// Copyright Epic Games, Inc. All Rights Reserved. #include "winsocktransport.h" #if ZEN_WITH_PLUGINS # include # include # include ZEN_THIRD_PARTY_INCLUDES_START # include ZEN_THIRD_PARTY_INCLUDES_END # if ZEN_PLATFORM_WINDOWS # include ZEN_THIRD_PARTY_INCLUDES_START # include ZEN_THIRD_PARTY_INCLUDES_END # endif # include # include # include namespace zen { struct AsioTransportAcceptor; class AsioTransportPlugin : public TransportPlugin, RefCounted { public: AsioTransportPlugin(); ~AsioTransportPlugin(); virtual uint32_t AddRef() const override; virtual uint32_t Release() const override; virtual void Configure(const char* OptionTag, const char* OptionValue) override; virtual void Initialize(TransportServer* ServerInterface) override; virtual void Shutdown() override; virtual const char* GetDebugName() override { return nullptr; } virtual bool IsAvailable() override; private: bool m_IsOk = true; uint16_t m_BasePort = 8558; int m_ThreadCount = 0; asio::io_service m_IoService; asio::io_service::work m_Work{m_IoService}; std::unique_ptr m_Acceptor; std::vector m_ThreadPool; }; struct AsioTransportConnection : public TransportConnection, std::enable_shared_from_this { AsioTransportConnection(std::unique_ptr&& Socket); ~AsioTransportConnection(); void Initialize(TransportServerConnection* ConnectionHandler); std::shared_ptr AsSharedPtr() { return shared_from_this(); } // TransportConnectionInterface virtual int64_t WriteBytes(const void* Buffer, size_t DataSize) override; virtual void Shutdown(bool Receive, bool Transmit) override; virtual void CloseConnection() override; virtual const char* GetDebugName() override { return nullptr; } private: void EnqueueRead(); void OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount); void OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount); Ref m_ConnectionHandler; asio::streambuf m_RequestBuffer; std::unique_ptr m_Socket; uint32_t m_ConnectionId = 0; std::atomic_flag m_IsTerminated{}; }; ////////////////////////////////////////////////////////////////////////// struct AsioTransportAcceptor { AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_service& IoService, uint16_t BasePort); ~AsioTransportAcceptor(); void Start(); void RequestStop(); inline int GetAcceptPort() { return m_Acceptor.local_endpoint().port(); } private: TransportServer* m_ServerInterface = nullptr; asio::io_service& m_IoService; asio::ip::tcp::acceptor m_Acceptor; std::atomic m_IsStopped{false}; void EnqueueAccept(); }; ////////////////////////////////////////////////////////////////////////// AsioTransportAcceptor::AsioTransportAcceptor(TransportServer* ServerInterface, asio::io_service& IoService, uint16_t BasePort) : m_ServerInterface(ServerInterface) , m_IoService(IoService) , m_Acceptor(m_IoService, asio::ip::tcp::v6()) { m_Acceptor.set_option(asio::ip::v6_only(false)); # if ZEN_PLATFORM_WINDOWS // Special option for Windows settings as !asio::socket_base::reuse_address is not the same as exclusive access on Windows platforms typedef asio::detail::socket_option::boolean exclusive_address; m_Acceptor.set_option(exclusive_address(true)); # else m_Acceptor.set_option(asio::socket_base::reuse_address(false)); # endif m_Acceptor.set_option(asio::ip::tcp::no_delay(true)); m_Acceptor.set_option(asio::socket_base::receive_buffer_size(128 * 1024)); m_Acceptor.set_option(asio::socket_base::send_buffer_size(256 * 1024)); uint16_t EffectivePort = BasePort; asio::error_code BindErrorCode; m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); // Sharing violation implies the port is being used by another process for (uint16_t PortOffset = 1; (BindErrorCode == asio::error::address_in_use) && (PortOffset < 10); ++PortOffset) { EffectivePort = BasePort + (PortOffset * 100); m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); } if (BindErrorCode == asio::error::access_denied) { EffectivePort = 0; m_Acceptor.bind(asio::ip::tcp::endpoint(asio::ip::address_v6::any(), EffectivePort), BindErrorCode); } if (BindErrorCode) { ZEN_ERROR("Unable open asio service, error '{}'", BindErrorCode.message()); } # if ZEN_PLATFORM_WINDOWS // On Windows, loopback connections can take advantage of a faster code path optionally with this flag. // This must be used by both the client and server side, and is only effective in the absence of // Windows Filtering Platform (WFP) callouts which can be installed by security software. // https://docs.microsoft.com/en-us/windows/win32/winsock/sio-loopback-fast-path SOCKET NativeSocket = m_Acceptor.native_handle(); int LoopbackOptionValue = 1; DWORD OptionNumberOfBytesReturned = 0; WSAIoctl(NativeSocket, SIO_LOOPBACK_FAST_PATH, &LoopbackOptionValue, sizeof(LoopbackOptionValue), NULL, 0, &OptionNumberOfBytesReturned, 0, 0); # endif ZEN_INFO("started asio transport at port: {}", EffectivePort); } AsioTransportAcceptor::~AsioTransportAcceptor() { } void AsioTransportAcceptor::Start() { m_Acceptor.listen(); EnqueueAccept(); } void AsioTransportAcceptor::RequestStop() { m_IsStopped = true; } void AsioTransportAcceptor::EnqueueAccept() { auto SocketPtr = std::make_unique(m_IoService); asio::ip::tcp::socket& SocketRef = *SocketPtr.get(); m_Acceptor.async_accept(SocketRef, [this, Socket = std::move(SocketPtr)](const asio::error_code& Ec) mutable { if (Ec) { ZEN_WARN("asio async_accept error ({}:{}): {}", m_Acceptor.local_endpoint().address().to_string(), m_Acceptor.local_endpoint().port(), Ec.message()); } else { // New connection established, pass socket ownership into connection object // and initiate request handling loop. The connection lifetime is // managed by the async read/write loop by passing the shared // reference to the callbacks. Socket->set_option(asio::ip::tcp::no_delay(true)); Socket->set_option(asio::socket_base::receive_buffer_size(128 * 1024)); Socket->set_option(asio::socket_base::send_buffer_size(256 * 1024)); auto Conn = std::make_shared(std::move(Socket)); Conn->Initialize(m_ServerInterface->CreateConnectionHandler(Conn.get())); } if (!m_IsStopped.load()) { EnqueueAccept(); } else { std::error_code CloseEc; m_Acceptor.close(CloseEc); if (CloseEc) { ZEN_WARN("acceptor close ERROR, reason '{}'", CloseEc.message()); } } }); } ////////////////////////////////////////////////////////////////////////// AsioTransportConnection::AsioTransportConnection(std::unique_ptr&& Socket) : m_Socket(std::move(Socket)) { } AsioTransportConnection::~AsioTransportConnection() { } void AsioTransportConnection::Initialize(TransportServerConnection* ConnectionHandler) { m_ConnectionHandler = ConnectionHandler; EnqueueRead(); } int64_t AsioTransportConnection::WriteBytes(const void* Buffer, size_t DataSize) { size_t WrittenBytes = asio::write(*m_Socket.get(), asio::const_buffer(Buffer, DataSize), asio::transfer_exactly(DataSize)); return WrittenBytes; } void AsioTransportConnection::Shutdown(bool Receive, bool Transmit) { std::error_code Ec; if (Receive) { if (Transmit) { m_Socket->shutdown(asio::socket_base::shutdown_both, Ec); } else { m_Socket->shutdown(asio::socket_base::shutdown_receive, Ec); } } else if (Transmit) { m_Socket->shutdown(asio::socket_base::shutdown_send, Ec); } } void AsioTransportConnection::CloseConnection() { if (m_IsTerminated.test()) { return; } if (m_IsTerminated.test_and_set() == false) { Shutdown(true, true); std::error_code Ec; m_Socket->close(Ec); } } void AsioTransportConnection::EnqueueRead() { if (m_IsTerminated.test() == false) { m_RequestBuffer.prepare(64 * 1024); asio::async_read( *m_Socket.get(), m_RequestBuffer, asio::transfer_at_least(1), [Conn = AsSharedPtr()](const asio::error_code& Ec, std::size_t ByteCount) { Conn->OnDataReceived(Ec, ByteCount); }); } } void AsioTransportConnection::OnDataReceived(const asio::error_code& Ec, std::size_t ByteCount) { ZEN_UNUSED(ByteCount); if (Ec) { if (!m_IsTerminated.test()) { ZEN_WARN("on data received ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message()); } const bool Receive = true; const bool Transmit = true; return Shutdown(Receive, Transmit); } while (m_RequestBuffer.size()) { const asio::const_buffer& InputBuffer = m_RequestBuffer.data(); m_ConnectionHandler->OnBytesRead(InputBuffer.data(), InputBuffer.size()); m_RequestBuffer.consume(InputBuffer.size()); } EnqueueRead(); } void AsioTransportConnection::OnResponseDataSent(const asio::error_code& Ec, std::size_t ByteCount) { ZEN_UNUSED(ByteCount); if (Ec) { ZEN_WARN("on data sent ERROR, connection: {}, reason '{}'", m_ConnectionId, Ec.message()); const bool Receive = true; const bool Transmit = true; return Shutdown(Receive, Transmit); } } ////////////////////////////////////////////////////////////////////////// AsioTransportPlugin::AsioTransportPlugin() : m_ThreadCount(Max(GetHardwareConcurrency(), 8u)) { } AsioTransportPlugin::~AsioTransportPlugin() { } uint32_t AsioTransportPlugin::AddRef() const { return RefCounted::AddRef(); } uint32_t AsioTransportPlugin::Release() const { return RefCounted::Release(); } void AsioTransportPlugin::Configure(const char* OptionTag, const char* OptionValue) { using namespace std::literals; if (OptionTag == "port"sv) { if (auto PortNum = ParseInt(OptionValue)) { m_BasePort = *PortNum; } } else if (OptionTag == "threads"sv) { if (auto ThreadCount = ParseInt(OptionValue)) { m_ThreadCount = *ThreadCount; } } else { // Unknown configuration option } } void AsioTransportPlugin::Initialize(TransportServer* ServerInterface) { ZEN_ASSERT(m_ThreadCount > 0); ZEN_ASSERT(ServerInterface); ZEN_INFO("starting asio http with {} service threads", m_ThreadCount); m_Acceptor.reset(new AsioTransportAcceptor(ServerInterface, m_IoService, m_BasePort)); m_Acceptor->Start(); // This should consist of a set of minimum threads and grow on demand to // meet concurrency needs? Right now we end up allocating a large number // of threads even if we never end up using all of them, which seems // wasteful. It's also not clear how the demand for concurrency should // be balanced with the engine side - ideally we'd have some kind of // global scheduling to prevent one side from starving the other side // and thus preventing progress. Or at the very least, thread priorities // should be considered for (int i = 0; i < m_ThreadCount; ++i) { m_ThreadPool.emplace_back([this, ThreadNumber = i + 1] { SetCurrentThreadName(fmt::format("asio_thr_{}", ThreadNumber)); try { m_IoService.run(); } catch (const std::exception& e) { ZEN_ERROR("exception caught in asio event loop: {}", e.what()); } }); } ZEN_INFO("asio http transport started (port {})", m_Acceptor->GetAcceptPort()); } void AsioTransportPlugin::Shutdown() { m_Acceptor->RequestStop(); m_IoService.stop(); for (auto& Thread : m_ThreadPool) { Thread.join(); } } bool AsioTransportPlugin::IsAvailable() { return true; } TransportPlugin* CreateAsioTransportPlugin() { return new AsioTransportPlugin(); } } // namespace zen #endif