diff options
| author | Stefan Boberg <[email protected]> | 2025-11-24 10:32:52 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2025-11-24 10:32:52 +0100 |
| commit | d6a2b1c247778e1bf2a847adba230ad20f44d21d (patch) | |
| tree | a6eb8c33798456d9d68a6918a693bdc54158f562 /src/zenhttp | |
| parent | changelog spelling (diff) | |
| download | zen-d6a2b1c247778e1bf2a847adba230ad20f44d21d.tar.xz zen-d6a2b1c247778e1bf2a847adba230ad20f44d21d.zip | |
Add regex-free route matching support (#662)
This change adds support for non-regex matching of routes. Instead of using regex patterns, you can associate matcher functions with pattern names; string-literal components are identified and matched directly.
Also implements tests for the `HttpRequestRouter` class.
Diffstat (limited to 'src/zenhttp')
| -rw-r--r-- | src/zenhttp/httpserver.cpp | 410 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpserver.h | 114 |
2 files changed, 474 insertions, 50 deletions
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index f48c22367..b28682375 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -698,24 +698,124 @@ void HttpRequestRouter::AddPattern(const char* Id, const char* Regex) { ZEN_ASSERT(m_PatternMap.find(Id) == m_PatternMap.end()); + ZEN_ASSERT(!m_IsFinalized); m_PatternMap.insert({Id, Regex}); } void -HttpRequestRouter::RegisterRoute(const char* Regex, HttpRequestRouter::HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs) +HttpRequestRouter::AddMatcher(const char* Id, std::function<bool(std::string_view)>&& Matcher) { - ExtendableStringBuilder<128> ExpandedRegex; - ProcessRegexSubstitutions(Regex, ExpandedRegex); + ZEN_ASSERT(m_MatcherNameMap.find(Id) == m_MatcherNameMap.end()); + ZEN_ASSERT(!m_IsFinalized); - m_Handlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), Regex); + const int MatcherIndex = gsl::narrow_cast<int>(m_MatcherFunctions.size()); + m_MatcherFunctions.push_back(Matcher); + m_MatcherNameMap.insert({Id, MatcherIndex}); } void +HttpRequestRouter::RegisterRoute(const char* UriPattern, HttpRequestRouter::HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs) +{ + ZEN_ASSERT(!m_IsFinalized); + + if (ExtendableStringBuilder<128> ExpandedRegex; ProcessRegexSubstitutions(UriPattern, ExpandedRegex)) + { + // Regex route + m_RegexHandlers.emplace_back(ExpandedRegex.c_str(), SupportedVerbs, std::move(HandlerFunc), UriPattern); + } + else + { + // New-style regex-free route. 
More efficient and should be used for everything eventually + + int RegexLen = gsl::narrow_cast<int>(strlen(UriPattern)); + + int i = 0; + + std::vector<int> MatcherIndices; + + while (i < RegexLen) + { + if (UriPattern[i] == '{') + { + bool IsComplete = false; + int PatternStart = i + 1; + while (++i < RegexLen) + { + if (UriPattern[i] == '}') + { + std::string_view Pattern(&UriPattern[PatternStart], i - PatternStart); + if (auto it = m_MatcherNameMap.find(std::string(Pattern)); it != m_MatcherNameMap.end()) + { + // It's a match + MatcherIndices.push_back(it->second); + IsComplete = true; + ++i; + break; + } + else + { + throw std::runtime_error(fmt::format("unknown matcher pattern '{}' in URI pattern '{}'", Pattern, UriPattern)); + } + } + } + if (!IsComplete) + { + throw std::runtime_error(fmt::format("unterminated matcher pattern in URI pattern '{}'", UriPattern)); + } + } + else + { + if (UriPattern[i] == '/') + { + throw std::runtime_error(fmt::format("unexpected '/' in literal segment of URI pattern '{}'", UriPattern)); + } + + int SegmentStart = i; + while (++i < RegexLen && UriPattern[i] != '/') + ; + + std::string_view Segment(&UriPattern[SegmentStart], (i - SegmentStart)); + int LiteralIndex = gsl::narrow_cast<int>(m_Literals.size()); + m_Literals.push_back(std::string(Segment)); + MatcherIndices.push_back(-1 - LiteralIndex); + } + + if (i < RegexLen && UriPattern[i] == '/') + { + ++i; // skip slash + } + } + + m_MatcherEndpoints.emplace_back(std::move(MatcherIndices), SupportedVerbs, std::move(HandlerFunc), UriPattern); + } +} + +std::string_view +HttpRouterRequest::GetCapture(uint32_t Index) const +{ + if (!m_CapturedSegments.empty()) + { + ZEN_ASSERT(Index < m_CapturedSegments.size()); + return m_CapturedSegments[Index]; + } + + ZEN_ASSERT(Index < m_Match.size()); + + const auto& Match = m_Match[Index]; + + return std::string_view(&*Match.first, Match.second - Match.first); +} + +bool HttpRequestRouter::ProcessRegexSubstitutions(const char* Regex, 
StringBuilderBase& OutExpandedRegex) { size_t RegexLen = strlen(Regex); + bool HasRegex = false; + + std::vector<std::string> UnknownPatterns; + for (size_t i = 0; i < RegexLen;) { bool matched = false; @@ -733,12 +833,11 @@ HttpRequestRouter::ProcessRegexSubstitutions(const char* Regex, StringBuilderBas if (auto it = m_PatternMap.find(Pattern); it != m_PatternMap.end()) { OutExpandedRegex.Append(it->second.c_str()); + HasRegex = true; } else { - // Default to anything goes (or should this just be an error?) - - OutExpandedRegex.Append("(.+?)"); + UnknownPatterns.push_back(Pattern); } // skip ahead @@ -756,17 +855,127 @@ HttpRequestRouter::ProcessRegexSubstitutions(const char* Regex, StringBuilderBas OutExpandedRegex.Append(Regex[i++]); } } + + if (HasRegex) + { + if (UnknownPatterns.size() > 0) + { + std::string UnknownList; + for (const auto& Pattern : UnknownPatterns) + { + if (!UnknownList.empty()) + { + UnknownList += ", "; + } + UnknownList += "'"; + UnknownList += Pattern; + UnknownList += "'"; + } + + throw std::runtime_error(fmt::format("unknown pattern(s) {} in regex route '{}'", UnknownList, Regex)); + } + + return true; + } + + return false; } bool HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) { + if (!m_IsFinalized) + { + m_IsFinalized = true; + } + const HttpVerb Verb = Request.RequestVerb(); std::string_view Uri = Request.RelativeUri(); HttpRouterRequest RouterRequest(Request); - for (const auto& Handler : m_Handlers) + // First try new-style matcher routes + + for (const auto& Handler : m_MatcherEndpoints) + { + if ((Handler.Verbs & Verb) == Verb) + { + size_t UriPos = 0; + const size_t UriLen = Uri.length(); + const std::vector<int>& Matchers = Handler.ComponentIndices; + bool IsMatch = true; + + std::vector<std::string_view> CapturedSegments; + + CapturedSegments.emplace_back(Uri); + + for (int MatcherIndex : Matchers) + { + if (UriPos >= UriLen) + { + IsMatch = false; + break; + } + + if (MatcherIndex < 0) + { + // Literal 
match + int LitIndex = -MatcherIndex - 1; + const std::string& LitStr = m_Literals[LitIndex]; + size_t LitLen = LitStr.length(); + + if (Uri.substr(UriPos, LitLen) == LitStr) + { + UriPos += LitLen; + } + else + { + IsMatch = false; + break; + } + } + else + { + // Matcher function + size_t SegmentStart = UriPos; + while (UriPos < UriLen && Uri[UriPos] != '/') + { + ++UriPos; + } + + std::string_view Segment = Uri.substr(SegmentStart, UriPos - SegmentStart); + + if (m_MatcherFunctions[MatcherIndex](Segment)) + { + CapturedSegments.push_back(Segment); + } + else + { + IsMatch = false; + break; + } + } + + // Skip slash + if (UriPos < UriLen && Uri[UriPos] == '/') + { + ++UriPos; + } + } + + if (IsMatch && UriPos == UriLen) + { + RouterRequest.m_CapturedSegments = std::move(CapturedSegments); + Handler.Handler(RouterRequest); + + return true; // Route matched + } + } + } + + // Old-style regex routes + + for (const auto& Handler : m_RegexHandlers) { if ((Handler.Verbs & Verb) == Verb && regex_match(begin(Uri), end(Uri), RouterRequest.m_Match, Handler.RegEx)) { @@ -1023,22 +1232,189 @@ TEST_CASE("http.common") { using namespace std::literals; - SUBCASE("router") + struct TestHttpServerRequest : public HttpServerRequest + { + TestHttpServerRequest(std::string_view Uri) { m_Uri = Uri; } + virtual IoBuffer ReadPayload() override { return IoBuffer(); } + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override + { + ZEN_UNUSED(ResponseCode, ContentType, Blobs); + } + virtual void WriteResponse(HttpResponseCode ResponseCode) override { ZEN_UNUSED(ResponseCode); } + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override + { + ZEN_UNUSED(ResponseCode, ContentType, ResponseString); + } + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override + { + ZEN_UNUSED(ContinuationHandler); + } + 
virtual Oid ParseSessionId() const override { return Oid(); } + virtual uint32_t ParseRequestId() const override { return 0; } + }; + + SUBCASE("router-regex") + { + bool HandledA = false; + bool HandledAA = false; + std::vector<std::string> Captures; + auto Reset = [&] { + Captures.clear(); + HandledA = HandledAA = false; + }; + + HttpRequestRouter r; + r.AddPattern("a", "([[:alpha:]]+)"); + r.RegisterRoute( + "{a}", + [&](auto& Req) { + HandledA = true; + Captures = {std::string(Req.GetCapture(1))}; + }, + HttpVerb::kGet); + + r.RegisterRoute( + "{a}/{a}", + [&](auto& Req) { + HandledAA = true; + Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; + }, + HttpVerb::kGet); + + { + Reset(); + TestHttpServerRequest req{"abc"sv}; + r.HandleRequest(req); + CHECK(HandledA); + CHECK(!HandledAA); + REQUIRE_EQ(Captures.size(), 1); + CHECK_EQ(Captures[0], "abc"sv); + } + + { + Reset(); + TestHttpServerRequest req{"abc/def"sv}; + r.HandleRequest(req); + CHECK(!HandledA); + CHECK(HandledAA); + REQUIRE_EQ(Captures.size(), 2); + CHECK_EQ(Captures[0], "abc"sv); + CHECK_EQ(Captures[1], "def"sv); + } + + { + Reset(); + TestHttpServerRequest req{"123"sv}; + r.HandleRequest(req); + CHECK(!HandledA); + } + + { + Reset(); + TestHttpServerRequest req{"a123"sv}; + r.HandleRequest(req); + CHECK(!HandledA); + } + } + + SUBCASE("router-matcher") { + bool HandledA = false; + bool HandledAA = false; + bool HandledAB = false; + bool HandledAandB = false; + std::vector<std::string> Captures; + auto Reset = [&] { + HandledA = HandledAA = HandledAB = HandledAandB = false; + Captures.clear(); + }; + HttpRequestRouter r; - r.AddPattern("a", "[[:alpha:]]+"); + r.AddMatcher("a", [](std::string_view In) -> bool { return In.length() % 2 == 0; }); + r.AddMatcher("b", [](std::string_view In) -> bool { return In.length() % 3 == 0; }); r.RegisterRoute( "{a}", - [&](auto) {}, + [&](auto& Req) { + HandledA = true; + Captures = {std::string(Req.GetCapture(1))}; + }, + 
HttpVerb::kGet); + r.RegisterRoute( + "{a}/{a}", + [&](auto& Req) { + HandledAA = true; + Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; + }, + HttpVerb::kGet); + r.RegisterRoute( + "{a}/{b}", + [&](auto& Req) { + HandledAB = true; + Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; + }, HttpVerb::kGet); + r.RegisterRoute( + "{a}/and/{b}", + [&](auto& Req) { + HandledAandB = true; + Captures = {std::string(Req.GetCapture(1)), std::string(Req.GetCapture(2))}; + }, + HttpVerb::kGet); + + { + Reset(); + TestHttpServerRequest req{"ab"sv}; + r.HandleRequest(req); + CHECK(HandledA); + CHECK(!HandledAA); + CHECK(!HandledAB); + + REQUIRE_EQ(Captures.size(), 1); + CHECK_EQ(Captures[0], "ab"sv); + } - // struct TestHttpServerRequest : public HttpServerRequest - //{ - // TestHttpServerRequest(std::string_view Uri) : m_uri{Uri} {} - //}; + { + Reset(); + TestHttpServerRequest req{"ab/def"sv}; + r.HandleRequest(req); + CHECK(!HandledA); + CHECK(!HandledAA); + CHECK(HandledAB); + REQUIRE_EQ(Captures.size(), 2); + CHECK_EQ(Captures[0], "ab"sv); + CHECK_EQ(Captures[1], "def"sv); + } + + { + Reset(); + TestHttpServerRequest req{"ab/and/def"sv}; + r.HandleRequest(req); + CHECK(!HandledA); + CHECK(!HandledAA); + CHECK(!HandledAB); + CHECK(HandledAandB); + REQUIRE_EQ(Captures.size(), 2); + CHECK_EQ(Captures[0], "ab"sv); + CHECK_EQ(Captures[1], "def"sv); + } - // TestHttpServerRequest req{}; - // r.HandleRequest(req); + { + Reset(); + TestHttpServerRequest req{"123"sv}; + r.HandleRequest(req); + CHECK(!HandledA); + CHECK(!HandledAA); + CHECK(!HandledAB); + } + + { + Reset(); + TestHttpServerRequest req{"a123"sv}; + r.HandleRequest(req); + CHECK(HandledA); + CHECK(!HandledAA); + CHECK(!HandledAB); + } } SUBCASE("content-type") diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index 03e547bf3..f95ec51d2 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ 
b/src/zenhttp/include/zenhttp/httpserver.h @@ -45,15 +45,13 @@ public: std::string_view GetValue(std::string_view ParamName) const { - for (const auto& Kv : KvPairs) + for (const auto& [Key, Value] : KvPairs) { - const std::string_view& Key = Kv.first; - if (Key.size() == ParamName.size()) { if (0 == StrCaseCompare(Key.data(), ParamName.data(), Key.size())) { - return Kv.second; + return Value; } } } @@ -213,43 +211,53 @@ Ref<HttpServer> CreateHttpServer(const HttpServerConfig& Config); class HttpRouterRequest { public: - HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} - + /** Get captured segment from matched URL + * + * @param Index Index of captured segment to retrieve. Note that due to + * backwards compatibility with regex-based routes, this index is 1-based + * and index=0 is the full matched URL + * @return Returns string view of captured segment + */ std::string_view GetCapture(uint32_t Index) const; inline HttpServerRequest& ServerRequest() { return m_HttpRequest; } private: + HttpRouterRequest(HttpServerRequest& Request) : m_HttpRequest(Request) {} + ~HttpRouterRequest() = default; + + HttpRouterRequest(const HttpRouterRequest&) = delete; + HttpRouterRequest& operator=(const HttpRouterRequest&) = delete; + using MatchResults_t = std::match_results<std::string_view::const_iterator>; - HttpServerRequest& m_HttpRequest; - MatchResults_t m_Match; + HttpServerRequest& m_HttpRequest; + MatchResults_t m_Match; + std::vector<std::string_view> m_CapturedSegments; // for matcher-based routes friend class HttpRequestRouter; }; -inline std::string_view -HttpRouterRequest::GetCapture(uint32_t Index) const -{ - ZEN_ASSERT(Index < m_Match.size()); - - const auto& Match = m_Match[Index]; - - return std::string_view(&*Match.first, Match.second - Match.first); -} - /** HTTP request router helper * * This helper class allows a service implementer to register one or more - * endpoints using pattern matching (currently using regex matching) + * 
endpoints using pattern matching. We currently support a legacy regex-based + * matching system, but also a new matcher-function based system which is more + * efficient and should be used whenever possible. * * This is intended to be initialized once only, there is no thread * safety so you can absolutely not add or remove endpoints once the handler - * goes live + * goes live. */ class HttpRequestRouter { public: + HttpRequestRouter() = default; + ~HttpRequestRouter() = default; + + HttpRequestRouter(const HttpRequestRouter&) = delete; + HttpRequestRouter& operator=(const HttpRequestRouter&) = delete; + typedef std::function<void(HttpRouterRequest&)> HandlerFunc_t; /** @@ -260,15 +268,21 @@ public: void AddPattern(const char* Id, const char* Regex); /** - * @brief Register a an endpoint handler for the given route - * @param Regex Regular expression used to match the handler to a request. This may - * contain pattern aliases registered via AddPattern + * @brief Add matcher function which can be referenced by name, used for URL components + * @param Id String used to identify matchers in endpoint specifications + * @param Matcher Function which will be called to match the component + */ + void AddMatcher(const char* Id, std::function<bool(std::string_view)>&& Matcher); + + /** + * @brief Register an endpoint handler for the given route + * @param Pattern Pattern used to match the handler to a request. 
This should + * only contain literal URI segments and pattern aliases registered + via AddPattern() or AddMatcher() * @param HandlerFunc Handler function to call for any matching request * @param SupportedVerbs Supported HTTP verbs for this handler */ - void RegisterRoute(const char* Regex, HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs); - - void ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& ExpandedRegex); + void RegisterRoute(const char* Pattern, HandlerFunc_t&& HandlerFunc, HttpVerb SupportedVerbs); /** * @brief HTTP request handling function - this should be called to route the @@ -279,9 +293,11 @@ public: bool HandleRequest(zen::HttpServerRequest& Request); private: - struct HandlerEntry + bool ProcessRegexSubstitutions(const char* Regex, StringBuilderBase& ExpandedRegex); + + struct RegexEndpoint { - HandlerEntry(const char* Regex, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern) + RegexEndpoint(const char* Regex, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern) : RegEx(Regex, std::regex::icase | std::regex::ECMAScript) , Verbs(SupportedVerbs) , Handler(std::move(Handler)) @@ -289,7 +305,7 @@ private: { } - ~HandlerEntry() = default; + ~RegexEndpoint() = default; std::regex RegEx; HttpVerb Verbs; @@ -297,12 +313,44 @@ private: const char* Pattern; private: - HandlerEntry& operator=(const HandlerEntry&) = delete; - HandlerEntry(const HandlerEntry&) = delete; + RegexEndpoint& operator=(const RegexEndpoint&) = delete; + RegexEndpoint(const RegexEndpoint&) = delete; }; - std::list<HandlerEntry> m_Handlers; + std::list<RegexEndpoint> m_RegexHandlers; std::unordered_map<std::string, std::string> m_PatternMap; + + // New-style matcher endpoints. 
Should be preferred over regex endpoints where possible + // as it is considerably more efficient + + struct MatcherEndpoint + { + MatcherEndpoint(std::vector<int>&& ComponentIndices, HttpVerb SupportedVerbs, HandlerFunc_t&& Handler, const char* Pattern) + : ComponentIndices(std::move(ComponentIndices)) + , Verbs(SupportedVerbs) + , Handler(std::move(Handler)) + , Pattern(Pattern) + { + } + + ~MatcherEndpoint() = default; + + // Negative indexes are literals, non-negative are matcher function indexes + std::vector<int> ComponentIndices; + HttpVerb Verbs; + HandlerFunc_t Handler; + const char* Pattern; + + private: + MatcherEndpoint& operator=(const MatcherEndpoint&) = delete; + MatcherEndpoint(const MatcherEndpoint&) = delete; + }; + + std::unordered_map<std::string, int> m_MatcherNameMap; + std::vector<std::function<bool(std::string_view)>> m_MatcherFunctions; + std::vector<std::string> m_Literals; + std::list<MatcherEndpoint> m_MatcherEndpoints; + bool m_IsFinalized = false; }; /** HTTP RPC request helper @@ -310,7 +358,7 @@ private: class RpcResult { - RpcResult(CbObject Result) : m_Result(std::move(Result)) {} + explicit RpcResult(CbObject Result) : m_Result(std::move(Result)) {} private: CbObject m_Result; |