// Copyright Epic Games, Inc. All Rights Reserved. #include "winsocktransport.h" #if ZEN_WITH_PLUGINS # include # include # include # if ZEN_PLATFORM_WINDOWS # include ZEN_THIRD_PARTY_INCLUDES_START # include # include ZEN_THIRD_PARTY_INCLUDES_END # endif # include namespace zen { struct SocketTransportConnection : public TransportConnection { public: SocketTransportConnection(); ~SocketTransportConnection(); void Initialize(TransportServerConnection* ServerConnection, SOCKET ClientSocket); void HandleConnection(); // TransportConnection virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override; virtual void Shutdown(bool Receive, bool Transmit) override; virtual void CloseConnection() override; virtual const char* GetDebugName() override; private: Ref m_ConnectionHandler; SOCKET m_ClientSocket{}; bool m_IsTerminated = false; }; SocketTransportConnection::SocketTransportConnection() { } SocketTransportConnection::~SocketTransportConnection() { } void SocketTransportConnection::Initialize(TransportServerConnection* ServerConnection, SOCKET ClientSocket) { ZEN_ASSERT(!m_ConnectionHandler); m_ConnectionHandler = ServerConnection; m_ClientSocket = ClientSocket; } void SocketTransportConnection::HandleConnection() { ZEN_ASSERT(m_ConnectionHandler); const int InputBufferSize = 64 * 1024; uint8_t* InputBuffer = new uint8_t[64 * 1024]; auto _ = MakeGuard([&] { delete[] InputBuffer; }); do { const int RecvBytes = recv(m_ClientSocket, (char*)InputBuffer, InputBufferSize, /* flags */ 0); if (RecvBytes == 0) { // Connection closed return CloseConnection(); } else if (RecvBytes < 0) { // Error return CloseConnection(); } m_ConnectionHandler->OnBytesRead(InputBuffer, RecvBytes); } while (m_ClientSocket); } void SocketTransportConnection::CloseConnection() { if (m_IsTerminated) { return; } ZEN_ASSERT(m_ClientSocket); m_IsTerminated = true; shutdown(m_ClientSocket, SD_BOTH); // We won't be sending or receiving any more data closesocket(m_ClientSocket); m_ClientSocket = 0; } const char* SocketTransportConnection::GetDebugName() { return nullptr; } int64_t SocketTransportConnection::WriteBytes(const void* Buffer, size_t DataSize) { const uint8_t* BufferCursor = reinterpret_cast(Buffer); int64_t TotalBytesSent = 0; while (DataSize) { const int MaxBlockSize = 128 * 1024; const int SendBlockSize = (DataSize > MaxBlockSize) ? MaxBlockSize : (int)DataSize; const int SentBytes = send(m_ClientSocket, (const char*)BufferCursor, SendBlockSize, /* flags */ 0); if (SentBytes < 0) { // Error return SentBytes; } BufferCursor += SentBytes; DataSize -= SentBytes; TotalBytesSent += SentBytes; } return TotalBytesSent; } void SocketTransportConnection::Shutdown(bool Receive, bool Transmit) { if (Receive) { if (Transmit) { shutdown(m_ClientSocket, SD_BOTH); } else { shutdown(m_ClientSocket, SD_RECEIVE); } } else if (Transmit) { shutdown(m_ClientSocket, SD_SEND); } } ////////////////////////////////////////////////////////////////////////// class SocketTransportPluginImpl : public TransportPlugin, RefCounted { public: SocketTransportPluginImpl(); ~SocketTransportPluginImpl(); virtual uint32_t AddRef() const override; virtual uint32_t Release() const override; virtual void Configure(const char* OptionTag, const char* OptionValue) override; virtual void Initialize(TransportServer* ServerInterface) override; virtual void Shutdown() override; virtual const char* GetDebugName() override; virtual bool IsAvailable() override; private: TransportServer* m_ServerInterface = nullptr; uint16_t m_BasePort = 8558; int m_ThreadCount = 8; bool m_IsOk = true; SOCKET m_ListenSocket{}; std::thread m_AcceptThread; std::atomic_flag m_KeepRunning; std::unique_ptr m_WorkerThreadpool; }; SocketTransportPluginImpl::SocketTransportPluginImpl() { # if ZEN_PLATFORM_WINDOWS WSADATA wsaData; if (int Result = WSAStartup(0x202, &wsaData); Result != 0) { m_IsOk = false; WSACleanup(); } # endif } SocketTransportPluginImpl::~SocketTransportPluginImpl() { Shutdown(); # if ZEN_PLATFORM_WINDOWS if (m_IsOk) { WSACleanup(); } # endif } uint32_t SocketTransportPluginImpl::AddRef() const { return RefCounted::AddRef(); } uint32_t SocketTransportPluginImpl::Release() const { return RefCounted::Release(); } void SocketTransportPluginImpl::Configure(const char* OptionTag, const char* OptionValue) { using namespace std::literals; if (OptionTag == "port"sv) { if (auto PortNum = ParseInt(OptionValue)) { m_BasePort = *PortNum; } } else if (OptionTag == "threads"sv) { if (auto ThreadCount = ParseInt(OptionValue)) { m_ThreadCount = *ThreadCount; } } else { // Unknown configuration option } } bool SocketTransportPluginImpl::IsAvailable() { return true; } void SocketTransportPluginImpl::Initialize(TransportServer* ServerInterface) { m_ServerInterface = ServerInterface; m_WorkerThreadpool = std::make_unique(m_ThreadCount, "http_conn"); m_ListenSocket = socket(AF_INET6, SOCK_STREAM, 0); if (m_ListenSocket == SOCKET_ERROR || m_ListenSocket == INVALID_SOCKET) { ZEN_ERROR("socket creation failed in HTTP plugin server init: {}", WSAGetLastError()); return; } sockaddr_in6 Server{}; Server.sin6_family = AF_INET6; Server.sin6_port = htons(m_BasePort); Server.sin6_addr = in6addr_any; if (int Result = bind(m_ListenSocket, (sockaddr*)&Server, sizeof(Server)); Result == SOCKET_ERROR) { ZEN_ERROR("bind call failed in HTTP plugin server init: {}", WSAGetLastError()); return; } if (int Result = listen(m_ListenSocket, AF_INET6); Result == SOCKET_ERROR) { ZEN_ERROR("listen call failed in HTTP plugin server init: {}", WSAGetLastError()); return; } m_KeepRunning.test_and_set(); m_AcceptThread = std::thread([&] { SetCurrentThreadName("http_plugin_acceptor"); ZEN_INFO("HTTP plugin server waiting for connections"); do { if (SOCKET ClientSocket = accept(m_ListenSocket, NULL, NULL); ClientSocket != SOCKET_ERROR) { int Flag = 1; setsockopt(ClientSocket, IPPROTO_TCP, TCP_NODELAY, (char*)&Flag, sizeof(Flag)); // Handle new connection SocketTransportConnection* Connection = new SocketTransportConnection(); TransportServerConnection* ConnectionInterface{m_ServerInterface->CreateConnectionHandler(Connection)}; Connection->Initialize(ConnectionInterface, ClientSocket); m_WorkerThreadpool->ScheduleWork( [Connection] { try { Connection->HandleConnection(); } catch (const std::exception& Ex) { ZEN_WARN("exception caught in connection loop: {}", Ex.what()); } delete Connection; }, WorkerThreadPool::EMode::EnableBacklog); } else { } } while (m_KeepRunning.test()); ZEN_INFO("HTTP plugin server accept thread exit"); }); } void SocketTransportPluginImpl::Shutdown() { // TODO: all pending/ongoing work should be drained here as well m_KeepRunning.clear(); if (m_ListenSocket) { closesocket(m_ListenSocket); m_ListenSocket = 0; } if (m_AcceptThread.joinable()) { m_AcceptThread.join(); } } const char* SocketTransportPluginImpl::GetDebugName() { return nullptr; } ////////////////////////////////////////////////////////////////////////// TransportPlugin* CreateSocketTransportPlugin() { return new SocketTransportPluginImpl; } } // namespace zen #endif