diff options
| author | Stefan Boberg <[email protected]> | 2023-10-11 14:59:25 +0200 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-10-11 14:59:25 +0200 |
| commit | 11f7f70b825c5b6784f5e2609463a1a9d1a0dabc (patch) | |
| tree | 98f65537f52327c354193afa98a29f9f838b42ff /src/zenhttp/transports | |
| parent | hide HttpAsioServer interface behind factory function (#463) (diff) | |
| download | zen-11f7f70b825c5b6784f5e2609463a1a9d1a0dabc.tar.xz zen-11f7f70b825c5b6784f5e2609463a1a9d1a0dabc.zip | |
pluggable asio transport (#460)
added pluggable transport based on asio. This is in an experimental state and is not yet a replacement for httpasio even though that is the ultimate goal
also moved plugin API header into dedicated part of the tree to clarify that it is meant to be usable in isolation, without any dependency on zencore et al
moved transport implementations into dedicated source directory in zenhttp
note that this adds code to the build but nothing should change at runtime since the instantiation of the new code is conditional and is inactive by default
Diffstat (limited to 'src/zenhttp/transports')
| -rw-r--r-- | src/zenhttp/transports/asiotransport.cpp | 439 | ||||
| -rw-r--r-- | src/zenhttp/transports/asiotransport.h | 15 | ||||
| -rw-r--r-- | src/zenhttp/transports/dlltransport.cpp | 250 | ||||
| -rw-r--r-- | src/zenhttp/transports/dlltransport.h | 37 | ||||
| -rw-r--r-- | src/zenhttp/transports/winsocktransport.cpp | 367 | ||||
| -rw-r--r-- | src/zenhttp/transports/winsocktransport.h | 15 |
6 files changed, 1123 insertions, 0 deletions
diff --git a/src/zenhttp/transports/asiotransport.cpp b/src/zenhttp/transports/asiotransport.cpp new file mode 100644 index 000000000..b8fef8f5f --- /dev/null +++ b/src/zenhttp/transports/asiotransport.cpp @@ -0,0 +1,439 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "winsocktransport.h" + +#if ZEN_WITH_PLUGINS + +# include <zencore/logging.h> +# include <zencore/scopeguard.h> +# include <zencore/workthreadpool.h> + +ZEN_THIRD_PARTY_INCLUDES_START +# include <asio.hpp> +ZEN_THIRD_PARTY_INCLUDES_END + +# if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +ZEN_THIRD_PARTY_INCLUDES_START +# include <mstcpip.h> +ZEN_THIRD_PARTY_INCLUDES_END +# endif + +# include <fmt/format.h> + +# include <memory> +# include <thread> + +namespace zen { + +struct AsioTransportAcceptor; + +class AsioTransportPlugin : public TransportPlugin, RefCounted +{ +public: + AsioTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + ~AsioTransportPlugin(); + + virtual uint32_t AddRef() const override; + virtual uint32_t Release() const override; + virtual void Initialize(TransportServer* ServerInterface) override; + virtual void Shutdown() override; + virtual bool IsAvailable() override; + +private: + bool m_IsOk = true; + uint16_t m_BasePort = 0; + int m_ThreadCount = 0; + + asio::io_service m_IoService; + asio::io_service::work m_Work{m_IoService}; + std::unique_ptr<AsioTransportAcceptor> m_Acceptor; + std::vector<std::thread> m_ThreadPool; +}; + +struct AsioTransportConnection : public TransportConnection, std::enable_shared_from_this<AsioTransportConnection> +{ + AsioTransportConnection(std::unique_ptr<asio::ip::tcp::socket>&& Socket); + ~AsioTransportConnection(); + + void Initialize(TransportServerConnection* ConnectionHandler); + + std::shared_ptr<AsioTransportConnection> AsSharedPtr() { return shared_from_this(); } + + // TransportConnectionInterface + + virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override; + virtual void Shutdown(bool Receive, bool Transmit) override; + virtual void CloseConnection() override; + +private: + void EnqueueRead(); + void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); + void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount); + + Ref<TransportServerConnection> m_ConnectionHandler; + asio::streambuf m_RequestBuffer; + std::unique_ptr<asio::ip::tcp::socket> m_Socket; + uint32_t m_ConnectionId = 0; + std::atomic_flag m_IsTerminated{}; +}; + +////////////////////////////////////////////////////////////////////////// + +struct AsioTransportAcceptor +{ + AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_service& IoService, uint16_t BasePort); + ~AsioTransportAcceptor(); + + void Start(); + void RequestStop(); + + inline int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } + +private: + TransportServer* m_ServerInterface = nullptr; + asio::io_service& m_IoService; + asio::ip::tcp::acceptor m_Acceptor; + std::atomic<bool> m_IsStopped{false}; + + void EnqueueAccept(); +}; + +////////////////////////////////////////////////////////////////////////// + +AsioTransportAcceptor::AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_service& IoService, uint16_t BasePort) +: m_ServerInterface(ServerInterface) +, m_IoService(IoService) +, m_Acceptor(m_IoService, asio::ip::tcp::v6()) +{ + m_Acceptor.set_option(asio::ip::v6_only(false)); + +# if ZEN_PLATFORM_WINDOWS + // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms + typedef asio::detail::socket_option::boolean<ASIO_OS_DEF(SOL_SOCKET), SO_EXCLUSIVEADDRUSE> exclusive_address; + m_Acceptor.set_option(exclusive_address(true)); +# else + m_Acceptor.set_option(asio::socket_base::reuse_address(false)); +# endif + + m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); + m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + uint16_t EffectivePort = BasePort; + + asio::error_code BindErrorCode; + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + // Sharing violation implies the port is being used by another process + for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset) + { + EffectivePort = BasePort + (PortOffset * 100); + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + } + if (BindErrorCode == asio::error::access_denied) + { + EffectivePort = 0; + m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); + } + if (BindErrorCode) + { + ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message()); + } + +# if ZEN_PLATFORM_WINDOWS + // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. + // This must be used by both the client and server side, and is only effective in the absence of + // Windows Filtering Platform (WFP) callouts which can be installed by security software. + // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path + SOCKET NativeSocket = m_Acceptor.native_handle(); + int LoopbackOptionValue = 1; + DWORD OptionNumberOfBytesReturned = 0; + WSAIoctl(NativeSocket, + SIO_LOOPBACK_FAST_PATH, + &LoopbackOptionValue, + sizeof(LoopbackOptionValue), + NULL, + 0, + &OptionNumberOfBytesReturned, + 0, + 0); +# endif + + ZEN_INFO("started asio transport at port: {}", EffectivePort); +} + +AsioTransportAcceptor::~AsioTransportAcceptor() +{ +} + +void +AsioTransportAcceptor::Start() +{ + m_Acceptor.listen(); + + EnqueueAccept(); +} + +void +AsioTransportAcceptor::RequestStop() +{ + m_IsStopped = true; +} + +void +AsioTransportAcceptor::EnqueueAccept() +{ + auto SocketPtr = std::make_unique<asio::ip::tcp::socket>(m_IoService); + asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); + + m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_WARN("asio async_accept error ({}:{}): {}", + m_Acceptor.local_endpoint().address().to_string(), + m_Acceptor.local_endpoint().port(), + Ec.message()); + } + else + { + // New connection established, pass socket ownership into connection object + // and initiate request handling loop. The connection lifetime is + // managed by the async read/write loop by passing the shared + // reference to the callbacks. + + Socket->set_option(asio::ip::tcp::no_delay(true)); + Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); + Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); + + auto Conn = std::make_shared<AsioTransportConnection>(std::move(Socket)); + Conn->Initialize(m_ServerInterface->CreateConnectionHandler(Conn.get())); + } + + if (!m_IsStopped.load()) + { + EnqueueAccept(); + } + else + { + std::error_code CloseEc; + m_Acceptor.close(CloseEc); + + if (CloseEc) + { + ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message()); + } + } + }); +} + +////////////////////////////////////////////////////////////////////////// + +AsioTransportConnection::AsioTransportConnection(std::unique_ptr<asio::ip::tcp::socket>&& Socket) : m_Socket(std::move(Socket)) +{ +} + +AsioTransportConnection::~AsioTransportConnection() +{ +} + +void +AsioTransportConnection::Initialize(TransportServerConnection* ConnectionHandler) +{ + m_ConnectionHandler = ConnectionHandler; + + EnqueueRead(); +} + +int64_t +AsioTransportConnection::WriteBytes(const void* Buffer, size_t DataSize) +{ + size_t WrittenBytes = asio::write(*m_Socket.get(), asio::const_buffer(Buffer, DataSize), asio::transfer_exactly(DataSize)); + + return WrittenBytes; +} + +void +AsioTransportConnection::Shutdown(bool Receive, bool Transmit) +{ + std::error_code Ec; + if (Receive) + { + if (Transmit) + { + m_Socket->shutdown(asio::socket_base::shutdown_both, Ec); + } + else + { + m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); + } + } + else if (Transmit) + { + m_Socket->shutdown(asio::socket_base::shutdown_send, Ec); + } +} + +void +AsioTransportConnection::CloseConnection() +{ + if (m_IsTerminated.test()) + { + return; + } + + if (m_IsTerminated.test_and_set() == false) + { + Shutdown(true, true); + + std::error_code Ec; + m_Socket->close(Ec); + } +} + +void +AsioTransportConnection::EnqueueRead() +{ + if (m_IsTerminated.test() == false) + { + m_RequestBuffer.prepare(64 * 1024); + + asio::async_read( + *m_Socket.get(), + m_RequestBuffer, + asio::transfer_at_least(1), + [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); + } +} + +void +AsioTransportConnection::OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount) +{ + ZEN_UNUSED(ByteCount); + + if (Ec) + { + if (!m_IsTerminated.test()) + { + ZEN_WARN("on data received ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message()); + } + + const bool Receive = true; + const bool Transmit = true; + return Shutdown(Receive, Transmit); + } + + while (m_RequestBuffer.size()) + { + const asio::const_buffer& InputBuffer = m_RequestBuffer.data(); + m_ConnectionHandler->OnBytesRead(InputBuffer.data(), InputBuffer.size()); + m_RequestBuffer.consume(InputBuffer.size()); + } + + EnqueueRead(); +} + +void +AsioTransportConnection::OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount) +{ + ZEN_UNUSED(ByteCount); + + if (Ec) + { + ZEN_WARN("on data sent ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message()); + + const bool Receive = true; + const bool Transmit = true; + return Shutdown(Receive, Transmit); + } +} + +////////////////////////////////////////////////////////////////////////// + +AsioTransportPlugin::AsioTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +: m_BasePort(BasePort) +, m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +{ +} + +AsioTransportPlugin::~AsioTransportPlugin() +{ +} + +uint32_t +AsioTransportPlugin::AddRef() const +{ + return RefCounted::AddRef(); +} + +uint32_t +AsioTransportPlugin::Release() const +{ + return RefCounted::Release(); +} + +void +AsioTransportPlugin::Initialize(TransportServer* ServerInterface) +{ + ZEN_ASSERT(m_ThreadCount > 0); + ZEN_ASSERT(ServerInterface); + + ZEN_INFO("starting asio http with {} service threads", m_ThreadCount); + + m_Acceptor.reset(new AsioTransportAcceptor(ServerInterface, m_IoService, m_BasePort)); + m_Acceptor->Start(); + + // This should consist of a set of minimum threads and grow on demand to + // meet concurrency needs? Right now we end up allocating a large number + // of threads even if we never end up using all of them, which seems + // wasteful. It's also not clear how the demand for concurrency should + // be balanced with the engine side - ideally we'd have some kind of + // global scheduling to prevent one side from starving the other side + // and thus preventing progress. Or at the very least, thread priorities + // should be considered + + for (int i = 0; i < m_ThreadCount; ++i) + { + m_ThreadPool.emplace_back([this, ThreadNumber = i + 1] { + SetCurrentThreadName(fmt::format("asio_thr_{}", ThreadNumber)); + + try + { + m_IoService.run(); + } + catch (std::exception& e) + { + ZEN_ERROR("exception caught in asio event loop: {}", e.what()); + } + }); + } + + ZEN_INFO("asio http transport started (port {})", m_Acceptor->GetAcceptPort()); +} + +void +AsioTransportPlugin::Shutdown() +{ + m_Acceptor->RequestStop(); + m_IoService.stop(); + + for (auto& Thread : m_ThreadPool) + { + Thread.join(); + } +} + +bool +AsioTransportPlugin::IsAvailable() +{ + return true; +} + +TransportPlugin* +CreateAsioTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +{ + return new AsioTransportPlugin(BasePort, ThreadCount); +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/transports/asiotransport.h b/src/zenhttp/transports/asiotransport.h new file mode 100644 index 000000000..b10174b85 --- /dev/null +++ b/src/zenhttp/transports/asiotransport.h @@ -0,0 +1,15 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpplugin.h> + +#if ZEN_WITH_PLUGINS + +namespace zen { + +TransportPlugin* CreateAsioTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + +} // namespace zen + +#endif diff --git a/src/zenhttp/transports/dlltransport.cpp b/src/zenhttp/transports/dlltransport.cpp new file mode 100644 index 000000000..04fb6caaa --- /dev/null +++ b/src/zenhttp/transports/dlltransport.cpp @@ -0,0 +1,250 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "dlltransport.h" + +#include <zencore/except.h> +#include <zencore/logging.h> +#include <zencore/scopeguard.h> + +#include <exception> +#include <thread> +#include <vector> + +#if ZEN_WITH_PLUGINS + +namespace zen { + +struct DllTransportConnection : public TransportConnection +{ +public: + DllTransportConnection(); + ~DllTransportConnection(); + + void Initialize(TransportServerConnection& ServerConnection); + void HandleConnection(); + + // TransportConnection + + virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override; + virtual void Shutdown(bool Receive, bool Transmit) override; + virtual void CloseConnection() override; + +private: + Ref<TransportServerConnection> m_ConnectionHandler; + bool m_IsTerminated = false; +}; + +DllTransportConnection::DllTransportConnection() +{ +} + +DllTransportConnection::~DllTransportConnection() +{ +} + +void +DllTransportConnection::Initialize(TransportServerConnection& ServerConnection) +{ + m_ConnectionHandler = &ServerConnection; // TODO: this is awkward +} + +void +DllTransportConnection::HandleConnection() +{ +} + +void +DllTransportConnection::CloseConnection() +{ + if (m_IsTerminated) + { + return; + } + + m_IsTerminated = true; +} + +int64_t +DllTransportConnection::WriteBytes(const void* Buffer, size_t DataSize) +{ + ZEN_UNUSED(Buffer, DataSize); + return DataSize; +} + +void +DllTransportConnection::Shutdown(bool Receive, bool Transmit) +{ + ZEN_UNUSED(Receive, Transmit); +} + +////////////////////////////////////////////////////////////////////////// + +struct LoadedDll +{ + std::string Name; + std::filesystem::path LoadedFromPath; + Ref<TransportPlugin> Plugin; +}; + +class DllTransportPluginImpl +{ +public: + DllTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount); + ~DllTransportPluginImpl(); + + uint16_t Start(TransportServer* ServerInterface); + void Stop(); + bool IsAvailable(); + void LoadDll(std::string_view Name); + +private: + TransportServer* m_ServerInterface = nullptr; + RwLock m_Lock; + std::vector<LoadedDll> m_Transports; + uint16_t m_BasePort = 0; + int m_ThreadCount = 0; +}; + +DllTransportPluginImpl::DllTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount) +: m_BasePort(BasePort) +, m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +{ +} + +DllTransportPluginImpl::~DllTransportPluginImpl() +{ +} + +uint16_t +DllTransportPluginImpl::Start(TransportServer* ServerIface) +{ + m_ServerInterface = ServerIface; + + RwLock::ExclusiveLockScope _(m_Lock); + + for (LoadedDll& Transport : m_Transports) + { + try + { + Transport.Plugin->Initialize(ServerIface); + } + catch (const std::exception&) + { + // TODO: report + } + } + + return m_BasePort; +} + +void +DllTransportPluginImpl::Stop() +{ + RwLock::ExclusiveLockScope _(m_Lock); + + for (LoadedDll& Transport : m_Transports) + { + try + { + Transport.Plugin->Shutdown(); + } + catch (const std::exception&) + { + // TODO: report + } + } +} + +bool +DllTransportPluginImpl::IsAvailable() +{ + return true; +} + +void +DllTransportPluginImpl::LoadDll(std::string_view Name) +{ + ExtendableStringBuilder<128> DllPath; + DllPath << Name << ".dll"; + HMODULE DllHandle = LoadLibraryA(DllPath.c_str()); + + if (!DllHandle) + { + std::error_code Ec = MakeErrorCodeFromLastError(); + + throw std::system_error(Ec, fmt::format("failed to load transport DLL from '{}'", DllPath)); + } + + TransportPlugin* CreateTransportPlugin(); + + PfnCreateTransportPlugin CreatePlugin = (PfnCreateTransportPlugin)GetProcAddress(DllHandle, "CreateTransportPlugin"); + + if (!CreatePlugin) + { + std::error_code Ec = MakeErrorCodeFromLastError(); + + FreeLibrary(DllHandle); + + throw std::system_error(Ec, fmt::format("API mismatch detected in transport DLL loaded from '{}'", DllPath)); + } + + LoadedDll NewDll; + + NewDll.Name = Name; + NewDll.LoadedFromPath = DllPath.c_str(); + NewDll.Plugin = CreatePlugin(); + + m_Transports.emplace_back(std::move(NewDll)); +} + +////////////////////////////////////////////////////////////////////////// + +DllTransportPlugin::DllTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +: m_Impl(std::make_unique<DllTransportPluginImpl>(BasePort, ThreadCount)) +{ +} + +DllTransportPlugin::~DllTransportPlugin() +{ + m_Impl->Stop(); +} + +uint32_t +DllTransportPlugin::AddRef() const +{ + return RefCounted::AddRef(); +} + +uint32_t +DllTransportPlugin::Release() const +{ + return RefCounted::Release(); +} + +void +DllTransportPlugin::Initialize(TransportServer* ServerInterface) +{ + m_Impl->Start(ServerInterface); +} + +void +DllTransportPlugin::Shutdown() +{ + m_Impl->Stop(); +} + +bool +DllTransportPlugin::IsAvailable() +{ + return m_Impl->IsAvailable(); +} + +void +DllTransportPlugin::LoadDll(std::string_view Name) +{ + return m_Impl->LoadDll(Name); +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/transports/dlltransport.h b/src/zenhttp/transports/dlltransport.h new file mode 100644 index 000000000..2dccdd0f9 --- /dev/null +++ b/src/zenhttp/transports/dlltransport.h @@ -0,0 +1,37 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpplugin.h> + +#if ZEN_WITH_PLUGINS + +namespace zen { + +class DllTransportPluginImpl; + +/** Transport plugin which supports dynamic loading of external transport + * provider modules + */ +class DllTransportPlugin : public TransportPlugin, RefCounted +{ +public: + DllTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + ~DllTransportPlugin(); + + virtual uint32_t AddRef() const override; + virtual uint32_t Release() const override; + + virtual void Initialize(TransportServer* ServerInterface) override; + virtual void Shutdown() override; + virtual bool IsAvailable() override; + + void LoadDll(std::string_view Name); + +private: + std::unique_ptr<DllTransportPluginImpl> m_Impl; +}; + +} // namespace zen + +#endif diff --git a/src/zenhttp/transports/winsocktransport.cpp b/src/zenhttp/transports/winsocktransport.cpp new file mode 100644 index 000000000..ad3302550 --- /dev/null +++ b/src/zenhttp/transports/winsocktransport.cpp @@ -0,0 +1,367 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "winsocktransport.h" + +#if ZEN_WITH_PLUGINS + +# include <zencore/logging.h> +# include <zencore/scopeguard.h> +# include <zencore/workthreadpool.h> + +# if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +ZEN_THIRD_PARTY_INCLUDES_START +# include <WinSock2.h> +# include <WS2tcpip.h> +ZEN_THIRD_PARTY_INCLUDES_END +# endif + +# include <thread> + +namespace zen { + +class SocketTransportPluginImpl; + +class SocketTransportPlugin : public TransportPlugin, RefCounted +{ +public: + SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + ~SocketTransportPlugin(); + + virtual uint32_t AddRef() const override; + virtual uint32_t Release() const override; + virtual void Initialize(TransportServer* ServerInterface) override; + virtual void Shutdown() override; + virtual bool IsAvailable() override; + +private: + bool m_IsOk = true; + uint16_t m_BasePort = 0; + int m_ThreadCount = 0; + SocketTransportPluginImpl* m_Impl; +}; + +struct SocketTransportConnection : public TransportConnection +{ +public: + SocketTransportConnection(); + ~SocketTransportConnection(); + + void Initialize(TransportServerConnection* ServerConnection, SOCKET ClientSocket); + void HandleConnection(); + + // TransportConnection + + virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override; + virtual void Shutdown(bool Receive, bool Transmit) override; + virtual void CloseConnection() override; + +private: + Ref<TransportServerConnection> m_ConnectionHandler; + SOCKET m_ClientSocket{}; + bool m_IsTerminated = false; +}; + +SocketTransportConnection::SocketTransportConnection() +{ +} + +SocketTransportConnection::~SocketTransportConnection() +{ +} + +void +SocketTransportConnection::Initialize(TransportServerConnection* ServerConnection, SOCKET ClientSocket) +{ + ZEN_ASSERT(!m_ConnectionHandler); + + m_ConnectionHandler = ServerConnection; + m_ClientSocket = ClientSocket; +} + +void +SocketTransportConnection::HandleConnection() +{ + ZEN_ASSERT(m_ConnectionHandler); + + const int InputBufferSize = 64 * 1024; + uint8_t* InputBuffer = new uint8_t[64 * 1024]; + auto _ = MakeGuard([&] { delete[] InputBuffer; }); + + do + { + const int RecvBytes = recv(m_ClientSocket, (char*)InputBuffer, InputBufferSize, /* flags */ 0); + + if (RecvBytes == 0) + { + // Connection closed + return CloseConnection(); + } + else if (RecvBytes < 0) + { + // Error + return CloseConnection(); + } + + m_ConnectionHandler->OnBytesRead(InputBuffer, RecvBytes); + } while (m_ClientSocket); +} + +void +SocketTransportConnection::CloseConnection() +{ + if (m_IsTerminated) + { + return; + } + + ZEN_ASSERT(m_ClientSocket); + m_IsTerminated = true; + + shutdown(m_ClientSocket, SD_BOTH); // We won't be sending or receiving any more data + + closesocket(m_ClientSocket); + m_ClientSocket = 0; +} + +int64_t +SocketTransportConnection::WriteBytes(const void* Buffer, size_t DataSize) +{ + const uint8_t* BufferCursor = reinterpret_cast<const uint8_t*>(Buffer); + int64_t TotalBytesSent = 0; + + while (DataSize) + { + const int MaxBlockSize = 128 * 1024; + const int SendBlockSize = (DataSize > MaxBlockSize) ? MaxBlockSize : (int)DataSize; + const int SentBytes = send(m_ClientSocket, (const char*)BufferCursor, SendBlockSize, /* flags */ 0); + + if (SentBytes < 0) + { + // Error + return SentBytes; + } + + BufferCursor += SentBytes; + DataSize -= SentBytes; + TotalBytesSent += SentBytes; + } + + return TotalBytesSent; +} + +void +SocketTransportConnection::Shutdown(bool Receive, bool Transmit) +{ + if (Receive) + { + if (Transmit) + { + shutdown(m_ClientSocket, SD_BOTH); + } + else + { + shutdown(m_ClientSocket, SD_RECEIVE); + } + } + else if (Transmit) + { + shutdown(m_ClientSocket, SD_SEND); + } +} + +////////////////////////////////////////////////////////////////////////// + +class SocketTransportPluginImpl +{ +public: + SocketTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount); + ~SocketTransportPluginImpl(); + + uint16_t Start(uint16_t Port, TransportServer* ServerInterface); + void Stop(); + +private: + TransportServer* m_ServerInterface = nullptr; + uint16_t m_BasePort = 0; + int m_ThreadCount = 0; + bool m_IsOk = true; + + SOCKET m_ListenSocket{}; + std::thread m_AcceptThread; + std::atomic_flag m_KeepRunning; + std::unique_ptr<WorkerThreadPool> m_WorkerThreadpool; +}; + +SocketTransportPluginImpl::SocketTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount) +: m_BasePort(BasePort) +, m_ThreadCount(ThreadCount) +{ +# if ZEN_PLATFORM_WINDOWS + WSADATA wsaData; + if (int Result = WSAStartup(0x202, &wsaData); Result != 0) + { + m_IsOk = false; + WSACleanup(); + } +# endif + + m_WorkerThreadpool = std::make_unique<WorkerThreadPool>(m_ThreadCount, "http_conn"); +} + +SocketTransportPluginImpl::~SocketTransportPluginImpl() +{ + Stop(); + +# if ZEN_PLATFORM_WINDOWS + if (m_IsOk) + { + WSACleanup(); + } +# endif +} + +uint16_t +SocketTransportPluginImpl::Start(uint16_t Port, TransportServer* ServerInterface) +{ + m_ServerInterface = ServerInterface; + m_ListenSocket = socket(AF_INET6, SOCK_STREAM, 0); + + if (m_ListenSocket == SOCKET_ERROR || m_ListenSocket == INVALID_SOCKET) + { + ZEN_ERROR("socket creation failed in HTTP plugin server init: {}", WSAGetLastError()); + + return 0; + } + + sockaddr_in6 Server{}; + Server.sin6_family = AF_INET6; + Server.sin6_port = htons(Port); + Server.sin6_addr = in6addr_any; + + if (int Result = bind(m_ListenSocket, (sockaddr*)&Server, sizeof(Server)); Result == SOCKET_ERROR) + { + ZEN_ERROR("bind call failed in HTTP plugin server init: {}", WSAGetLastError()); + + return 0; + } + + if (int Result = listen(m_ListenSocket, AF_INET6); Result == SOCKET_ERROR) + { + ZEN_ERROR("listen call failed in HTTP plugin server init: {}", WSAGetLastError()); + + return 0; + } + + m_KeepRunning.test_and_set(); + + m_AcceptThread = std::thread([&] { + SetCurrentThreadName("http_plugin_acceptor"); + + ZEN_INFO("HTTP plugin server waiting for connections"); + + do + { + if (SOCKET ClientSocket = accept(m_ListenSocket, NULL, NULL); ClientSocket != SOCKET_ERROR) + { + int Flag = 1; + setsockopt(ClientSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&Flag, sizeof(Flag)); + + // Handle new connection + SocketTransportConnection* Connection = new SocketTransportConnection(); + TransportServerConnection* ConnectionInterface{m_ServerInterface->CreateConnectionHandler(Connection)}; + Connection->Initialize(ConnectionInterface, ClientSocket); + + m_WorkerThreadpool->ScheduleWork([Connection] { + try + { + Connection->HandleConnection(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception caught in connection loop: {}", Ex.what()); + } + + delete Connection; + }); + } + else + { + } + } while (!IsApplicationExitRequested() && m_KeepRunning.test()); + + ZEN_INFO("HTTP plugin server accept thread exit"); + }); + + return Port; +} + +void +SocketTransportPluginImpl::Stop() +{ + // TODO: all pending/ongoing work should be drained here as well + + m_KeepRunning.clear(); + + closesocket(m_ListenSocket); + m_ListenSocket = 0; + + if (m_AcceptThread.joinable()) + { + m_AcceptThread.join(); + } +} + +////////////////////////////////////////////////////////////////////////// + +SocketTransportPlugin::SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +: m_BasePort(BasePort) +, m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +, m_Impl(new SocketTransportPluginImpl(BasePort, m_ThreadCount)) +{ +} + +SocketTransportPlugin::~SocketTransportPlugin() +{ + delete m_Impl; +} + +uint32_t +SocketTransportPlugin::AddRef() const +{ + return RefCounted::AddRef(); +} + +uint32_t +SocketTransportPlugin::Release() const +{ + return RefCounted::Release(); +} + +void +SocketTransportPlugin::Initialize(TransportServer* ServerInterface) +{ + m_Impl->Start(m_BasePort, ServerInterface); +} + +void +SocketTransportPlugin::Shutdown() +{ + m_Impl->Stop(); +} + +bool +SocketTransportPlugin::IsAvailable() +{ + return true; +} + +TransportPlugin* +CreateSocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +{ + return new SocketTransportPlugin(BasePort, ThreadCount); +} + +} // namespace zen + +#endif diff --git a/src/zenhttp/transports/winsocktransport.h b/src/zenhttp/transports/winsocktransport.h new file mode 100644 index 000000000..2b2a55aef --- /dev/null +++ b/src/zenhttp/transports/winsocktransport.h @@ -0,0 +1,15 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#pragma once + +#include <zenhttp/httpplugin.h> + +#if ZEN_WITH_PLUGINS + +namespace zen { + +TransportPlugin* CreateSocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + +} // namespace zen + +#endif |