aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/winsocktransport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhttp/winsocktransport.cpp')
-rw-r--r--src/zenhttp/winsocktransport.cpp367
1 files changed, 367 insertions, 0 deletions
diff --git a/src/zenhttp/winsocktransport.cpp b/src/zenhttp/winsocktransport.cpp
new file mode 100644
index 000000000..e86e4822e
--- /dev/null
+++ b/src/zenhttp/winsocktransport.cpp
@@ -0,0 +1,367 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "winsocktransport.h"
+
+#if ZEN_WITH_PLUGINS
+
+# include <zencore/logging.h>
+# include <zencore/scopeguard.h>
+# include <zencore/workthreadpool.h>
+
+# if ZEN_PLATFORM_WINDOWS
+# include <zencore/windows.h>
+ZEN_THIRD_PARTY_INCLUDES_START
+# include <WinSock2.h>
+# include <WS2tcpip.h>
+ZEN_THIRD_PARTY_INCLUDES_END
+# endif
+
+# include <thread>
+
+namespace zen {
+
+class SocketTransportPluginImpl;
+
+class SocketTransportPlugin : public TransportPluginInterface, RefCounted
+{
+public:
+ SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount);
+ ~SocketTransportPlugin();
+
+ virtual uint32_t AddRef() const override;
+ virtual uint32_t Release() const override;
+ virtual void Initialize(TransportServerInterface* ServerInterface) override;
+ virtual void Shutdown() override;
+ virtual bool IsAvailable() override;
+
+private:
+ bool m_IsOk = true;
+ uint16_t m_BasePort = 0;
+ int m_ThreadCount = 0;
+ SocketTransportPluginImpl* m_Impl;
+};
+
+struct SocketTransportConnection : public TransportConnectionInterface
+{
+public:
+ SocketTransportConnection();
+ ~SocketTransportConnection();
+
+ void Initialize(TransportServerConnectionHandler* ServerConnection, SOCKET ClientSocket);
+ void HandleConnection();
+
+ // TransportConnectionInterface
+
+ virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override;
+ virtual void Shutdown(bool Receive, bool Transmit) override;
+ virtual void CloseConnection() override;
+
+private:
+ Ref<TransportServerConnectionHandler> m_ConnectionHandler;
+ SOCKET m_ClientSocket{};
+ bool m_IsTerminated = false;
+};
+
+SocketTransportConnection::SocketTransportConnection()
+{
+}
+
+SocketTransportConnection::~SocketTransportConnection()
+{
+}
+
+void
+SocketTransportConnection::Initialize(TransportServerConnectionHandler* ServerConnection, SOCKET ClientSocket)
+{
+ ZEN_ASSERT(!m_ConnectionHandler);
+
+ m_ConnectionHandler = ServerConnection;
+ m_ClientSocket = ClientSocket;
+}
+
+void
+SocketTransportConnection::HandleConnection()
+{
+ ZEN_ASSERT(m_ConnectionHandler);
+
+ const int InputBufferSize = 64 * 1024;
+ uint8_t* InputBuffer = new uint8_t[64 * 1024];
+ auto _ = MakeGuard([&] { delete[] InputBuffer; });
+
+ do
+ {
+ const int RecvBytes = recv(m_ClientSocket, (char*)InputBuffer, InputBufferSize, /* flags */ 0);
+
+ if (RecvBytes == 0)
+ {
+ // Connection closed
+ return CloseConnection();
+ }
+ else if (RecvBytes < 0)
+ {
+ // Error
+ return CloseConnection();
+ }
+
+ m_ConnectionHandler->OnBytesRead(InputBuffer, RecvBytes);
+ } while (m_ClientSocket);
+}
+
+void
+SocketTransportConnection::CloseConnection()
+{
+ if (m_IsTerminated)
+ {
+ return;
+ }
+
+ ZEN_ASSERT(m_ClientSocket);
+ m_IsTerminated = true;
+
+ shutdown(m_ClientSocket, SD_BOTH); // We won't be sending or receiving any more data
+
+ closesocket(m_ClientSocket);
+ m_ClientSocket = 0;
+}
+
+int64_t
+SocketTransportConnection::WriteBytes(const void* Buffer, size_t DataSize)
+{
+ const uint8_t* BufferCursor = reinterpret_cast<const uint8_t*>(Buffer);
+ int64_t TotalBytesSent = 0;
+
+ while (DataSize)
+ {
+ const int MaxBlockSize = 128 * 1024;
+ const int SendBlockSize = (DataSize > MaxBlockSize) ? MaxBlockSize : (int)DataSize;
+ const int SentBytes = send(m_ClientSocket, (const char*)BufferCursor, SendBlockSize, /* flags */ 0);
+
+ if (SentBytes < 0)
+ {
+ // Error
+ return SentBytes;
+ }
+
+ BufferCursor += SentBytes;
+ DataSize -= SentBytes;
+ TotalBytesSent += SentBytes;
+ }
+
+ return TotalBytesSent;
+}
+
+void
+SocketTransportConnection::Shutdown(bool Receive, bool Transmit)
+{
+ if (Receive)
+ {
+ if (Transmit)
+ {
+ shutdown(m_ClientSocket, SD_BOTH);
+ }
+ else
+ {
+ shutdown(m_ClientSocket, SD_RECEIVE);
+ }
+ }
+ else if (Transmit)
+ {
+ shutdown(m_ClientSocket, SD_SEND);
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+class SocketTransportPluginImpl
+{
+public:
+ SocketTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount);
+ ~SocketTransportPluginImpl();
+
+ uint16_t Start(uint16_t Port, TransportServerInterface* ServerInterface);
+ void Stop();
+
+private:
+ TransportServerInterface* m_ServerInterface = nullptr;
+ uint16_t m_BasePort = 0;
+ int m_ThreadCount = 0;
+ bool m_IsOk = true;
+
+ SOCKET m_ListenSocket{};
+ std::thread m_AcceptThread;
+ std::atomic_flag m_KeepRunning;
+ std::unique_ptr<WorkerThreadPool> m_WorkerThreadpool;
+};
+
+SocketTransportPluginImpl::SocketTransportPluginImpl(uint16_t BasePort, unsigned int ThreadCount)
+: m_BasePort(BasePort)
+, m_ThreadCount(ThreadCount)
+{
+# if ZEN_PLATFORM_WINDOWS
+ WSADATA wsaData;
+ if (int Result = WSAStartup(0x202, &wsaData); Result != 0)
+ {
+ m_IsOk = false;
+ WSACleanup();
+ }
+# endif
+
+ m_WorkerThreadpool = std::make_unique<WorkerThreadPool>(m_ThreadCount, "http_conn");
+}
+
+SocketTransportPluginImpl::~SocketTransportPluginImpl()
+{
+ Stop();
+
+# if ZEN_PLATFORM_WINDOWS
+ if (m_IsOk)
+ {
+ WSACleanup();
+ }
+# endif
+}
+
+uint16_t
+SocketTransportPluginImpl::Start(uint16_t Port, TransportServerInterface* ServerInterface)
+{
+ m_ServerInterface = ServerInterface;
+ m_ListenSocket = socket(AF_INET6, SOCK_STREAM, 0);
+
+ if (m_ListenSocket == SOCKET_ERROR || m_ListenSocket == INVALID_SOCKET)
+ {
+ ZEN_ERROR("socket creation failed in HTTP plugin server init: {}", WSAGetLastError());
+
+ return 0;
+ }
+
+ sockaddr_in6 Server{};
+ Server.sin6_family = AF_INET6;
+ Server.sin6_port = htons(Port);
+ Server.sin6_addr = in6addr_any;
+
+ if (int Result = bind(m_ListenSocket, (sockaddr*)&Server, sizeof(Server)); Result == SOCKET_ERROR)
+ {
+ ZEN_ERROR("bind call failed in HTTP plugin server init: {}", WSAGetLastError());
+
+ return 0;
+ }
+
+ if (int Result = listen(m_ListenSocket, AF_INET6); Result == SOCKET_ERROR)
+ {
+ ZEN_ERROR("listen call failed in HTTP plugin server init: {}", WSAGetLastError());
+
+ return 0;
+ }
+
+ m_KeepRunning.test_and_set();
+
+ m_AcceptThread = std::thread([&] {
+ SetCurrentThreadName("http_plugin_acceptor");
+
+ ZEN_INFO("HTTP plugin server waiting for connections");
+
+ do
+ {
+ if (SOCKET ClientSocket = accept(m_ListenSocket, NULL, NULL); ClientSocket != SOCKET_ERROR)
+ {
+ int Flag = 1;
+ setsockopt(ClientSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&Flag, sizeof(Flag));
+
+ // Handle new connection
+ SocketTransportConnection* Connection = new SocketTransportConnection();
+ TransportServerConnectionHandler* ConnectionInterface{m_ServerInterface->CreateConnectionHandler(Connection)};
+ Connection->Initialize(ConnectionInterface, ClientSocket);
+
+ m_WorkerThreadpool->ScheduleWork([Connection] {
+ try
+ {
+ Connection->HandleConnection();
+ }
+ catch (std::exception& Ex)
+ {
+ ZEN_WARN("exception caught in connection loop: {}", Ex.what());
+ }
+
+ delete Connection;
+ });
+ }
+ else
+ {
+ }
+ } while (!IsApplicationExitRequested() && m_KeepRunning.test());
+
+ ZEN_INFO("HTTP plugin server accept thread exit");
+ });
+
+ return Port;
+}
+
+void
+SocketTransportPluginImpl::Stop()
+{
+ // TODO: all pending/ongoing work should be drained here as well
+
+ m_KeepRunning.clear();
+
+ closesocket(m_ListenSocket);
+ m_ListenSocket = 0;
+
+ if (m_AcceptThread.joinable())
+ {
+ m_AcceptThread.join();
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+
+SocketTransportPlugin::SocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount)
+: m_BasePort(BasePort)
+, m_ThreadCount(ThreadCount != 0 ? ThreadCount : Max(std::thread::hardware_concurrency(), 8u))
+, m_Impl(new SocketTransportPluginImpl(BasePort, m_ThreadCount))
+{
+}
+
+SocketTransportPlugin::~SocketTransportPlugin()
+{
+ delete m_Impl;
+}
+
+uint32_t
+SocketTransportPlugin::AddRef() const
+{
+ return RefCounted::AddRef();
+}
+
+uint32_t
+SocketTransportPlugin::Release() const
+{
+ return RefCounted::Release();
+}
+
+void
+SocketTransportPlugin::Initialize(TransportServerInterface* ServerInterface)
+{
+ m_Impl->Start(m_BasePort, ServerInterface);
+}
+
+void
+SocketTransportPlugin::Shutdown()
+{
+ m_Impl->Stop();
+}
+
+bool
+SocketTransportPlugin::IsAvailable()
+{
+ return true;
+}
+
+TransportPluginInterface*
+CreateSocketTransportPlugin(uint16_t BasePort, unsigned int ThreadCount)
+{
+ return new SocketTransportPlugin(BasePort, ThreadCount);
+}
+
+} // namespace zen
+
+#endif