diff options
Diffstat (limited to 'src/zenhttp/winsocktransport.cpp')
| -rw-r--r-- | src/zenhttp/winsocktransport.cpp | 367 |
1 files changed, 367 insertions, 0 deletions
diff --git a/src/zenhttp/winsocktransport.cpp b/src/zenhttp/winsocktransport.cpp new file mode 100644 index 000000000..e86e4822e --- /dev/null +++ b/src/zenhttp/winsocktransport.cpp @@ -0,0 +1,367 @@ +// Copyright Epic Games, Inc. All Rights Reserved. + +#include "winsocktransport.h" + +#if ZEN_WITH_PLUGINS + +# include <zencore/logging.h> +# include <zencore/scopeguard.h> +# include <zencore/workthreadpool.h> + +# if ZEN_PLATFORM_WINDOWS +# include <zencore/windows.h> +ZEN_THIRD_PARTY_INCLUDES_START +# include <WinSock2.h> +# include <WS2tcpip.h> +ZEN_THIRD_PARTY_INCLUDES_END +# endif + +# include <thread> + +namespace zen { + +class SocketTransportPluginImpl; + +class SocketTransportPlugin : public TransportPluginInterface, RefCounted +{ +public: + SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount); + ~SocketTransportPlugin(); + + virtual uint32_t AddRef() const override; + virtual uint32_t Release() const override; + virtual void Initialize(TransportServerInterface* ServerInterface) override; + virtual void Shutdown() override; + virtual bool IsAvailable() override; + +private: + bool m_IsOk = true; + uint16_t m_BasePort = 0; + int m_ThreadCount = 0; + SocketTransportPluginImpl* m_Impl; +}; + +struct SocketTransportConnection : public TransportConnectionInterface +{ +public: + SocketTransportConnection(); + ~SocketTransportConnection(); + + void Initialize(TransportServerConnectionHandler* ServerConnection, SOCKET ClientSocket); + void HandleConnection(); + + // TransportConnectionInterface + + virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override; + virtual void Shutdown(bool Receive, bool Transmit) override; + virtual void CloseConnection() override; + +private: + Ref<TransportServerConnectionHandler> m_ConnectionHandler; + SOCKET m_ClientSocket{}; + bool m_IsTerminated = false; +}; + +SocketTransportConnection::SocketTransportConnection() +{ +} + +SocketTransportConnection::~SocketTransportConnection() +{ +} + +void +SocketTransportConnection::Initialize(TransportServerConnectionHandler* ServerConnection, SOCKET ClientSocket) +{ + ZEN_ASSERT(!m_ConnectionHandler); + + m_ConnectionHandler = ServerConnection; + m_ClientSocket = ClientSocket; +} + +void +SocketTransportConnection::HandleConnection() +{ + ZEN_ASSERT(m_ConnectionHandler); + + const int InputBufferSize = 64 * 1024; + uint8_t* InputBuffer = new uint8_t[64 * 1024]; + auto _ = MakeGuard([&] { delete[] InputBuffer; }); + + do + { + const int RecvBytes = recv(m_ClientSocket, (char*)InputBuffer, InputBufferSize, /* flags */ 0); + + if (RecvBytes == 0) + { + // Connection closed + return CloseConnection(); + } + else if (RecvBytes < 0) + { + // Error + return CloseConnection(); + } + + m_ConnectionHandler->OnBytesRead(InputBuffer, RecvBytes); + } while (m_ClientSocket); +} + +void +SocketTransportConnection::CloseConnection() +{ + if (m_IsTerminated) + { + return; + } + + ZEN_ASSERT(m_ClientSocket); + m_IsTerminated = true; + + shutdown(m_ClientSocket, SD_BOTH); // We won't be sending or receiving any more data + + closesocket(m_ClientSocket); + m_ClientSocket = 0; +} + +int64_t +SocketTransportConnection::WriteBytes(const void* Buffer, size_t DataSize) +{ + const uint8_t* BufferCursor = reinterpret_cast<const uint8_t*>(Buffer); + int64_t TotalBytesSent = 0; + + while (DataSize) + { + const int MaxBlockSize = 128 * 1024; + const int SendBlockSize = (DataSize > MaxBlockSize) ? MaxBlockSize : (int)DataSize; + const int SentBytes = send(m_ClientSocket, (const char*)BufferCursor, SendBlockSize, /* flags */ 0); + + if (SentBytes < 0) + { + // Error + return SentBytes; + } + + BufferCursor += SentBytes; + DataSize -= SentBytes; + TotalBytesSent += SentBytes; + } + + return TotalBytesSent; +} + +void +SocketTransportConnection::Shutdown(bool Receive, bool Transmit) +{ + if (Receive) + { + if (Transmit) + { + shutdown(m_ClientSocket, SD_BOTH); + } + else + { + shutdown(m_ClientSocket, SD_RECEIVE); + } + } + else if (Transmit) + { + shutdown(m_ClientSocket, SD_SEND); + } +} + +////////////////////////////////////////////////////////////////////////// + +class SocketTransportPluginImpl +{ +public: + SocketTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount); + ~SocketTransportPluginImpl(); + + uint16_t Start(uint16_t Port, TransportServerInterface* ServerInterface); + void Stop(); + +private: + TransportServerInterface* m_ServerInterface = nullptr; + uint16_t m_BasePort = 0; + int m_ThreadCount = 0; + bool m_IsOk = true; + + SOCKET m_ListenSocket{}; + std::thread m_AcceptThread; + std::atomic_flag m_KeepRunning; + std::unique_ptr<WorkerThreadPool> m_WorkerThreadpool; +}; + +SocketTransportPluginImpl::SocketTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount) +: m_BasePort(BasePort) +, m_ThreadCount(ThreadCount) +{ +# if ZEN_PLATFORM_WINDOWS + WSADATA wsaData; + if (int Result = WSAStartup(0x202, &wsaData); Result != 0) + { + m_IsOk = false; + WSACleanup(); + } +# endif + + m_WorkerThreadpool = std::make_unique<WorkerThreadPool>(m_ThreadCount, "http_conn"); +} + +SocketTransportPluginImpl::~SocketTransportPluginImpl() +{ + Stop(); + +# if ZEN_PLATFORM_WINDOWS + if (m_IsOk) + { + WSACleanup(); + } +# endif +} + +uint16_t +SocketTransportPluginImpl::Start(uint16_t Port, TransportServerInterface* ServerInterface) +{ + m_ServerInterface = ServerInterface; + m_ListenSocket = socket(AF_INET6, SOCK_STREAM, 0); + + if (m_ListenSocket == SOCKET_ERROR || m_ListenSocket == INVALID_SOCKET) + { + ZEN_ERROR("socket creation failed in HTTP plugin server init: {}", WSAGetLastError()); + + return 0; + } + + sockaddr_in6 Server{}; + Server.sin6_family = AF_INET6; + Server.sin6_port = htons(Port); + Server.sin6_addr = in6addr_any; + + if (int Result = bind(m_ListenSocket, (sockaddr*)&Server, sizeof(Server)); Result == SOCKET_ERROR) + { + ZEN_ERROR("bind call failed in HTTP plugin server init: {}", WSAGetLastError()); + + return 0; + } + + if (int Result = listen(m_ListenSocket, AF_INET6); Result == SOCKET_ERROR) + { + ZEN_ERROR("listen call failed in HTTP plugin server init: {}", WSAGetLastError()); + + return 0; + } + + m_KeepRunning.test_and_set(); + + m_AcceptThread = std::thread([&] { + SetCurrentThreadName("http_plugin_acceptor"); + + ZEN_INFO("HTTP plugin server waiting for connections"); + + do + { + if (SOCKET ClientSocket = accept(m_ListenSocket, NULL, NULL); ClientSocket != SOCKET_ERROR) + { + int Flag = 1; + setsockopt(ClientSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&Flag, sizeof(Flag)); + + // Handle new connection + SocketTransportConnection* Connection = new SocketTransportConnection(); + TransportServerConnectionHandler* ConnectionInterface{m_ServerInterface->CreateConnectionHandler(Connection)}; + Connection->Initialize(ConnectionInterface, ClientSocket); + + m_WorkerThreadpool->ScheduleWork([Connection] { + try + { + Connection->HandleConnection(); + } + catch (std::exception& Ex) + { + ZEN_WARN("exception caught in connection loop: {}", Ex.what()); + } + + delete Connection; + }); + } + else + { + } + } while (!IsApplicationExitRequested() && m_KeepRunning.test()); + + ZEN_INFO("HTTP plugin server accept thread exit"); + }); + + return Port; +} + +void +SocketTransportPluginImpl::Stop() +{ + // TODO: all pending/ongoing work should be drained here as well + + m_KeepRunning.clear(); + + closesocket(m_ListenSocket); + m_ListenSocket = 0; + + if (m_AcceptThread.joinable()) + { + m_AcceptThread.join(); + } +} + +////////////////////////////////////////////////////////////////////////// + +SocketTransportPlugin::SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +: m_BasePort(BasePort) +, m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u)) +, m_Impl(new SocketTransportPluginImpl(BasePort, m_ThreadCount)) +{ +} + +SocketTransportPlugin::~SocketTransportPlugin() +{ + delete m_Impl; +} + +uint32_t +SocketTransportPlugin::AddRef() const +{ + return RefCounted::AddRef(); +} + +uint32_t +SocketTransportPlugin::Release() const +{ + return RefCounted::Release(); +} + +void +SocketTransportPlugin::Initialize(TransportServerInterface* ServerInterface) +{ + m_Impl->Start(m_BasePort, ServerInterface); +} + +void +SocketTransportPlugin::Shutdown() +{ + m_Impl->Stop(); +} + +bool +SocketTransportPlugin::IsAvailable() +{ + return true; +} + +TransportPluginInterface* +CreateSocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount) +{ + return new SocketTransportPlugin(BasePort, ThreadCount); +} + +} // namespace zen + +#endif |