aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/servers/httpplugin.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhttp/servers/httpplugin.cpp')
-rw-r--r--src/zenhttp/servers/httpplugin.cpp140
1 files changed, 95 insertions, 45 deletions
diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp
index b9217ed87..4bf8c61bb 100644
--- a/src/zenhttp/servers/httpplugin.cpp
+++ b/src/zenhttp/servers/httpplugin.cpp
@@ -96,6 +96,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
// HttpPluginServer
virtual void OnRegisterService(HttpService& Service) override;
+ virtual void OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter) override;
virtual int OnInitialize(int BasePort, std::filesystem::path DataDir) override;
virtual void OnRun(bool IsInteractiveSession) override;
virtual void OnRequestExit() override;
@@ -104,7 +105,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
virtual void AddPlugin(Ref<TransportPlugin> Plugin) override;
virtual void RemovePlugin(Ref<TransportPlugin> Plugin) override;
- HttpService* RouteRequest(std::string_view Url);
+ HttpService* RouteRequest(std::string_view Url);
+ IHttpRequestFilter::Result FilterRequest(HttpServerRequest& Request);
struct ServiceEntry
{
@@ -112,7 +114,8 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
HttpService* Service;
};
- bool m_IsInitialized = false;
+ std::atomic<IHttpRequestFilter*> m_HttpRequestFilter = nullptr;
+ bool m_IsInitialized = false;
RwLock m_Lock;
std::vector<ServiceEntry> m_UriHandlers;
std::vector<Ref<TransportPlugin>> m_Plugins;
@@ -120,7 +123,7 @@ struct HttpPluginServerImpl : public HttpPluginServer, TransportServer
bool m_IsRequestLoggingEnabled = false;
LoggerRef m_RequestLog;
std::atomic_uint32_t m_ConnectionIdCounter{0};
- int m_BasePort;
+ int m_BasePort = 0;
HttpServerTracer m_RequestTracer;
@@ -143,8 +146,11 @@ public:
HttpPluginServerRequest(const HttpPluginServerRequest&) = delete;
HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete;
- virtual Oid ParseSessionId() const override;
- virtual uint32_t ParseRequestId() const override;
+ // As this is plugin transport connection used for specialized connections we assume it is not a machine local connection
+ virtual bool IsLocalMachineRequest() const /* override*/ { return false; }
+ virtual std::string_view GetAuthorizationHeader() const override;
+ virtual Oid ParseSessionId() const override;
+ virtual uint32_t ParseRequestId() const override;
virtual IoBuffer ReadPayload() override;
virtual void WriteResponse(HttpResponseCode ResponseCode) override;
@@ -288,7 +294,7 @@ HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPlug
ConnectionName = "anonymous";
}
- ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('')", m_ConnectionId, ConnectionName);
+ ZEN_LOG_TRACE(m_Server->m_RequestLog, "NEW connection #{} ('{}')", m_ConnectionId, ConnectionName);
}
uint32_t
@@ -372,12 +378,14 @@ HttpPluginConnectionHandler::HandleRequest()
{
ZEN_TRACE_CPU("http_plugin::HandleRequest");
+ m_Server->MarkRequest();
+
HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body());
const HttpVerb RequestVerb = Request.RequestVerb();
const std::string_view Uri = Request.RelativeUri();
- if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace))
+ if (m_Server->m_RequestLog.ShouldLog(logging::Trace))
{
ZEN_LOG_TRACE(m_Server->m_RequestLog,
"connection #{} Handling Request: {} {} ({} bytes ({}), accept: {})",
@@ -392,53 +400,65 @@ HttpPluginConnectionHandler::HandleRequest()
std::vector<IoBuffer>{Request.ReadPayload()});
}
- if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
+ IHttpRequestFilter::Result FilterResult = m_Server->FilterRequest(Request);
+ if (FilterResult == IHttpRequestFilter::Result::Accepted)
{
- try
- {
- Service->HandleRequest(Request);
- }
- catch (const AssertException& AssertEx)
+ if (!HandlePackageOffers(*Service, Request, m_PackageHandler))
{
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
- }
- catch (const std::system_error& SystemError)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
-
- if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ try
{
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ Service->HandleRequest(Request);
}
- else
+ catch (const AssertException& AssertEx)
{
- ZEN_WARN("Caught system error exception while handling request: {}. ({})",
- SystemError.what(),
- SystemError.code().value());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ ZEN_ERROR("Caught assert exception while handling request: {}", AssertEx.FullDescription());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, AssertEx.FullDescription());
}
- }
- catch (const std::bad_alloc& BadAlloc)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ catch (const std::system_error& SystemError)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
+
+ if (IsOOM(SystemError.code()) || IsOOD(SystemError.code()))
+ {
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what());
+ }
+ else
+ {
+ ZEN_WARN("Caught system error exception while handling request: {}. ({})",
+ SystemError.what(),
+ SystemError.code().value());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what());
+ }
+ }
+ catch (const std::bad_alloc& BadAlloc)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
- }
- catch (const std::exception& ex)
- {
- // Drop any partially formatted response
- Request.m_Response.reset();
+ Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what());
+ }
+ catch (const std::exception& ex)
+ {
+ // Drop any partially formatted response
+ Request.m_Response.reset();
- ZEN_WARN("Caught exception while handling request: {}", ex.what());
- Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ ZEN_WARN("Caught exception while handling request: {}", ex.what());
+ Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what());
+ }
}
}
+ else if (FilterResult == IHttpRequestFilter::Result::Forbidden)
+ {
+ Request.WriteResponse(HttpResponseCode::Forbidden);
+ }
+ else
+ {
+ ZEN_ASSERT(FilterResult == IHttpRequestFilter::Result::ResponseSent);
+ }
if (std::unique_ptr<HttpPluginResponse> Response = std::move(Request.m_Response))
{
@@ -462,7 +482,7 @@ HttpPluginConnectionHandler::HandleRequest()
const std::vector<IoBuffer>& ResponseBuffers = Response->ResponseBuffers();
- if (m_Server->m_RequestLog.ShouldLog(logging::level::Trace))
+ if (m_Server->m_RequestLog.ShouldLog(logging::Trace))
{
m_Server->m_RequestTracer.WriteDebugPayload(fmt::format("response_{}_{}.bin", m_ConnectionId, RequestNumber),
ResponseBuffers);
@@ -618,6 +638,12 @@ HttpPluginServerRequest::~HttpPluginServerRequest()
{
}
+std::string_view
+HttpPluginServerRequest::GetAuthorizationHeader() const
+{
+ return m_Request.AuthorizationHeader();
+}
+
Oid
HttpPluginServerRequest::ParseSessionId() const
{
@@ -750,6 +776,13 @@ HttpPluginServerImpl::OnInitialize(int InBasePort, std::filesystem::path DataDir
}
void
+HttpPluginServerImpl::OnSetHttpRequestFilter(IHttpRequestFilter* RequestFilter)
+{
+ RwLock::ExclusiveLockScope _(m_Lock);
+ m_HttpRequestFilter.store(RequestFilter);
+}
+
+void
HttpPluginServerImpl::OnClose()
{
if (!m_IsInitialized)
@@ -806,6 +839,7 @@ HttpPluginServerImpl::OnRun(bool IsInteractive)
if (c == 27 || c == 'Q' || c == 'q')
{
+ m_ShutdownEvent.Set();
RequestApplicationExit(0);
}
}
@@ -894,6 +928,22 @@ HttpPluginServerImpl::RouteRequest(std::string_view Url)
return CandidateService;
}
+IHttpRequestFilter::Result
+HttpPluginServerImpl::FilterRequest(HttpServerRequest& Request)
+{
+ if (!m_HttpRequestFilter.load())
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ RwLock::SharedLockScope _(m_Lock);
+ IHttpRequestFilter* RequestFilter = m_HttpRequestFilter.load();
+ if (!RequestFilter)
+ {
+ return IHttpRequestFilter::Result::Accepted;
+ }
+ return RequestFilter->FilterRequest(Request);
+}
+
//////////////////////////////////////////////////////////////////////////
struct HttpPluginServerImpl;