diff options
| author | Stefan Boberg <[email protected]> | 2023-10-10 13:30:07 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-10-10 13:30:07 +0200 |
| commit | 7905d0d21af95c6016768d7a8a81dd9204b34d24 (patch) | |
| tree | 076329f5fad1c33503ee611f6399853da2490567 /src | |
| parent | cache reference tracking (#455) (diff) | |
| download | zen-7905d0d21af95c6016768d7a8a81dd9204b34d24.tar.xz zen-7905d0d21af95c6016768d7a8a81dd9204b34d24.zip | |
experimental pluggable transport support (#436)
this change adds a `--http=plugin` mode where we support pluggable transports. Currently this defaults to a barebones blocking winsock implementation but there is also support for dynamic loading of transport plugins, which will be further developed in the near future.
Diffstat (limited to 'src')
| -rw-r--r-- | src/plugins/winsock/winsock.cpp | 350 | ||||
| -rw-r--r-- | src/plugins/winsock/xmake.lua | 18 | ||||
| -rw-r--r-- | src/zen/xmake.lua | 2 | ||||
| -rw-r--r-- | src/zenhttp/dlltransport.cpp | 251 | ||||
| -rw-r--r-- | src/zenhttp/dlltransport.h | 37 | ||||
| -rw-r--r-- | src/zenhttp/httpasio.cpp | 7 | ||||
| -rw-r--r-- | src/zenhttp/httpasio.h | 2 | ||||
| -rw-r--r-- | src/zenhttp/httpparser.cpp | 7 | ||||
| -rw-r--r-- | src/zenhttp/httpparser.h | 63 | ||||
| -rw-r--r-- | src/zenhttp/httpplugin.cpp | 781 | ||||
| -rw-r--r-- | src/zenhttp/httpserver.cpp | 35 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/httpplugin.h | 49 | ||||
| -rw-r--r-- | src/zenhttp/include/zenhttp/transportplugin.h | 97 | ||||
| -rw-r--r-- | src/zenhttp/winsocktransport.cpp | 367 | ||||
| -rw-r--r-- | src/zenhttp/winsocktransport.h | 15 |
15 files changed, 2039 insertions, 42 deletions
diff --git a/src/plugins/winsock/winsock.cpp b/src/plugins/winsock/winsock.cpp new file mode 100644 index 000000000..dca1fdbe7 --- /dev/null +++ b/src/plugins/winsock/winsock.cpp @@ -0,0 +1,350 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <inttypes.h> +#include <atomic> +#include <exception> +#include <future> +#include <memory> +#include <thread> + +#include <zencore/refcount.h> +#include <zencore/zencore.h> + +#ifndef _WIN32_WINNT +# define _WIN32_WINNT 0x0A00 +#endif + +ZEN_THIRD_PARTY_INCLUDES_START +#include <WS2tcpip.h> +#include <WinSock2.h> +#include <windows.h> +ZEN_THIRD_PARTY_INCLUDES_END + +#include <transportplugin.h> + +////////////////////////////////////////////////////////////////////////// + +class SocketTransportPlugin : public TransportPluginInterface, zen::RefCounted +{ +public: + SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + ~SocketTransportPlugin(); + + // TransportPluginInterface implementation + + virtual uint32_t AddRef() const override; + virtual uint32_t Release() const override; + virtual void Initialize(TransportServerInterface* ServerInterface) override; + virtual void Shutdown() override; + virtual bool IsAvailable() override; + +private: + TransportServerInterface* m_ServerInterface = nullptr; + bool m_IsOk = true; + uint16_t m_BasePort = 0; + int m_ThreadCount = 0; + + SOCKET m_ListenSocket{}; + std::thread m_AcceptThread; + std::atomic_flag m_KeepRunning; + std::vector<std::future<void>> m_Connections; +}; + +struct SocketTransportConnection : public TransportConnectionInterface +{ +public: + SocketTransportConnection(); + ~SocketTransportConnection(); + + void Initialize(TransportServerConnectionHandler* ServerConnection, SOCKET ClientSocket); + void HandleConnection(); + + // TransportConnectionInterface implementation + + virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override; + virtual void Shutdown(bool Receive, bool Transmit) override; + virtual void CloseConnection() override; + +private: + zen::Ref<TransportServerConnectionHandler> m_ConnectionHandler; + SOCKET m_ClientSocket{}; + bool m_IsTerminated = false; +}; + +////////////////////////////////////////////////////////////////////////// + +SocketTransportConnection::SocketTransportConnection() +{ +} + +SocketTransportConnection::~SocketTransportConnection() +{ +} + +void +SocketTransportConnection::Initialize(TransportServerConnectionHandler* ServerConnection, SOCKET ClientSocket) +{ + // ZEN_ASSERT(!m_ConnectionHandler); + + m_ConnectionHandler = ServerConnection; + m_ClientSocket = ClientSocket; +} + +void +SocketTransportConnection::HandleConnection() +{ + // ZEN_ASSERT(m_ConnectionHandler); + + const int InputBufferSize = 64 * 1024; + std::unique_ptr<uint8_t[]> InputBuffer{new uint8_t[64 * 1024]}; + + do + { + const int RecvBytes = recv(m_ClientSocket, (char*)InputBuffer.get(), InputBufferSize, /* flags */ 0); + + if (RecvBytes == 0) + { + // Connection closed + return CloseConnection(); + } + else if (RecvBytes < 0) + { + // Error + return CloseConnection(); + } + + m_ConnectionHandler->OnBytesRead(InputBuffer.get(), RecvBytes); + } while (m_ClientSocket); +} + +void +SocketTransportConnection::CloseConnection() +{ + if (m_IsTerminated) + { + return; + } + + // ZEN_ASSERT(m_ClientSocket); + m_IsTerminated = true; + + shutdown(m_ClientSocket, SD_BOTH); // We won't be sending or receiving any more data + + closesocket(m_ClientSocket); + m_ClientSocket = 0; +} + +int64_t +SocketTransportConnection::WriteBytes(const void* Buffer, size_t DataSize) +{ + const uint8_t* BufferCursor = reinterpret_cast<const uint8_t*>(Buffer); + int64_t TotalBytesSent = 0; + + while (DataSize) + { + const int MaxBlockSize = 128 * 1024; + const int SendBlockSize = (DataSize > MaxBlockSize) ? MaxBlockSize : (int)DataSize; + const int SentBytes = send(m_ClientSocket, (const char*)BufferCursor, SendBlockSize, /* flags */ 0); + + if (SentBytes < 0) + { + // Error + return SentBytes; + } + + BufferCursor += SentBytes; + DataSize -= SentBytes; + TotalBytesSent += SentBytes; + } + + return TotalBytesSent; +} + +void +SocketTransportConnection::Shutdown(bool Receive, bool Transmit) +{ + if (Receive) + { + if (Transmit) + { + shutdown(m_ClientSocket, SD_BOTH); + } + else + { + shutdown(m_ClientSocket, SD_RECEIVE); + } + } + else if (Transmit) + { + shutdown(m_ClientSocket, SD_SEND); + } +} + +////////////////////////////////////////////////////////////////////////// + +SocketTransportPlugin::SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +: m_BasePort(BasePort) +, m_ThreadCount(ThreadCount != 0 ? ThreadCount : std::max(std::thread::hardware_concurrency(), 8u)) +{ +#if ZEN_PLATFORM_WINDOWS + WSADATA wsaData; + if (int Result = WSAStartup(0x202, &wsaData); Result != 0) + { + m_IsOk = false; + WSACleanup(); + } +#endif +} + +SocketTransportPlugin::~SocketTransportPlugin() +{ + Shutdown(); + +#if ZEN_PLATFORM_WINDOWS + if (m_IsOk) + { + WSACleanup(); + } +#endif +} + +uint32_t +SocketTransportPlugin::AddRef() const +{ + return RefCounted::AddRef(); +} + +uint32_t +SocketTransportPlugin::Release() const +{ + return RefCounted::Release(); +} + +void +SocketTransportPlugin::Initialize(TransportServerInterface* ServerInterface) +{ + uint16_t Port = m_BasePort; + + m_ServerInterface = ServerInterface; + m_ListenSocket = socket(AF_INET6, SOCK_STREAM, 0); + + if (m_ListenSocket == SOCKET_ERROR || m_ListenSocket == INVALID_SOCKET) + { + throw std::system_error(std::error_code(WSAGetLastError(), std::system_category()), + "socket creation failed in HTTP plugin server init"); + } + + sockaddr_in6 Server{}; + Server.sin6_family = AF_INET6; + Server.sin6_port = htons(Port); + Server.sin6_addr = in6addr_any; + + if (int Result = bind(m_ListenSocket, (sockaddr*)&Server, sizeof(Server)); Result == SOCKET_ERROR) + { + throw std::system_error(std::error_code(WSAGetLastError(), std::system_category()), "bind call failed in HTTP plugin server init"); + } + + if (int Result = listen(m_ListenSocket, AF_INET6); Result == SOCKET_ERROR) + { + throw std::system_error(std::error_code(WSAGetLastError(), std::system_category()), + "listen call failed in HTTP plugin server init"); + } + + m_KeepRunning.test_and_set(); + + m_AcceptThread = std::thread([&] { + // SetCurrentThreadName("http_plugin_acceptor"); + + // ZEN_INFO("HTTP plugin server waiting for connections"); + + do + { + if (SOCKET ClientSocket = accept(m_ListenSocket, NULL, NULL); ClientSocket != SOCKET_ERROR) + { + int Flag = 1; + setsockopt(ClientSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&Flag, sizeof(Flag)); + + // Handle new connection + SocketTransportConnection* Connection = new SocketTransportConnection(); + TransportServerConnectionHandler* ConnectionInterface{m_ServerInterface->CreateConnectionHandler(Connection)}; + Connection->Initialize(ConnectionInterface, ClientSocket); + + m_Connections.push_back(std::async(std::launch::async, [Connection] { + try + { + Connection->HandleConnection(); + } + catch (std::exception&) + { + // ZEN_WARN("exception caught in connection loop: {}", Ex.what()); + } + + delete Connection; + })); + } + else + { + } + } while (m_KeepRunning.test()); + + // ZEN_INFO("HTTP plugin server accept thread exit"); + }); +} + +void +SocketTransportPlugin::Shutdown() +{ + // TODO: all pending/ongoing work should be drained here as well + + m_KeepRunning.clear(); + + closesocket(m_ListenSocket); + m_ListenSocket = 0; + + if (m_AcceptThread.joinable()) + { + m_AcceptThread.join(); + } +} + +bool +SocketTransportPlugin::IsAvailable() +{ + return true; +} + +////////////////////////////////////////////////////////////////////////// + +TransportPluginInterface* +CreateTransportPluginInterface() +{ + return new SocketTransportPlugin(1337, 8); +} + +BOOL WINAPI +DllMain([[maybe_unused]] HINSTANCE hinstDLL, // handle to DLL module + DWORD fdwReason, // reason for calling function + LPVOID lpvReserved) // reserved +{ + // Perform actions based on the reason for calling. + switch (fdwReason) + { + case DLL_PROCESS_ATTACH: + break; + + case DLL_THREAD_ATTACH: + break; + + case DLL_THREAD_DETACH: + break; + + case DLL_PROCESS_DETACH: + if (lpvReserved != nullptr) + { + break; // do not do cleanup if process termination scenario + } + break; + } + + return TRUE; +} diff --git a/src/plugins/winsock/xmake.lua b/src/plugins/winsock/xmake.lua new file mode 100644 index 000000000..a4ef02a98 --- /dev/null +++ b/src/plugins/winsock/xmake.lua @@ -0,0 +1,18 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target("winsock") + set_kind("shared") + add_headerfiles("**.h") + add_files("**.cpp") + add_includedirs(".", "../../zenhttp/include/zenhttp", "../../zencore/include") + set_symbols("debug") + + add_cxxflags("/showIncludes") + + if is_mode("release") then + set_optimize("fastest") + end + + if is_plat("windows") then + add_links("Ws2_32") + end diff --git a/src/zen/xmake.lua b/src/zen/xmake.lua index b83999efc..fef48e7bc 100644 --- a/src/zen/xmake.lua +++ b/src/zen/xmake.lua @@ -17,7 +17,7 @@ target("zen") add_files("zen.rc") add_ldflags("/subsystem:console,5.02") add_ldflags("/LTCG") - add_ldflags("crypt32.lib", "wldap32.lib", "Ws2_32.lib") + add_links("crypt32", "wldap32", "Ws2_32") end if is_plat("macosx") then diff --git a/src/zenhttp/dlltransport.cpp b/src/zenhttp/dlltransport.cpp new file mode 100644 index 000000000..0bd5e3720 --- /dev/null +++ b/src/zenhttp/dlltransport.cpp @@ -0,0 +1,251 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "dlltransport.h" + +#include <zencore/except.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> + +#include <exception> +#include <thread> +#include <vector> + +#if ZEN_WITH_PLUGINS + +namespace zen { + +struct DllTransportConnection : public TransportConnectionInterface +{ +public: + DllTransportConnection(); + ~DllTransportConnection(); + + void Initialize(TransportServerConnectionHandler& ServerConnection); + void HandleConnection(); + + // TransportConnectionInterface + + virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override; + virtual void Shutdown(bool Receive, bool Transmit) override; + virtual void CloseConnection() override; + +private: + Ref<TransportServerConnectionHandler> m_ConnectionHandler; + bool m_IsTerminated = false; +}; + +DllTransportConnection::DllTransportConnection() +{ +} + +DllTransportConnection::~DllTransportConnection() +{ +} + +void +DllTransportConnection::Initialize(TransportServerConnectionHandler& ServerConnection) +{ + m_ConnectionHandler = &ServerConnection; // TODO: this is awkward +} + +void +DllTransportConnection::HandleConnection() +{ +} + +void +DllTransportConnection::CloseConnection() +{ + if (m_IsTerminated) + { + return; + } + + m_IsTerminated = true; +} + +int64_t +DllTransportConnection::WriteBytes(const void* Buffer, size_t DataSize) +{ + ZEN_UNUSED(Buffer, DataSize); + return DataSize; +} + +void +DllTransportConnection::Shutdown(bool Receive, bool Transmit) +{ + ZEN_UNUSED(Receive, Transmit); +} + +////////////////////////////////////////////////////////////////////////// + +struct LoadedDll +{ + std::string Name; + std::filesystem::path LoadedFromPath; + Ref<TransportPluginInterface> Plugin; +}; + +class DllTransportPluginImpl +{ +public: + DllTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount); + ~DllTransportPluginImpl(); + + uint16_t Start(TransportServerInterface* ServerInterface); + void Stop(); + bool IsAvailable(); + void LoadDll(std::string_view Name); + +private: + TransportServerInterface* m_ServerInterface = nullptr; + RwLock m_Lock; + std::vector<LoadedDll> m_Transports; + uint16_t m_BasePort = 0; + int m_ThreadCount = 0; +}; + +DllTransportPluginImpl::DllTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount) +: m_BasePort(BasePort) +, m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +{ +} + +DllTransportPluginImpl::~DllTransportPluginImpl() +{ +} + +uint16_t +DllTransportPluginImpl::Start(TransportServerInterface* ServerIface) +{ + m_ServerInterface = ServerIface; + + RwLock::ExclusiveLockScope _(m_Lock); + + for (LoadedDll& Transport : m_Transports) + { + try + { + Transport.Plugin->Initialize(ServerIface); + } + catch (const std::exception&) + { + // TODO: report + } + } + + return m_BasePort; +} + +void +DllTransportPluginImpl::Stop() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + for (LoadedDll& Transport : m_Transports) + { + try + { + Transport.Plugin->Shutdown(); + } + catch (const std::exception&) + { + // TODO: report + } + } +} + +bool +DllTransportPluginImpl::IsAvailable() +{ + return true; +} + +void +DllTransportPluginImpl::LoadDll(std::string_view Name) +{ + ExtendableStringBuilder<128> DllPath; + DllPath << Name << ".dll"; + HMODULE DllHandle = LoadLibraryA(DllPath.c_str()); + + if (!DllHandle) + { + std::error_code Ec = MakeErrorCodeFromLastError(); + + throw std::system_error(Ec, fmt::format("failed to load transport DLL from '{}'", DllPath)); + } + + TransportPluginInterface* CreateTransportPluginInterface(); + + PfnCreateTransportPluginInterface CreatePlugin = + (PfnCreateTransportPluginInterface)GetProcAddress(DllHandle, "CreateTransportPluginInterface"); + + if (!CreatePlugin) + { + std::error_code Ec = MakeErrorCodeFromLastError(); + + FreeLibrary(DllHandle); + + throw std::system_error(Ec, fmt::format("API mismatch detected in transport DLL loaded from '{}'", DllPath)); + } + + LoadedDll NewDll; + + NewDll.Name = Name; + NewDll.LoadedFromPath = DllPath.c_str(); + NewDll.Plugin = CreatePlugin(); + + m_Transports.emplace_back(std::move(NewDll)); +} + +////////////////////////////////////////////////////////////////////////// + +DllTransportPlugin::DllTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +: m_Impl(std::make_unique<DllTransportPluginImpl>(BasePort, ThreadCount)) +{ +} + +DllTransportPlugin::~DllTransportPlugin() +{ + m_Impl->Stop(); +} + +uint32_t +DllTransportPlugin::AddRef() const +{ + return RefCounted::AddRef(); +} + +uint32_t +DllTransportPlugin::Release() const +{ + return RefCounted::Release(); +} + +void +DllTransportPlugin::Initialize(TransportServerInterface* ServerInterface) +{ + m_Impl->Start(ServerInterface); +} + +void +DllTransportPlugin::Shutdown() +{ + m_Impl->Stop(); +} + +bool +DllTransportPlugin::IsAvailable() +{ + return m_Impl->IsAvailable(); +} + +void +DllTransportPlugin::LoadDll(std::string_view Name) +{ + return m_Impl->LoadDll(Name); +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/dlltransport.h b/src/zenhttp/dlltransport.h new file mode 100644 index 000000000..b13bab804 --- /dev/null +++ b/src/zenhttp/dlltransport.h @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpplugin.h> + +#if ZEN_WITH_PLUGINS + +namespace zen { + +class DllTransportPluginImpl; + +/** Transport plugin which supports dynamic loading of external transport + * provider modules + */ +class DllTransportPlugin : public TransportPluginInterface, RefCounted +{ +public: + DllTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + ~DllTransportPlugin(); + + virtual uint32_t AddRef() const override; + virtual uint32_t Release() const override; + + virtual void Initialize(TransportServerInterface* ServerInterface) override; + virtual void Shutdown() override; + virtual bool IsAvailable() override; + + void LoadDll(std::string_view Name); + +private: + std::unique_ptr<DllTransportPluginImpl> m_Impl; +}; + +} // namespace zen + +#endif diff --git a/src/zenhttp/httpasio.cpp b/src/zenhttp/httpasio.cpp index 702ca11fd..562f75e3d 100644 --- a/src/zenhttp/httpasio.cpp +++ b/src/zenhttp/httpasio.cpp @@ -214,7 +214,7 @@ private: ////////////////////////////////////////////////////////////////////////// -struct HttpServerConnection : public HttpConnectionBase, std::enable_shared_from_this<HttpServerConnection> +struct HttpServerConnection : public HttpRequestParserCallbacks, std::enable_shared_from_this<HttpServerConnection> { HttpServerConnection(HttpAsioServerImpl& Server, std::unique_ptr<asio::ip::tcp::socket>&& Socket); ~HttpServerConnection(); @@ -223,10 +223,11 @@ struct HttpServerConnection : public HttpConnectionBase, std::enable_shared_from // HttpConnectionBase implementation - virtual void HandleNewRequest() override; virtual void TerminateConnection() override; virtual void HandleRequest() override; + void HandleNewRequest(); + private: enum class RequestState { @@ -276,8 +277,6 @@ HttpServerConnection::~HttpServerConnection() void HttpServerConnection::HandleNewRequest() { - m_RequestData.Initialize(); - EnqueueRead(); } diff --git a/src/zenhttp/httpasio.h b/src/zenhttp/httpasio.h index 81aadfc23..e8d13a57f 100644 --- a/src/zenhttp/httpasio.h +++ b/src/zenhttp/httpasio.h @@ -10,8 +10,6 @@ namespace zen { namespace asio_http { - struct HttpServerConnection; - struct HttpAcceptor; struct HttpAsioServerImpl; } // namespace asio_http diff --git a/src/zenhttp/httpparser.cpp b/src/zenhttp/httpparser.cpp index ebfe36227..6b987151a 100644 --- a/src/zenhttp/httpparser.cpp +++ b/src/zenhttp/httpparser.cpp @@ -38,8 +38,7 @@ http_parser_settings HttpRequestParser::s_ParserSettings{ .on_chunk_header{}, .on_chunk_complete{}}; -void -HttpRequestParser::Initialize() +HttpRequestParser::HttpRequestParser(HttpRequestParserCallbacks& Connection) : m_Connection(Connection) { http_parser_init(&m_Parser, HTTP_REQUEST); m_Parser.data = this; @@ -47,6 +46,10 @@ HttpRequestParser::Initialize() ResetState(); } +HttpRequestParser::~HttpRequestParser() +{ +} + size_t HttpRequestParser::ConsumeData(const char* InputData, size_t DataSize) { diff --git a/src/zenhttp/httpparser.h b/src/zenhttp/httpparser.h index cce51fcca..219ac351d 100644 --- a/src/zenhttp/httpparser.h +++ b/src/zenhttp/httpparser.h @@ -11,20 +11,19 @@ ZEN_THIRD_PARTY_INCLUDES_END namespace zen { -class HttpConnectionBase +class HttpRequestParserCallbacks { public: - virtual ~HttpConnectionBase() = default; - virtual void HandleNewRequest() = 0; - virtual void TerminateConnection() = 0; - virtual void HandleRequest() = 0; + virtual ~HttpRequestParserCallbacks() = default; + virtual void HandleRequest() = 0; + virtual void TerminateConnection() = 0; }; struct HttpRequestParser { - explicit HttpRequestParser(HttpConnectionBase& Connection) : m_Connection(Connection) {} + explicit HttpRequestParser(HttpRequestParserCallbacks& Connection); + ~HttpRequestParser(); - void Initialize(); size_t ConsumeData(const char* InputData, size_t DataSize); void ResetState(); @@ -70,31 +69,31 @@ private: std::string_view Value; }; - HttpConnectionBase& m_Connection; - char* m_HeaderCursor = m_HeaderBuffer; - char* m_Url = nullptr; - size_t m_UrlLength = 0; - char* m_QueryString = nullptr; - size_t m_QueryLength = 0; - char* m_CurrentHeaderName = nullptr; // Used while parsing headers - size_t m_CurrentHeaderNameLength = 0; - char* m_CurrentHeaderValue = nullptr; // Used while parsing headers - size_t m_CurrentHeaderValueLength = 0; - std::vector<HeaderEntry> m_Headers; - int8_t m_ContentLengthHeaderIndex; - int8_t m_AcceptHeaderIndex; - int8_t m_ContentTypeHeaderIndex; - int8_t m_RangeHeaderIndex; - HttpVerb m_RequestVerb; - bool m_KeepAlive = false; - bool m_Expect100Continue = false; - int m_RequestId = -1; - Oid m_SessionId{}; - IoBuffer m_BodyBuffer; - uint64_t m_BodyPosition = 0; - http_parser m_Parser; - char m_HeaderBuffer[1024]; - std::string m_NormalizedUrl; + HttpRequestParserCallbacks& m_Connection; + char* m_HeaderCursor = m_HeaderBuffer; + char* m_Url = nullptr; + size_t m_UrlLength = 0; + char* m_QueryString = nullptr; + size_t m_QueryLength = 0; + char* m_CurrentHeaderName = nullptr; // Used while parsing headers + size_t m_CurrentHeaderNameLength = 0; + char* m_CurrentHeaderValue = nullptr; // Used while parsing headers + size_t m_CurrentHeaderValueLength = 0; + std::vector<HeaderEntry> m_Headers; + int8_t m_ContentLengthHeaderIndex; + int8_t m_AcceptHeaderIndex; + int8_t m_ContentTypeHeaderIndex; + int8_t m_RangeHeaderIndex; + HttpVerb m_RequestVerb; + bool m_KeepAlive = false; + bool m_Expect100Continue = false; + int m_RequestId = -1; + Oid m_SessionId{}; + IoBuffer m_BodyBuffer; + uint64_t m_BodyPosition = 0; + http_parser m_Parser; + char m_HeaderBuffer[1024]; + std::string m_NormalizedUrl; void AppendCurrentHeader(); diff --git a/src/zenhttp/httpplugin.cpp b/src/zenhttp/httpplugin.cpp new file mode 100644 index 000000000..45a6c23bd --- /dev/null +++ b/src/zenhttp/httpplugin.cpp @@ -0,0 +1,781 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include <zenhttp/httpplugin.h> + +#if ZEN_WITH_PLUGINS + +# include "httpparser.h" + +# include <zencore/except.h> +# include <zencore/logging.h> +# include <zencore/trace.h> +# include <zencore/workthreadpool.h> +# include <zenhttp/httpserver.h> + +# include <memory> +# include <string_view> + +# if ZEN_PLATFORM_WINDOWS +# include <conio.h> +# endif + +# define PLUGIN_VERBOSE_TRACE 1 + +# if PLUGIN_VERBOSE_TRACE +# define ZEN_TRACE_VERBOSE ZEN_TRACE +# else +# define ZEN_TRACE_VERBOSE(fmtstr, ...) +# endif + +namespace zen { + +struct HttpPluginServerImpl; +struct HttpPluginResponse; + +using namespace std::literals; + +////////////////////////////////////////////////////////////////////////// + +struct HttpPluginConnectionHandler : public TransportServerConnectionHandler, public HttpRequestParserCallbacks, RefCounted +{ + virtual uint32_t AddRef() const override; + virtual uint32_t Release() const override; + + virtual void OnBytesRead(const void* Buffer, size_t DataSize) override; + + // HttpRequestParserCallbacks + + virtual void HandleRequest() override; + virtual void TerminateConnection() override; + + void Initialize(TransportConnectionInterface* Transport, HttpPluginServerImpl& Server); + +private: + enum class RequestState + { + kInitialState, + kInitialRead, + kReadingMore, + kWriting, // Currently writing response, connection will be re-used + kWritingFinal, // Writing response, connection will be closed + kDone, + kTerminated + }; + + RequestState m_RequestState = RequestState::kInitialState; + HttpRequestParser m_RequestParser{*this}; + + uint32_t m_ConnectionId = 0; + Ref<IHttpPackageHandler> m_PackageHandler; + + TransportConnectionInterface* m_TransportConnection = nullptr; + HttpPluginServerImpl* m_Server = nullptr; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpPluginServerImpl : public TransportServerInterface +{ + HttpPluginServerImpl(); + ~HttpPluginServerImpl(); + + void AddPlugin(Ref<TransportPluginInterface> Plugin); + void RemovePlugin(Ref<TransportPluginInterface> Plugin); + + void Start(); + void Stop(); + + void RegisterService(const char* InUrlPath, HttpService& Service); + HttpService* RouteRequest(std::string_view Url); + + struct ServiceEntry + { + std::string ServiceUrlPath; + HttpService* Service; + }; + + RwLock m_Lock; + std::vector<ServiceEntry> m_UriHandlers; + std::vector<Ref<TransportPluginInterface>> m_Plugins; + + // TransportServerInterface + + virtual TransportServerConnectionHandler* CreateConnectionHandler(TransportConnectionInterface* Connection) override; +}; + +/** This is the class which request handlers interface with when + generating responses + */ + +class HttpPluginServerRequest : public HttpServerRequest +{ +public: + HttpPluginServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer); + ~HttpPluginServerRequest(); + + HttpPluginServerRequest(const HttpPluginServerRequest&) = delete; + HttpPluginServerRequest& operator=(const HttpPluginServerRequest&) = delete; + + virtual Oid ParseSessionId() const override; + virtual uint32_t ParseRequestId() const override; + + virtual IoBuffer ReadPayload() override; + virtual void WriteResponse(HttpResponseCode ResponseCode) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) override; + virtual void WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) override; + virtual void WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) override; + virtual bool TryGetRanges(HttpRanges& Ranges) override; + + using HttpServerRequest::WriteResponse; + + HttpRequestParser& m_Request; + IoBuffer m_PayloadBuffer; + std::unique_ptr<HttpPluginResponse> m_Response; +}; + +////////////////////////////////////////////////////////////////////////// + +struct HttpPluginResponse +{ +public: + HttpPluginResponse() = default; + explicit HttpPluginResponse(HttpContentType ContentType) : m_ContentType(ContentType) {} + + void InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList); + + inline uint16_t ResponseCode() const { return m_ResponseCode; } + inline uint64_t ContentLength() const { return m_ContentLength; } + + const std::vector<IoBuffer>& ResponseBuffers() const { return m_ResponseBuffers; } + void SuppressPayload() { m_ResponseBuffers.resize(1); } + +private: + uint16_t m_ResponseCode = 0; + bool m_IsKeepAlive = true; + HttpContentType m_ContentType = HttpContentType::kBinary; + uint64_t m_ContentLength = 0; + std::vector<IoBuffer> m_ResponseBuffers; + ExtendableStringBuilder<160> m_Headers; + + std::string_view GetHeaders(); +}; + +void +HttpPluginResponse::InitializeForPayload(uint16_t ResponseCode, std::span<IoBuffer> BlobList) +{ + ZEN_TRACE_CPU("http_plugin::InitializeForPayload"); + + m_ResponseCode = ResponseCode; + const uint32_t ChunkCount = gsl::narrow<uint32_t>(BlobList.size()); + + m_ResponseBuffers.reserve(ChunkCount + 1); + m_ResponseBuffers.push_back({}); // Placeholder for header + + uint64_t TotalDataSize = 0; + + for (IoBuffer& Buffer : BlobList) + { + uint64_t BufferDataSize = Buffer.Size(); + + ZEN_ASSERT(BufferDataSize); + + TotalDataSize += BufferDataSize; + + IoBufferFileReference FileRef; + if (Buffer.GetFileReference(/* out */ FileRef)) + { + // TODO: Use direct file transfer, via TransmitFile/sendfile + + m_ResponseBuffers.emplace_back(std::move(Buffer)).MakeOwned(); + } + else + { + // Send from memory + + m_ResponseBuffers.emplace_back(std::move(Buffer)).MakeOwned(); + } + } + m_ContentLength = TotalDataSize; + + auto Headers = GetHeaders(); + m_ResponseBuffers[0] = IoBufferBuilder::MakeCloneFromMemory(Headers.data(), Headers.size()); +} + +std::string_view +HttpPluginResponse::GetHeaders() +{ + m_Headers << "HTTP/1.1 " << ResponseCode() << " " << ReasonStringForHttpResultCode(ResponseCode()) << "\r\n" + << "Content-Type: " << MapContentTypeToString(m_ContentType) << "\r\n" + << "Content-Length: " << ContentLength() << "\r\n"sv; + + if (!m_IsKeepAlive) + { + m_Headers << "Connection: close\r\n"sv; + } + + m_Headers << "\r\n"sv; + + return m_Headers; +} + +////////////////////////////////////////////////////////////////////////// + +void +HttpPluginConnectionHandler::Initialize(TransportConnectionInterface* Transport, HttpPluginServerImpl& Server) +{ + m_TransportConnection = Transport; + m_Server = &Server; +} + +uint32_t +HttpPluginConnectionHandler::AddRef() const +{ + return RefCounted::AddRef(); +} + +uint32_t +HttpPluginConnectionHandler::Release() const +{ + return RefCounted::Release(); +} + +void +HttpPluginConnectionHandler::OnBytesRead(const void* Buffer, size_t AvailableBytes) +{ + while (AvailableBytes) + { + const size_t ConsumedBytes = m_RequestParser.ConsumeData((const char*)Buffer, AvailableBytes); + + if (ConsumedBytes == ~0ull) + { + // terminate connection + + return TerminateConnection(); + } + + Buffer = reinterpret_cast<const uint8_t*>(Buffer) + ConsumedBytes; + AvailableBytes -= ConsumedBytes; + } +} + +// HttpRequestParserCallbacks + +void +HttpPluginConnectionHandler::HandleRequest() +{ + if (!m_RequestParser.IsKeepAlive()) + { + // Once response has been written, connection is done + m_RequestState = RequestState::kWritingFinal; + + // We're not going to read any more data from this socket + + const bool Receive = true; + const bool Transmit = false; + m_TransportConnection->Shutdown(Receive, Transmit); + } + else + { + m_RequestState = RequestState::kWriting; + } + + auto SendBuffer = [&](const IoBuffer& InBuffer) -> int64_t { + const char* Buffer = reinterpret_cast<const char*>(InBuffer.GetData()); + size_t Bytes = InBuffer.GetSize(); + + return m_TransportConnection->WriteBytes(Buffer, Bytes); + }; + + // Generate response + + if (HttpService* Service = m_Server->RouteRequest(m_RequestParser.Url())) + { + ZEN_TRACE_CPU("http_plugin::HandleRequest"); + + HttpPluginServerRequest Request(m_RequestParser, *Service, m_RequestParser.Body()); + + if (!HandlePackageOffers(*Service, Request, m_PackageHandler)) + { + try + { + Service->HandleRequest(Request); + } + catch (std::system_error& SystemError) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + if (IsOOM(SystemError.code()) || IsOOD(SystemError.code())) + { + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, SystemError.what()); + } + else + { + ZEN_ERROR("Caught system error exception while handling request: {}", SystemError.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, SystemError.what()); + } + } + catch (std::bad_alloc& BadAlloc) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + Request.WriteResponse(HttpResponseCode::InsufficientStorage, HttpContentType::kText, BadAlloc.what()); + } + catch (std::exception& ex) + { + // Drop any partially formatted response + Request.m_Response.reset(); + + ZEN_ERROR("Caught exception while handling request: {}", ex.what()); + Request.WriteResponse(HttpResponseCode::InternalServerError, HttpContentType::kText, ex.what()); + } + } + + if (std::unique_ptr<HttpPluginResponse> Response = std::move(Request.m_Response)) + { + // Transmit the response + + if (m_RequestParser.RequestVerb() == HttpVerb::kHead) + { + Response->SuppressPayload(); + } + + const std::vector<IoBuffer>& ResponseBuffers = Response->ResponseBuffers(); + + //// TODO: should cork/uncork for Linux? + + for (const IoBuffer& Buffer : ResponseBuffers) + { + int64_t SentBytes = SendBuffer(Buffer); + + if (SentBytes < 0) + { + TerminateConnection(); + + return; + } + } + + return; + } + } + + // No route found for request + + std::string_view Response; + + if (m_RequestParser.RequestVerb() == HttpVerb::kHead) + { + if (m_RequestParser.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "\r\n"sv; + } + else + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Connection: close\r\n" + "\r\n"sv; + } + } + else + { + if (m_RequestParser.IsKeepAlive()) + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "No suitable route found"sv; + } + else + { + Response = + "HTTP/1.1 404 NOT FOUND\r\n" + "Content-Length: 23\r\n" + "Content-Type: text/plain\r\n" + "Connection: close\r\n" + "\r\n" + "No suitable route found"sv; + } + } + + const int64_t SentBytes = SendBuffer(IoBufferBuilder::MakeFromMemory(MakeMemoryView(Response))); + + if (SentBytes < 0) + { + TerminateConnection(); + + return; + } +} + +void +HttpPluginConnectionHandler::TerminateConnection() +{ + ZEN_ASSERT(m_TransportConnection); + m_TransportConnection->CloseConnection(); +} + +////////////////////////////////////////////////////////////////////////// + +HttpPluginServerRequest::HttpPluginServerRequest(HttpRequestParser& Request, HttpService& Service, IoBuffer PayloadBuffer) +: m_Request(Request) +, m_PayloadBuffer(std::move(PayloadBuffer)) +{ + const int PrefixLength = Service.UriPrefixLength(); + + std::string_view Uri = Request.Url(); + Uri.remove_prefix(std::min(PrefixLength, static_cast<int>(Uri.size()))); + m_Uri = Uri; + m_UriWithExtension = Uri; + m_QueryString = Request.QueryString(); + + m_Verb = Request.RequestVerb(); + m_ContentLength = Request.Body().Size(); + m_ContentType = Request.ContentType(); + + HttpContentType AcceptContentType = HttpContentType::kUnknownContentType; + + // Parse any extension, to allow requesting a particular response encoding via the URL + + { + std::string_view UriSuffix8{m_Uri}; + + const size_t LastComponentIndex = UriSuffix8.find_last_of('/'); + + if (LastComponentIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastComponentIndex); + } + + const size_t LastDotIndex = UriSuffix8.find_last_of('.'); + + if (LastDotIndex != std::string_view::npos) + { + UriSuffix8.remove_prefix(LastDotIndex + 1); + + AcceptContentType = ParseContentType(UriSuffix8); + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_Uri.remove_suffix(uint32_t(UriSuffix8.size() + 1)); + } + } + } + + // It an explicit content type extension was specified then we'll use that over any + // Accept: header value that may be present + + if (AcceptContentType != HttpContentType::kUnknownContentType) + { + m_AcceptType = AcceptContentType; + } + else + { + m_AcceptType = Request.AcceptType(); + } +} + +HttpPluginServerRequest::~HttpPluginServerRequest() +{ +} + +Oid +HttpPluginServerRequest::ParseSessionId() const +{ + return m_Request.SessionId(); +} + +uint32_t +HttpPluginServerRequest::ParseRequestId() const +{ + return m_Request.RequestId(); +} + +IoBuffer +HttpPluginServerRequest::ReadPayload() +{ + return m_PayloadBuffer; +} + +void +HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpPluginResponse(HttpContentType::kBinary)); + std::array<IoBuffer, 0> Empty; + + m_Response->InitializeForPayload((uint16_t)ResponseCode, Empty); +} + +void +HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::span<IoBuffer> Blobs) +{ + ZEN_ASSERT(!m_Response); + + m_Response.reset(new HttpPluginResponse(ContentType)); + m_Response->InitializeForPayload((uint16_t)ResponseCode, Blobs); +} + +void +HttpPluginServerRequest::WriteResponse(HttpResponseCode ResponseCode, HttpContentType ContentType, std::u8string_view ResponseString) +{ + ZEN_ASSERT(!m_Response); + m_Response.reset(new HttpPluginResponse(ContentType)); + + IoBuffer MessageBuffer(IoBuffer::Wrap, ResponseString.data(), ResponseString.size()); + std::array<IoBuffer, 1> SingleBufferList({MessageBuffer}); + + m_Response->InitializeForPayload((uint16_t)ResponseCode, SingleBufferList); +} + +void +HttpPluginServerRequest::WriteResponseAsync(std::function<void(HttpServerRequest&)>&& ContinuationHandler) +{ + ZEN_ASSERT(!m_Response); + + // Not one bit async, innit + ContinuationHandler(*this); +} + +bool +HttpPluginServerRequest::TryGetRanges(HttpRanges& Ranges) +{ + return TryParseHttpRangeHeader(m_Request.RangeHeader(), Ranges); +} + +////////////////////////////////////////////////////////////////////////// + +HttpPluginServerImpl::HttpPluginServerImpl() +{ +} + +HttpPluginServerImpl::~HttpPluginServerImpl() +{ +} + +TransportServerConnectionHandler* +HttpPluginServerImpl::CreateConnectionHandler(TransportConnectionInterface* Connection) +{ + HttpPluginConnectionHandler* Handler{new HttpPluginConnectionHandler()}; + Handler->Initialize(Connection, *this); + return Handler; +} + +void +HttpPluginServerImpl::Start() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + for (auto& Plugin : m_Plugins) + { + try + { + Plugin->Initialize(this); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception caught during plugin initialization: {}", Ex.what()); + } + } +} + +void +HttpPluginServerImpl::Stop() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + for (auto& Plugin : m_Plugins) + { + try + { + Plugin->Shutdown(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception caught during plugin shutdown: {}", Ex.what()); + } + + Plugin = nullptr; + } + + m_Plugins.clear(); +} + +void +HttpPluginServerImpl::AddPlugin(Ref<TransportPluginInterface> Plugin) +{ + RwLock::ExclusiveLockScope _(m_Lock); + m_Plugins.emplace_back(std::move(Plugin)); +} + +void +HttpPluginServerImpl::RemovePlugin(Ref<TransportPluginInterface> Plugin) +{ + RwLock::ExclusiveLockScope _(m_Lock); + auto It = std::find(begin(m_Plugins), end(m_Plugins), Plugin); + if (It != m_Plugins.end()) + { + m_Plugins.erase(It); + } +} + +void +HttpPluginServerImpl::RegisterService(const char* InUrlPath, HttpService& Service) +{ + std::string_view UrlPath(InUrlPath); + Service.SetUriPrefixLength(UrlPath.size()); + if (!UrlPath.empty() && UrlPath.back() == '/') + { + UrlPath.remove_suffix(1); + } + + RwLock::ExclusiveLockScope _(m_Lock); + m_UriHandlers.push_back({std::string(UrlPath), &Service}); +} + +HttpService* +HttpPluginServerImpl::RouteRequest(std::string_view Url) +{ + RwLock::SharedLockScope _(m_Lock); + + HttpService* CandidateService = nullptr; + std::string::size_type CandidateMatchSize = 0; + for (const ServiceEntry& SvcEntry : m_UriHandlers) + { + const std::string& SvcUrl = SvcEntry.ServiceUrlPath; + const std::string::size_type SvcUrlSize = SvcUrl.size(); + if ((SvcUrlSize >= CandidateMatchSize) && Url.compare(0, SvcUrlSize, SvcUrl) == 0 && + ((SvcUrlSize == Url.size()) || (Url[SvcUrlSize] == '/'))) + { + CandidateMatchSize = SvcUrl.size(); + CandidateService = SvcEntry.Service; + } + } + + return CandidateService; +} + +////////////////////////////////////////////////////////////////////////// + +HttpPluginServer::HttpPluginServer(unsigned int ThreadCount) +: m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +, m_Impl(new HttpPluginServerImpl) +{ +} + +HttpPluginServer::~HttpPluginServer() +{ + if (m_Impl) + { + ZEN_ERROR("~HttpPluginServer() called without calling Close() first"); + } +} + +int +HttpPluginServer::Initialize(int BasePort) +{ + try + { + m_Impl->Start(); + } + catch (std::exception& ex) + { + ZEN_WARN("Caught exception starting http plugin server: {}", ex.what()); + } + + return BasePort; +} + +void +HttpPluginServer::Close() +{ + try + { + m_Impl->Stop(); + } + catch (std::exception& ex) + { + ZEN_WARN("Caught exception stopping http plugin server: {}", ex.what()); + } + + delete m_Impl; + m_Impl = nullptr; +} + +void +HttpPluginServer::Run(bool IsInteractive) +{ + const bool TestMode = !IsInteractive; + + int WaitTimeout = -1; + if (!TestMode) + { + WaitTimeout = 1000; + } + +# if ZEN_PLATFORM_WINDOWS + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (plugin HTTP). Press ESC or Q to quit"); + } + + do + { + if (!TestMode && _kbhit() != 0) + { + char c = (char)_getch(); + + if (c == 27 || c == 'Q' || c == 'q') + { + RequestApplicationExit(0); + } + } + + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +# else + if (TestMode == false) + { + zen::logging::ConsoleLog().info("Zen Server running (plugin HTTP). Ctrl-C to quit"); + } + + do + { + m_ShutdownEvent.Wait(WaitTimeout); + } while (!IsApplicationExitRequested()); +# endif +} + +void +HttpPluginServer::RequestExit() +{ + m_ShutdownEvent.Set(); +} + +void +HttpPluginServer::RegisterService(HttpService& Service) +{ + m_Impl->RegisterService(Service.BaseUri(), Service); +} + +void +HttpPluginServer::AddPlugin(Ref<TransportPluginInterface> Plugin) +{ + m_Impl->AddPlugin(Plugin); +} + +void +HttpPluginServer::RemovePlugin(Ref<TransportPluginInterface> Plugin) +{ + m_Impl->RemovePlugin(Plugin); +} + +} // namespace zen +#endif diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index a98a3c9bb..523befa72 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -6,6 +6,13 @@ #include "httpnull.h" #include "httpsys.h" +#include "zenhttp/httpplugin.h" + +#if ZEN_WITH_PLUGINS +# include "dlltransport.h" +# include "winsocktransport.h" +#endif + #include <zencore/compactbinary.h> #include <zencore/compactbinarybuilder.h> #include <zencore/compactbinarypackage.h> @@ -711,6 +718,7 @@ enum class HttpServerClass { kHttpAsio, kHttpSys, + kHttpPlugin, kHttpNull }; @@ -723,7 +731,7 @@ CreateHttpServer(const HttpServerConfig& Config) #if ZEN_WITH_HTTPSYS Class = HttpServerClass::kHttpSys; -#elif 1 +#else Class = HttpServerClass::kHttpAsio; #endif @@ -735,6 +743,10 @@ CreateHttpServer(const HttpServerConfig& Config) { Class = HttpServerClass::kHttpSys; } + else if (Config.ServerClass == "plugin"sv) + { + Class = HttpServerClass::kHttpPlugin; + } else if (Config.ServerClass == "null"sv) { Class = HttpServerClass::kHttpNull; @@ -747,6 +759,27 @@ CreateHttpServer(const HttpServerConfig& Config) ZEN_INFO("using asio HTTP server implementation"); return Ref<HttpServer>(new HttpAsioServer(Config.ThreadCount)); +#if ZEN_WITH_PLUGINS + case HttpServerClass::kHttpPlugin: + { + ZEN_INFO("using plugin HTTP server implementation"); + Ref<HttpPluginServer> Server{new HttpPluginServer(Config.ThreadCount)}; + +# if 1 + Ref<TransportPluginInterface> WinsockPlugin{CreateSocketTransportPlugin(1337, Config.ThreadCount)}; + Server->AddPlugin(WinsockPlugin); +# endif + +# if 0 + Ref<DllTransportPlugin> DllPlugin{new DllTransportPlugin(1337, Config.ThreadCount)}; + DllPlugin->LoadDll("winsock"); + Server->AddPlugin(DllPlugin); +# endif + + return Server; + } +#endif + #if ZEN_WITH_HTTPSYS case HttpServerClass::kHttpSys: ZEN_INFO("using http.sys server implementation"); diff --git a/src/zenhttp/include/zenhttp/httpplugin.h b/src/zenhttp/include/zenhttp/httpplugin.h new file mode 100644 index 000000000..409eb1b61 --- /dev/null +++ b/src/zenhttp/include/zenhttp/httpplugin.h @@ -0,0 +1,49 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zencore/refcount.h> +#include <zencore/thread.h> + +#if !defined(ZEN_WITH_PLUGINS) +# if ZEN_PLATFORM_WINDOWS +# define ZEN_WITH_PLUGINS 1 +# else +# define ZEN_WITH_PLUGINS 0 +# endif +#endif + +#if ZEN_WITH_PLUGINS +# include "transportplugin.h" +# include <zenhttp/httpserver.h> + +namespace zen { + +struct HttpPluginServerImpl; + +class HttpPluginServer : public HttpServer +{ +public: + HttpPluginServer(unsigned int ThreadCount); + ~HttpPluginServer(); + + virtual void RegisterService(HttpService& Service) override; + virtual int Initialize(int BasePort) override; + virtual void Run(bool IsInteractiveSession) override; + virtual void RequestExit() override; + virtual void Close() override; + + void AddPlugin(Ref<TransportPluginInterface> Plugin); + void RemovePlugin(Ref<TransportPluginInterface> Plugin); + +private: + Event m_ShutdownEvent; + int m_BasePort = 0; + unsigned int m_ThreadCount = 0; + + HttpPluginServerImpl* m_Impl = nullptr; +}; + +} // namespace zen + +#endif diff --git a/src/zenhttp/include/zenhttp/transportplugin.h b/src/zenhttp/include/zenhttp/transportplugin.h new file mode 100644 index 000000000..38b07d471 --- /dev/null +++ b/src/zenhttp/include/zenhttp/transportplugin.h @@ -0,0 +1,97 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <stdint.h> + +// Important note: this header is meant to compile standalone +// and should therefore not depend on anything from the Zen codebase + +class TransportConnectionInterface; +class TransportPluginInterface; +class TransportServerConnectionHandler; +class TransportServerInterface; + +/************************************************************************* + + The following interfaces are implemented on the server side, and instances + are provided to the plugins. + +*************************************************************************/ + +/** Plugin-server interface for connection + * + * This is how the transport feeds data to the connection handler + * which will parse the incoming messages and dispatch to + * appropriate request handlers and ultimately call into functions + * which write data back to the client. + */ +class TransportServerConnectionHandler +{ +public: + virtual uint32_t AddRef() const = 0; + virtual uint32_t Release() const = 0; + virtual void OnBytesRead(const void* Buffer, size_t DataSize) = 0; +}; + +/** Plugin-server interface + * + * There will be one instance of this per plugin, and the plugin + * should use this to manage lifetimes of connections and any + * other resources. + */ +class TransportServerInterface +{ +public: + virtual TransportServerConnectionHandler* CreateConnectionHandler(TransportConnectionInterface* Connection) = 0; +}; + +/************************************************************************* + + The following interfaces are to be implemented by transport plugins. + +*************************************************************************/ + +/** Interface which needs to be implemented by a transport plugin + * + * This is responsible for setting up and running the communication + * for a given transport. + */ +class TransportPluginInterface +{ +public: + virtual uint32_t AddRef() const = 0; + virtual uint32_t Release() const = 0; + virtual void Initialize(TransportServerInterface* ServerInterface) = 0; + virtual void Shutdown() = 0; + + /** Check whether this transport is usable. + */ + virtual bool IsAvailable() = 0; +}; + +/** A transport plugin provider needs to implement this interface + * + * There will be one instance of this per established connection and + * this interface is used to write response data back to the client. + */ +class TransportConnectionInterface +{ +public: + virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) = 0; + virtual void Shutdown(bool Receive, bool Transmit) = 0; + virtual void CloseConnection() = 0; +}; + +#if defined(_MSC_VER) +# define DLL_TRANSPORT_API __declspec(dllexport) +#else +# define DLL_TRANSPORT_API +#endif + +extern "C" +{ + DLL_TRANSPORT_API TransportPluginInterface* CreateTransportPluginInterface(); +} + +typedef TransportPluginInterface* (*PfnCreateTransportPluginInterface)(); diff --git a/src/zenhttp/winsocktransport.cpp b/src/zenhttp/winsocktransport.cpp new file mode 100644 index 000000000..e86e4822e --- /dev/null +++ b/src/zenhttp/winsocktransport.cpp @@ -0,0 +1,367 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "winsocktransport.h" + +#if ZEN_WITH_PLUGINS + +# include <zencore/logging.h> +# include <zencore/scopeguard.h> +# include <zencore/workthreadpool.h> + +# if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +ZEN_THIRD_PARTY_INCLUDES_START +# include <WinSock2.h> +# include <WS2tcpip.h> +ZEN_THIRD_PARTY_INCLUDES_END +# endif + +# include <thread> + +namespace zen { + +class SocketTransportPluginImpl; + +class SocketTransportPlugin : public TransportPluginInterface, RefCounted +{ +public: + SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + ~SocketTransportPlugin(); + + virtual uint32_t AddRef() const override; + virtual uint32_t Release() const override; + virtual void Initialize(TransportServerInterface* ServerInterface) override; + virtual void Shutdown() override; + virtual bool IsAvailable() override; + +private: + bool m_IsOk = true; + uint16_t m_BasePort = 0; + int m_ThreadCount = 0; + SocketTransportPluginImpl* m_Impl; +}; + +struct SocketTransportConnection : public TransportConnectionInterface +{ +public: + SocketTransportConnection(); + ~SocketTransportConnection(); + + void Initialize(TransportServerConnectionHandler* ServerConnection, SOCKET ClientSocket); + void HandleConnection(); + + // TransportConnectionInterface + + virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override; + virtual void Shutdown(bool Receive, bool Transmit) override; + virtual void CloseConnection() override; + +private: + Ref<TransportServerConnectionHandler> m_ConnectionHandler; + SOCKET m_ClientSocket{}; + bool m_IsTerminated = false; +}; + +SocketTransportConnection::SocketTransportConnection() +{ +} + +SocketTransportConnection::~SocketTransportConnection() +{ +} + +void +SocketTransportConnection::Initialize(TransportServerConnectionHandler* ServerConnection, SOCKET ClientSocket) +{ + ZEN_ASSERT(!m_ConnectionHandler); + + m_ConnectionHandler = ServerConnection; + m_ClientSocket = ClientSocket; +} + +void +SocketTransportConnection::HandleConnection() +{ + ZEN_ASSERT(m_ConnectionHandler); + + const int InputBufferSize = 64 * 1024; + uint8_t* InputBuffer = new uint8_t[64 * 1024]; + auto _ = MakeGuard([&] { delete[] InputBuffer; }); + + do + { + const int RecvBytes = recv(m_ClientSocket, (char*)InputBuffer, InputBufferSize, /* flags */ 0); + + if (RecvBytes == 0) + { + // Connection closed + return CloseConnection(); + } + else if (RecvBytes < 0) + { + // Error + return CloseConnection(); + } + + m_ConnectionHandler->OnBytesRead(InputBuffer, RecvBytes); + } while (m_ClientSocket); +} + +void +SocketTransportConnection::CloseConnection() +{ + if (m_IsTerminated) + { + return; + } + + ZEN_ASSERT(m_ClientSocket); + m_IsTerminated = true; + + shutdown(m_ClientSocket, SD_BOTH); // We won't be sending or receiving any more data + + closesocket(m_ClientSocket); + m_ClientSocket = 0; +} + +int64_t +SocketTransportConnection::WriteBytes(const void* Buffer, size_t DataSize) +{ + const uint8_t* BufferCursor = reinterpret_cast<const uint8_t*>(Buffer); + int64_t TotalBytesSent = 0; + + while (DataSize) + { + const int MaxBlockSize = 128 * 1024; + const int SendBlockSize = (DataSize > MaxBlockSize) ? MaxBlockSize : (int)DataSize; + const int SentBytes = send(m_ClientSocket, (const char*)BufferCursor, SendBlockSize, /* flags */ 0); + + if (SentBytes < 0) + { + // Error + return SentBytes; + } + + BufferCursor += SentBytes; + DataSize -= SentBytes; + TotalBytesSent += SentBytes; + } + + return TotalBytesSent; +} + +void +SocketTransportConnection::Shutdown(bool Receive, bool Transmit) +{ + if (Receive) + { + if (Transmit) + { + shutdown(m_ClientSocket, SD_BOTH); + } + else + { + shutdown(m_ClientSocket, SD_RECEIVE); + } + } + else if (Transmit) + { + shutdown(m_ClientSocket, SD_SEND); + } +} + +////////////////////////////////////////////////////////////////////////// + +class SocketTransportPluginImpl +{ +public: + SocketTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount); + ~SocketTransportPluginImpl(); + + uint16_t Start(uint16_t Port, TransportServerInterface* ServerInterface); + void Stop(); + +private: + TransportServerInterface* m_ServerInterface = nullptr; + uint16_t m_BasePort = 0; + int m_ThreadCount = 0; + bool m_IsOk = true; + + SOCKET m_ListenSocket{}; + std::thread m_AcceptThread; + std::atomic_flag m_KeepRunning; + std::unique_ptr<WorkerThreadPool> m_WorkerThreadpool; +}; + +SocketTransportPluginImpl::SocketTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount) +: m_BasePort(BasePort) +, m_ThreadCount(ThreadCount) +{ +# if ZEN_PLATFORM_WINDOWS + WSADATA wsaData; + if (int Result = WSAStartup(0x202, &wsaData); Result != 0) + { + m_IsOk = false; + WSACleanup(); + } +# endif + + m_WorkerThreadpool = std::make_unique<WorkerThreadPool>(m_ThreadCount, "http_conn"); +} + +SocketTransportPluginImpl::~SocketTransportPluginImpl() +{ + Stop(); + +# if ZEN_PLATFORM_WINDOWS + if (m_IsOk) + { + WSACleanup(); + } +# endif +} + +uint16_t +SocketTransportPluginImpl::Start(uint16_t Port, TransportServerInterface* ServerInterface) +{ + m_ServerInterface = ServerInterface; + m_ListenSocket = socket(AF_INET6, SOCK_STREAM, 0); + + if (m_ListenSocket == SOCKET_ERROR || m_ListenSocket == INVALID_SOCKET) + { + ZEN_ERROR("socket creation failed in HTTP plugin server init: {}", WSAGetLastError()); + + return 0; + } + + sockaddr_in6 Server{}; + Server.sin6_family = AF_INET6; + Server.sin6_port = htons(Port); + Server.sin6_addr = in6addr_any; + + if (int Result = bind(m_ListenSocket, (sockaddr*)&Server, sizeof(Server)); Result == SOCKET_ERROR) + { + ZEN_ERROR("bind call failed in HTTP plugin server init: {}", WSAGetLastError()); + + return 0; + } + + if (int Result = listen(m_ListenSocket, AF_INET6); Result == SOCKET_ERROR) + { + ZEN_ERROR("listen call failed in HTTP plugin server init: {}", WSAGetLastError()); + + return 0; + } + + m_KeepRunning.test_and_set(); + + m_AcceptThread = std::thread([&] { + SetCurrentThreadName("http_plugin_acceptor"); + + ZEN_INFO("HTTP plugin server waiting for connections"); + + do + { + if (SOCKET ClientSocket = accept(m_ListenSocket, NULL, NULL); ClientSocket != SOCKET_ERROR) + { + int Flag = 1; + setsockopt(ClientSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&Flag, sizeof(Flag)); + + // Handle new connection + SocketTransportConnection* Connection = new SocketTransportConnection(); + TransportServerConnectionHandler* ConnectionInterface{m_ServerInterface->CreateConnectionHandler(Connection)}; + Connection->Initialize(ConnectionInterface, ClientSocket); + + m_WorkerThreadpool->ScheduleWork([Connection] { + try + { + Connection->HandleConnection(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception caught in connection loop: {}", Ex.what()); + } + + delete Connection; + }); + } + else + { + } + } while (!IsApplicationExitRequested() && m_KeepRunning.test()); + + ZEN_INFO("HTTP plugin server accept thread exit"); + }); + + return Port; +} + +void +SocketTransportPluginImpl::Stop() +{ + // TODO: all pending/ongoing work should be drained here as well + + m_KeepRunning.clear(); + + closesocket(m_ListenSocket); + m_ListenSocket = 0; + + if (m_AcceptThread.joinable()) + { + m_AcceptThread.join(); + } +} + +////////////////////////////////////////////////////////////////////////// + +SocketTransportPlugin::SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +: m_BasePort(BasePort) +, m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +, m_Impl(new SocketTransportPluginImpl(BasePort, m_ThreadCount)) +{ +} + +SocketTransportPlugin::~SocketTransportPlugin() +{ + delete m_Impl; +} + +uint32_t +SocketTransportPlugin::AddRef() const +{ + return RefCounted::AddRef(); +} + +uint32_t +SocketTransportPlugin::Release() const +{ + return RefCounted::Release(); +} + +void +SocketTransportPlugin::Initialize(TransportServerInterface* ServerInterface) +{ + m_Impl->Start(m_BasePort, ServerInterface); +} + +void +SocketTransportPlugin::Shutdown() +{ + m_Impl->Stop(); +} + +bool +SocketTransportPlugin::IsAvailable() +{ + return true; +} + +TransportPluginInterface* +CreateSocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +{ + return new SocketTransportPlugin(BasePort, ThreadCount); +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/winsocktransport.h b/src/zenhttp/winsocktransport.h new file mode 100644 index 000000000..34809cddc --- /dev/null +++ b/src/zenhttp/winsocktransport.h @@ -0,0 +1,15 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpplugin.h> + +#if ZEN_WITH_PLUGINS + +namespace zen { + +TransportPluginInterface* CreateSocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + +} // namespace zen + +#endif |