From 11f7f70b825c5b6784f5e2609463a1a9d1a0dabc Mon Sep 17 00:00:00 2001 From: Stefan Boberg Date: Wed, 11 Oct 2023 14:59:25 +0200 Subject: pluggable asio transport (#460) added pluggable transport based on asio. This is in an experimental state and is not yet a replacement for httpasio even though that is the ultimate goal also moved plugin API header into dedicated part of the tree to clarify that it is meant to be usable in isolation, without any dependency on zencore et al moved transport implementations into dedicated source directory in zenhttp note that this adds code to the build but nothing should change at runtime since the instantiation of the new code is conditional and is inactive by default --- src/plugins/include/transportplugin.h | 111 +++++++ src/plugins/winsock/winsock.cpp | 2 + src/plugins/winsock/xmake.lua | 10 +- src/plugins/xmake.lua | 7 + src/zenhttp/dlltransport.cpp | 250 --------------- src/zenhttp/dlltransport.h | 37 --- src/zenhttp/httpasio.cpp | 1 + src/zenhttp/httpasio.h | 3 - src/zenhttp/httpserver.cpp | 18 +- src/zenhttp/httpsys.cpp | 88 +++++- src/zenhttp/httpsys.h | 87 +---- src/zenhttp/include/zenhttp/transportplugin.h | 97 ------ src/zenhttp/transports/asiotransport.cpp | 439 ++++++++++++++++++++++++++ src/zenhttp/transports/asiotransport.h | 15 + src/zenhttp/transports/dlltransport.cpp | 250 +++++++++++++++ src/zenhttp/transports/dlltransport.h | 37 +++ src/zenhttp/transports/winsocktransport.cpp | 367 +++++++++++++++++++++ src/zenhttp/transports/winsocktransport.h | 15 + src/zenhttp/winsocktransport.cpp | 367 --------------------- src/zenhttp/winsocktransport.h | 15 - src/zenhttp/xmake.lua | 2 +- 21 files changed, 1354 insertions(+), 864 deletions(-) create mode 100644 src/plugins/include/transportplugin.h create mode 100644 src/plugins/xmake.lua delete mode 100644 src/zenhttp/dlltransport.cpp delete mode 100644 src/zenhttp/dlltransport.h delete mode 100644 src/zenhttp/include/zenhttp/transportplugin.h create mode 100644 src/zenhttp/transports/asiotransport.cpp create mode 100644 src/zenhttp/transports/asiotransport.h create mode 100644 src/zenhttp/transports/dlltransport.cpp create mode 100644 src/zenhttp/transports/dlltransport.h create mode 100644 src/zenhttp/transports/winsocktransport.cpp create mode 100644 src/zenhttp/transports/winsocktransport.h delete mode 100644 src/zenhttp/winsocktransport.cpp delete mode 100644 src/zenhttp/winsocktransport.h (limited to 'src') diff --git a/src/plugins/include/transportplugin.h b/src/plugins/include/transportplugin.h new file mode 100644 index 000000000..aee5b2e7a --- /dev/null +++ b/src/plugins/include/transportplugin.h @@ -0,0 +1,111 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +// Important note: this header is meant to compile standalone +// and should therefore not depend on anything from the Zen codebase + +namespace zen { + +class TransportConnection; +class TransportPlugin; +class TransportServerConnection; +class TransportServer; + +/************************************************************************* + + The following interfaces are implemented on the server side, and instances + are provided to the plugins. + +*************************************************************************/ + +/** Plugin-server interface for connection + + This is returned by a call to TransportServer::CreateConnectionHandler + and there should be one instance created per established connection + + The plugin uses this interface to feed data into the server side + protocol implementation which will parse the incoming messages and + dispatch to appropriate request handlers and ultimately call into + TransportConnection functions which write data back to the client + */ +class TransportServerConnection +{ +public: + virtual uint32_t AddRef() const = 0; + virtual uint32_t Release() const = 0; + virtual void OnBytesRead(const void* Buffer, size_t DataSize) = 0; +}; + +/** Server interface + + There will be one instance of this provided by the system to the transport plugin + + The plugin can use this to register new connections + + */ +class TransportServer +{ +public: + virtual TransportServerConnection* CreateConnectionHandler(TransportConnection* Connection) = 0; +}; + +/************************************************************************* + + The following interfaces are to be implemented by transport plugins. + +*************************************************************************/ + +/** Interface which needs to be implemented by a transport plugin + + This is responsible for setting up and running the communication + for a given transport. + + Once initialized, the plugin should be ready to accept connections + using its own execution resources (threads, thread pools etc) + */ +class TransportPlugin +{ +public: + virtual uint32_t AddRef() const = 0; + virtual uint32_t Release() const = 0; + virtual void Initialize(TransportServer* ServerInterface) = 0; + virtual void Shutdown() = 0; + + /** Check whether this transport is usable. + */ + virtual bool IsAvailable() = 0; +}; + +/** A transport plugin provider needs to implement this interface + + The plugin should create one instance of this per established + connection and register it with the TransportServer instance + CreateConnectionHandler() function. The server will subsequently + use this interface to write response data back to the client and + to manage the connection life cycle in general +*/ +class TransportConnection +{ +public: + virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) = 0; + virtual void Shutdown(bool Receive, bool Transmit) = 0; + virtual void CloseConnection() = 0; +}; + +} // namespace zen + +#if defined(_MSC_VER) +# define DLL_TRANSPORT_API __declspec(dllexport) +#else +# define DLL_TRANSPORT_API +#endif + +extern "C" +{ + DLL_TRANSPORT_API zen::TransportPlugin* CreateTransportPlugin(); +} + +typedef zen::TransportPlugin* (*PfnCreateTransportPlugin)(); diff --git a/src/plugins/winsock/winsock.cpp b/src/plugins/winsock/winsock.cpp index 3ee3f0ccd..a6cfed1e3 100644 --- a/src/plugins/winsock/winsock.cpp +++ b/src/plugins/winsock/winsock.cpp @@ -24,6 +24,8 @@ ZEN_THIRD_PARTY_INCLUDES_END ////////////////////////////////////////////////////////////////////////// +using namespace zen; + class SocketTransportPlugin : public TransportPlugin, zen::RefCounted { public: diff --git a/src/plugins/winsock/xmake.lua b/src/plugins/winsock/xmake.lua index a4ef02a98..408a248b1 100644 --- a/src/plugins/winsock/xmake.lua +++ b/src/plugins/winsock/xmake.lua @@ -4,15 +4,11 @@ target("winsock") set_kind("shared") add_headerfiles("**.h") add_files("**.cpp") - add_includedirs(".", "../../zenhttp/include/zenhttp", "../../zencore/include") + add_links("Ws2_32") + add_includedirs(".", "../../zencore/include") set_symbols("debug") - - add_cxxflags("/showIncludes") + add_deps("plugins") if is_mode("release") then set_optimize("fastest") end - - if is_plat("windows") then - add_links("Ws2_32") - end diff --git a/src/plugins/xmake.lua b/src/plugins/xmake.lua new file mode 100644 index 000000000..9e4d49685 --- /dev/null +++ b/src/plugins/xmake.lua @@ -0,0 +1,7 @@ +-- Copyright Epic Games, Inc. All Rights Reserved. + +target('plugins') + set_kind("headeronly") + set_group("plugins") + add_headerfiles("**.h") + add_includedirs("include", {public=true}) diff --git a/src/zenhttp/dlltransport.cpp b/src/zenhttp/dlltransport.cpp deleted file mode 100644 index 04fb6caaa..000000000 --- a/src/zenhttp/dlltransport.cpp +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "dlltransport.h" - -#include -#include -#include - -#include -#include -#include - -#if ZEN_WITH_PLUGINS - -namespace zen { - -struct DllTransportConnection : public TransportConnection -{ -public: - DllTransportConnection(); - ~DllTransportConnection(); - - void Initialize(TransportServerConnection& ServerConnection); - void HandleConnection(); - - // TransportConnection - - virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override; - virtual void Shutdown(bool Receive, bool Transmit) override; - virtual void CloseConnection() override; - -private: - Ref m_ConnectionHandler; - bool m_IsTerminated = false; -}; - -DllTransportConnection::DllTransportConnection() -{ -} - -DllTransportConnection::~DllTransportConnection() -{ -} - -void -DllTransportConnection::Initialize(TransportServerConnection& ServerConnection) -{ - m_ConnectionHandler = &ServerConnection; // TODO: this is awkward -} - -void -DllTransportConnection::HandleConnection() -{ -} - -void -DllTransportConnection::CloseConnection() -{ - if (m_IsTerminated) - { - return; - } - - m_IsTerminated = true; -} - -int64_t -DllTransportConnection::WriteBytes(const void* Buffer, size_t DataSize) -{ - ZEN_UNUSED(Buffer, DataSize); - return DataSize; -} - -void -DllTransportConnection::Shutdown(bool Receive, bool Transmit) -{ - ZEN_UNUSED(Receive, Transmit); -} - -////////////////////////////////////////////////////////////////////////// - -struct LoadedDll -{ - std::string Name; - std::filesystem::path LoadedFromPath; - Ref Plugin; -}; - -class DllTransportPluginImpl -{ -public: - DllTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount); - ~DllTransportPluginImpl(); - - uint16_t Start(TransportServer* ServerInterface); - void Stop(); - bool IsAvailable(); - void LoadDll(std::string_view Name); - -private: - TransportServer* m_ServerInterface = nullptr; - RwLock m_Lock; - std::vector m_Transports; - uint16_t m_BasePort = 0; - int m_ThreadCount = 0; -}; - -DllTransportPluginImpl::DllTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount) -: m_BasePort(BasePort) -, m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) -{ -} - -DllTransportPluginImpl::~DllTransportPluginImpl() -{ -} - -uint16_t -DllTransportPluginImpl::Start(TransportServer* ServerIface) -{ - m_ServerInterface = ServerIface; - - RwLock::ExclusiveLockScope _(m_Lock); - - for (LoadedDll& Transport : m_Transports) - { - try - { - Transport.Plugin->Initialize(ServerIface); - } - catch (const std::exception&) - { - // TODO: report - } - } - - return m_BasePort; -} - -void -DllTransportPluginImpl::Stop() -{ - RwLock::ExclusiveLockScope _(m_Lock); - - for (LoadedDll& Transport : m_Transports) - { - try - { - Transport.Plugin->Shutdown(); - } - catch (const std::exception&) - { - // TODO: report - } - } -} - -bool -DllTransportPluginImpl::IsAvailable() -{ - return true; -} - -void -DllTransportPluginImpl::LoadDll(std::string_view Name) -{ - ExtendableStringBuilder<128> DllPath; - DllPath << Name << ".dll"; - HMODULE DllHandle = LoadLibraryA(DllPath.c_str()); - - if (!DllHandle) - { - std::error_code Ec = MakeErrorCodeFromLastError(); - - throw std::system_error(Ec, fmt::format("failed to load transport DLL from '{}'", DllPath)); - } - - TransportPlugin* CreateTransportPlugin(); - - PfnCreateTransportPlugin CreatePlugin = (PfnCreateTransportPlugin)GetProcAddress(DllHandle, "CreateTransportPlugin"); - - if (!CreatePlugin) - { - std::error_code Ec = MakeErrorCodeFromLastError(); - - FreeLibrary(DllHandle); - - throw std::system_error(Ec, fmt::format("API mismatch detected in transport DLL loaded from '{}'", DllPath)); - } - - LoadedDll NewDll; - - NewDll.Name = Name; - NewDll.LoadedFromPath = DllPath.c_str(); - NewDll.Plugin = CreatePlugin(); - - m_Transports.emplace_back(std::move(NewDll)); -} - -////////////////////////////////////////////////////////////////////////// - -DllTransportPlugin::DllTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) -: m_Impl(std::make_unique(BasePort, ThreadCount)) -{ -} - -DllTransportPlugin::~DllTransportPlugin() -{ - m_Impl->Stop(); -} - -uint32_t -DllTransportPlugin::AddRef() const -{ - return RefCounted::AddRef(); -} - -uint32_t -DllTransportPlugin::Release() const -{ - return RefCounted::Release(); -} - -void -DllTransportPlugin::Initialize(TransportServer* ServerInterface) -{ - m_Impl->Start(ServerInterface); -} - -void -DllTransportPlugin::Shutdown() -{ - m_Impl->Stop(); -} - -bool -DllTransportPlugin::IsAvailable() -{ - return m_Impl->IsAvailable(); -} - -void -DllTransportPlugin::LoadDll(std::string_view Name) -{ - return m_Impl->LoadDll(Name); -} - -} // namespace zen - -#endif diff --git a/src/zenhttp/dlltransport.h b/src/zenhttp/dlltransport.h deleted file mode 100644 index 2dccdd0f9..000000000 --- a/src/zenhttp/dlltransport.h +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include - -#if ZEN_WITH_PLUGINS - -namespace zen { - -class DllTransportPluginImpl; - -/** Transport plugin which supports dynamic loading of external transport - * provider modules - */ -class DllTransportPlugin : public TransportPlugin, RefCounted -{ -public: - DllTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); - ~DllTransportPlugin(); - - virtual uint32_t AddRef() const override; - virtual uint32_t Release() const override; - - virtual void Initialize(TransportServer* ServerInterface) override; - virtual void Shutdown() override; - virtual bool IsAvailable() override; - - void LoadDll(std::string_view Name); - -private: - std::unique_ptr m_Impl; -}; - -} // namespace zen - -#endif diff --git a/src/zenhttp/httpasio.cpp b/src/zenhttp/httpasio.cpp index f23e0edb1..0c6b189f9 100644 --- a/src/zenhttp/httpasio.cpp +++ b/src/zenhttp/httpasio.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include diff --git a/src/zenhttp/httpasio.h b/src/zenhttp/httpasio.h index 57068f7c5..2366f3437 100644 --- a/src/zenhttp/httpasio.h +++ b/src/zenhttp/httpasio.h @@ -2,11 +2,8 @@ #pragma once -#include #include -#include - namespace zen { Ref CreateHttpAsioServer(unsigned int ThreadCount); diff --git a/src/zenhttp/httpserver.cpp b/src/zenhttp/httpserver.cpp index f22438a58..a2ea4cff8 100644 --- a/src/zenhttp/httpserver.cpp +++ b/src/zenhttp/httpserver.cpp @@ -9,8 +9,9 @@ #include "zenhttp/httpplugin.h" #if ZEN_WITH_PLUGINS -# include "dlltransport.h" -# include "winsocktransport.h" +# include "transports/asiotransport.h" +# include "transports/dlltransport.h" +# include "transports/winsocktransport.h" #endif #include @@ -770,6 +771,11 @@ CreateHttpServer(const HttpServerConfig& Config) Server->AddPlugin(WinsockPlugin); # endif +# if 0 + Ref AsioPlugin{CreateAsioTransportPlugin(1337, Config.ThreadCount)}; + Server->AddPlugin(AsioPlugin); +# endif + # if 0 Ref DllPlugin{new DllTransportPlugin(1337, Config.ThreadCount)}; DllPlugin->LoadDll("winsock"); @@ -783,10 +789,10 @@ CreateHttpServer(const HttpServerConfig& Config) #if ZEN_WITH_HTTPSYS case HttpServerClass::kHttpSys: ZEN_INFO("using http.sys server implementation"); - return Ref(new HttpSysServer({.ThreadCount = Config.ThreadCount, - .AsyncWorkThreadCount = Config.HttpSys.AsyncWorkThreadCount, - .IsAsyncResponseEnabled = Config.HttpSys.IsAsyncResponseEnabled, - .IsRequestLoggingEnabled = Config.HttpSys.IsRequestLoggingEnabled})); + return Ref(CreateHttpSysServer({.ThreadCount = Config.ThreadCount, + .AsyncWorkThreadCount = Config.HttpSys.AsyncWorkThreadCount, + .IsAsyncResponseEnabled = Config.HttpSys.IsAsyncResponseEnabled, + .IsRequestLoggingEnabled = Config.HttpSys.IsRequestLoggingEnabled})); #endif case HttpServerClass::kHttpNull: diff --git a/src/zenhttp/httpsys.cpp b/src/zenhttp/httpsys.cpp index c7ed0bb2f..8401dcf83 100644 --- a/src/zenhttp/httpsys.cpp +++ b/src/zenhttp/httpsys.cpp @@ -15,6 +15,86 @@ #include #include +#if ZEN_WITH_HTTPSYS +# define _WINSOCKAPI_ +# include +# include +# include "iothreadpool.h" + +# include + +namespace spdlog { +class logger; +} + +namespace zen { + +/** + * @brief Windows implementation of HTTP server based on http.sys + * + * This requires elevation to function + */ +class HttpSysServer : public HttpServer +{ + friend class HttpSysTransaction; + +public: + explicit HttpSysServer(const HttpSysConfig& Config); + ~HttpSysServer(); + + // HttpServer interface implementation + + virtual int Initialize(int BasePort) override; + virtual void Run(bool TestMode) override; + virtual void RequestExit() override; + virtual void RegisterService(HttpService& Service) override; + virtual void Close() override; + + WorkerThreadPool& WorkPool(); + + inline bool IsOk() const { return m_IsOk; } + inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } + +private: + int InitializeServer(int BasePort); + void Cleanup(); + + void StartServer(); + void OnHandlingNewRequest(); + void IssueNewRequestMaybe(); + + void RegisterService(const char* Endpoint, HttpService& Service); + void UnregisterService(const char* Endpoint, HttpService& Service); + +private: + spdlog::logger& m_Log; + spdlog::logger& m_RequestLog; + spdlog::logger& Log() { return m_Log; } + + bool m_IsOk = false; + bool m_IsHttpInitialized = false; + bool m_IsRequestLoggingEnabled = false; + bool m_IsAsyncResponseEnabled = true; + + WinIoThreadPool m_ThreadPool; + RwLock m_AsyncWorkPoolInitLock; + WorkerThreadPool* m_AsyncWorkPool = nullptr; + + std::vector m_BaseUris; // eg: http://*:nnnn/ + HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; + HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; + HANDLE m_RequestQueueHandle = 0; + std::atomic_int32_t m_PendingRequests{0}; + std::atomic_int32_t m_IsShuttingDown{0}; + int32_t m_MinPendingRequests = 16; + int32_t m_MaxPendingRequests = 128; + Event m_ShutdownEvent; + const HttpSysConfig m_InitialConfig; +}; + +} // namespace zen +#endif + #if ZEN_WITH_HTTPSYS # include @@ -809,7 +889,7 @@ HttpAsyncWorkRequest::AsyncWorkItem::Execute() \/ \/ \/ */ -HttpSysServer::HttpSysServer(const Config& Config) +HttpSysServer::HttpSysServer(const HttpSysConfig& Config) : m_Log(logging::Get("http")) , m_RequestLog(logging::Get("http_requests")) , m_IsRequestLoggingEnabled(Config.IsRequestLoggingEnabled) @@ -1868,5 +1948,11 @@ HttpSysServer::RegisterService(HttpService& Service) RegisterService(Service.BaseUri(), Service); } +Ref +CreateHttpSysServer(HttpSysConfig Config) +{ + return Ref(new HttpSysServer(Config)); +} + } // namespace zen #endif diff --git a/src/zenhttp/httpsys.h b/src/zenhttp/httpsys.h index 65239bae7..1553d56ef 100644 --- a/src/zenhttp/httpsys.h +++ b/src/zenhttp/httpsys.h @@ -12,89 +12,16 @@ # endif #endif -#if ZEN_WITH_HTTPSYS -# define _WINSOCKAPI_ -# include -# include -# include "iothreadpool.h" - -# include - -namespace spdlog { -class logger; -} - namespace zen { -/** - * @brief Windows implementation of HTTP server based on http.sys - * - * This requires elevation to function - */ -class HttpSysServer : public HttpServer +struct HttpSysConfig { - friend class HttpSysTransaction; - -public: - struct Config - { - unsigned int ThreadCount = 0; - unsigned int AsyncWorkThreadCount = 0; - bool IsAsyncResponseEnabled = true; - bool IsRequestLoggingEnabled = false; - }; - explicit HttpSysServer(const Config& Config); - ~HttpSysServer(); - - // HttpServer interface implementation - - virtual int Initialize(int BasePort) override; - virtual void Run(bool TestMode) override; - virtual void RequestExit() override; - virtual void RegisterService(HttpService& Service) override; - virtual void Close() override; - - WorkerThreadPool& WorkPool(); - - inline bool IsOk() const { return m_IsOk; } - inline bool IsAsyncResponseEnabled() const { return m_IsAsyncResponseEnabled; } - -private: - int InitializeServer(int BasePort); - void Cleanup(); - - void StartServer(); - void OnHandlingNewRequest(); - void IssueNewRequestMaybe(); - - void RegisterService(const char* Endpoint, HttpService& Service); - void UnregisterService(const char* Endpoint, HttpService& Service); - -private: - spdlog::logger& m_Log; - spdlog::logger& m_RequestLog; - spdlog::logger& Log() { return m_Log; } - - bool m_IsOk = false; - bool m_IsHttpInitialized = false; - bool m_IsRequestLoggingEnabled = false; - bool m_IsAsyncResponseEnabled = true; - - WinIoThreadPool m_ThreadPool; - RwLock m_AsyncWorkPoolInitLock; - WorkerThreadPool* m_AsyncWorkPool = nullptr; - - std::vector m_BaseUris; // eg: http://*:nnnn/ - HTTP_SERVER_SESSION_ID m_HttpSessionId = 0; - HTTP_URL_GROUP_ID m_HttpUrlGroupId = 0; - HANDLE m_RequestQueueHandle = 0; - std::atomic_int32_t m_PendingRequests{0}; - std::atomic_int32_t m_IsShuttingDown{0}; - int32_t m_MinPendingRequests = 16; - int32_t m_MaxPendingRequests = 128; - Event m_ShutdownEvent; - const Config m_InitialConfig; + unsigned int ThreadCount = 0; + unsigned int AsyncWorkThreadCount = 0; + bool IsAsyncResponseEnabled = true; + bool IsRequestLoggingEnabled = false; }; +Ref CreateHttpSysServer(HttpSysConfig Config); + } // namespace zen -#endif diff --git a/src/zenhttp/include/zenhttp/transportplugin.h b/src/zenhttp/include/zenhttp/transportplugin.h deleted file mode 100644 index fe17680de..000000000 --- a/src/zenhttp/include/zenhttp/transportplugin.h +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include - -// Important note: this header is meant to compile standalone -// and should therefore not depend on anything from the Zen codebase - -class TransportConnection; -class TransportPlugin; -class TransportServerConnection; -class TransportServer; - -/************************************************************************* - - The following interfaces are implemented on the server side, and instances - are provided to the plugins. - -*************************************************************************/ - -/** Plugin-server interface for connection - * - * This is how the transport feeds data to the connection handler - * which will parse the incoming messages and dispatch to - * appropriate request handlers and ultimately call into functions - * which write data back to the client. - */ -class TransportServerConnection -{ -public: - virtual uint32_t AddRef() const = 0; - virtual uint32_t Release() const = 0; - virtual void OnBytesRead(const void* Buffer, size_t DataSize) = 0; -}; - -/** Plugin-server interface - * - * There will be one instance of this per plugin, and the plugin - * should use this to manage lifetimes of connections and any - * other resources. - */ -class TransportServer -{ -public: - virtual TransportServerConnection* CreateConnectionHandler(TransportConnection* Connection) = 0; -}; - -/************************************************************************* - - The following interfaces are to be implemented by transport plugins. - -*************************************************************************/ - -/** Interface which needs to be implemented by a transport plugin - * - * This is responsible for setting up and running the communication - * for a given transport. - */ -class TransportPlugin -{ -public: - virtual uint32_t AddRef() const = 0; - virtual uint32_t Release() const = 0; - virtual void Initialize(TransportServer* ServerInterface) = 0; - virtual void Shutdown() = 0; - - /** Check whether this transport is usable. - */ - virtual bool IsAvailable() = 0; -}; - -/** A transport plugin provider needs to implement this interface - * - * There will be one instance of this per established connection and - * this interface is used to write response data back to the client. - */ -class TransportConnection -{ -public: - virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) = 0; - virtual void Shutdown(bool Receive, bool Transmit) = 0; - virtual void CloseConnection() = 0; -}; - -#if defined(_MSC_VER) -# define DLL_TRANSPORT_API __declspec(dllexport) -#else -# define DLL_TRANSPORT_API -#endif - -extern "C" -{ - DLL_TRANSPORT_API TransportPlugin* CreateTransportPlugin(); -} - -typedef TransportPlugin* (*PfnCreateTransportPlugin)(); diff --git a/src/zenhttp/transports/asiotransport.cpp b/src/zenhttp/transports/asiotransport.cpp new file mode 100644 index 000000000..b8fef8f5f --- /dev/null +++ b/src/zenhttp/transports/asiotransport.cpp @@ -0,0 +1,439 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "winsocktransport.h" + +#if ZEN_WITH_PLUGINS + +# include +# include +# include + +ZEN_THIRD_PARTY_INCLUDES_START +# include +ZEN_THIRD_PARTY_INCLUDES_END + +# if ZEN_PLATFORM_WINDOWS +# include +ZEN_THIRD_PARTY_INCLUDES_START +# include +ZEN_THIRD_PARTY_INCLUDES_END +# endif + +# include + +# include +# include + +namespace zen { + +struct AsioTransportAcceptor; + +class AsioTransportPlugin : public TransportPlugin, RefCounted +{ +public: + AsioTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + ~AsioTransportPlugin(); + + virtual uint32_t AddRef() const override; + virtual uint32_t Release() const override; + virtual void Initialize(TransportServer* ServerInterface) override; + virtual void Shutdown() override; + virtual bool IsAvailable() override; + +private: + bool m_IsOk = true; + uint16_t m_BasePort = 0; + int m_ThreadCount = 0; + + asio::io_service m_IoService; + asio::io_service::work m_Work{m_IoService}; + std::unique_ptr m_Acceptor; + std::vector m_ThreadPool; +}; + +struct AsioTransportConnection : public TransportConnection, std::enable_shared_from_this +{ + AsioTransportConnection(std::unique_ptr&& Socket); + ~AsioTransportConnection(); + + void Initialize(TransportServerConnection* ConnectionHandler); + + std::shared_ptr AsSharedPtr() { return shared_from_this(); } + + // TransportConnectionInterface + + virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override; + virtual void Shutdown(bool Receive, bool Transmit) override; + virtual void CloseConnection() override; + +private: + void EnqueueRead(); + void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); + void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount); + + Ref m_ConnectionHandler; + asio::streambuf m_RequestBuffer; + std::unique_ptr m_Socket; + uint32_t m_ConnectionId = 0; + std::atomic_flag m_IsTerminated{}; +}; + +////////////////////////////////////////////////////////////////////////// + +struct AsioTransportAcceptor +{ + AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_service& IoService, uint16_t BasePort); + ~AsioTransportAcceptor(); + + void Start(); + void RequestStop(); + + inline int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } + +private: + TransportServer* m_ServerInterface = nullptr; + asio::io_service& m_IoService; + asio::ip::tcp::acceptor m_Acceptor; + std::atomic m_IsStopped{false}; + + void EnqueueAccept(); +}; + +////////////////////////////////////////////////////////////////////////// + +AsioTransportAcceptor::AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_service& IoService, uint16_t BasePort) +: m_ServerInterface(ServerInterface) +, m_IoService(IoService) +, m_Acceptor(m_IoService, asio::ip::tcp::v6()) +{ + m_Acceptor.set_option(asio::ip::v6_only(false)); + +# if ZEN_PLATFORM_WINDOWS + // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms + typedef asio::detail::socket_option::boolean exclusive_address; + m_Acceptor.set_option(exclusive_address(true)); +# else + m_Acceptor.set_option(asio::socket_base::reuse_address(false)); +# endif + + m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + uint16_t EffectivePort = BasePort; + + asio::error_code BindErrorCode; + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + // Sharing violation implies the port is being used by another process + for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + } + if (BindErrorCode == asio::error::access_denied) + { + EffectivePort = 0; + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + } + if (BindErrorCode) + { + ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message()); + } + +# if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor.native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); +# endif + + ZEN_INFO("started asio transport at port: {}", EffectivePort); +} + +AsioTransportAcceptor::~AsioTransportAcceptor() +{ +} + +void +AsioTransportAcceptor::Start() +{ + m_Acceptor.listen(); + + EnqueueAccept(); +} + +void +AsioTransportAcceptor::RequestStop() +{ + m_IsStopped = true; +} + +void +AsioTransportAcceptor::EnqueueAccept() +{ + auto SocketPtr = std::make_unique(m_IoService); + asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); + + m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_WARN("asio async_accept error ({}:{}): {}", + m_Acceptor.local_endpoint().address().to_string(), + m_Acceptor.local_endpoint().port(), + Ec.message()); + } + else + { + // New connection established, pass socket ownership into connection object + // and initiate request handling loop. The connection lifetime is + // managed by the async read/write loop by passing the shared + // reference to the callbacks. + + Socket->set_option(asio::ip::tcp::no_delay(true)); + Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + auto Conn = std::make_shared(std::move(Socket)); + Conn->Initialize(m_ServerInterface->CreateConnectionHandler(Conn.get())); + } + + if (!m_IsStopped.load()) + { + EnqueueAccept(); + } + else + { + std::error_code CloseEc; + m_Acceptor.close(CloseEc); + + if (CloseEc) + { + ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message()); + } + } + }); +} + +////////////////////////////////////////////////////////////////////////// + +AsioTransportConnection::AsioTransportConnection(std::unique_ptr&& Socket) : m_Socket(std::move(Socket)) +{ +} + +AsioTransportConnection::~AsioTransportConnection() +{ +} + +void +AsioTransportConnection::Initialize(TransportServerConnection* ConnectionHandler) +{ + m_ConnectionHandler = ConnectionHandler; + + EnqueueRead(); +} + +int64_t +AsioTransportConnection::WriteBytes(const void* Buffer, size_t DataSize) +{ + size_t WrittenBytes = asio::write(*m_Socket.get(), asio::const_buffer(Buffer, DataSize), asio::transfer_exactly(DataSize)); + + return WrittenBytes; +} + +void +AsioTransportConnection::Shutdown(bool Receive, bool Transmit) +{ + std::error_code Ec; + if (Receive) + { + if (Transmit) + { + m_Socket->shutdown(asio::socket_base::shutdown_both, Ec); + } + else + { + m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + } + } + else if (Transmit) + { + m_Socket->shutdown(asio::socket_base::shutdown_send, Ec); + } +} + +void +AsioTransportConnection::CloseConnection() +{ + if (m_IsTerminated.test()) + { + return; + } + + if (m_IsTerminated.test_and_set() == false) + { + Shutdown(true, true); + + std::error_code Ec; + m_Socket->close(Ec); + } +} + +void +AsioTransportConnection::EnqueueRead() +{ + if (m_IsTerminated.test() == false) + { + m_RequestBuffer.prepare(64 * 1024); + + asio::async_read( + *m_Socket.get(), + m_RequestBuffer, + asio::transfer_at_least(1), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); + } +} + +void +AsioTransportConnection::OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount) +{ + ZEN_UNUSED(ByteCount); + + if (Ec) + { + if (!m_IsTerminated.test()) + { + ZEN_WARN("on data received ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message()); + } + + const bool Receive = true; + const bool Transmit = true; + return Shutdown(Receive, Transmit); + } + + while (m_RequestBuffer.size()) + { + const asio::const_buffer& InputBuffer = m_RequestBuffer.data(); + m_ConnectionHandler->OnBytesRead(InputBuffer.data(), InputBuffer.size()); + m_RequestBuffer.consume(InputBuffer.size()); + } + + EnqueueRead(); +} + +void +AsioTransportConnection::OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount) +{ + ZEN_UNUSED(ByteCount); + + if (Ec) + { + ZEN_WARN("on data sent ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message()); + + const bool Receive = true; + const bool Transmit = true; + return Shutdown(Receive, Transmit); + } +} + +////////////////////////////////////////////////////////////////////////// + +AsioTransportPlugin::AsioTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +: m_BasePort(BasePort) +, m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +{ +} + +AsioTransportPlugin::~AsioTransportPlugin() +{ +} + +uint32_t +AsioTransportPlugin::AddRef() const +{ + return RefCounted::AddRef(); +} + +uint32_t +AsioTransportPlugin::Release() const +{ + return RefCounted::Release(); +} + +void +AsioTransportPlugin::Initialize(TransportServer* ServerInterface) +{ + ZEN_ASSERT(m_ThreadCount > 0); + ZEN_ASSERT(ServerInterface); + + ZEN_INFO("starting asio http with {} service threads", m_ThreadCount); + + m_Acceptor.reset(new AsioTransportAcceptor(ServerInterface, m_IoService, m_BasePort)); + m_Acceptor->Start(); + + // This should consist of a set of minimum threads and grow on demand to + // meet concurrency needs? Right now we end up allocating a large number + // of threads even if we never end up using all of them, which seems + // wasteful. It's also not clear how the demand for concurrency should + // be balanced with the engine side - ideally we'd have some kind of + // global scheduling to prevent one side from starving the other side + // and thus preventing progress. Or at the very least, thread priorities + // should be considered + + for (int i = 0; i < m_ThreadCount; ++i) + { + m_ThreadPool.emplace_back([this, ThreadNumber = i + 1] { + SetCurrentThreadName(fmt::format("asio_thr_{}", ThreadNumber)); + + try + { + m_IoService.run(); + } + catch (std::exception& e) + { + ZEN_ERROR("exception caught in asio event loop: {}", e.what()); + } + }); + } + + ZEN_INFO("asio http transport started (port {})", m_Acceptor->GetAcceptPort()); +} + +void +AsioTransportPlugin::Shutdown() +{ + m_Acceptor->RequestStop(); + m_IoService.stop(); + + for (auto& Thread : m_ThreadPool) + { + Thread.join(); + } +} + +bool +AsioTransportPlugin::IsAvailable() +{ + return true; +} + +TransportPlugin* +CreateAsioTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +{ + return new AsioTransportPlugin(BasePort, ThreadCount); +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/transports/asiotransport.h b/src/zenhttp/transports/asiotransport.h new file mode 100644 index 000000000..b10174b85 --- /dev/null +++ b/src/zenhttp/transports/asiotransport.h @@ -0,0 +1,15 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#if ZEN_WITH_PLUGINS + +namespace zen { + +TransportPlugin* CreateAsioTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + +} // namespace zen + +#endif diff --git a/src/zenhttp/transports/dlltransport.cpp b/src/zenhttp/transports/dlltransport.cpp new file mode 100644 index 000000000..04fb6caaa --- /dev/null +++ b/src/zenhttp/transports/dlltransport.cpp @@ -0,0 +1,250 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "dlltransport.h" + +#include +#include +#include + +#include +#include +#include + +#if ZEN_WITH_PLUGINS + +namespace zen { + +struct DllTransportConnection : public TransportConnection +{ +public: + DllTransportConnection(); + ~DllTransportConnection(); + + void Initialize(TransportServerConnection& ServerConnection); + void HandleConnection(); + + // TransportConnection + + virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override; + virtual void Shutdown(bool Receive, bool Transmit) override; + virtual void CloseConnection() override; + +private: + Ref m_ConnectionHandler; + bool m_IsTerminated = false; +}; + +DllTransportConnection::DllTransportConnection() +{ +} + +DllTransportConnection::~DllTransportConnection() +{ +} + +void +DllTransportConnection::Initialize(TransportServerConnection& ServerConnection) +{ + m_ConnectionHandler = &ServerConnection; // TODO: this is awkward +} + +void +DllTransportConnection::HandleConnection() +{ +} + +void +DllTransportConnection::CloseConnection() +{ + if (m_IsTerminated) + { + return; + } + + m_IsTerminated = true; +} + +int64_t +DllTransportConnection::WriteBytes(const void* Buffer, size_t DataSize) +{ + ZEN_UNUSED(Buffer, DataSize); + return DataSize; +} + +void +DllTransportConnection::Shutdown(bool Receive, bool Transmit) +{ + ZEN_UNUSED(Receive, Transmit); +} + +////////////////////////////////////////////////////////////////////////// + +struct LoadedDll +{ + std::string Name; + std::filesystem::path LoadedFromPath; + Ref Plugin; +}; + +class DllTransportPluginImpl +{ +public: + DllTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount); + ~DllTransportPluginImpl(); + + uint16_t Start(TransportServer* ServerInterface); + void Stop(); + bool IsAvailable(); + void LoadDll(std::string_view Name); + +private: + TransportServer* m_ServerInterface = nullptr; + RwLock m_Lock; + std::vector m_Transports; + uint16_t m_BasePort = 0; + int m_ThreadCount = 0; +}; + +DllTransportPluginImpl::DllTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount) +: m_BasePort(BasePort) +, m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +{ +} + +DllTransportPluginImpl::~DllTransportPluginImpl() +{ +} + +uint16_t +DllTransportPluginImpl::Start(TransportServer* ServerIface) +{ + m_ServerInterface = ServerIface; + + RwLock::ExclusiveLockScope _(m_Lock); + + for (LoadedDll& Transport : m_Transports) + { + try + { + Transport.Plugin->Initialize(ServerIface); + } + catch (const std::exception&) + { + // TODO: report + } + } + + return m_BasePort; +} + +void +DllTransportPluginImpl::Stop() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + for (LoadedDll& Transport : m_Transports) + { + try + { + Transport.Plugin->Shutdown(); + } + catch (const std::exception&) + { + // TODO: report + } + } +} + +bool +DllTransportPluginImpl::IsAvailable() +{ + return true; +} + +void +DllTransportPluginImpl::LoadDll(std::string_view Name) +{ + ExtendableStringBuilder<128> DllPath; + DllPath << Name << ".dll"; + HMODULE DllHandle = LoadLibraryA(DllPath.c_str()); + + if (!DllHandle) + { + std::error_code Ec = MakeErrorCodeFromLastError(); + + throw std::system_error(Ec, fmt::format("failed to load transport DLL from '{}'", DllPath)); + } + + TransportPlugin* CreateTransportPlugin(); + + PfnCreateTransportPlugin CreatePlugin = (PfnCreateTransportPlugin)GetProcAddress(DllHandle, "CreateTransportPlugin"); + + if (!CreatePlugin) + { + std::error_code Ec = MakeErrorCodeFromLastError(); + + FreeLibrary(DllHandle); + + throw std::system_error(Ec, fmt::format("API mismatch detected in transport DLL loaded from '{}'", DllPath)); + } + + LoadedDll NewDll; + + NewDll.Name = Name; + NewDll.LoadedFromPath = DllPath.c_str(); + NewDll.Plugin = CreatePlugin(); + + m_Transports.emplace_back(std::move(NewDll)); +} + +////////////////////////////////////////////////////////////////////////// + +DllTransportPlugin::DllTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +: m_Impl(std::make_unique(BasePort, ThreadCount)) +{ +} + +DllTransportPlugin::~DllTransportPlugin() +{ + m_Impl->Stop(); +} + +uint32_t +DllTransportPlugin::AddRef() const +{ + return RefCounted::AddRef(); +} + +uint32_t +DllTransportPlugin::Release() const +{ + return RefCounted::Release(); +} + +void +DllTransportPlugin::Initialize(TransportServer* ServerInterface) +{ + m_Impl->Start(ServerInterface); +} + +void +DllTransportPlugin::Shutdown() +{ + m_Impl->Stop(); +} + +bool +DllTransportPlugin::IsAvailable() +{ + return m_Impl->IsAvailable(); +} + +void +DllTransportPlugin::LoadDll(std::string_view Name) +{ + return m_Impl->LoadDll(Name); +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/transports/dlltransport.h b/src/zenhttp/transports/dlltransport.h new file mode 100644 index 000000000..2dccdd0f9 --- /dev/null +++ b/src/zenhttp/transports/dlltransport.h @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#if ZEN_WITH_PLUGINS + +namespace zen { + +class DllTransportPluginImpl; + +/** Transport plugin which supports dynamic loading of external transport + * provider modules + */ +class DllTransportPlugin : public TransportPlugin, RefCounted +{ +public: + DllTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + ~DllTransportPlugin(); + + virtual uint32_t AddRef() const override; + virtual uint32_t Release() const override; + + virtual void Initialize(TransportServer* ServerInterface) override; + virtual void Shutdown() override; + virtual bool IsAvailable() override; + + void LoadDll(std::string_view Name); + +private: + std::unique_ptr m_Impl; +}; + +} // namespace zen + +#endif diff --git a/src/zenhttp/transports/winsocktransport.cpp b/src/zenhttp/transports/winsocktransport.cpp new file mode 100644 index 000000000..ad3302550 --- /dev/null +++ b/src/zenhttp/transports/winsocktransport.cpp @@ -0,0 +1,367 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "winsocktransport.h" + +#if ZEN_WITH_PLUGINS + +# include +# include +# include + +# if ZEN_PLATFORM_WINDOWS +# include +ZEN_THIRD_PARTY_INCLUDES_START +# include +# include +ZEN_THIRD_PARTY_INCLUDES_END +# endif + +# include + +namespace zen { + +class SocketTransportPluginImpl; + +class SocketTransportPlugin : public TransportPlugin, RefCounted +{ +public: + SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + ~SocketTransportPlugin(); + + virtual uint32_t AddRef() const override; + virtual uint32_t Release() const override; + virtual void Initialize(TransportServer* ServerInterface) override; + virtual void Shutdown() override; + virtual bool IsAvailable() override; + +private: + bool m_IsOk = true; + uint16_t m_BasePort = 0; + int m_ThreadCount = 0; + SocketTransportPluginImpl* m_Impl; +}; + +struct SocketTransportConnection : public TransportConnection +{ +public: + SocketTransportConnection(); + ~SocketTransportConnection(); + + void Initialize(TransportServerConnection* ServerConnection, SOCKET ClientSocket); + void HandleConnection(); + + // TransportConnection + + virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override; + virtual void Shutdown(bool Receive, bool Transmit) override; + virtual void CloseConnection() override; + +private: + Ref m_ConnectionHandler; + SOCKET m_ClientSocket{}; + bool m_IsTerminated = false; +}; + +SocketTransportConnection::SocketTransportConnection() +{ +} + +SocketTransportConnection::~SocketTransportConnection() +{ +} + +void +SocketTransportConnection::Initialize(TransportServerConnection* ServerConnection, SOCKET ClientSocket) +{ + ZEN_ASSERT(!m_ConnectionHandler); + + m_ConnectionHandler = ServerConnection; + m_ClientSocket = ClientSocket; +} + +void +SocketTransportConnection::HandleConnection() +{ + ZEN_ASSERT(m_ConnectionHandler); + + const int InputBufferSize = 64 * 1024; + uint8_t* InputBuffer = new uint8_t[64 * 1024]; + auto _ = MakeGuard([&] { delete[] InputBuffer; }); + + do + { + const int RecvBytes = recv(m_ClientSocket, (char*)InputBuffer, InputBufferSize, /* flags */ 0); + + if (RecvBytes == 0) + { + // Connection closed + return CloseConnection(); + } + else if (RecvBytes < 0) + { + // Error + return CloseConnection(); + } + + m_ConnectionHandler->OnBytesRead(InputBuffer, RecvBytes); + } while (m_ClientSocket); +} + +void +SocketTransportConnection::CloseConnection() +{ + if (m_IsTerminated) + { + return; + } + + ZEN_ASSERT(m_ClientSocket); + m_IsTerminated = true; + + shutdown(m_ClientSocket, SD_BOTH); // We won't be sending or receiving any more data + + closesocket(m_ClientSocket); + m_ClientSocket = 0; +} + +int64_t +SocketTransportConnection::WriteBytes(const void* Buffer, size_t DataSize) +{ + const uint8_t* BufferCursor = reinterpret_cast(Buffer); + int64_t TotalBytesSent = 0; + + while (DataSize) + { + const int MaxBlockSize = 128 * 1024; + const int SendBlockSize = (DataSize > MaxBlockSize) ? MaxBlockSize : (int)DataSize; + const int SentBytes = send(m_ClientSocket, (const char*)BufferCursor, SendBlockSize, /* flags */ 0); + + if (SentBytes < 0) + { + // Error + return SentBytes; + } + + BufferCursor += SentBytes; + DataSize -= SentBytes; + TotalBytesSent += SentBytes; + } + + return TotalBytesSent; +} + +void +SocketTransportConnection::Shutdown(bool Receive, bool Transmit) +{ + if (Receive) + { + if (Transmit) + { + shutdown(m_ClientSocket, SD_BOTH); + } + else + { + shutdown(m_ClientSocket, SD_RECEIVE); + } + } + else if (Transmit) + { + shutdown(m_ClientSocket, SD_SEND); + } +} + +////////////////////////////////////////////////////////////////////////// + +class SocketTransportPluginImpl +{ +public: + SocketTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount); + ~SocketTransportPluginImpl(); + + uint16_t Start(uint16_t Port, TransportServer* ServerInterface); + void Stop(); + +private: + TransportServer* m_ServerInterface = nullptr; + uint16_t m_BasePort = 0; + int m_ThreadCount = 0; + bool m_IsOk = true; + + SOCKET m_ListenSocket{}; + std::thread m_AcceptThread; + std::atomic_flag m_KeepRunning; + std::unique_ptr m_WorkerThreadpool; +}; + +SocketTransportPluginImpl::SocketTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount) +: m_BasePort(BasePort) +, m_ThreadCount(ThreadCount) +{ +# if ZEN_PLATFORM_WINDOWS + WSADATA wsaData; + if (int Result = WSAStartup(0x202, &wsaData); Result != 0) + { + m_IsOk = false; + WSACleanup(); + } +# endif + + m_WorkerThreadpool = std::make_unique(m_ThreadCount, "http_conn"); +} + +SocketTransportPluginImpl::~SocketTransportPluginImpl() +{ + Stop(); + +# if ZEN_PLATFORM_WINDOWS + if (m_IsOk) + { + WSACleanup(); + } +# endif +} + +uint16_t +SocketTransportPluginImpl::Start(uint16_t Port, TransportServer* ServerInterface) +{ + m_ServerInterface = ServerInterface; + m_ListenSocket = socket(AF_INET6, SOCK_STREAM, 0); + + if (m_ListenSocket == SOCKET_ERROR || m_ListenSocket == INVALID_SOCKET) + { + ZEN_ERROR("socket creation failed in HTTP plugin server init: {}", WSAGetLastError()); + + return 0; + } + + sockaddr_in6 Server{}; + Server.sin6_family = AF_INET6; + Server.sin6_port = htons(Port); + Server.sin6_addr = in6addr_any; + + if (int Result = bind(m_ListenSocket, (sockaddr*)&Server, sizeof(Server)); Result == SOCKET_ERROR) + { + ZEN_ERROR("bind call failed in HTTP plugin server init: {}", WSAGetLastError()); + + return 0; + } + + if (int Result = listen(m_ListenSocket, AF_INET6); Result == SOCKET_ERROR) + { + ZEN_ERROR("listen call failed in HTTP plugin server init: {}", WSAGetLastError()); + + return 0; + } + + m_KeepRunning.test_and_set(); + + m_AcceptThread = std::thread([&] { + SetCurrentThreadName("http_plugin_acceptor"); + + ZEN_INFO("HTTP plugin server waiting for connections"); + + do + { + if (SOCKET ClientSocket = accept(m_ListenSocket, NULL, NULL); ClientSocket != SOCKET_ERROR) + { + int Flag = 1; + setsockopt(ClientSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&Flag, sizeof(Flag)); + + // Handle new connection + SocketTransportConnection* Connection = new SocketTransportConnection(); + TransportServerConnection* ConnectionInterface{m_ServerInterface->CreateConnectionHandler(Connection)}; + Connection->Initialize(ConnectionInterface, ClientSocket); + + m_WorkerThreadpool->ScheduleWork([Connection] { + try + { + Connection->HandleConnection(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception caught in connection loop: {}", Ex.what()); + } + + delete Connection; + }); + } + else + { + } + } while (!IsApplicationExitRequested() && m_KeepRunning.test()); + + ZEN_INFO("HTTP plugin server accept thread exit"); + }); + + return Port; +} + +void +SocketTransportPluginImpl::Stop() +{ + // TODO: all pending/ongoing work should be drained here as well + + m_KeepRunning.clear(); + + closesocket(m_ListenSocket); + m_ListenSocket = 0; + + if (m_AcceptThread.joinable()) + { + m_AcceptThread.join(); + } +} + +////////////////////////////////////////////////////////////////////////// + +SocketTransportPlugin::SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +: m_BasePort(BasePort) +, m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +, m_Impl(new SocketTransportPluginImpl(BasePort, m_ThreadCount)) +{ +} + +SocketTransportPlugin::~SocketTransportPlugin() +{ + delete m_Impl; +} + +uint32_t +SocketTransportPlugin::AddRef() const +{ + return RefCounted::AddRef(); +} + +uint32_t +SocketTransportPlugin::Release() const +{ + return RefCounted::Release(); +} + +void +SocketTransportPlugin::Initialize(TransportServer* ServerInterface) +{ + m_Impl->Start(m_BasePort, ServerInterface); +} + +void +SocketTransportPlugin::Shutdown() +{ + m_Impl->Stop(); +} + +bool +SocketTransportPlugin::IsAvailable() +{ + return true; +} + +TransportPlugin* +CreateSocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +{ + return new SocketTransportPlugin(BasePort, ThreadCount); +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/transports/winsocktransport.h b/src/zenhttp/transports/winsocktransport.h new file mode 100644 index 000000000..2b2a55aef --- /dev/null +++ b/src/zenhttp/transports/winsocktransport.h @@ -0,0 +1,15 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include + +#if ZEN_WITH_PLUGINS + +namespace zen { + +TransportPlugin* CreateSocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + +} // namespace zen + +#endif diff --git a/src/zenhttp/winsocktransport.cpp b/src/zenhttp/winsocktransport.cpp deleted file mode 100644 index ad3302550..000000000 --- a/src/zenhttp/winsocktransport.cpp +++ /dev/null @@ -1,367 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#include "winsocktransport.h" - -#if ZEN_WITH_PLUGINS - -# include -# include -# include - -# if ZEN_PLATFORM_WINDOWS -# include -ZEN_THIRD_PARTY_INCLUDES_START -# include -# include -ZEN_THIRD_PARTY_INCLUDES_END -# endif - -# include - -namespace zen { - -class SocketTransportPluginImpl; - -class SocketTransportPlugin : public TransportPlugin, RefCounted -{ -public: - SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); - ~SocketTransportPlugin(); - - virtual uint32_t AddRef() const override; - virtual uint32_t Release() const override; - virtual void Initialize(TransportServer* ServerInterface) override; - virtual void Shutdown() override; - virtual bool IsAvailable() override; - -private: - bool m_IsOk = true; - uint16_t m_BasePort = 0; - int m_ThreadCount = 0; - SocketTransportPluginImpl* m_Impl; -}; - -struct SocketTransportConnection : public TransportConnection -{ -public: - SocketTransportConnection(); - ~SocketTransportConnection(); - - void Initialize(TransportServerConnection* ServerConnection, SOCKET ClientSocket); - void HandleConnection(); - - // TransportConnection - - virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override; - virtual void Shutdown(bool Receive, bool Transmit) override; - virtual void CloseConnection() override; - -private: - Ref m_ConnectionHandler; - SOCKET m_ClientSocket{}; - bool m_IsTerminated = false; -}; - -SocketTransportConnection::SocketTransportConnection() -{ -} - -SocketTransportConnection::~SocketTransportConnection() -{ -} - -void -SocketTransportConnection::Initialize(TransportServerConnection* ServerConnection, SOCKET ClientSocket) -{ - ZEN_ASSERT(!m_ConnectionHandler); - - m_ConnectionHandler = ServerConnection; - m_ClientSocket = ClientSocket; -} - -void -SocketTransportConnection::HandleConnection() -{ - ZEN_ASSERT(m_ConnectionHandler); - - const int InputBufferSize = 64 * 1024; - uint8_t* InputBuffer = new uint8_t[64 * 1024]; - auto _ = MakeGuard([&] { delete[] InputBuffer; }); - - do - { - const int RecvBytes = recv(m_ClientSocket, (char*)InputBuffer, InputBufferSize, /* flags */ 0); - - if (RecvBytes == 0) - { - // Connection closed - return CloseConnection(); - } - else if (RecvBytes < 0) - { - // Error - return CloseConnection(); - } - - m_ConnectionHandler->OnBytesRead(InputBuffer, RecvBytes); - } while (m_ClientSocket); -} - -void -SocketTransportConnection::CloseConnection() -{ - if (m_IsTerminated) - { - return; - } - - ZEN_ASSERT(m_ClientSocket); - m_IsTerminated = true; - - shutdown(m_ClientSocket, SD_BOTH); // We won't be sending or receiving any more data - - closesocket(m_ClientSocket); - m_ClientSocket = 0; -} - -int64_t -SocketTransportConnection::WriteBytes(const void* Buffer, size_t DataSize) -{ - const uint8_t* BufferCursor = reinterpret_cast(Buffer); - int64_t TotalBytesSent = 0; - - while (DataSize) - { - const int MaxBlockSize = 128 * 1024; - const int SendBlockSize = (DataSize > MaxBlockSize) ? MaxBlockSize : (int)DataSize; - const int SentBytes = send(m_ClientSocket, (const char*)BufferCursor, SendBlockSize, /* flags */ 0); - - if (SentBytes < 0) - { - // Error - return SentBytes; - } - - BufferCursor += SentBytes; - DataSize -= SentBytes; - TotalBytesSent += SentBytes; - } - - return TotalBytesSent; -} - -void -SocketTransportConnection::Shutdown(bool Receive, bool Transmit) -{ - if (Receive) - { - if (Transmit) - { - shutdown(m_ClientSocket, SD_BOTH); - } - else - { - shutdown(m_ClientSocket, SD_RECEIVE); - } - } - else if (Transmit) - { - shutdown(m_ClientSocket, SD_SEND); - } -} - -////////////////////////////////////////////////////////////////////////// - -class SocketTransportPluginImpl -{ -public: - SocketTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount); - ~SocketTransportPluginImpl(); - - uint16_t Start(uint16_t Port, TransportServer* ServerInterface); - void Stop(); - -private: - TransportServer* m_ServerInterface = nullptr; - uint16_t m_BasePort = 0; - int m_ThreadCount = 0; - bool m_IsOk = true; - - SOCKET m_ListenSocket{}; - std::thread m_AcceptThread; - std::atomic_flag m_KeepRunning; - std::unique_ptr m_WorkerThreadpool; -}; - -SocketTransportPluginImpl::SocketTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount) -: m_BasePort(BasePort) -, m_ThreadCount(ThreadCount) -{ -# if ZEN_PLATFORM_WINDOWS - WSADATA wsaData; - if (int Result = WSAStartup(0x202, &wsaData); Result != 0) - { - m_IsOk = false; - WSACleanup(); - } -# endif - - m_WorkerThreadpool = std::make_unique(m_ThreadCount, "http_conn"); -} - -SocketTransportPluginImpl::~SocketTransportPluginImpl() -{ - Stop(); - -# if ZEN_PLATFORM_WINDOWS - if (m_IsOk) - { - WSACleanup(); - } -# endif -} - -uint16_t -SocketTransportPluginImpl::Start(uint16_t Port, TransportServer* ServerInterface) -{ - m_ServerInterface = ServerInterface; - m_ListenSocket = socket(AF_INET6, SOCK_STREAM, 0); - - if (m_ListenSocket == SOCKET_ERROR || m_ListenSocket == INVALID_SOCKET) - { - ZEN_ERROR("socket creation failed in HTTP plugin server init: {}", WSAGetLastError()); - - return 0; - } - - sockaddr_in6 Server{}; - Server.sin6_family = AF_INET6; - Server.sin6_port = htons(Port); - Server.sin6_addr = in6addr_any; - - if (int Result = bind(m_ListenSocket, (sockaddr*)&Server, sizeof(Server)); Result == SOCKET_ERROR) - { - ZEN_ERROR("bind call failed in HTTP plugin server init: {}", WSAGetLastError()); - - return 0; - } - - if (int Result = listen(m_ListenSocket, AF_INET6); Result == SOCKET_ERROR) - { - ZEN_ERROR("listen call failed in HTTP plugin server init: {}", WSAGetLastError()); - - return 0; - } - - m_KeepRunning.test_and_set(); - - m_AcceptThread = std::thread([&] { - SetCurrentThreadName("http_plugin_acceptor"); - - ZEN_INFO("HTTP plugin server waiting for connections"); - - do - { - if (SOCKET ClientSocket = accept(m_ListenSocket, NULL, NULL); ClientSocket != SOCKET_ERROR) - { - int Flag = 1; - setsockopt(ClientSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&Flag, sizeof(Flag)); - - // Handle new connection - SocketTransportConnection* Connection = new SocketTransportConnection(); - TransportServerConnection* ConnectionInterface{m_ServerInterface->CreateConnectionHandler(Connection)}; - Connection->Initialize(ConnectionInterface, ClientSocket); - - m_WorkerThreadpool->ScheduleWork([Connection] { - try - { - Connection->HandleConnection(); - } - catch (std::exception& Ex) - { - ZEN_WARN("exception caught in connection loop: {}", Ex.what()); - } - - delete Connection; - }); - } - else - { - } - } while (!IsApplicationExitRequested() && m_KeepRunning.test()); - - ZEN_INFO("HTTP plugin server accept thread exit"); - }); - - return Port; -} - -void -SocketTransportPluginImpl::Stop() -{ - // TODO: all pending/ongoing work should be drained here as well - - m_KeepRunning.clear(); - - closesocket(m_ListenSocket); - m_ListenSocket = 0; - - if (m_AcceptThread.joinable()) - { - m_AcceptThread.join(); - } -} - -////////////////////////////////////////////////////////////////////////// - -SocketTransportPlugin::SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) -: m_BasePort(BasePort) -, m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) -, m_Impl(new SocketTransportPluginImpl(BasePort, m_ThreadCount)) -{ -} - -SocketTransportPlugin::~SocketTransportPlugin() -{ - delete m_Impl; -} - -uint32_t -SocketTransportPlugin::AddRef() const -{ - return RefCounted::AddRef(); -} - -uint32_t -SocketTransportPlugin::Release() const -{ - return RefCounted::Release(); -} - -void -SocketTransportPlugin::Initialize(TransportServer* ServerInterface) -{ - m_Impl->Start(m_BasePort, ServerInterface); -} - -void -SocketTransportPlugin::Shutdown() -{ - m_Impl->Stop(); -} - -bool -SocketTransportPlugin::IsAvailable() -{ - return true; -} - -TransportPlugin* -CreateSocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) -{ - return new SocketTransportPlugin(BasePort, ThreadCount); -} - -} // namespace zen - -#endif diff --git a/src/zenhttp/winsocktransport.h b/src/zenhttp/winsocktransport.h deleted file mode 100644 index 2b2a55aef..000000000 --- a/src/zenhttp/winsocktransport.h +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright Epic Games, Inc. All Rights Reserved. - -#pragma once - -#include - -#if ZEN_WITH_PLUGINS - -namespace zen { - -TransportPlugin* CreateSocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); - -} // namespace zen - -#endif diff --git a/src/zenhttp/xmake.lua b/src/zenhttp/xmake.lua index 411436b16..9c3869911 100644 --- a/src/zenhttp/xmake.lua +++ b/src/zenhttp/xmake.lua @@ -7,7 +7,7 @@ target('zenhttp') add_files("**.cpp") add_files("httpsys.cpp", {unity_ignored=true}) add_includedirs("include", {public=true}) - add_deps("zencore") + add_deps("zencore", "plugins") add_packages( "vcpkg::cpr", "vcpkg::curl", -- required by cpr -- cgit v1.2.3