summaryrefslogtreecommitdiff
path: root/utils/vmpi/tcpsocket.cpp
diff options
context:
space:
mode:
authorFluorescentCIAAfricanAmerican <[email protected]>2020-04-22 12:56:21 -0400
committerFluorescentCIAAfricanAmerican <[email protected]>2020-04-22 12:56:21 -0400
commit3bf9df6b2785fa6d951086978a3e66f49427166a (patch)
tree2c0f1f0c63c4832882bc93814ebd2c2b1c6224e5 /utils/vmpi/tcpsocket.cpp
downloadarchived-source-engine-2018-hl2-src-3bf9df6b2785fa6d951086978a3e66f49427166a.tar.xz
archived-source-engine-2018-hl2-src-3bf9df6b2785fa6d951086978a3e66f49427166a.zip
Diffstat (limited to 'utils/vmpi/tcpsocket.cpp')
-rw-r--r--utils/vmpi/tcpsocket.cpp1178
1 files changed, 1178 insertions, 0 deletions
diff --git a/utils/vmpi/tcpsocket.cpp b/utils/vmpi/tcpsocket.cpp
new file mode 100644
index 0000000..8ab67cf
--- /dev/null
+++ b/utils/vmpi/tcpsocket.cpp
@@ -0,0 +1,1178 @@
+//========= Copyright Valve Corporation, All rights reserved. ============//
+//
+// Purpose:
+//
+// $NoKeywords: $
+//=============================================================================//
+
+
+//#define PARANOID
+
+#if defined( PARANOID )
+ #include <stdlib.h>
+ #include <crtdbg.h>
+#endif
+
+#include <winsock2.h>
+#include <mswsock.h>
+#include "tcpsocket.h"
+#include "tier1/utllinkedlist.h"
+#include <stdio.h>
+#include "threadhelpers.h"
+#include "tier0/dbg.h"
+
+
+
+#error "I am TCPSocket and I suck. Use IThreadedTCPSocket or ThreadedTCPSocketEmu instead."
+
+
+extern TIMEVAL SetupTimeVal( double flTimeout );
+extern void IPAddrToSockAddr( const CIPAddr *pIn, sockaddr_in *pOut );
+extern void SockAddrToIPAddr( const sockaddr_in *pIn, CIPAddr *pOut );
+
+
+#define SENTINEL_DISCONNECT -1
+#define SENTINEL_KEEPALIVE -2
+
+
+#define KEEPALIVE_INTERVAL_MS 3000 // keepalives are sent every N MS
+#define KEEPALIVE_TIMEOUT_SECONDS 15.0 // connections timeout after this long
+
+
+static bool g_bEnableTCPTimeout = true;
+
+
+class CRecvData
+{
+public:
+ int m_Count;
+ unsigned char m_Data[1];
+};
+
+
+
+SOCKET TCPBind( const CIPAddr *pAddr )
+{
+ // Create a socket to send and receive through.
+ SOCKET sock = WSASocket( AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED );
+ if ( sock == INVALID_SOCKET )
+ {
+ Assert( false );
+ return INVALID_SOCKET;
+ }
+
+ // bind to it!
+ sockaddr_in addr;
+ IPAddrToSockAddr( pAddr, &addr );
+
+ int status = bind( sock, (sockaddr*)&addr, sizeof(addr) );
+ if ( status == 0 )
+ {
+ return sock;
+ }
+ else
+ {
+ closesocket( sock );
+ return INVALID_SOCKET;
+ }
+}
+
+
+
+// ---------------------------------------------------------------------------------------- //
+// TCP sockets.
+// ---------------------------------------------------------------------------------------- //
+
+enum
+{
+ OP_RECV=111,
+ OP_SEND
+};
+
+// We use this for all OVERLAPPED structures.
+class COverlappedPlus : public WSAOVERLAPPED
+{
+public:
+ COverlappedPlus()
+ {
+ memset( this, 0, sizeof( WSAOVERLAPPED ) );
+ }
+
+ int m_OPType; // One of the OP_ defines.
+};
+
+typedef struct SendBuf_t
+{
+ COverlappedPlus m_Overlapped;
+ int m_Index; // Index into m_SendBufs.
+ int m_DataLength;
+ char m_Data[1];
+} SendBuf_s;
+
+
+// These manage a thread that calls SendKeepalive() on all TCPSockets.
+// AddGlobalTCPSocket shouldn't be called until you're ready for SendKeepalive() to be called.
+class CTCPSocket;
+void AddGlobalTCPSocket( CTCPSocket *pSocket );
+void RemoveGlobalTCPSocket( CTCPSocket *pSocket );
+
+
+
+// ------------------------------------------------------------------------------------------ //
+// CTCPSocket implementation.
+// ------------------------------------------------------------------------------------------ //
+
+class CTCPSocket : public ITCPSocket
+{
+friend class CTCPListenSocket;
+
+public:
+
+ CTCPSocket()
+ {
+ m_Socket = INVALID_SOCKET;
+ m_bConnected = false;
+
+ m_hIOCP = NULL;
+
+ m_bShouldExitThreads = false;
+ m_bConnectionLost = false;
+ m_nSizeBytesReceived = 0;
+
+ m_pIncomingData = NULL;
+
+ memset( &m_RecvOverlapped, 0, sizeof( m_RecvOverlapped ) );
+ m_RecvOverlapped.m_OPType = OP_RECV;
+
+ m_hRecvSignal = CreateEvent( NULL, FALSE, FALSE, NULL );
+ m_RecvStage = -1;
+
+ m_MainThreadID = GetCurrentThreadId();
+ }
+
+ virtual ~CTCPSocket()
+ {
+ Term();
+ CloseHandle( m_hRecvSignal );
+ }
+
+ void Term()
+ {
+ Assert( GetCurrentThreadId() == m_MainThreadID );
+
+ RemoveGlobalTCPSocket( this );
+
+ if ( m_Socket != SOCKET_ERROR && !m_bConnectionLost )
+ {
+ SendDisconnectSentinel();
+
+ // Give the sends a second to complete. SO_LINGER is having trouble for some reason.
+ WaitForSendsToComplete( 1 );
+ }
+
+
+ StopThreads();
+
+ if ( m_Socket != INVALID_SOCKET )
+ {
+ closesocket( m_Socket );
+ m_Socket = INVALID_SOCKET;
+ }
+
+ if ( m_hIOCP )
+ {
+ CloseHandle( m_hIOCP );
+ m_hIOCP = NULL;
+ }
+
+ m_bConnected = false;
+ m_bConnectionLost = true;
+ m_RecvStage = -1;
+
+ FOR_EACH_LL( m_SendBufs, i )
+ {
+ SendBuf_t *pSendBuf = m_SendBufs[i];
+ ParanoidMemoryCheck( pSendBuf );
+ free( pSendBuf );
+ }
+ m_SendBufs.Purge();
+
+ FOR_EACH_LL( m_RecvDatas, j )
+ {
+ CRecvData *pRecvData = m_RecvDatas[j];
+ ParanoidMemoryCheck( pRecvData );
+ free( pRecvData );
+ }
+ m_RecvDatas.Purge();
+
+ if ( m_pIncomingData )
+ {
+ ParanoidMemoryCheck( m_pIncomingData );
+ free( m_pIncomingData );
+ m_pIncomingData = 0;
+ }
+ }
+
+ virtual void Release()
+ {
+ delete this;
+ }
+
+
+ void ParanoidMemoryCheck( void *ptr = NULL )
+ {
+#if defined( PARANOID )
+ Assert( _CrtIsValidHeapPointer( this ) );
+
+ if ( ptr )
+ {
+ Assert( _CrtIsValidHeapPointer( ptr ) );
+ }
+
+ Assert( _CrtCheckMemory() == TRUE );
+#endif
+ }
+
+
+ virtual bool BindToAny( const unsigned short port )
+ {
+ Term();
+
+ CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY
+ m_Socket = TCPBind( &addr );
+ if ( m_Socket == INVALID_SOCKET )
+ {
+ return false;
+ }
+ else
+ {
+ SetInitialSocketOptions();
+ return true;
+ }
+ }
+
+
+ // Set the initial socket options that we want.
+ void SetInitialSocketOptions()
+ {
+ // Set nodelay to improve latency.
+ BOOL val = TRUE;
+ setsockopt( m_Socket, IPPROTO_TCP, TCP_NODELAY, (const char FAR *)&val, sizeof(BOOL) );
+
+ // Make it linger for 3 seconds when it exits.
+ LINGER linger;
+ linger.l_onoff = 1;
+ linger.l_linger = 3;
+ setsockopt( m_Socket, SOL_SOCKET, SO_LINGER, (char*)&linger, sizeof( linger ) );
+ }
+
+
+ // Called only by main thread interface functions.
+ // Returns true if the connection is lost.
+ bool CheckConnectionLost()
+ {
+ Assert( GetCurrentThreadId() == m_MainThreadID );
+
+ if ( m_Socket == SOCKET_ERROR )
+ return true;
+
+ // Have we timed out?
+ if ( g_bEnableTCPTimeout && (Plat_FloatTime() - m_LastRecvTime > KEEPALIVE_TIMEOUT_SECONDS) )
+ {
+ SetConnectionLost( "Connection timed out." );
+ }
+
+ // Has any thread posted that the connection has been lost?
+ CCriticalSectionLock postLock( &m_ConnectionLostCS );
+ postLock.Lock();
+ if ( m_bConnectionLost )
+ {
+ Term();
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+
+ // Called by any thread. All interface functions call CheckConnectionLost() and return errors if it's lost.
+ void SetConnectionLost( const char *pErrorString, int err = -1 )
+ {
+ CCriticalSectionLock postLock( &m_ConnectionLostCS );
+ postLock.Lock();
+ m_bConnectionLost = true;
+ postLock.Unlock();
+
+ // Handle it right away if we're in the main thread. If we're in an IO thread,
+ // it has to wait until the next interface function calls CheckConnectionLost().
+ if ( GetCurrentThreadId() == m_MainThreadID )
+ {
+ Term();
+ }
+
+ if ( pErrorString )
+ {
+ m_ErrorString.CopyArray( pErrorString, strlen( pErrorString ) + 1 );
+ }
+ else
+ {
+ char *lpMsgBuf;
+ FormatMessage(
+ FORMAT_MESSAGE_ALLOCATE_BUFFER |
+ FORMAT_MESSAGE_FROM_SYSTEM |
+ FORMAT_MESSAGE_IGNORE_INSERTS,
+ NULL,
+ err,
+ MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language
+ (LPTSTR) &lpMsgBuf,
+ 0,
+ NULL
+ );
+
+ m_ErrorString.CopyArray( lpMsgBuf, strlen( lpMsgBuf ) + 1 );
+ LocalFree( lpMsgBuf );
+ }
+ }
+
+
+ // -------------------------------------------------------------------------------------------------- //
+ // The receive code.
+ // -------------------------------------------------------------------------------------------------- //
+
+ virtual bool StartWaitingForSize( bool bFresh )
+ {
+ Assert( m_Socket != INVALID_SOCKET );
+ Assert( m_bConnected );
+
+ m_RecvStage = 0;
+ m_RecvDataSize = -1;
+ if ( bFresh )
+ m_nSizeBytesReceived = 0;
+
+ DWORD dwNumBytesReceived = 0;
+ WSABUF buf = { sizeof( &m_RecvDataSize ) - m_nSizeBytesReceived, ((char*)&m_RecvDataSize) + m_nSizeBytesReceived };
+ DWORD dwFlags = 0;
+
+ int status = WSARecv(
+ m_Socket,
+ &buf,
+ 1,
+ &dwNumBytesReceived,
+ &dwFlags,
+ &m_RecvOverlapped,
+ NULL );
+
+ int err = -1;
+ if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING )
+ {
+ SetConnectionLost( NULL, err );
+ return false;
+ }
+ else
+ {
+ return true;
+ }
+ }
+
+
+ bool PostNextDataPart()
+ {
+ DWORD dwNumBytesReceived = 0;
+ WSABUF buf = { m_RecvDataSize - m_AmountReceived, (char*)m_pIncomingData->m_Data + m_AmountReceived };
+ DWORD dwFlags = 0;
+
+ int status = WSARecv(
+ m_Socket,
+ &buf,
+ 1,
+ &dwNumBytesReceived,
+ &dwFlags,
+ &m_RecvOverlapped,
+ NULL );
+
+ int err = -1;
+ if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING )
+ {
+ SetConnectionLost( NULL, err );
+ return false;
+ }
+ else
+ {
+ return true;
+ }
+ }
+
+
+ bool StartWaitingForData()
+ {
+ Assert( m_Socket != INVALID_SOCKET );
+ Assert( m_RecvStage == 0 );
+ Assert( m_bConnected );
+ Assert( m_RecvDataSize > 0 );
+
+ m_RecvStage = 1;
+
+ // Add a CRecvData element.
+ ParanoidMemoryCheck();
+ m_pIncomingData = (CRecvData*)malloc( sizeof( CRecvData ) - 1 + m_RecvDataSize );
+ if ( !m_pIncomingData )
+ {
+ char str[512];
+ _snprintf( str, sizeof( str ), "malloc() failed. m_RecvDataSize = %d\n", m_RecvDataSize );
+ SetConnectionLost( str );
+ return false;
+ }
+
+ m_pIncomingData->m_Count = m_RecvDataSize;
+
+ m_AmountReceived = 0;
+
+ return PostNextDataPart();
+ }
+
+ virtual bool Recv( CUtlVector<unsigned char> &data, double flTimeout )
+ {
+ if ( CheckConnectionLost() )
+ return false;
+
+ // Wait in 50ms chunks, checking for disconnections along the way.
+ bool bGotData = false;
+ DWORD msToWait = (DWORD)( flTimeout * 1000.0 );
+ do
+ {
+ DWORD curWaitTime = min( msToWait, 50 );
+ DWORD ret = WaitForSingleObject( m_hRecvSignal, curWaitTime );
+ if ( ret == WAIT_OBJECT_0 )
+ {
+ bGotData = true;
+ break;
+ }
+
+ // Did the connection timeout?
+ if ( CheckConnectionLost() )
+ return false;
+
+ msToWait -= curWaitTime;
+ } while ( msToWait );
+
+ // If we never got a WAIT_OBJECT_0, then we never received anything.
+ if ( !bGotData )
+ return false;
+
+
+ CCriticalSectionLock csLock( &m_RecvDataCS );
+ csLock.Lock();
+
+ // Pickup the head m_RecvDatas element.
+ CRecvData *pRecvData = m_RecvDatas[ m_RecvDatas.Head() ];
+ data.CopyArray( pRecvData->m_Data, pRecvData->m_Count );
+
+ // Now free it.
+ m_RecvDatas.Remove( m_RecvDatas.Head() );
+ ParanoidMemoryCheck( pRecvData );
+ free( pRecvData );
+
+ // Set the event again for the next time around, if there is more data waiting.
+ if ( m_RecvDatas.Count() > 0 )
+ SetEvent( m_hRecvSignal );
+
+ return true;
+ }
+
+ // INSIDE IO THREAD.
+ void HandleRecvCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes )
+ {
+ if ( dwNumBytes == 0 )
+ {
+ SetConnectionLost( "Got 0 bytes in HandleRecvCompletion" );
+ return;
+ }
+
+ m_LastRecvTime = Plat_FloatTime();
+ if ( m_RecvStage == 0 )
+ {
+ m_nSizeBytesReceived += dwNumBytes;
+ if ( m_nSizeBytesReceived == sizeof( m_RecvDataSize ) )
+ {
+ // Size of -1 means the other size is breaking the connection.
+ if ( m_RecvDataSize == SENTINEL_DISCONNECT )
+ {
+ SetConnectionLost( "Got a graceful disconnect message." );
+ return;
+ }
+ else if ( m_RecvDataSize == SENTINEL_KEEPALIVE )
+ {
+ // No data follows this. Just let m_LastRecvTime get updated.
+ StartWaitingForSize( true );
+ return;
+ }
+
+ StartWaitingForData();
+ }
+ else if ( m_nSizeBytesReceived < sizeof( m_RecvDataSize ) )
+ {
+ // Handle the case where we only got some of the data (maybe one of the clients got disconnected).
+ StartWaitingForSize( false );
+ }
+ else
+ {
+ // This case should never ever happen!
+#if defined( _DEBUG )
+ __asm int 3;
+#endif
+
+ SetConnectionLost( "Received too much data in a packet!" );
+ return;
+ }
+ }
+ else if ( m_RecvStage == 1 )
+ {
+ // Got the data, make sure we got it all.
+ m_AmountReceived += dwNumBytes;
+
+ // Sanity check.
+#if defined( _DEBUG )
+ Assert( m_RecvDataSize == m_pIncomingData->m_Count );
+ Assert( m_AmountReceived <= m_RecvDataSize ); // TODO: make this threadsafe for multiple IO threads.
+#endif
+
+ if ( m_AmountReceived == m_RecvDataSize )
+ {
+ m_RecvStage = 2;
+
+ // Add the data to the list of packets waiting to be picked up.
+ CCriticalSectionLock csLock( &m_RecvDataCS );
+ csLock.Lock();
+
+ m_RecvDatas.AddToTail( m_pIncomingData );
+ m_pIncomingData = NULL;
+
+ if ( m_RecvDatas.Count() == 1 )
+ SetEvent( m_hRecvSignal ); // Notify the Recv() function.
+
+ StartWaitingForSize( true );
+ }
+ else
+ {
+ PostNextDataPart();
+ }
+ }
+ else
+ {
+ Assert( false );
+ }
+ }
+
+
+ // -------------------------------------------------------------------------------------------------- //
+ // The send code.
+ // -------------------------------------------------------------------------------------------------- //
+
+ virtual void WaitForSendsToComplete( double flTimeout )
+ {
+ CWaitTimer waitTimer( flTimeout );
+ while ( 1 )
+ {
+ CCriticalSectionLock sendBufLock( &m_SendCS );
+ sendBufLock.Lock();
+ if( m_SendBufs.Count() == 0 )
+ return;
+ sendBufLock.Unlock();
+
+ if ( waitTimer.ShouldKeepWaiting() )
+ Sleep( 10 );
+ else
+ break;
+ }
+ }
+
+
+ // This is called in the keepalive thread.
+ void SendKeepalive()
+ {
+ // Send a message saying we're exiting.
+ ParanoidMemoryCheck();
+ SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) );
+ if ( !pBuf )
+ {
+ SetConnectionLost( "malloc() in SendKeepalive() failed." );
+ return;
+ }
+
+ pBuf->m_DataLength = sizeof( int );
+ *((int*)pBuf->m_Data) = SENTINEL_KEEPALIVE;
+ InternalSendDataBuf( pBuf );
+ }
+
+
+ void SendDisconnectSentinel()
+ {
+ // Send a message saying we're exiting.
+ ParanoidMemoryCheck();
+ SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) );
+ if ( pBuf )
+ {
+ pBuf->m_DataLength = sizeof( int );
+ *((int*)pBuf->m_Data) = SENTINEL_DISCONNECT; // This signifies that we're exiting.
+ InternalSendDataBuf( pBuf );
+ }
+ }
+
+
+ virtual bool Send( const void *pData, int len )
+ {
+ const void *pChunks[1] = { pData };
+ int chunkLengths[1] = { len };
+ return SendChunks( pChunks, chunkLengths, 1 );
+ }
+
+
+ virtual bool SendChunks( void const * const *pChunks, const int *pChunkLengths, int nChunks )
+ {
+ if ( CheckConnectionLost() )
+ return false;
+
+ CChunkWalker walker( pChunks, pChunkLengths, nChunks );
+ int totalLength = walker.GetTotalLength();
+
+ if ( !totalLength )
+ return true;
+
+ // Create a buffer to hold the data and copy the data in.
+ ParanoidMemoryCheck();
+ SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + totalLength + sizeof( int ) );
+ if ( !pBuf )
+ {
+ char str[512];
+ _snprintf( str, sizeof( str ), "malloc() in SendChunks() failed. totalLength = %d.", totalLength );
+ SetConnectionLost( str );
+ return false;
+ }
+
+ pBuf->m_DataLength = totalLength + sizeof( int );
+
+ int *pByteCountPos = (int*)pBuf->m_Data;
+ *pByteCountPos = totalLength;
+
+ char *pDataPos = &pBuf->m_Data[ sizeof( int ) ];
+ walker.CopyTo( pDataPos, totalLength );
+
+ int status = InternalSendDataBuf( pBuf );
+ int err = -1;
+ if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING )
+ {
+ SetConnectionLost( NULL, err );
+ return false;
+ }
+ else
+ {
+ return true;
+ }
+ }
+
+
+ int InternalSendDataBuf( SendBuf_t *pBuf )
+ {
+ // Protect against interference from the keepalive thread.
+ CCriticalSectionLock csLock( &m_SendCS );
+ csLock.Lock();
+
+
+ pBuf->m_Overlapped.m_OPType = OP_SEND;
+ pBuf->m_Overlapped.hEvent = NULL;
+
+ // Add it to our list of buffers.
+ pBuf->m_Index = m_SendBufs.AddToTail( pBuf );
+
+ // Tell Winsock to send it.
+ WSABUF buf = { pBuf->m_DataLength, pBuf->m_Data };
+
+ DWORD dwNumBytesSent = 0;
+ return WSASend(
+ m_Socket,
+ &buf,
+ 1,
+ &dwNumBytesSent,
+ 0,
+ &pBuf->m_Overlapped,
+ NULL );
+ }
+
+
+ // INSIDE IO THREAD.
+ void HandleSendCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes )
+ {
+ if ( dwNumBytes == 0 )
+ {
+ SetConnectionLost( "0 bytes in HandleSendCompletion." );
+ return;
+ }
+
+ // Just free the buffer.
+ SendBuf_t *pBuf = (SendBuf_t*)pInfo;
+ Assert( dwNumBytes == (DWORD)pBuf->m_DataLength );
+
+ CCriticalSectionLock sendBufLock( &m_SendCS );
+ sendBufLock.Lock();
+ m_SendBufs.Remove( pBuf->m_Index );
+ sendBufLock.Unlock();
+
+ ParanoidMemoryCheck( pBuf );
+ free( pBuf );
+ }
+
+
+ // -------------------------------------------------------------------------------------------------- //
+ // The connect code.
+ // -------------------------------------------------------------------------------------------------- //
+
+ virtual bool BeginConnect( const CIPAddr &inputAddr )
+ {
+ sockaddr_in addr;
+ IPAddrToSockAddr( &inputAddr, &addr );
+
+ m_bConnected = false;
+ int ret = connect( m_Socket, (struct sockaddr*)&addr, sizeof( addr ) );
+ ret=ret;
+
+ return true;
+ }
+
+
+ virtual bool UpdateConnect()
+ {
+ // We're still ok.. just wait until the socket becomes writable (is connected) or we timeout.
+ fd_set writeSet;
+ writeSet.fd_count = 1;
+ writeSet.fd_array[0] = m_Socket;
+ TIMEVAL timeVal = SetupTimeVal( 0 );
+
+ // See if it has a packet waiting.
+ int status = select( 0, NULL, &writeSet, NULL, &timeVal );
+ if ( status > 0 )
+ {
+ SetupConnected();
+ return true;
+ }
+
+ return false;
+ }
+
+
+ void SetupConnected()
+ {
+ m_bConnected = true;
+ m_bConnectionLost = false;
+ m_LastRecvTime = Plat_FloatTime();
+
+ CreateThreads();
+ StartWaitingForSize( true );
+ AddGlobalTCPSocket( this );
+ }
+
+
+ virtual bool IsConnected()
+ {
+ CheckConnectionLost();
+ return m_bConnected;
+ }
+
+
+ virtual void GetDisconnectReason( CUtlVector<char> &reason )
+ {
+ reason = m_ErrorString;
+ }
+
+
+ // -------------------------------------------------------------------------------------------------- //
+ // Threads code.
+ // -------------------------------------------------------------------------------------------------- //
+
+ // Create our IO Completion Port threads.
+ bool CreateThreads()
+ {
+ int nThreads = 1;
+ SetShouldExitThreads( false );
+
+ // Create our IO completion port and hook it to our socket.
+ m_hIOCP = CreateIoCompletionPort(
+ INVALID_HANDLE_VALUE, NULL, 0, 0);
+
+ m_hIOCP = CreateIoCompletionPort( (HANDLE)m_Socket, m_hIOCP, (unsigned long)this, nThreads );
+
+ for ( int i=0; i < nThreads; i++ )
+ {
+ DWORD dwThreadID = 0;
+ HANDLE hThread = CreateThread(
+ NULL,
+ 0,
+ &CTCPSocket::StaticThreadFn,
+ this,
+ 0,
+ &dwThreadID );
+
+ if ( hThread )
+ {
+ SetThreadPriority( hThread, THREAD_PRIORITY_ABOVE_NORMAL );
+ m_Threads.AddToTail( hThread );
+ }
+ else
+ {
+ StopThreads();
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+
+ void StopThreads()
+ {
+ // Tell the threads to exit, then wait for them to do so.
+ SetShouldExitThreads( true );
+ WaitForMultipleObjects( m_Threads.Count(), m_Threads.Base(), TRUE, INFINITE );
+
+ for ( int i=0; i < m_Threads.Count(); i++ )
+ {
+ CloseHandle( m_Threads[i] );
+ }
+ m_Threads.Purge();
+ }
+
+
+ void SetShouldExitThreads( bool bShouldExit )
+ {
+ CCriticalSectionLock lock( &m_ThreadsCS );
+ lock.Lock();
+ m_bShouldExitThreads = bShouldExit;
+ }
+
+
+ bool ShouldExitThreads()
+ {
+ CCriticalSectionLock lock( &m_ThreadsCS );
+ lock.Lock();
+
+ bool bRet = m_bShouldExitThreads;
+ return bRet;
+ }
+
+
+ DWORD ThreadFn()
+ {
+ while ( 1 )
+ {
+ DWORD dwNumBytes = 0;
+ unsigned long pInputTCPSocket;
+ LPOVERLAPPED pOverlapped;
+
+ if ( GetQueuedCompletionStatus(
+ m_hIOCP, // the port we're listening on
+ &dwNumBytes, // # bytes received on the port
+ &pInputTCPSocket,// "completion key" = CTCPSocket*
+ &pOverlapped, // the overlapped info that was passed into AcceptEx, WSARecv, or WSASend.
+ 100 // listen for 100ms at a time so we can exit gracefully when the socket is deleted.
+ ) )
+ {
+ COverlappedPlus *pInfo = (COverlappedPlus*)pOverlapped;
+ ParanoidMemoryCheck( pInfo );
+
+ if ( pInfo->m_OPType == OP_RECV )
+ {
+ Assert( pInfo == &m_RecvOverlapped );
+ HandleRecvCompletion( pInfo, dwNumBytes );
+ }
+ else
+ {
+ Assert( pInfo->m_OPType == OP_SEND );
+ HandleSendCompletion( pInfo, dwNumBytes );
+ }
+ }
+
+ if ( ShouldExitThreads() )
+ break;
+ }
+
+ return 0;
+ }
+
+
+ static DWORD WINAPI StaticThreadFn( LPVOID pParameter )
+ {
+ return ((CTCPSocket*)pParameter)->ThreadFn();
+ }
+
+
+
+private:
+
+ SOCKET m_Socket;
+ bool m_bConnected;
+
+
+ // m_RecvOverlapped is setup to first wait for the size, then the data.
+ // Then it is not posted until the app grabs the data.
+ HANDLE m_hRecvSignal; // Tells Recv() when we have data.
+ COverlappedPlus m_RecvOverlapped;
+ int m_RecvStage; // -1 = not initialized
+ // 0 = waiting for size
+ // 1 = waiting for data
+ // 2 = waiting for app to pickup the data
+
+ CUtlLinkedList<CRecvData*,int> m_RecvDatas; // The head element is the next one to be picked up.
+ CRecvData *m_pIncomingData; // The packet we're currently receiving.
+ CCriticalSection m_RecvDataCS; // This protects adds and removes in the list.
+
+ // These reference the element at the tail of m_RecvData. It is the current one getting
+ volatile int m_nSizeBytesReceived; // How much of m_RecvDataSize have we received yet?
+ int m_RecvDataSize; // this is received over the network
+ int m_AmountReceived; // How much we've received so far.
+
+ // Last time we received anything from this connection. Used to determine if the connection is
+ // still active.
+ double m_LastRecvTime;
+
+
+ // Outgoing send buffers.
+ CUtlLinkedList<SendBuf_t*,int> m_SendBufs;
+ CCriticalSection m_SendCS;
+
+
+ // All the threads waiting for IO.
+ CUtlVector<HANDLE> m_Threads;
+ HANDLE m_hIOCP;
+
+ // Used during shutdown.
+ volatile bool m_bShouldExitThreads;
+ CCriticalSection m_ThreadsCS;
+
+ // For debugging.
+ DWORD m_MainThreadID;
+
+ // Set by the main thread or IO threads to signal connection lost.
+ bool m_bConnectionLost;
+ CCriticalSection m_ConnectionLostCS;
+
+ // This is set when we get disconnected.
+ CUtlVector<char> m_ErrorString;
+};
+
+
+// ------------------------------------------------------------------------------------------ //
+// ITCPListenSocket implementation.
+// ------------------------------------------------------------------------------------------ //
+
+class CTCPListenSocket : public ITCPListenSocket
+{
+public:
+
+ CTCPListenSocket()
+ {
+ m_Socket = INVALID_SOCKET;
+ }
+
+
+ virtual ~CTCPListenSocket()
+ {
+ if ( m_Socket != INVALID_SOCKET )
+ {
+ closesocket( m_Socket );
+ }
+ }
+
+
+ // The main function to create one of these suckers.
+ static ITCPListenSocket* Create( const unsigned short port, int nQueueLength )
+ {
+ CTCPListenSocket *pRet = new CTCPListenSocket;
+ if ( !pRet )
+ return NULL;
+
+ // Bind it to a socket and start listening.
+ CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY
+ pRet->m_Socket = TCPBind( &addr );
+ if ( pRet->m_Socket == INVALID_SOCKET ||
+ listen( pRet->m_Socket, nQueueLength == -1 ? SOMAXCONN : nQueueLength ) != 0 )
+ {
+ pRet->Release();
+ return false;
+ }
+
+ return pRet;
+ }
+
+
+ virtual void Release()
+ {
+ delete this;
+ }
+
+
+ virtual ITCPSocket* UpdateListen( CIPAddr *pAddr )
+ {
+ // We're still ok.. just wait until the socket becomes writable (is connected) or we timeout.
+ fd_set readSet;
+ readSet.fd_count = 1;
+ readSet.fd_array[0] = m_Socket;
+ TIMEVAL timeVal = SetupTimeVal( 0 );
+
+ // Wait until it connects.
+ int status = select( 0, &readSet, NULL, NULL, &timeVal );
+ if ( status > 0 )
+ {
+ sockaddr_in addr;
+ int addrSize = sizeof( addr );
+
+ // Now accept the final connection.
+ SOCKET newSock = accept( m_Socket, (struct sockaddr*)&addr, &addrSize );
+ if ( newSock == INVALID_SOCKET )
+ {
+ Assert( false );
+ }
+ else
+ {
+ CTCPSocket *pRet = new CTCPSocket;
+ if ( !pRet )
+ {
+ closesocket( newSock );
+ return NULL;
+ }
+
+ pRet->m_Socket = newSock;
+ pRet->SetInitialSocketOptions();
+ pRet->SetupConnected();
+
+ // Report the address..
+ SockAddrToIPAddr( &addr, pAddr );
+
+ return pRet;
+ }
+ }
+
+ return NULL;
+ }
+
+
+private:
+ SOCKET m_Socket;
+};
+
+
+
+ITCPListenSocket* CreateTCPListenSocket( const unsigned short port, int nQueueLength )
+{
+ return CTCPListenSocket::Create( port, nQueueLength );
+}
+
+
+ITCPSocket* CreateTCPSocket()
+{
+ return new CTCPSocket;
+}
+
+
+void TCPSocket_EnableTimeout( bool bEnable )
+{
+ g_bEnableTCPTimeout = bEnable;
+}
+
+
+// --------------------------------------------------------------------------------- //
+// This thread sends keepalives on all active TCP sockets.
+// --------------------------------------------------------------------------------- //
+
+HANDLE g_hKeepaliveThread;
+HANDLE g_hKeepaliveThreadSignal;
+HANDLE g_hKeepaliveThreadReply;
+CUtlLinkedList<CTCPSocket*,int> g_TCPSockets;
+CCriticalSection g_TCPSocketsCS;
+
+
+DWORD WINAPI TCPKeepaliveThread( LPVOID pParameter )
+{
+ while ( 1 )
+ {
+ if ( WaitForSingleObject( g_hKeepaliveThreadSignal, KEEPALIVE_INTERVAL_MS ) == WAIT_OBJECT_0 )
+ break;
+
+ // Tell all TCP sockets to send a keepalive.
+ CCriticalSectionLock csLock( &g_TCPSocketsCS );
+ csLock.Lock();
+
+ FOR_EACH_LL( g_TCPSockets, i )
+ {
+ g_TCPSockets[i]->SendKeepalive();
+ }
+ }
+
+ SetEvent( g_hKeepaliveThreadReply );
+ return 0;
+}
+
+
+void AddGlobalTCPSocket( CTCPSocket *pSocket )
+{
+ CCriticalSectionLock csLock( &g_TCPSocketsCS );
+ csLock.Lock();
+
+ Assert( g_TCPSockets.Find( pSocket ) == g_TCPSockets.InvalidIndex() );
+ g_TCPSockets.AddToTail( pSocket );
+
+ // If this is the first one, create the keepalive thread.
+ if ( g_TCPSockets.Count() == 1 )
+ {
+ g_hKeepaliveThreadSignal = CreateEvent( NULL, false, false, NULL );
+ g_hKeepaliveThreadReply = CreateEvent( NULL, false, false, NULL );
+
+ DWORD dwThreadID = 0;
+ g_hKeepaliveThread = CreateThread(
+ NULL,
+ 0,
+ TCPKeepaliveThread,
+ NULL,
+ 0,
+ &dwThreadID
+ );
+ }
+}
+
+
+void RemoveGlobalTCPSocket( CTCPSocket *pSocket )
+{
+ bool bThreadRunning = false;
+ DWORD dwExitCode = 0;
+ if ( GetExitCodeThread( g_hKeepaliveThread, &dwExitCode ) && dwExitCode == STILL_ACTIVE )
+ {
+ bThreadRunning = true;
+ }
+
+ CCriticalSectionLock csLock( &g_TCPSocketsCS );
+ csLock.Lock();
+
+ int index = g_TCPSockets.Find( pSocket );
+ if ( index != g_TCPSockets.InvalidIndex() )
+ {
+ g_TCPSockets.Remove( index );
+
+ // If this was the last one, delete the thread.
+ if ( g_TCPSockets.Count() == 0 )
+ {
+ csLock.Unlock();
+
+ if ( bThreadRunning )
+ {
+ SetEvent( g_hKeepaliveThreadSignal );
+ WaitForSingleObject( g_hKeepaliveThreadReply, INFINITE );
+ }
+
+ CloseHandle( g_hKeepaliveThreadSignal );
+ CloseHandle( g_hKeepaliveThreadReply );
+ CloseHandle( g_hKeepaliveThread );
+ return;
+ }
+ }
+
+ csLock.Unlock();
+}