aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDan Engelbrecht <[email protected]>2026-02-13 13:27:08 +0100
committerGitHub Enterprise <[email protected]>2026-02-13 13:27:08 +0100
commit3b5b777900d9f59ff32eb7cea79e3a72a08c67a6 (patch)
treef5ffdeaad0ca9e291085d707209938c6dfe86d20
parentbump sentry to 0.12.1 (#721) (diff)
downloadzen-3b5b777900d9f59ff32eb7cea79e3a72a08c67a6.tar.xz
zen-3b5b777900d9f59ff32eb7cea79e3a72a08c67a6.zip
add IHttpRequestFilter to allow server implementation to filter/reject requests (#753)
* add IHttpRequestFilter to allow server implementation to filter/reject requests
-rw-r--r--src/zenhttp/httpclient.cpp104
-rw-r--r--src/zenhttp/httpserver.cpp12
-rw-r--r--src/zenhttp/include/zenhttp/httpserver.h22
-rw-r--r--src/zenhttp/servers/httpasio.cpp131
-rw-r--r--src/zenhttp/servers/httpmulti.cpp9
-rw-r--r--src/zenhttp/servers/httpmulti.h1
-rw-r--r--src/zenhttp/servers/httpnull.cpp6
-rw-r--r--src/zenhttp/servers/httpnull.h1
-rw-r--r--src/zenhttp/servers/httpplugin.cpp117
-rw-r--r--src/zenhttp/servers/httpsys.cpp46
10 files changed, 353 insertions, 96 deletions
diff --git a/src/zenhttp/httpclient.cpp b/src/zenhttp/httpclient.cpp
index c77be8624..16729ce38 100644
--- a/src/zenhttp/httpclient.cpp
+++ b/src/zenhttp/httpclient.cpp
@@ -427,13 +427,13 @@ TEST_CASE("httpclient")
AsioServer->RegisterService(TestService);
- std::thread SeverThread([&]() { AsioServer->Run(false); });
+ std::thread ServerThread([&]() { AsioServer->Run(false); });
{
auto _ = MakeGuard([&]() {
- if (SeverThread.joinable())
+ if (ServerThread.joinable())
{
- SeverThread.join();
+ ServerThread.join();
}
AsioServer->Close();
});
@@ -502,13 +502,13 @@ TEST_CASE("httpclient")
HttpSysServer->RegisterService(TestService);
- std::thread SeverThread([&]() { HttpSysServer->Run(false); });
+ std::thread ServerThread([&]() { HttpSysServer->Run(false); });
{
auto _ = MakeGuard([&]() {
- if (SeverThread.joinable())
+ if (ServerThread.joinable())
{
- SeverThread.join();
+ ServerThread.join();
}
HttpSysServer->Close();
});
@@ -570,6 +570,98 @@ TEST_CASE("httpclient")
# endif // ZEN_PLATFORM_WINDOWS
}
+TEST_CASE("httpclient.requestfilter")
+{
+ using namespace std::literals;
+
+ struct TestHttpService : public HttpService
+ {
+ TestHttpService() = default;
+
+ virtual const char* BaseUri() const override { return "/test/"; }
+ virtual void HandleRequest(HttpServerRequest& HttpServiceRequest) override
+ {
+ if (HttpServiceRequest.RelativeUri() == "yo")
+ {
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hey family");
+ }
+
+ {
+ CHECK(HttpServiceRequest.RelativeUri() != "should_filter");
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError);
+ }
+
+ {
+ CHECK(HttpServiceRequest.RelativeUri() != "should_forbid");
+ return HttpServiceRequest.WriteResponse(HttpResponseCode::InternalServerError);
+ }
+ }
+ };
+
+ TestHttpService TestService;
+ ScopedTemporaryDirectory TmpDir;
+
+ class MyFilterImpl : public IHttpRequestFilter
+ {
+ public:
+ virtual Result FilterRequest(HttpServerRequest& Request)
+ {
+ if (Request.RelativeUri() == "should_filter")
+ {
+ Request.WriteResponse(HttpResponseCode::MethodNotAllowed, HttpContentType::kText, "no thank you");
+ return Result::ResponseSent;
+ }
+ else if (Request.RelativeUri() == "should_forbid")
+ {
+ return Result::Forbidden;
+ }
+ return Result::Accepted;
+ }
+ };
+
+ MyFilterImpl MyFilter;
+
+ Ref<HttpServer> AsioServer = CreateHttpAsioServer(AsioConfig{});
+
+ AsioServer->SetHttpRequestFilter(&MyFilter);
+
+ int Port = AsioServer->Initialize(7575, TmpDir.Path());
+ REQUIRE(Port != -1);
+
+ AsioServer->RegisterService(TestService);
+
+ std::thread ServerThread([&]() { AsioServer->Run(false); });
+
+ {
+ auto _ = MakeGuard([&]() {
+ if (ServerThread.joinable())
+ {
+ ServerThread.join();
+ }
+ AsioServer->Close();
+ });
+
+ HttpClient Client(fmt::format("localhost:{}", Port),
+ HttpClientSettings{},
+ /*CheckIfAbortFunction*/ {});
+
+ ZEN_INFO("Request using {}", Client.GetBaseUri());
+
+ HttpClient::Response YoResponse = Client.Get("/test/yo");
+ CHECK(YoResponse.IsSuccess());
+ CHECK_EQ(YoResponse.AsText(), "hey family");
+
+ HttpClient::Response ShouldFilterResponse = Client.Get("/test/should_filter");
+ CHECK_EQ(ShouldFilterResponse.StatusCode, HttpResponseCode::MethodNotAllowed);
+ CHECK_EQ(ShouldFilterResponse.AsText(), "no thank you");
+
+ HttpClient::Response ShouldForbitResponse = Client.Get("/test/should_forbid");
+ CHECK_EQ(ShouldForbitResponse.StatusCode, HttpResponseCode::Forbidden);
+
+ AsioServer->RequestExit();
+ }
+}
+
void
httpclient_forcelink()
{
diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp
index 8985120b0..d8367fcb2 100644
--- a/src/zenhttp/httpserver.cpp
+++ b/src/zenhttp/httpserver.cpp
@@ -463,7 +463,7 @@ HttpService::HandlePackageRequest(HttpServerRequest& HttpServiceRequest)
//////////////////////////////////////////////////////////////////////////
-HttpServerRequest::HttpServerRequest(HttpService& Service) : m_BaseUri(Service.BaseUri())
+HttpServerRequest::HttpServerRequest(HttpService& Service) : m_Service(Service)
{
}
@@ -970,7 +970,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan())
{
ExtendableStringBuilder<128> RoutePath;
- RoutePath.Append(Request.BaseUri());
+ RoutePath.Append(Request.Service().BaseUri());
RoutePath.Append(Handler.Pattern);
ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView());
}
@@ -994,7 +994,7 @@ HttpRequestRouter::HandleRequest(zen::HttpServerRequest& Request)
if (otel::Span* ActiveSpan = otel::Span::GetCurrentSpan())
{
ExtendableStringBuilder<128> RoutePath;
- RoutePath.Append(Request.BaseUri());
+ RoutePath.Append(Request.Service().BaseUri());
RoutePath.Append(Handler.Pattern);
ActiveSpan->AddAttribute("http.route"sv, RoutePath.ToView());
}
@@ -1052,6 +1052,12 @@ HttpServer::EnumerateServices(std::function<void(HttpService& Service)>&& Callba
}
}
+void
+HttpServer::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ OnSetHttpRequestFilter(RequestFilter);
+}
+
//////////////////////////////////////////////////////////////////////////
HttpRpcHandler::HttpRpcHandler()
diff --git a/src/zenhttp/include/zenhttp/httpserver.h b/src/zenhttp/include/zenhttp/httpserver.h
index f0a667686..60f6bc9f2 100644
--- a/src/zenhttp/include/zenhttp/httpserver.h
+++ b/src/zenhttp/include/zenhttp/httpserver.h
@@ -41,7 +41,7 @@ public:
[[nodiscard]] inline std::string_view RelativeUri() const { return m_Uri; } // Returns URI without service prefix
[[nodiscard]] std::string_view RelativeUriWithExtension() const { return m_UriWithExtension; }
[[nodiscard]] inline std::string_view QueryString() const { return m_QueryString; }
- [[nodiscard]] inline std::string_view BaseUri() const { return m_BaseUri; } // Service prefix
+ [[nodiscard]] inline HttpService& Service() const { return m_Service; }
struct QueryParams
{
@@ -121,13 +121,14 @@ protected:
kHaveSessionId = 1 << 3,
};
- mutable uint32_t m_Flags = 0;
+ mutable uint32_t m_Flags = 0;
+
+ HttpService& m_Service; // Service handling this request
HttpVerb m_Verb = HttpVerb::kGet;
HttpContentType m_ContentType = HttpContentType::kBinary;
HttpContentType m_AcceptType = HttpContentType::kUnknownContentType;
uint64_t m_ContentLength = ~0ull;
- std::string_view m_BaseUri; // Base URI path of the service handling this request
- std::string_view m_Uri; // URI without service prefix
+ std::string_view m_Uri; // URI without service prefix
std::string_view m_UriWithExtension;
std::string_view m_QueryString;
mutable uint32_t m_RequestId = ~uint32_t(0);
@@ -148,6 +149,17 @@ public:
virtual void OnRequestComplete() = 0;
};
+struct IHttpRequestFilter
+{
+ enum class Result
+ {
+ Forbidden,
+ ResponseSent,
+ Accepted
+ };
+ virtual Result FilterRequest(HttpServerRequest& Request) = 0;
+};
+
/**
* Base class for implementing an HTTP "service"
*
@@ -184,6 +196,7 @@ class HttpServer : public RefCounted
public:
void RegisterService(HttpService& Service);
void EnumerateServices(std::function<void(HttpService&)>&& Callback);
+ void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter);
int Initialize(int BasePort, std::filesystem::path DataDir);
void Run(bool IsInteractiveSession);
@@ -195,6 +208,7 @@ private:
virtual void OnRegisterService(HttpService& Service) = 0;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) = 0;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) = 0;
virtual void OnRun(bool IsInteractiveSession) = 0;
virtual void OnRequestExit() = 0;
virtual void OnClose() = 0;
diff --git a/src/zenhttp/servers/httpasio.cpp b/src/zenhttp/servers/httpasio.cpp
index 8bfbd8b37..230aac6a8 100644
--- a/src/zenhttp/servers/httpasio.cpp
+++ b/src/zenhttp/servers/httpasio.cpp
@@ -498,16 +498,19 @@ public:
HttpAsioServerImpl();
~HttpAsioServerImpl();
- void Initialize(std::filesystem::path DataDir);
- int Start(uint16_t Port, const AsioConfig& Config);
- void Stop();
- void RegisterService(const char* UrlPath, HttpService& Service);
- HttpService* RouteRequest(std::string_view Url);
+ void Initialize(std::filesystem::path DataDir);
+ int Start(uint16_t Port, const AsioConfig& Config);
+ void Stop();
+ void RegisterService(const char* UrlPath, HttpService& Service);
+ void SetHttpRequestFilter(IHttpRequestFilter* RequestFilter);
+ HttpService* RouteRequest(std::string_view Url);
+ IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request);
asio::io_service m_IoService;
asio::io_service::work m_Work{m_IoService};
std::unique_ptr<asio_http::HttpAcceptor> m_Acceptor;
std::vector<std::thread> m_ThreadPool;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
LoggerRef m_RequestLog;
HttpServerTracer m_RequestTracer;
@@ -1199,53 +1202,65 @@ HttpServerConnection::HandleRequest()
std::vector<IoBuffer>{Request.ReadPayload()});
}
- if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
+ IHttpRequestFilter::Result FilterResult = m_Server.FilterRequest(Request);
+ if (FilterResult == IHttpRequestFilter::Result::Accepted)
{
- try
- {
- Service->HandleRequest(Request);
- }
- catch (const AssertException& AssertEx)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
- }
- catch (const std::system_error& SystemError)
+ if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
{
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ try
{
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ Service->HandleRequest(Request);
}
- else
+ catch (const AssertException& AssertEx)
{
- ZEN_WARN("Caught system error exception while handling request: {}. ({})",
- SystemError.what(),
- SystemError.code().value());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
}
- }
- catch (const std::bad_alloc& BadAlloc)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ catch (const std::system_error& SystemError)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
- }
- catch (const std::exception& ex)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ {
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ }
+ else
+ {
+ ZEN_WARN("Caught system error exception while handling request: {}. ({})",
+ SystemError.what(),
+ SystemError.code().value());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ }
+ }
+ catch (const std::bad_alloc& BadAlloc)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
+ }
+ catch (const std::exception& ex)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- ZEN_WARN("Caught exception while handling request: {}", ex.what());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ ZEN_WARN("Caught exception while handling request: {}", ex.what());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ }
}
}
+ else if (FilterResult == IHttpRequestFilter::Result::Forbidden)
+ {
+ Request.WriteResponse(HttpResponseCode::Forbidden);
+ }
+ else
+ {
+ ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent);
+ }
if (std::unique_ptr<HttpResponse> Response = std::move(Request.m_Response))
{
@@ -1923,6 +1938,31 @@ HttpAsioServerImpl::RouteRequest(std::string_view Url)
return CandidateService;
}
+void
+HttpAsioServerImpl::SetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ ZEN_MEMSCOPE(GetHttpasioTag());
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_HttpRequestFilter.store(RequestFilter);
+}
+
+IHttpRequestFilter::Result
+HttpAsioServerImpl::FilterRequest(HttpServerRequest& Request)
+{
+ if (!m_HttpRequestFilter.load())
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ RwLock::SharedLockScope _(m_Lock);
+ IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load();
+ if (!RequestFilter)
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request);
+ return FilterResult;
+}
+
} // namespace zen::asio_http
//////////////////////////////////////////////////////////////////////////
@@ -1937,6 +1977,7 @@ public:
virtual void OnRegisterService(HttpService& Service) override;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
virtual void OnRun(bool IsInteractiveSession) override;
virtual void OnRequestExit() override;
virtual void OnClose() override;
@@ -1984,6 +2025,12 @@ HttpAsioServer::OnRegisterService(HttpService& Service)
m_Impl->RegisterService(Service.BaseUri(), Service);
}
+void
+HttpAsioServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ m_Impl->SetHttpRequestFilter(RequestFilter);
+}
+
int
HttpAsioServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
{
diff --git a/src/zenhttp/servers/httpmulti.cpp b/src/zenhttp/servers/httpmulti.cpp
index 95624245f..850d7d6b9 100644
--- a/src/zenhttp/servers/httpmulti.cpp
+++ b/src/zenhttp/servers/httpmulti.cpp
@@ -54,6 +54,15 @@ HttpMultiServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
}
void
+HttpMultiServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ for (auto& Server : m_Servers)
+ {
+ Server->SetHttpRequestFilter(RequestFilter);
+ }
+}
+
+void
HttpMultiServer::OnRun(bool IsInteractiveSession)
{
const int WaitTimeout = 1000;
diff --git a/src/zenhttp/servers/httpmulti.h b/src/zenhttp/servers/httpmulti.h
index ae0ed74cf..1897587a9 100644
--- a/src/zenhttp/servers/httpmulti.h
+++ b/src/zenhttp/servers/httpmulti.h
@@ -16,6 +16,7 @@ public:
~HttpMultiServer();
virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
virtual void OnRun(bool IsInteractiveSession) override;
virtual void OnRequestExit() override;
diff --git a/src/zenhttp/servers/httpnull.cpp b/src/zenhttp/servers/httpnull.cpp
index b770b97db..db360c5fb 100644
--- a/src/zenhttp/servers/httpnull.cpp
+++ b/src/zenhttp/servers/httpnull.cpp
@@ -24,6 +24,12 @@ HttpNullServer::OnRegisterService(HttpService& Service)
ZEN_UNUSED(Service);
}
+void
+HttpNullServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ ZEN_UNUSED(RequestFilter);
+}
+
int
HttpNullServer::OnInitialize(int BasePort, std::filesystem::path DataDir)
{
diff --git a/src/zenhttp/servers/httpnull.h b/src/zenhttp/servers/httpnull.h
index ce7230938..52838f012 100644
--- a/src/zenhttp/servers/httpnull.h
+++ b/src/zenhttp/servers/httpnull.h
@@ -18,6 +18,7 @@ public:
~HttpNullServer();
virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
virtual void OnRun(bool IsInteractiveSession) override;
virtual void OnRequestExit() override;
diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp
index 1a630c16f..4219dc292 100644
--- a/src/zenhttp/servers/httpplugin.cpp
+++ b/src/zenhttp/servers/httpplugin.cpp
@@ -96,6 +96,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
// HttpPluginServer
virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
virtual void OnRun(bool IsInteractiveSession) override;
virtual void OnRequestExit() override;
@@ -104,7 +105,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
virtual void AddPlugin(Ref<TransportPlugin> Plugin) override;
virtual void RemovePlugin(Ref<TransportPlugin> Plugin) override;
- HttpService* RouteRequest(std::string_view Url);
+ HttpService* RouteRequest(std::string_view Url);
+ IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request);
struct ServiceEntry
{
@@ -112,7 +114,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
HttpService* Service;
};
- bool m_IsInitialized = false;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+ bool m_IsInitialized = false;
RwLock m_Lock;
std::vector<ServiceEntry> m_UriHandlers;
std::vector<Ref<TransportPlugin>> m_Plugins;
@@ -395,53 +398,65 @@ HttpPluginConnectionHandler::HandleRequest()
std::vector<IoBuffer>{Request.ReadPayload()});
}
- if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
+ IHttpRequestFilter::Result FilterResult = m_Server->FilterRequest(Request);
+ if (FilterResult == IHttpRequestFilter::Result::Accepted)
{
- try
- {
- Service->HandleRequest(Request);
- }
- catch (const AssertException& AssertEx)
+ if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
{
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
- }
- catch (const std::system_error& SystemError)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ try
{
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ Service->HandleRequest(Request);
}
- else
+ catch (const AssertException& AssertEx)
{
- ZEN_WARN("Caught system error exception while handling request: {}. ({})",
- SystemError.what(),
- SystemError.code().value());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
}
- }
- catch (const std::bad_alloc& BadAlloc)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ catch (const std::system_error& SystemError)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ {
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ }
+ else
+ {
+ ZEN_WARN("Caught system error exception while handling request: {}. ({})",
+ SystemError.what(),
+ SystemError.code().value());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ }
+ }
+ catch (const std::bad_alloc& BadAlloc)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
- }
- catch (const std::exception& ex)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
+ }
+ catch (const std::exception& ex)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- ZEN_WARN("Caught exception while handling request: {}", ex.what());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ ZEN_WARN("Caught exception while handling request: {}", ex.what());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ }
}
}
+ else if (FilterResult == IHttpRequestFilter::Result::Forbidden)
+ {
+ Request.WriteResponse(HttpResponseCode::Forbidden);
+ }
+ else
+ {
+ ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent);
+ }
if (std::unique_ptr<HttpPluginResponse> Response = std::move(Request.m_Response))
{
@@ -753,6 +768,13 @@ HttpPluginServerImpl::OnInitialize(int InBasePort, std::filesystem::path DataDir
}
void
+HttpPluginServerImpl::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_HttpRequestFilter.store(RequestFilter);
+}
+
+void
HttpPluginServerImpl::OnClose()
{
if (!m_IsInitialized)
@@ -897,6 +919,23 @@ HttpPluginServerImpl::RouteRequest(std::string_view Url)
return CandidateService;
}
+IHttpRequestFilter::Result
+HttpPluginServerImpl::FilterRequest(HttpServerRequest& Request)
+{
+ if (!m_HttpRequestFilter.load())
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ RwLock::SharedLockScope _(m_Lock);
+ IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load();
+ if (!RequestFilter)
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request);
+ return FilterResult;
+}
+
//////////////////////////////////////////////////////////////////////////
struct HttpPluginServerImpl;
diff --git a/src/zenhttp/servers/httpsys.cpp b/src/zenhttp/servers/httpsys.cpp
index 01c4559a1..4df4cd079 100644
--- a/src/zenhttp/servers/httpsys.cpp
+++ b/src/zenhttp/servers/httpsys.cpp
@@ -94,6 +94,7 @@ public:
virtual void OnRun(bool TestMode) override;
virtual void OnRequestExit() override;
virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
virtual void OnClose() override;
WorkerThreadPool& WorkPool();
@@ -101,6 +102,8 @@ public:
inline bool IsOk() const { return m_IsOk; }
inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; }
+ IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request);
+
private:
int InitializeServer(int BasePort);
void Cleanup();
@@ -137,6 +140,9 @@ private:
int32_t m_MaxPendingRequests = 128;
Event m_ShutdownEvent;
HttpSysConfig m_InitialConfig;
+
+ RwLock m_RequestFilterLock;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
};
} // namespace zen
@@ -1672,9 +1678,21 @@ HttpSysTransaction::InvokeRequestHandler(HttpService& Service, IoBuffer Payload)
otel::ScopedSpan HttpSpan(SpanNamer, SpanAnnotator);
# endif
- if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler))
+ IHttpRequestFilter::Result FilterResult = m_HttpServer.FilterRequest(ThisRequest);
+ if (FilterResult == IHttpRequestFilter::Result::Accepted)
+ {
+ if (!HandlePackageOffers(Service, ThisRequest, m_PackageHandler))
+ {
+ Service.HandleRequest(ThisRequest);
+ }
+ }
+ else if (FilterResult == IHttpRequestFilter::Result::Forbidden)
{
- Service.HandleRequest(ThisRequest);
+ ThisRequest.WriteResponse(HttpResponseCode::Forbidden);
+ }
+ else
+ {
+ ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent);
}
return ThisRequest;
@@ -2244,6 +2262,30 @@ HttpSysServer::OnRegisterService(HttpService& Service)
RegisterService(Service.BaseUri(), Service);
}
+void
+HttpSysServer::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ RwLock::ExclusiveLockScope _(m_RequestFilterLock);
+ m_HttpRequestFilter.store(RequestFilter);
+}
+
+IHttpRequestFilter::Result
+HttpSysServer::FilterRequest(HttpServerRequest& Request)
+{
+ if (!m_HttpRequestFilter.load())
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ RwLock::SharedLockScope _(m_RequestFilterLock);
+ IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load();
+ if (!RequestFilter)
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ IHttpRequestFilter::Result FilterResult = RequestFilter->FilterRequest(Request);
+ return FilterResult;
+}
+
Ref<HttpServer>
CreateHttpSysServer(HttpSysConfig Config)
{