diff options
| author | FluorescentCIAAfricanAmerican <[email protected]> | 2020-04-22 12:56:21 -0400 |
|---|---|---|
| committer | FluorescentCIAAfricanAmerican <[email protected]> | 2020-04-22 12:56:21 -0400 |
| commit | 3bf9df6b2785fa6d951086978a3e66f49427166a (patch) | |
| tree | 2c0f1f0c63c4832882bc93814ebd2c2b1c6224e5 /utils/vmpi/ThreadedTCPSocket.cpp | |
| download | archived-source-engine-2018-hl2-src-3bf9df6b2785fa6d951086978a3e66f49427166a.tar.xz archived-source-engine-2018-hl2-src-3bf9df6b2785fa6d951086978a3e66f49427166a.zip | |
Diffstat (limited to 'utils/vmpi/ThreadedTCPSocket.cpp')
| -rw-r--r-- | utils/vmpi/ThreadedTCPSocket.cpp | 1085 |
1 files changed, 1085 insertions, 0 deletions
diff --git a/utils/vmpi/ThreadedTCPSocket.cpp b/utils/vmpi/ThreadedTCPSocket.cpp new file mode 100644 index 0000000..444e22b --- /dev/null +++ b/utils/vmpi/ThreadedTCPSocket.cpp @@ -0,0 +1,1085 @@ +//========= Copyright Valve Corporation, All rights reserved. ============// +// +// Purpose: +// +// $NoKeywords: $ +// +//=============================================================================// +// ThreadedTCPSocket.cpp : Defines the entry point for the console application. +// + +#include <winsock2.h> +#include <mswsock.h> +#include "IThreadedTCPSocket.h" +#include "utllinkedlist.h" +#include "threadhelpers.h" +#include "iphelpers.h" +#include "tier1/strtools.h" + + +#define SEND_KEEPALIVE_INTERVAL 3000 +#define KEEPALIVE_TIMEOUT 25000 // The sockets timeout after this long. + +#define KEEPALIVE_SENTINEL -12345 // When first 4 bytes of a packet = this, then it's just a keepalive. + + +static int g_KeepaliveSentinel = KEEPALIVE_SENTINEL; +bool g_bHandleTimeouts = true; + +// If true, it'll set the socket thread priorities lower than normal. +bool g_bSetTCPSocketThreadPriorities = true; + +// We get crashes at runtime if they don't link in the multithreaded runtime libraries, +// so raise a ruckus if they're using singlethreaded libraries. +#ifndef _MT + #pragma message( "**** WARNING **** ThreadedTCPSocket requires multithreaded runtime libraries to be used.\n" ) + class MTChecker + { + public: + MTChecker() { Assert( false ); } + } g_MTChecker; +#endif + + + +// ------------------------------------------------------------------------------------------------ // +// Static helpers. +// ------------------------------------------------------------------------------------------------ // +static SOCKET TCPBind( const CIPAddr *pAddr ) +{ + // Create a socket to send and receive through. + SOCKET sock = WSASocket( AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED ); + if ( sock == INVALID_SOCKET ) + { + Assert( false ); + return INVALID_SOCKET; + } + + // bind to it! + sockaddr_in addr; + IPAddrToSockAddr( pAddr, &addr ); + + int status = bind( sock, (sockaddr*)&addr, sizeof(addr) ); + if ( status == 0 ) + { + return sock; + } + else + { + closesocket( sock ); + return INVALID_SOCKET; + } +} + + + +// ------------------------------------------------------------------------------------------------ // +// CTCPPacket. +// ------------------------------------------------------------------------------------------------ // + +int CTCPPacket::GetUserData() const +{ + return m_UserData; +} + +void CTCPPacket::SetUserData( int userData ) +{ + m_UserData = userData; +} + +void CTCPPacket::Release() +{ + free( this ); +} + + +// ------------------------------------------------------------------------------------------------ // +// CThreadedTCPSocket. +// ------------------------------------------------------------------------------------------------ // +class CThreadedTCPSocket : public IThreadedTCPSocket +{ +public: + + static IThreadedTCPSocket* Create( SOCKET iSocket, CIPAddr remoteAddr, ITCPSocketHandler *pHandler ) + { + CThreadedTCPSocket *pRet = new CThreadedTCPSocket; + if ( pRet->Init( iSocket, remoteAddr, pHandler ) ) + { + return pRet; + } + else + { + pRet->Release(); + return NULL; + } + } + + +// IThreadedTCPSocket implementation. +public: + + virtual void Release() + { + delete this; + } + + virtual CIPAddr GetRemoteAddr() const + { + return m_RemoteAddr; + } + + virtual bool IsValid() + { + return !CheckErrorSignal(); + } + + virtual bool Send( const void *pData, int len ) + { + const void *pChunks[1] = { pData }; + return SendChunks( pChunks, &len, 1 ); + } + + virtual bool SendChunks( void const * const *pChunks, const int *pChunkLengths, int nChunks ) + { + if ( CheckErrorSignal() ) + return false; + + return InternalSend( pChunks, pChunkLengths, nChunks, true ); + } + + +// Initialization. +private: + + CThreadedTCPSocket() + { + m_Socket = INVALID_SOCKET; + m_pHandler = NULL; + memset( &m_SendOverlapped, 0, sizeof( m_SendOverlapped ) ); + memset( &m_RecvOverlapped, 0, sizeof( m_RecvOverlapped ) ); + m_bWaitingForSendCompletion = false; + m_nBytesToReceive = -1; + m_bWaitingForSize = false; + m_bErrorSignal = false; + m_pRecvBuffer = NULL; + } + + virtual ~CThreadedTCPSocket() + { + Term(); + } + + bool Init( SOCKET iSocket, CIPAddr remoteAddr, ITCPSocketHandler *pHandler ) + { + m_Socket = iSocket; + m_RemoteAddr = remoteAddr; + m_pHandler = pHandler; + + SetInitialSocketOptions(); + + // Create all the event objects we'll use to communicate. + m_hExitThreadsEvent.Init( true, false ); + m_hSendCompletionEvent.Init( false, false ); + m_hReadyToSendEvent.Init( false, false ); + m_hRecvEvent.Init( false, false ); + + m_SendOverlapped.hEvent = m_hSendCompletionEvent.GetEventHandle(); + m_RecvOverlapped.hEvent = m_hRecvEvent.GetEventHandle(); + + // Create our threads. + DWORD dwSendThreadID, dwRecvThreadID; + m_hSendThread = CreateThread( NULL, 0, &CThreadedTCPSocket::StaticSendThreadFn, this, CREATE_SUSPENDED, &dwSendThreadID ); + m_hRecvThread = CreateThread( NULL, 0, &CThreadedTCPSocket::StaticRecvThreadFn, this, CREATE_SUSPENDED, &dwRecvThreadID ); + if ( !m_hSendThread || !m_hRecvThread ) + { + return false; + } + + if ( g_bSetTCPSocketThreadPriorities ) + { + SetThreadPriority( m_hSendThread, THREAD_PRIORITY_LOWEST ); + SetThreadPriority( m_hRecvThread, THREAD_PRIORITY_LOWEST ); + } + + ThreadSetDebugName( (ThreadId_t)dwSendThreadID, "TCPSend" ); + ThreadSetDebugName( (ThreadId_t)dwRecvThreadID, "TCPRecv" ); + + // Make sure to init the handler before the threads actually run, so it isn't handed data before initializing. + m_pHandler->Init( this ); + + ResumeThread( m_hSendThread ); + ResumeThread( m_hRecvThread ); + + return true; + } + + void Term() + { + // Signal our threads to exit. + m_hExitThreadsEvent.SetEvent(); + if ( m_hSendThread ) + { + WaitForSingleObject( m_hSendThread, INFINITE ); + CloseHandle( m_hSendThread ); + m_hSendThread = NULL; + } + + if ( m_hRecvThread ) + { + WaitForSingleObject( m_hRecvThread, INFINITE ); + CloseHandle( m_hRecvThread ); + m_hRecvThread = NULL; + } + m_hExitThreadsEvent.ResetEvent(); + + + if ( m_Socket != INVALID_SOCKET ) + { + closesocket( m_Socket ); + m_Socket = INVALID_SOCKET; + } + } + + // Set the initial socket options that we want. + void SetInitialSocketOptions() + { + // Set nodelay to improve latency. + BOOL val = TRUE; + setsockopt( m_Socket, IPPROTO_TCP, TCP_NODELAY, (const char FAR *)&val, sizeof(BOOL) ); + + // Make it linger for 3 seconds when it exits. + LINGER linger; + linger.l_onoff = 1; + linger.l_linger = 3; + setsockopt( m_Socket, SOL_SOCKET, SO_LINGER, (char*)&linger, sizeof( linger ) ); + } + + +// Send thread functionality. +private: + + // This function copies off the payload and adds a SendChunk_t to the list of chunks to be sent. + // It also fires the ReadyToSend event so the thread will pick it up. + bool InternalSend( void const * const *pChunks, const int *pChunkLengths, int nChunks, bool bPrependLength ) + { + int totalLength = 0; + for ( int i=0; i < nChunks; i++ ) + totalLength += pChunkLengths[i]; + + if ( bPrependLength ) + { + if ( totalLength == 0 ) + return true; + + totalLength += 4; + } + + // Copy all the data into a SendData_t. + SendData_t *pSendData = (SendData_t*)malloc( sizeof( SendData_t ) - 1 + totalLength ); + pSendData->m_Len = totalLength; + + char *pOut = pSendData->m_Payload; + if ( bPrependLength ) + { + *((int*)pOut) = totalLength - 4; // The length we prepend is the size of the data, not data size + integer for length. + pOut += 4; + } + for ( int i=0; i < nChunks; i++ ) + { + memcpy( pOut, pChunks[i], pChunkLengths[i] ); + pOut += pChunkLengths[i]; + } + + CCriticalSectionLock csLock( &m_SendCS ); + csLock.Lock(); + + m_SendDatas.AddToTail( pSendData ); + m_hReadyToSendEvent.SetEvent(); // Notify the thread that there is data to send. + + csLock.Unlock(); + + return true; + } + + void SendThread_HandleTimeout() + { + // Timeout.. send a keepalive. + // But only if we're not already sending something. + CCriticalSectionLock csLock( &m_SendCS ); + csLock.Lock(); + int count = m_SendDatas.Count(); + csLock.Unlock(); + + if ( count == 0 ) + { + void *pBuf[1] = { &g_KeepaliveSentinel }; + int len[1] = { sizeof( g_KeepaliveSentinel ) }; + InternalSend( pBuf, len, 1, false ); + } + } + + bool SendThread_HandleSendCompletionEvent() + { + Assert( m_bWaitingForSendCompletion ); + m_bWaitingForSendCompletion = false; + + // A send operation just completed. Now do the next one. + DWORD cbTransfer, flags; + if ( !WSAGetOverlappedResult( m_Socket, &m_SendOverlapped, &cbTransfer, TRUE, &flags ) ) + { + HandleError( WSAGetLastError() ); + return false; + } + + if ( cbTransfer != m_nBytesToTransfer ) + { + char str[512]; + Q_snprintf( str, sizeof( str ), "Invalid # bytes transferred (%d) in send thread (should be %d)", cbTransfer, m_nBytesToTransfer ); + HandleError( ITCPSocketHandler::SocketError, str ); + return false; + } + + // Remove the block we just sent. + CCriticalSectionLock csLock( &m_SendCS ); + csLock.Lock(); + + SendData_t *pSendData = m_SendDatas[ m_SendDatas.Head() ]; + free( pSendData ); + m_SendDatas.Remove( m_SendDatas.Head() ); + + m_bWaitingForSendCompletion = false; + + // Set our send event if there's anything else to send. + if ( m_SendDatas.Count() > 0 ) + m_hReadyToSendEvent.SetEvent(); + + csLock.Unlock(); + return true; + } + + bool SendThread_HandleReadyToSendEvent() + { + // We've got at least one buffer that's ready to be sent. + // NOTE: don't send anything until our current send is completed. + CCriticalSectionLock csLock( &m_SendCS ); + csLock.Lock(); + + Assert( !m_bWaitingForSendCompletion ); + + // Send it off! + SendData_t *pSendData = m_SendDatas[ m_SendDatas.Head() ]; + WSABUF buf = { pSendData->m_Len, pSendData->m_Payload }; + + m_nBytesToTransfer = pSendData->m_Len; + m_bWaitingForSendCompletion = true; + + csLock.Unlock(); + + DWORD dwNumBytesSent = 0; + DWORD ret = WSASend( m_Socket, &buf, 1, &dwNumBytesSent, 0, &m_SendOverlapped, NULL ); + DWORD err = WSAGetLastError(); + if ( ret == 0 || ( ret == SOCKET_ERROR && err == WSA_IO_PENDING ) ) + { + // Either way, the operation completed successfully, and m_hSendCompletionEvent is now set. + return true; + } + else + { + HandleError( err ); + return false; + } + + return true; + } + + DWORD SendThreadFn() + { + while ( 1 ) + { + HANDLE handles[] = + { + m_hExitThreadsEvent.GetEventHandle(), + m_hSendCompletionEvent.GetEventHandle(), + m_hReadyToSendEvent.GetEventHandle() + }; + int nHandles = ARRAYSIZE( handles ); + + // While waiting for send completion, don't handle "ready to send" events. + if ( m_bWaitingForSendCompletion ) + --nHandles; + + DWORD waitValue = WaitForMultipleObjects( nHandles, handles, FALSE, SEND_KEEPALIVE_INTERVAL ); + switch ( waitValue ) + { + case WAIT_TIMEOUT: + { + if ( g_bHandleTimeouts ) + { + // We haven't sent anything in a bit. Send out a keepalive. + SendThread_HandleTimeout(); + } + } + break; + + case WAIT_OBJECT_0: + { + // The main thread is signaling us to exit. + return 0; + } + + case WAIT_OBJECT_0 + 1: + { + if ( !SendThread_HandleSendCompletionEvent() ) + return 1; + } + break; + + case WAIT_OBJECT_0 + 2: + { + if ( !SendThread_HandleReadyToSendEvent() ) + return 1; + } + break; + + case WAIT_FAILED: + { + // Uh oh. We're dead. Cleanup and signal an error. + HandleError( GetLastError() ); + return 1; + } + + default: + { + char str[512]; + Q_snprintf( str, sizeof( str ), "Unknown return value (%lu) from WaitForMultipleObjects", waitValue ); + HandleError( ITCPSocketHandler::SocketError, str ); + return 0; + } + } + } + + return 0; + } + + static DWORD WINAPI StaticSendThreadFn( LPVOID pParameter ) + { + return ((CThreadedTCPSocket*)pParameter)->SendThreadFn(); + } + + +// Receive thread functionality. +private: + + bool RecvThread_WaitToReceiveSize() + { + return RecvThread_InternalRecv( &m_NextPacketLen, sizeof( m_NextPacketLen ), false, true ); + } + + + bool RecvThread_InternalHandleRecvCompletion( DWORD dwTransfer ) + { + int cbTransfer = (int)dwTransfer; + int nBytesWanted = m_nBytesToReceive - m_nBytesReceivedSoFar; + if ( cbTransfer > nBytesWanted ) + { + char str[512]; + Q_snprintf( str, sizeof( str ), "Invalid # bytes received (%d) in recv thread (should be %d)", cbTransfer, m_nBytesToReceive ); + HandleError( ITCPSocketHandler::SocketError, str ); + return false; + } + else if ( cbTransfer < nBytesWanted ) + { + // We have to reissue the receive command because it didn't receive all the data. + m_nBytesReceivedSoFar += cbTransfer; + + char *pDest = (char*)&m_NextPacketLen; + if ( !m_bWaitingForSize ) + { + Assert( m_pRecvBuffer ); + pDest = m_pRecvBuffer->m_Data; + } + + return RecvThread_InternalRecv( &pDest[m_nBytesReceivedSoFar], m_nBytesToReceive - m_nBytesReceivedSoFar, true ); + } + + if ( m_bWaitingForSize ) + { + // If we were waiting for size, now wait for the data. + if ( m_NextPacketLen == KEEPALIVE_SENTINEL ) + { + // Ok, it was just a keepalive. Wait for size again. + return RecvThread_WaitToReceiveSize(); + } + else + { + if ( m_NextPacketLen < 1 || m_NextPacketLen > 1024*1024*75 ) + { + char str[512]; + Q_snprintf( str, sizeof( str ), "Invalid packet size in RecvThread (size = %d)", m_NextPacketLen ); + HandleError( ITCPSocketHandler::SocketError, str ); + return false; + } + else + { + Assert( !m_pRecvBuffer ); + m_pRecvBuffer = (CTCPPacket*)malloc( sizeof( CTCPPacket ) - 1 + m_NextPacketLen ); + m_pRecvBuffer->m_UserData = 0; + m_pRecvBuffer->m_Len = m_NextPacketLen; + + return RecvThread_InternalRecv( m_pRecvBuffer->m_Data, m_pRecvBuffer->m_Len, false, false ); + } + } + } + else + { + // Got a packet! Give it to the app. + m_pHandler->OnPacketReceived( m_pRecvBuffer ); + m_pRecvBuffer = NULL; + + return RecvThread_WaitToReceiveSize(); + } + } + + bool RecvThread_HandleRecvCompletionEvent() + { + // A send operation just completed. Now do the next one. + DWORD cbTransfer, flags; + if ( !WSAGetOverlappedResult( m_Socket, &m_RecvOverlapped, &cbTransfer, TRUE, &flags ) ) + { + HandleError( WSAGetLastError() ); + return false; + } + + return RecvThread_InternalHandleRecvCompletion( cbTransfer ); + } + + bool RecvThread_InternalRecv( void *pDest, int destSize, bool bContinuation, bool bWaitingForSize = false ) + { + WSABUF buf = { destSize, (char*)pDest }; + + if ( !bContinuation ) + { + // If this is not a continuation of whatever we were receiving before, then + m_bWaitingForSize = bWaitingForSize; + m_nBytesToReceive = destSize; + m_nBytesReceivedSoFar = 0; + } + + DWORD dwFlags = 0; + DWORD nBytesReceived = 0; + DWORD ret = WSARecv( m_Socket, &buf, 1, &nBytesReceived, &dwFlags, &m_RecvOverlapped, NULL ); + DWORD dwLastError = WSAGetLastError(); + if ( ret == 0 || ( ret == SOCKET_ERROR && dwLastError == WSA_IO_PENDING ) ) + { + // Note: m_hRecvEvent is in a signaled state, so the RecvThread will pick up the results next time around. + return true; + } + else + { + HandleError( dwLastError ); + return false; + } + } + + + + DWORD RecvThreadFn() + { + // Start us off by setting up to receive the first packet size. + if ( !RecvThread_WaitToReceiveSize() ) + return 1; + + HANDLE handles[] = + { + m_hExitThreadsEvent.GetEventHandle(), + m_hRecvEvent.GetEventHandle() + }; + + while ( 1 ) + { + DWORD waitValue = WaitForMultipleObjects( ARRAYSIZE( handles ), handles, FALSE, KEEPALIVE_TIMEOUT ); + switch ( waitValue ) + { + case WAIT_TIMEOUT: + { + if ( g_bHandleTimeouts ) + { + HandleError( ITCPSocketHandler::ConnectionTimedOut, "Connection timed out" ); + return 1; + } + } + break; + + case WAIT_OBJECT_0: + { + // We're being told to exit. + return 0; + } + + case WAIT_OBJECT_0 + 1: + { + // Just finished receiving something. + if ( !RecvThread_HandleRecvCompletionEvent() ) + return 1; + } + break; + + case WAIT_FAILED: + { + // Uh oh. We're dead. Cleanup and signal an error. + HandleError( GetLastError() ); + return 1; + } + + default: + { + char str[512]; + Q_snprintf( str, sizeof( str ), "Unknown return value (%lu) from WaitForMultipleObjects", waitValue ); + HandleError( ITCPSocketHandler::SocketError, str ); + return 1; + } + } + } + + return 0; + } + + static DWORD WINAPI StaticRecvThreadFn( LPVOID pParameter ) + { + return ((CThreadedTCPSocket*)pParameter)->RecvThreadFn(); + } + + + +// Error handling. +private: + + // This checks to see if either thread has signaled an error. If so, it shuts down the socket and returns true. + bool CheckErrorSignal() + { + return m_bErrorSignal; + } + + // This is called from any of the threads and signals that something went awry. It shuts down the object + // and makes it return false from all of its functions. + void HandleError( DWORD errorValue ) + { + char *lpMsgBuf; + + FormatMessage( + FORMAT_MESSAGE_ALLOCATE_BUFFER | + FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, + GetLastError(), + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language + (char*)&lpMsgBuf, + 0, + NULL + ); + + // Windows likes to stick a carriage return in there and we don't want it so get rid of it. + int len = strlen( lpMsgBuf ); + while ( len > 0 && ( lpMsgBuf[len-1] == '\n' || lpMsgBuf[len-1] == '\r' ) ) + { + --len; + lpMsgBuf[len] = 0; + } + + HandleError( ITCPSocketHandler::SocketError, lpMsgBuf ); + + LocalFree( lpMsgBuf ); + } + + + void HandleError( int errorCode, const char *pErrorString ) + { + //Assert( false ); + + // Tell the app. + m_pHandler->OnError( errorCode, pErrorString ); + + // Tell the threads to exit. + m_hExitThreadsEvent.SetEvent(); + + // Notify the main thread so it can call Term() when it gets a chance. + m_bErrorSignal = true; + } + + +private: + + // Data for the send thread. + typedef struct + { + //WSAOVERLAPPED m_Overlapped; + int m_Len; + char m_Payload[1]; + } SendData_t; + + HANDLE m_hSendThread; + WSAOVERLAPPED m_SendOverlapped; + CEvent m_hReadyToSendEvent; + CEvent m_hSendCompletionEvent; + + CCriticalSection m_SendCS; + DWORD m_nBytesToTransfer; + bool m_bWaitingForSendCompletion; + CUtlLinkedList<SendData_t*, int> m_SendDatas; // Added to the tail, popped off the head for sending. + + + // Data for the recv thread. + HANDLE m_hRecvThread; + int m_nBytesToReceive; // This stores how many bytes we want to receive for the next packet. + int m_nBytesReceivedSoFar; // This stores how many bytes we've received so far. + bool m_bWaitingForSize; // This tells if we're trying to receive the next packet length or the next packet's data. + + int m_NextPacketLen; // Data is received INTO here before it receives each packet. This + // holds the length of each incoming packet. + + WSAOVERLAPPED m_RecvOverlapped; + CEvent m_hRecvEvent; + CTCPPacket *m_pRecvBuffer; // This is allocated for each packet we're receiving and given to the + // app when the packet is done being received. + + + volatile bool m_bErrorSignal; + + + CEvent m_hExitThreadsEvent; + + ITCPSocketHandler *m_pHandler; + + SOCKET m_Socket; + CIPAddr m_RemoteAddr; +}; + + +// ------------------------------------------------------------------------------------------------ // +// CTCPConnectSocket_Listener +// ------------------------------------------------------------------------------------------------ // +class CTCPConnectSocket_Listener : public ITCPConnectSocket +{ +public: + CTCPConnectSocket_Listener() + { + m_Socket = INVALID_SOCKET; + } + + + virtual ~CTCPConnectSocket_Listener() + { + if ( m_Socket != INVALID_SOCKET ) + { + closesocket( m_Socket ); + } + } + + + // The main function to create one of these suckers. + static ITCPConnectSocket* Create( + IHandlerCreator *pHandlerCreator, + const unsigned short port, + int nQueueLength + ) + { + CTCPConnectSocket_Listener *pRet = new CTCPConnectSocket_Listener; + if ( !pRet ) + return NULL; + + if ( nQueueLength < 0 ) + { + Error( "CTCPConnectSocket_Listener::Create - SOMAXCONN not allowed - causes some XP SP2 systems to stop receiving any network data (systemwide)." ); + } + + // Bind it to a socket and start listening. + CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY + pRet->m_Socket = TCPBind( &addr ); + if ( pRet->m_Socket == INVALID_SOCKET || + listen( pRet->m_Socket, nQueueLength == -1 ? SOMAXCONN : nQueueLength ) != 0 ) + { + pRet->Release(); + return false; + } + + pRet->m_pHandler = pHandlerCreator; + return pRet; + } + + +// ITCPConnectSocket implementation. +public: + + virtual void Release() + { + delete this; + } + + virtual bool Update( IThreadedTCPSocket **pSocket, unsigned long milliseconds ) + { + *pSocket = NULL; + if ( m_Socket == INVALID_SOCKET ) + return false; + + // We're still ok.. just wait until the socket becomes writable (is connected) or we timeout. + fd_set readSet; + readSet.fd_count = 1; + readSet.fd_array[0] = m_Socket; + TIMEVAL timeVal = {0, milliseconds*1000}; + + // Wait until it connects. + int status = select( 0, &readSet, NULL, NULL, &timeVal ); + if ( status > 0 ) + { + sockaddr_in addr; + int addrSize = sizeof( addr ); + + // Now accept the final connection. + SOCKET newSock = accept( m_Socket, (struct sockaddr*)&addr, &addrSize ); + if ( newSock == INVALID_SOCKET ) + { + Assert( false ); + return true; + } + else + { + CIPAddr connectedAddr; + SockAddrToIPAddr( &addr, &connectedAddr ); + + IThreadedTCPSocket *pRet = CThreadedTCPSocket::Create( newSock, connectedAddr, m_pHandler->CreateNewHandler() ); + if ( !pRet ) + { + Assert( false ); + closesocket( m_Socket ); + m_Socket = INVALID_SOCKET; + return false; + } + + *pSocket = pRet; + return true; + } + } + else if ( status == SOCKET_ERROR ) + { + closesocket( m_Socket ); + m_Socket = INVALID_SOCKET; + return false; + } + else + { + return true; + } + } + + +private: + SOCKET m_Socket; + + IHandlerCreator *m_pHandler; +}; + + + +ITCPConnectSocket* ThreadedTCP_CreateListener( + IHandlerCreator *pHandlerCreator, + const unsigned short port, + int nQueueLength + ) +{ + return CTCPConnectSocket_Listener::Create( pHandlerCreator, port, nQueueLength ); +} + + + +// ------------------------------------------------------------------------------------------------ // +// CTCPConnectSocket_Connector +// ------------------------------------------------------------------------------------------------ // +class CTCPConnectSocket_Connector : public ITCPConnectSocket +{ +public: + + CTCPConnectSocket_Connector() + { + m_bConnected = false; + m_Socket = INVALID_SOCKET; + m_bError = false; + } + + virtual ~CTCPConnectSocket_Connector() + { + if ( m_Socket != INVALID_SOCKET ) + { + closesocket( m_Socket ); + } + } + + static ITCPConnectSocket* Create( + const CIPAddr &connectAddr, + const CIPAddr &localAddr, + IHandlerCreator *pHandlerCreator + ) + { + CTCPConnectSocket_Connector *pRet = new CTCPConnectSocket_Connector; + + pRet->m_Socket = TCPBind( &localAddr ); + if ( pRet->m_Socket == INVALID_SOCKET ) + { + pRet->Release(); + return NULL; + } + + sockaddr_in addr; + IPAddrToSockAddr( &connectAddr, &addr ); + + // We don't want the connect() call to block. + DWORD val = 1; + int status = ioctlsocket( pRet->m_Socket, FIONBIO, &val ); + if ( status != 0 ) + { + Assert( false ); + pRet->Release(); + return NULL; + } + + pRet->m_RemoteAddr = connectAddr; + pRet->m_pHandlerCreator = pHandlerCreator; + + int ret = connect( pRet->m_Socket, (struct sockaddr*)&addr, sizeof( addr ) ); + if ( ret == 0 ) + { + pRet->m_bConnected = true; + return pRet; + } + else if ( ret == SOCKET_ERROR && WSAGetLastError() == WSAEWOULDBLOCK ) + { + return pRet; + } + else + { + Assert( false ); + pRet->Release(); + return NULL; + } + } + + +// ITCPConnectSocket implementation. +public: + + virtual void Release() + { + delete this; + } + + + virtual bool Update( IThreadedTCPSocket **pSocket, unsigned long milliseconds ) + { + *pSocket = NULL; + + // If we got an error previously, keep returning false. + if ( m_bError ) + return false; + + // If this condition holds, then we already returned a valid socket and we're just waiting to be released. + if ( m_Socket == INVALID_SOCKET ) + return true; + + // Ok, see if we're connected now. + if ( !m_bConnected ) + { + TIMEVAL timeVal = { 0, milliseconds*1000 }; + + fd_set writeSet; + writeSet.fd_count = 1; + writeSet.fd_array[0] = m_Socket; + + int ret = select( 0, NULL, &writeSet, NULL, &timeVal ); + if ( ret > 0 ) + { + m_bConnected = true; + } + else if ( ret == SOCKET_ERROR ) + { + return EnterErrorMode(); + } + } + + if ( m_bConnected ) + { + // Ok, return a connected socket for them. + + // Make our socket blocking again. + DWORD val = 0; + int status = ioctlsocket( m_Socket, FIONBIO, &val ); + if ( status != 0 ) + { + Assert( false ); + m_bError = true; + closesocket( m_Socket ); + m_Socket = INVALID_SOCKET; + return false; + } + + IThreadedTCPSocket *pRet = CThreadedTCPSocket::Create( m_Socket, m_RemoteAddr, m_pHandlerCreator->CreateNewHandler() ); + if ( pRet ) + { + m_Socket = INVALID_SOCKET; + *pSocket = pRet; + return true; + } + else + { + return EnterErrorMode(); + } + } + else + { + // Still waiting.. + return true; + } + } + + // Shutdown the socket and start returning false from Update(). + bool EnterErrorMode() + { + Assert( false ); + m_bError = true; + closesocket( m_Socket ); + m_Socket = INVALID_SOCKET; + return false; + } + + +private: + + bool m_bError; + bool m_bConnected; + + SOCKET m_Socket; + CIPAddr m_RemoteAddr; + + IHandlerCreator *m_pHandlerCreator; +}; + + +ITCPConnectSocket* ThreadedTCP_CreateConnector( + const CIPAddr &addr, + const CIPAddr &localAddr, + IHandlerCreator *pHandlerCreator + ) +{ + return CTCPConnectSocket_Connector::Create( addr, localAddr, pHandlerCreator ); +} + + +void ThreadedTCP_EnableTimeouts( bool bEnable ) +{ + g_bHandleTimeouts = bEnable; +} + + +void ThreadedTCP_SetTCPSocketThreadPriorities( bool bSetTCPSocketThreadPriorities ) +{ + g_bSetTCPSocketThreadPriorities = bSetTCPSocketThreadPriorities; +} + |