diff options
| author | Stefan Boberg <[email protected]> | 2023-10-13 09:55:27 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-10-13 09:55:27 +0200 |
| commit | 74d104d4eb3735e0881f0e1fccc2df8aa4d3f57d (patch) | |
| tree | acae59dac67b4d051403f35e580201c214ec4fda /src/zenhttp/servers/httpplugin.cpp | |
| parent | faster oplog iteration (#471) (diff) | |
| download | zen-74d104d4eb3735e0881f0e1fccc2df8aa4d3f57d.tar.xz zen-74d104d4eb3735e0881f0e1fccc2df8aa4d3f57d.zip | |
restructured zenhttp (#472)
separating the http server implementations into a directory and moved diagsvcs into zenserver since it's somewhat hard-coded for it
Diffstat (limited to 'src/zenhttp/servers/httpplugin.cpp')
| -rw-r--r-- | src/zenhttp/servers/httpplugin.cpp | 781 |
1 files changed, 781 insertions, 0 deletions
diff --git a/src/zenhttp/servers/httpplugin.cpp b/src/zenhttp/servers/httpplugin.cpp new file mode 100644 index 000000000..2e934473e --- /dev/null +++ b/src/zenhttp/servers/httpplugin.cpp @@ -0,0 +1,781 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpplugin.h> + +#if ZEN_WITH_PLUGINS + +# include "httpparser.h" + +# include <zencore/except.h> +# include <zencore/logging.h> +# include <zencore/trace.h> +# include <zencore/workthreadpool.h> +# include <zenhttp/httpserver.h> + +# include <memory> +# include <string_view> + +# if ZEN_PLATFORM_WINDOWS +# include <conio.h> +# endif + +# define PLUGIN_VERBOSE_TRACE 1 + +# if PLUGIN_VERBOSE_TRACE +# define ZEN_TRACE_VERBOSE ZEN_TRACE +# else +# define ZEN_TRACE_VERBOSE(fmtstr, ...) +# endif + +namespace zen { + +struct HttpPluginServerImpl; +struct HttpPluginResponse; + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +struct HttpPluginConnectionHandler : public TransportServerConnection, public HttpRequestParserCallbacks, RefCounted +{ + virtual uint32_t AddRef() const override; + virtual uint32_t Release() const override; + + virtual void OnBytesRead(const void* Buffer, size_t DataSize) override; + + // HttpRequestParserCallbacks + + virtual void HandleRequest() override; + virtual void TerminateConnection() override; + + void Initialize(TransportConnection* Transport, HttpPluginServerImpl& Server); + +private: + enum class RequestState + { + kInitialState, + kInitialRead, + kReadingMore, + kWriting, // Currently writing response, connection will be re-used + kWritingFinal, // Writing response, connection will be closed + kDone, + kTerminated + }; + + RequestState m_RequestState = RequestState::kInitialState; + HttpRequestParser m_RequestParser{*this}; + + uint32_t m_ConnectionId = 0; + Ref<IHttpPackageHandler> m_PackageHandler; + + TransportConnection* m_TransportConnection = nullptr; + HttpPluginServerImpl* m_Server = nullptr; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpPluginServerImpl : public TransportServer +{ + HttpPluginServerImpl(); + ~HttpPluginServerImpl(); + + void AddPlugin(Ref<TransportPlugin> Plugin); + void RemovePlugin(Ref<TransportPlugin> Plugin); + + void Start(); + void Stop(); + + void RegisterService(const char* InUrlPath, HttpService& Service); + HttpService* RouteRequest(std::string_view Url); + + struct ServiceEntry + { + std::string ServiceUrlPath; + HttpService* Service; + }; + + RwLock m_Lock; + std::vector<ServiceEntry> m_UriHandlers; + std::vector<Ref<TransportPlugin>> m_Plugins; + + // TransportServer + + virtual TransportServerConnection* CreateConnectionHandler(TransportConnection* Connection) override; +}; + +/** This is the class which request handlers interface with when + generating responses + */ + +class HttpPluginServerRequest : public HttpServerRequest +{ +public: + HttpPluginServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpPluginServerRequest(); + + HttpPluginServerRequest(const HttpPluginServerRequest&) = delete; + HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete; + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; + virtual bool TryGetRanges(HttpRanges& Ranges) override; + + using HttpServerRequest::WriteResponse; + + HttpRequestParser& m_Request; + IoBuffer m_PayloadBuffer; + std::unique_ptr<HttpPluginResponse> m_Response; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpPluginResponse +{ +public: + HttpPluginResponse() = default; + explicit HttpPluginResponse(HttpContentType ContentType) : m_ContentType(ContentType) {} + + void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList); + + inline uint16_t ResponseCode() const { return m_ResponseCode; } + inline uint64_t ContentLength() const { return m_ContentLength; } + + const std::vector<IoBuffer>& ResponseBuffers() const { return m_ResponseBuffers; } + void SuppressPayload() { m_ResponseBuffers.resize(1); } + +private: + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + uint64_t m_ContentLength = 0; + std::vector<IoBuffer> m_ResponseBuffers; + ExtendableStringBuilder<160> m_Headers; + + std::string_view GetHeaders(); +}; + +void +HttpPluginResponse::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) +{ + ZEN_TRACE_CPU("http_plugin::InitializeForPayload"); + + m_ResponseCode = ResponseCode; + const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); + + m_ResponseBuffers.reserve(ChunkCount + 1); + m_ResponseBuffers.push_back({}); // Placeholder for header + + uint64_t TotalDataSize = 0; + + for (IoBuffer& Buffer : BlobList) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + TotalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // TODO: Use direct file transfer, via TransmitFile/sendfile + + m_ResponseBuffers.emplace_back(std::move(Buffer)).MakeOwned(); + } + else + { + // Send from memory + + m_ResponseBuffers.emplace_back(std::move(Buffer)).MakeOwned(); + } + } + m_ContentLength = TotalDataSize; + + auto Headers = GetHeaders(); + m_ResponseBuffers[0] = IoBufferBuilder::MakeCloneFromMemory(Headers.data(), Headers.size()); +} + +std::string_view +HttpPluginResponse::GetHeaders() +{ + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" + << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Length: " << ContentLength() << "\r\n"sv; + + if (!m_IsKeepAlive) + { + m_Headers << "Connection: close\r\n"sv; + } + + m_Headers << "\r\n"sv; + + return m_Headers; +} + +////////////////////////////////////////////////////////////////////////// + +void +HttpPluginConnectionHandler::Initialize(TransportConnection* Transport, HttpPluginServerImpl& Server) +{ + m_TransportConnection = Transport; + m_Server = &Server; +} + +uint32_t +HttpPluginConnectionHandler::AddRef() const +{ + return RefCounted::AddRef(); +} + +uint32_t +HttpPluginConnectionHandler::Release() const +{ + return RefCounted::Release(); +} + +void +HttpPluginConnectionHandler::OnBytesRead(const void* Buffer, size_t AvailableBytes) +{ + while (AvailableBytes) + { + const size_t ConsumedBytes = m_RequestParser.ConsumeData((const char*)Buffer, AvailableBytes); + + if (ConsumedBytes == ~0ull) + { + // terminate connection + + return TerminateConnection(); + } + + Buffer = reinterpret_cast<const uint8_t*>(Buffer) + ConsumedBytes; + AvailableBytes -= ConsumedBytes; + } +} + +// HttpRequestParserCallbacks + +void +HttpPluginConnectionHandler::HandleRequest() +{ + if (!m_RequestParser.IsKeepAlive()) + { + // Once response has been written, connection is done + m_RequestState = RequestState::kWritingFinal; + + // We're not going to read any more data from this socket + + const bool Receive = true; + const bool Transmit = false; + m_TransportConnection->Shutdown(Receive, Transmit); + } + else + { + m_RequestState = RequestState::kWriting; + } + + auto SendBuffer = [&](const IoBuffer& InBuffer) -> int64_t { + const char* Buffer = reinterpret_cast<const char*>(InBuffer.GetData()); + size_t Bytes = InBuffer.GetSize(); + + return m_TransportConnection->WriteBytes(Buffer, Bytes); + }; + + // Generate response + + if (HttpService* Service = m_Server->RouteRequest(m_RequestParser.Url())) + { + ZEN_TRACE_CPU("http_plugin::HandleRequest"); + + HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body()); + + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + { + try + { + Service->HandleRequest(Request); + } + catch (std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } + } + + if (std::unique_ptr<HttpPluginResponse> Response = std::move(Request.m_Response)) + { + // Transmit the response + + if (m_RequestParser.RequestVerb() == HttpVerb::kHead) + { + Response->SuppressPayload(); + } + + const std::vector<IoBuffer>& ResponseBuffers = Response->ResponseBuffers(); + + //// TODO: should cork/uncork for Linux? + + for (const IoBuffer& Buffer : ResponseBuffers) + { + int64_t SentBytes = SendBuffer(Buffer); + + if (SentBytes < 0) + { + TerminateConnection(); + + return; + } + } + + return; + } + } + + // No route found for request + + std::string_view Response; + + if (m_RequestParser.RequestVerb() == HttpVerb::kHead) + { + if (m_RequestParser.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "\r\n"sv; + } + else + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Connection: close\r\n" + "\r\n"sv; + } + } + else + { + if (m_RequestParser.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "No suitable route found"sv; + } + else + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "Connection: close\r\n" + "\r\n" + "No suitable route found"sv; + } + } + + const int64_t SentBytes = SendBuffer(IoBufferBuilder::MakeFromMemory(MakeMemoryView(Response))); + + if (SentBytes < 0) + { + TerminateConnection(); + + return; + } +} + +void +HttpPluginConnectionHandler::TerminateConnection() +{ + ZEN_ASSERT(m_TransportConnection); + m_TransportConnection->CloseConnection(); +} + +////////////////////////////////////////////////////////////////////////// + +HttpPluginServerRequest::HttpPluginServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer) +: m_Request(Request) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const int PrefixLength = Service.UriPrefixLength(); + + std::string_view Uri = Request.Url(); + Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size()))); + m_Uri = Uri; + m_UriWithExtension = Uri; + m_QueryString = Request.QueryString(); + + m_Verb = Request.RequestVerb(); + m_ContentLength = Request.Body().Size(); + m_ContentType = Request.ContentType(); + + HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; + + // Parse any extension, to allow requesting a particular response encoding via the URL + + { + std::string_view UriSuffix8{m_Uri}; + + const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); + + if (LastComponentIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastComponentIndex); + } + + const size_t LastDotIndex = UriSuffix8.find_last_of('.'); + + if (LastDotIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastDotIndex + 1); + + AcceptContentType = ParseContentType(UriSuffix8); + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1)); + } + } + } + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_AcceptType = AcceptContentType; + } + else + { + m_AcceptType = Request.AcceptType(); + } +} + +HttpPluginServerRequest::~HttpPluginServerRequest() +{ +} + +Oid +HttpPluginServerRequest::ParseSessionId() const +{ + return m_Request.SessionId(); +} + +uint32_t +HttpPluginServerRequest::ParseRequestId() const +{ + return m_Request.RequestId(); +} + +IoBuffer +HttpPluginServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary)); + std::array<IoBuffer, 0> Empty; + + m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); +} + +void +HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpPluginResponse(ContentType)); + m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); +} + +void +HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(!m_Response); + m_Response.reset(new HttpPluginResponse(ContentType)); + + IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList); +} + +void +HttpPluginServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) +{ + ZEN_ASSERT(!m_Response); + + // Not one bit async, innit + ContinuationHandler(*this); +} + +bool +HttpPluginServerRequest::TryGetRanges(HttpRanges& Ranges) +{ + return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges); +} + +////////////////////////////////////////////////////////////////////////// + +HttpPluginServerImpl::HttpPluginServerImpl() +{ +} + +HttpPluginServerImpl::~HttpPluginServerImpl() +{ +} + +TransportServerConnection* +HttpPluginServerImpl::CreateConnectionHandler(TransportConnection* Connection) +{ + HttpPluginConnectionHandler* Handler{new HttpPluginConnectionHandler()}; + Handler->Initialize(Connection, *this); + return Handler; +} + +void +HttpPluginServerImpl::Start() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + for (auto& Plugin : m_Plugins) + { + try + { + Plugin->Initialize(this); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception caught during plugin initialization: {}", Ex.what()); + } + } +} + +void +HttpPluginServerImpl::Stop() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + for (auto& Plugin : m_Plugins) + { + try + { + Plugin->Shutdown(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception caught during plugin shutdown: {}", Ex.what()); + } + + Plugin = nullptr; + } + + m_Plugins.clear(); +} + +void +HttpPluginServerImpl::AddPlugin(Ref<TransportPlugin> Plugin) +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_Plugins.emplace_back(std::move(Plugin)); +} + +void +HttpPluginServerImpl::RemovePlugin(Ref<TransportPlugin> Plugin) +{ + RwLock::ExclusiveLockScope _(m_Lock); + auto It = std::find(begin(m_Plugins), end(m_Plugins), Plugin); + if (It != m_Plugins.end()) + { + m_Plugins.erase(It); + } +} + +void +HttpPluginServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) +{ + std::string_view UrlPath(InUrlPath); + Service.SetUriPrefixLength(UrlPath.size()); + if (!UrlPath.empty() && UrlPath.back() == '/') + { + UrlPath.remove_suffix(1); + } + + RwLock::ExclusiveLockScope _(m_Lock); + m_UriHandlers.push_back({std::string(UrlPath), &Service}); +} + +HttpService* +HttpPluginServerImpl::RouteRequest(std::string_view Url) +{ + RwLock::SharedLockScope _(m_Lock); + + HttpService* CandidateService = nullptr; + std::string::size_type CandidateMatchSize = 0; + for (const ServiceEntry& SvcEntry : m_UriHandlers) + { + const std::string& SvcUrl = SvcEntry.ServiceUrlPath; + const std::string::size_type SvcUrlSize = SvcUrl.size(); + if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 && + ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/'))) + { + CandidateMatchSize = SvcUrl.size(); + CandidateService = SvcEntry.Service; + } + } + + return CandidateService; +} + +////////////////////////////////////////////////////////////////////////// + +HttpPluginServer::HttpPluginServer(unsigned int ThreadCount) +: m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +, m_Impl(new HttpPluginServerImpl) +{ +} + +HttpPluginServer::~HttpPluginServer() +{ + if (m_Impl) + { + ZEN_ERROR("~HttpPluginServer() called without calling Close() first"); + } +} + +int +HttpPluginServer::Initialize(int BasePort) +{ + try + { + m_Impl->Start(); + } + catch (std::exception& ex) + { + ZEN_WARN("Caught exception starting http plugin server: {}", ex.what()); + } + + return BasePort; +} + +void +HttpPluginServer::Close() +{ + try + { + m_Impl->Stop(); + } + catch (std::exception& ex) + { + ZEN_WARN("Caught exception stopping http plugin server: {}", ex.what()); + } + + delete m_Impl; + m_Impl = nullptr; +} + +void +HttpPluginServer::Run(bool IsInteractive) +{ + const bool TestMode = !IsInteractive; + + int WaitTimeout = -1; + if (!TestMode) + { + WaitTimeout = 1000; + } + +# if ZEN_PLATFORM_WINDOWS + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (plugin HTTP). Press ESC or Q to quit"); + } + + do + { + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +# else + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (plugin HTTP). Ctrl-C to quit"); + } + + do + { + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +# endif +} + +void +HttpPluginServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +void +HttpPluginServer::RegisterService(HttpService& Service) +{ + m_Impl->RegisterService(Service.BaseUri(), Service); +} + +void +HttpPluginServer::AddPlugin(Ref<TransportPlugin> Plugin) +{ + m_Impl->AddPlugin(Plugin); +} + +void +HttpPluginServer::RemovePlugin(Ref<TransportPlugin> Plugin) +{ + m_Impl->RemovePlugin(Plugin); +} + +} // namespace zen +#endif |