diff options
| author | Dan Engelbrecht <[email protected]> | 2026-02-13 13:27:08 +0100 |
|---|---|---|
| committer | GitHub Enterprise <[email protected]> | 2026-02-13 13:27:08 +0100 |
| commit | 3b5b777900d9f59ff32eb7cea79e3a72a08c67a6 (patch) | |
| tree | f5ffdeaad0ca9e291085d707209938c6dfe86d20 /src/zenhttp/servers | |
| parent | bump sentry to 0.12.1 (#721) (diff) | |
| download | zen-3b5b777900d9f59ff32eb7cea79e3a72a08c67a6.tar.xz zen-3b5b777900d9f59ff32eb7cea79e3a72a08c67a6.zip | |
add IHttpRequestFilter to allow server implementation to filter/reject requests (#753)
* add IHttpRequestFilter to allow server implementation to filter/reject requests
Diffstat (limited to 'src/zenhttp/servers')
| -rw-r--r-- | src/zenhttp/servers/httpasio.cpp | 131 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpmulti.cpp | 9 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpmulti.h | 1 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpnull.cpp | 6 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpnull.h | 1 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpplugin.cpp | 117 | ||||
| -rw-r--r-- | src/zenhttp/servers/httpsys.cpp | 46 |
7 files changed, 228 insertions, 83 deletions
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp index 8bfbd8b37..230aac6a8 100644 --- a/src/zenhttp/servers/httpasio.cpp +++ b/src/zenhttp/servers/httpasio.cpp @@ -498,16 +498,19 @@ public: HttpAsioServerImpl(); ~HttpAsioServerImpl(); - void Initialize(std::filesystem::path DataDir); - int Start(uint16_t Port, const AsioConfig& Config); - void Stop(); - void RegisterService(const char* UrlPath, HttpService& Service); - HttpService* RouteRequest(std::string_view Url); + void Initialize(std::filesystem::path DataDir); + int Start(uint16_t Port, const AsioConfig& Config); + void Stop(); + void RegisterService(const char* UrlPath, HttpService& Service); + void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter); + HttpService* RouteRequest(std::string_view Url); + IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); asio::io_service m_IoService; asio::io_service::work m_Work{m_IoService}; std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor; std::vector<std::thread> m_ThreadPool; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; LoggerRef m_RequestLog; HttpServerTracer m_RequestTracer; @@ -1199,53 +1202,65 @@ HttpServerConnection::HandleRequest() std::vector<IoBuffer>{Request.ReadPayload()}); } - if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_Server.FilterRequest(Request); + if (FilterResult == IHttpRequestFilter::Result::Accepted) { - try - { - Service->HandleRequest(Request); - } - catch (const AssertException& AssertEx) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); - } - catch (const std::system_error& SystemError) + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) { - // Drop any partially formatted response - Request.m_Response.reset(); - - if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + try { - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + Service->HandleRequest(Request); } - else + catch (const AssertException& AssertEx) { - ZEN_WARN("Caught system error exception while handling request: {}. ({})", - SystemError.what(), - SystemError.code().value()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); } - } - catch (const std::bad_alloc& BadAlloc) - { - // Drop any partially formatted response - Request.m_Response.reset(); + catch (const std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); - } - catch (const std::exception& ex) - { - // Drop any partially formatted response - Request.m_Response.reset(); + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_WARN("Caught system error exception while handling request: {}. ({})", + SystemError.what(), + SystemError.code().value()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (const std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (const std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); - ZEN_WARN("Caught exception while handling request: {}", ex.what()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + ZEN_WARN("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } } } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) + { + Request.WriteResponse(HttpResponseCode::Forbidden); + } + else + { + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); + } if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response)) { @@ -1923,6 +1938,31 @@ HttpAsioServerImpl::RouteRequest(std::string_view Url) return CandidateService; } +void +HttpAsioServerImpl::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + ZEN_MEMSCOPE(GetHttpasioTag()); + RwLock::ExclusiveLockScope _(m_Lock); + m_HttpRequestFilter.store(RequestFilter); +} + +IHttpRequestFilter::Result +HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_Lock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); + return FilterResult; +} + } // namespace zen::asio_http ////////////////////////////////////////////////////////////////////////// @@ -1937,6 +1977,7 @@ public: virtual void OnRegisterService(HttpService& Service) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; virtual void OnClose() override; @@ -1984,6 +2025,12 @@ HttpAsioServer::OnRegisterService(HttpService& Service) m_Impl->RegisterService(Service.BaseUri(), Service); } +void +HttpAsioServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + m_Impl->SetHttpRequestFilter(RequestFilter); +} + int HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir) { diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp index 95624245f..850d7d6b9 100644 --- a/src/zenhttp/servers/httpmulti.cpp +++ b/src/zenhttp/servers/httpmulti.cpp @@ -54,6 +54,15 @@ HttpMultiServer::OnInitialize(int BasePort, std::filesystem::path DataDir) } void +HttpMultiServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + for (auto& Server : m_Servers) + { + Server->SetHttpRequestFilter(RequestFilter); + } +} + +void HttpMultiServer::OnRun(bool IsInteractiveSession) { const int WaitTimeout = 1000; diff --git a/src/zenhttp/servers/httpmulti.h b/src/zenhttp/servers/httpmulti.h index ae0ed74cf..1897587a9 100644 --- a/src/zenhttp/servers/httpmulti.h +++ b/src/zenhttp/servers/httpmulti.h @@ -16,6 +16,7 @@ public: ~HttpMultiServer(); virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp index b770b97db..db360c5fb 100644 --- a/src/zenhttp/servers/httpnull.cpp +++ b/src/zenhttp/servers/httpnull.cpp @@ -24,6 +24,12 @@ HttpNullServer::OnRegisterService(HttpService& Service) ZEN_UNUSED(Service); } +void +HttpNullServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + ZEN_UNUSED(RequestFilter); +} + int HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir) { diff --git a/src/zenhttp/servers/httpnull.h b/src/zenhttp/servers/httpnull.h index ce7230938..52838f012 100644 --- a/src/zenhttp/servers/httpnull.h +++ b/src/zenhttp/servers/httpnull.h @@ -18,6 +18,7 @@ public: ~HttpNullServer(); virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp index 1a630c16f..4219dc292 100644 --- a/src/zenhttp/servers/httpplugin.cpp +++ b/src/zenhttp/servers/httpplugin.cpp @@ -96,6 +96,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer // HttpPluginServer virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override; virtual void OnRun(bool IsInteractiveSession) override; virtual void OnRequestExit() override; @@ -104,7 +105,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer virtual void AddPlugin(Ref<TransportPlugin> Plugin) override; virtual void RemovePlugin(Ref<TransportPlugin> Plugin) override; - HttpService* RouteRequest(std::string_view Url); + HttpService* RouteRequest(std::string_view Url); + IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); struct ServiceEntry { @@ -112,7 +114,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer HttpService* Service; }; - bool m_IsInitialized = false; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; + bool m_IsInitialized = false; RwLock m_Lock; std::vector<ServiceEntry> m_UriHandlers; std::vector<Ref<TransportPlugin>> m_Plugins; @@ -395,53 +398,65 @@ HttpPluginConnectionHandler::HandleRequest() std::vector<IoBuffer>{Request.ReadPayload()}); } - if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_Server->FilterRequest(Request); + if (FilterResult == IHttpRequestFilter::Result::Accepted) { - try - { - Service->HandleRequest(Request); - } - catch (const AssertException& AssertEx) + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) { - // Drop any partially formatted response - Request.m_Response.reset(); - - ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); - } - catch (const std::system_error& SystemError) - { - // Drop any partially formatted response - Request.m_Response.reset(); - - if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + try { - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + Service->HandleRequest(Request); } - else + catch (const AssertException& AssertEx) { - ZEN_WARN("Caught system error exception while handling request: {}. ({})", - SystemError.what(), - SystemError.code().value()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription()); } - } - catch (const std::bad_alloc& BadAlloc) - { - // Drop any partially formatted response - Request.m_Response.reset(); + catch (const std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_WARN("Caught system error exception while handling request: {}. ({})", + SystemError.what(), + SystemError.code().value()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (const std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); - Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); - } - catch (const std::exception& ex) - { - // Drop any partially formatted response - Request.m_Response.reset(); + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (const std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); - ZEN_WARN("Caught exception while handling request: {}", ex.what()); - Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + ZEN_WARN("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } } } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) + { + Request.WriteResponse(HttpResponseCode::Forbidden); + } + else + { + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); + } if (std::unique_ptr<HttpPluginResponse> Response = std::move(Request.m_Response)) { @@ -753,6 +768,13 @@ HttpPluginServerImpl::OnInitialize(int InBasePort, std::filesystem::path DataDir } void +HttpPluginServerImpl::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_HttpRequestFilter.store(RequestFilter); +} + +void HttpPluginServerImpl::OnClose() { if (!m_IsInitialized) @@ -897,6 +919,23 @@ HttpPluginServerImpl::RouteRequest(std::string_view Url) return CandidateService; } +IHttpRequestFilter::Result +HttpPluginServerImpl::FilterRequest(HttpServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_Lock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); + return FilterResult; +} + ////////////////////////////////////////////////////////////////////////// struct HttpPluginServerImpl; diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp index 01c4559a1..4df4cd079 100644 --- a/src/zenhttp/servers/httpsys.cpp +++ b/src/zenhttp/servers/httpsys.cpp @@ -94,6 +94,7 @@ public: virtual void OnRun(bool TestMode) override; virtual void OnRequestExit() override; virtual void OnRegisterService(HttpService& Service) override; + virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override; virtual void OnClose() override; WorkerThreadPool& WorkPool(); @@ -101,6 +102,8 @@ public: inline bool IsOk() const { return m_IsOk; } inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } + IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request); + private: int InitializeServer(int BasePort); void Cleanup(); @@ -137,6 +140,9 @@ private: int32_t m_MaxPendingRequests = 128; Event m_ShutdownEvent; HttpSysConfig m_InitialConfig; + + RwLock m_RequestFilterLock; + std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr; }; } // namespace zen @@ -1672,9 +1678,21 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload) otel::ScopedSpan HttpSpan(SpanNamer, SpanAnnotator); # endif - if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + IHttpRequestFilter::Result FilterResult = m_HttpServer.FilterRequest(ThisRequest); + if (FilterResult == IHttpRequestFilter::Result::Accepted) + { + if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler)) + { + Service.HandleRequest(ThisRequest); + } + } + else if (FilterResult == IHttpRequestFilter::Result::Forbidden) { - Service.HandleRequest(ThisRequest); + ThisRequest.WriteResponse(HttpResponseCode::Forbidden); + } + else + { + ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent); } return ThisRequest; @@ -2244,6 +2262,30 @@ HttpSysServer::OnRegisterService(HttpService& Service) RegisterService(Service.BaseUri(), Service); } +void +HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) +{ + RwLock::ExclusiveLockScope _(m_RequestFilterLock); + m_HttpRequestFilter.store(RequestFilter); +} + +IHttpRequestFilter::Result +HttpSysServer::FilterRequest(HttpServerRequest& Request) +{ + if (!m_HttpRequestFilter.load()) + { + return IHttpRequestFilter::Result::Accepted; + } + RwLock::SharedLockScope _(m_RequestFilterLock); + IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load(); + if (!RequestFilter) + { + return IHttpRequestFilter::Result::Accepted; + } + IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request); + return FilterResult; +} + Ref<HttpServer> CreateHttpSysServer(HttpSysConfig Config) { |