diff options
| author | Dan Engelbrecht <[email protected]> | 2026-02-13 13:27:08 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-02-13 13:27:08 +0100 |
| commit | 3b5b777900d9f59ff32eb7cea79e3a72a08c67a6 (patch) | |
| tree | f5ffdeaad0ca9e291085d707209938c6dfe86d20 | |
| parent | bump sentry to 0.12.1 (#721) (diff) | |
| download | zen-3b5b777900d9f59ff32eb7cea79e3a72a08c67a6.tar.xz zen-3b5b777900d9f59ff32eb7cea79e3a72a08c67a6.zip | |
add IHttpRequestFilter to allow server implementation to filter/reject requests (#753)
* add IHttpRequestFilter to allow server implementation to filter/reject requests
| -rw-r--r-- | src/zenhttp/httpclient.cpp | 104 | ||||
| -rw-r--r-- | src/zenhttp/httpserver.cpp | 12 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpserver.h | 22 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpasio.cpp | 131 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpmulti.cpp | 9 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpmulti.h | 1 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpnull.cpp | 6 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpnull.h | 1 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpplugin.cpp | 117 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpsys.cpp | 46 |
10 files changed, 353 insertions, 96 deletions
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp index c77be8624..16729ce38 100644 --- a/src/zenhttp/httpclient.cpp +++ b/src/zenhttp/httpclient.cpp @@ -427,13 +427,13 @@ TEST_CASE("httpclient") AsioServer->RegisterService(TestService); - std::thread SeverThread([&]() { AsioServer->Run(false); }); + std::thread ServerThread([&]() { AsioServer->Run(false); }); { auto _ = MakeGuard([&]() { - if (SeverThread.joinable()) + if (ServerThread.joinable()) { - SeverThread.join(); + ServerThread.join(); } AsioServer->Close(); }); @@ -502,13 +502,13 @@ TEST_CASE("httpclient") HttpSysServer->RegisterService(TestService); - std::thread SeverThread([&]() { HttpSysServer->Run(false); }); + std::thread ServerThread([&]() { HttpSysServer->Run(false); }); { auto _ = MakeGuard([&]() { - if (SeverThread.joinable()) + if (ServerThread.joinable()) { - SeverThread.join(); + ServerThread.join(); } HttpSysServer->Close(); }); @@ -570,6 +570,98 @@ TEST_CASE("httpclient") # endif // ZEN_PLATFORM_WINDOWS } +TEST_CASE("httpclient.requestfilter") +{ + using namespace std::literals; + + struct TestHttpService : public HttpService + { + TestHttpService() = default; + + virtual const char* BaseUri() const override { return "/test/"; } + virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override + { + if (HttpServiceRequest.RelativeUri() == "yo") + { + return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family"); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_filter"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + + { + CHECK(HttpServiceRequest.RelativeUri() != "should_forbid"); + return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError); + } + } + }; + + TestHttpService TestService; + ScopedTemporaryDirectory TmpDir; + + class MyFilterImpl : public IHttpRequestFilter + { + public: + virtual Result FilterRequest(HttpServerRequest& Request) + { + if (Request.RelativeUri() == "should_filter") + { + Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "no thank you"); + return Result::ResponseSent; + } + else if (Request.RelativeUri() == "should_forbid") + { + return Result::Forbidden; + } + return Result::Accepted; + } + }; + + MyFilterImpl MyFilter; + + Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{}); + + AsioServer->SetHttpRequestFilter(&MyFilter); + + int Port = AsioServer->Initialize(7575, TmpDir.Path()); + REQUIRE(Port != -1); + + AsioServer->RegisterService(TestService); + + std::thread ServerThread([&]() { AsioServer->Run(false); }); + + { + auto _ = MakeGuard([&]() { + if (ServerThread.joinable()) + { + ServerThread.join(); + } + AsioServer->Close(); + }); + + HttpClient Client(fmt::format("localhost:{}", Port), + HttpClientSettings{}, + /*CheckIfAbortFunction*/ {}); + + ZEN_INFO("Request using {}", Client.GetBaseUri()); + + HttpClient::Response YoResponse = Client.Get("/test/yo"); + CHECK(YoResponse.IsSuccess()); + CHECK_EQ(YoResponse.AsText(), "hey family"); + + HttpClient::Response ShouldFilterResponse = Client.Get("/test/should_filter"); + CHECK_EQ(ShouldFilterResponse.StatusCode, HttpResponseCode::MethodNotAllowed); + CHECK_EQ(ShouldFilterResponse.AsText(), "no thank you"); + + HttpClient::Response ShouldForbitResponse = Client.Get("/test/should_forbid"); + CHECK_EQ(ShouldForbitResponse.StatusCode, HttpResponseCode::Forbidden); + + AsioServer->RequestExit(); + } +} + void httpclient_forcelink() { diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index 8985120b0..d8367fcb2 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -463,7 +463,7 @@ HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest) ////////////////////////////////////////////////////////////////////////// -HttpServerRequest::HttpServerRequest(HttpService& Service) : m_BaseUri(Service.BaseUri()) +HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service) { } @@ -970,7 +970,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) { ExtendableStringBuilder<128> RoutePath; - RoutePath.Append(Request.BaseUri()); + RoutePath.Append(Request.Service().BaseUri()); RoutePath.Append(Handler.Pattern); ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView()); } @@ -994,7 +994,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request) if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan()) { ExtendableStringBuilder<128> RoutePath; - RoutePath.Append(Request.BaseUri()); + RoutePath.Append(Request.Service().BaseUri()); RoutePath.Append(Handler.Pattern); ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView()); } @@ -1052,6 +1052,12 @@ HttpServer::EnumerateServices(std::function<void(HttpService& Service)>&& Callba } } +void +HttpServer::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + OnSetHttpRequestFilter(RequestFilter); +} + ////////////////////////////////////////////////////////////////////////// HttpRpcHandler::HttpRpcHandler() diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h index f0a667686..60f6bc9f2 100644 --- a/src/zenhttp/include/zenhttp/httpserver.h +++ b/src/zenhttp/include/zenhttp/httpserver.h @@ -41,7 +41,7 @@ public: [[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix [[nodiscard]] std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; } [[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; } - [[nodiscard]] inline std::string_view BaseUri() const { return m_BaseUri; } // Service prefix + [[nodiscard]] inline HttpService& Service() const { return m_Service; } struct QueryParams { @@ -121,13 +121,14 @@ protected: kHaveSessionId = 1 << 3, }; - mutable uint32_t m_Flags = 0; + mutable uint32_t m_Flags = 0; + + HttpService& m_Service; // Service handling this request HttpVerb m_Verb = HttpVerb::kGet; HttpContentType m_ContentType = HttpContentType::kBinary; HttpContentType m_AcceptType = HttpContentType::kUnknownContentType; uint64_t m_ContentLength = ~0ull; - std::string_view m_BaseUri; // Base URI path of the service handling this request - std::string_view m_Uri; // URI without service prefix + std::string_view m_Uri; // URI without service prefix std::string_view m_UriWithExtension; std::string_view m_QueryString; mutable uint32_t m_RequestId = ~uint32_t(0); @@ -148,6 +149,17 @@ public: virtual void OnRequestComplete() = 0; }; +struct IHttpRequestFilter +{ + enum class Result + { + Forbidden, + ResponseSent, + Accepted + }; + virtual Result FilterRequest(HttpServerRequest& Request) = 0; +}; + /** * Base class for implementing an HTTP "service" * @@ -184,6 +196,7 @@ class HttpServer : public RefCounted public: void RegisterService(HttpService& Service); void EnumerateServices(std::function<void(HttpService&)>&& Callback); + void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter); int Initialize(int BasePort, std::filesystem::path DataDir); void Run(bool IsInteractiveSession); @@ -195,6 +208,7 @@ private: virtual void OnRegisterService(HttpService& Service) = 0; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) = 0; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) = 0; virtual void OnRun(bool IsInteractiveSession) = 0; virtual void OnRequestExit() = 0; virtual void OnClose() = 0; diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 8bfbd8b37..230aac6a8 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -498,16 +498,19 @@ public: HttpAsioServerImpl(); ~HttpAsioServerImpl(); - void Initialize(std::filesystem::path DataDir); - int Start(uint16_t Port, const AsioConfig& Config); - void Stop(); - void RegisterService(const char* UrlPath, HttpService& Service); - HttpService* RouteRequest(std::string_view Url); + void Initialize(std::filesystem::path DataDir); + int Start(uint16_t Port, const AsioConfig& Config); + void Stop(); + void RegisterService(const char* UrlPath, HttpService& Service); + void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter); + HttpService* RouteRequest(std::string_view Url); + IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); asio::io_service m_IoService; asio::io_service::work m_Work{m_IoService}; std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; std::vector<std::thread> m_ThreadPool; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; LoggerRef m_RequestLog; HttpServerTracer m_RequestTracer; @@ -1199,53 +1202,65 @@ HttpServerConnection::HandleRequest() std::vector<IoBuffer>{Request.ReadPayload()}); } - if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_Server.FilterRequest(Request); + if (FilterResult == IHttpRequestFilter::Result::Accepted) { - try - { - Service->HandleRequest(Request); - } - catch (const AssertException& AssertEx) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); - } - catch (const std::system_error& SystemError) + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) { - // Drop any partially formatted response - Request.m_Response.reset(); - - if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + try { - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + Service->HandleRequest(Request); } - else + catch (const AssertException& AssertEx) { - ZEN_WARN("Caught system error exception while handling request: {}. ({})", - SystemError.what(), - SystemError.code().value()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); } - } - catch (const std::bad_alloc& BadAlloc) - { - // Drop any partially formatted response - Request.m_Response.reset(); + catch (const std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); - } - catch (const std::exception& ex) - { - // Drop any partially formatted response - Request.m_Response.reset(); + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_WARN("Caught system error exception while handling request: {}. ({})", + SystemError.what(), + SystemError.code().value()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (const std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (const std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); - ZEN_WARN("Caught exception while handling request: {}", ex.what()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + ZEN_WARN("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } } } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) + { + Request.WriteResponse(HttpResponseCode::Forbidden); + } + else + { + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); + } if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response)) { @@ -1923,6 +1938,31 @@ HttpAsioServerImpl::RouteRequest(std::string_view Url) return CandidateService; } +void +HttpAsioServerImpl::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + ZEN_MEMSCOPE(GetHttpasioTag()); + RwLock::ExclusiveLockScope _(m_Lock); + m_HttpRequestFilter.store(RequestFilter); +} + +IHttpRequestFilter::Result +HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_Lock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); + return FilterResult; +} + } // namespace zen::asio_http ////////////////////////////////////////////////////////////////////////// @@ -1937,6 +1977,7 @@ public: virtual void OnRegisterService(HttpService& Service) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; virtual void OnClose() override; @@ -1984,6 +2025,12 @@ HttpAsioServer::OnRegisterService(HttpService& Service) m_Impl->RegisterService(Service.BaseUri(), Service); } +void +HttpAsioServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + m_Impl->SetHttpRequestFilter(RequestFilter); +} + int HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir) { diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp index 95624245f..850d7d6b9 100644 --- a/src/zenhttp/servers/httpmulti.cpp +++ b/src/zenhttp/servers/httpmulti.cpp @@ -54,6 +54,15 @@ HttpMultiServer::OnInitialize(int BasePort, std::filesystem::path DataDir) } void +HttpMultiServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + for (auto& Server : m_Servers) + { + Server->SetHttpRequestFilter(RequestFilter); + } +} + +void HttpMultiServer::OnRun(bool IsInteractiveSession) { const int WaitTimeout = 1000; diff --git a/src/zenhttp/servers/httpmulti.h b/src/zenhttp/servers/httpmulti.h index ae0ed74cf..1897587a9 100644 --- a/src/zenhttp/servers/httpmulti.h +++ b/src/zenhttp/servers/httpmulti.h @@ -16,6 +16,7 @@ public: ~HttpMultiServer(); virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp index b770b97db..db360c5fb 100644 --- a/src/zenhttp/servers/httpnull.cpp +++ b/src/zenhttp/servers/httpnull.cpp @@ -24,6 +24,12 @@ HttpNullServer::OnRegisterService(HttpService& Service) ZEN_UNUSED(Service); } +void +HttpNullServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + ZEN_UNUSED(RequestFilter); +} + int HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir) { diff --git a/src/zenhttp/servers/httpnull.h b/src/zenhttp/servers/httpnull.h index ce7230938..52838f012 100644 --- a/src/zenhttp/servers/httpnull.h +++ b/src/zenhttp/servers/httpnull.h @@ -18,6 +18,7 @@ public: ~HttpNullServer(); virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 1a630c16f..4219dc292 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -96,6 +96,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer // HttpPluginServer virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; @@ -104,7 +105,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer virtual void AddPlugin(Ref<TransportPlugin> Plugin) override; virtual void RemovePlugin(Ref<TransportPlugin> Plugin) override; - HttpService* RouteRequest(std::string_view Url); + HttpService* RouteRequest(std::string_view Url); + IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); struct ServiceEntry { @@ -112,7 +114,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer HttpService* Service; }; - bool m_IsInitialized = false; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; + bool m_IsInitialized = false; RwLock m_Lock; std::vector<ServiceEntry> m_UriHandlers; std::vector<Ref<TransportPlugin>> m_Plugins; @@ -395,53 +398,65 @@ HttpPluginConnectionHandler::HandleRequest() std::vector<IoBuffer>{Request.ReadPayload()}); } - if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_Server->FilterRequest(Request); + if (FilterResult == IHttpRequestFilter::Result::Accepted) { - try - { - Service->HandleRequest(Request); - } - catch (const AssertException& AssertEx) + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) { - // Drop any partially formatted response - Request.m_Response.reset(); - - ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); - } - catch (const std::system_error& SystemError) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + try { - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + Service->HandleRequest(Request); } - else + catch (const AssertException& AssertEx) { - ZEN_WARN("Caught system error exception while handling request: {}. ({})", - SystemError.what(), - SystemError.code().value()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); } - } - catch (const std::bad_alloc& BadAlloc) - { - // Drop any partially formatted response - Request.m_Response.reset(); + catch (const std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_WARN("Caught system error exception while handling request: {}. ({})", + SystemError.what(), + SystemError.code().value()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (const std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); - } - catch (const std::exception& ex) - { - // Drop any partially formatted response - Request.m_Response.reset(); + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (const std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); - ZEN_WARN("Caught exception while handling request: {}", ex.what()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + ZEN_WARN("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } } } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) + { + Request.WriteResponse(HttpResponseCode::Forbidden); + } + else + { + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); + } if (std::unique_ptr<HttpPluginResponse> Response = std::move(Request.m_Response)) { @@ -753,6 +768,13 @@ HttpPluginServerImpl::OnInitialize(int InBasePort, std::filesystem::path DataDir } void +HttpPluginServerImpl::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_HttpRequestFilter.store(RequestFilter); +} + +void HttpPluginServerImpl::OnClose() { if (!m_IsInitialized) @@ -897,6 +919,23 @@ HttpPluginServerImpl::RouteRequest(std::string_view Url) return CandidateService; } +IHttpRequestFilter::Result +HttpPluginServerImpl::FilterRequest(HttpServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_Lock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); + return FilterResult; +} + ////////////////////////////////////////////////////////////////////////// struct HttpPluginServerImpl; diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 01c4559a1..4df4cd079 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -94,6 +94,7 @@ public: virtual void OnRun(bool TestMode) override; virtual void OnRequestExit() override; virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual void OnClose() override; WorkerThreadPool& WorkPool(); @@ -101,6 +102,8 @@ public: inline bool IsOk() const { return m_IsOk; } inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } + IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); + private: int InitializeServer(int BasePort); void Cleanup(); @@ -137,6 +140,9 @@ private: int32_t m_MaxPendingRequests = 128; Event m_ShutdownEvent; HttpSysConfig m_InitialConfig; + + RwLock m_RequestFilterLock; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; }; } // namespace zen @@ -1672,9 +1678,21 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) otel::ScopedSpan HttpSpan(SpanNamer, SpanAnnotator); # endif - if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_HttpServer.FilterRequest(ThisRequest); + if (FilterResult == IHttpRequestFilter::Result::Accepted) + { + if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + { + Service.HandleRequest(ThisRequest); + } + } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) { - Service.HandleRequest(ThisRequest); + ThisRequest.WriteResponse(HttpResponseCode::Forbidden); + } + else + { + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); } return ThisRequest; @@ -2244,6 +2262,30 @@ HttpSysServer::OnRegisterService(HttpService& Service) RegisterService(Service.BaseUri(), Service); } +void +HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + RwLock::ExclusiveLockScope _(m_RequestFilterLock); + m_HttpRequestFilter.store(RequestFilter); +} + +IHttpRequestFilter::Result +HttpSysServer::FilterRequest(HttpServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_RequestFilterLock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); + return FilterResult; +} + Ref<HttpServer> CreateHttpSysServer(HttpSysConfig Config) { |