diff options
| author | FluorescentCIAAfricanAmerican <[email protected]> | 2020-04-22 12:56:21 -0400 |
|---|---|---|
| committer | FluorescentCIAAfricanAmerican <[email protected]> | 2020-04-22 12:56:21 -0400 |
| commit | 3bf9df6b2785fa6d951086978a3e66f49427166a (patch) | |
| tree | 2c0f1f0c63c4832882bc93814ebd2c2b1c6224e5 /utils/vmpi/tcpsocket.cpp | |
| download | archived-source-engine-2018-hl2-src-master.tar.xz archived-source-engine-2018-hl2-src-master.zip | |
Diffstat (limited to 'utils/vmpi/tcpsocket.cpp')
| -rw-r--r-- | utils/vmpi/tcpsocket.cpp | 1178 |
1 files changed, 1178 insertions, 0 deletions
diff --git a/utils/vmpi/tcpsocket.cpp b/utils/vmpi/tcpsocket.cpp new file mode 100644 index 0000000..8ab67cf --- /dev/null +++ b/utils/vmpi/tcpsocket.cpp @@ -0,0 +1,1178 @@ +//========= Copyright Valve Corporation, All rights reserved. ============// +// +// Purpose: +// +// $NoKeywords: $ +//=============================================================================// + + +//#define PARANOID + +#if defined( PARANOID ) + #include <stdlib.h> + #include <crtdbg.h> +#endif + +#include <winsock2.h> +#include <mswsock.h> +#include "tcpsocket.h" +#include "tier1/utllinkedlist.h" +#include <stdio.h> +#include "threadhelpers.h" +#include "tier0/dbg.h" + + + +#error "I am TCPSocket and I suck. Use IThreadedTCPSocket or ThreadedTCPSocketEmu instead." + + +extern TIMEVAL SetupTimeVal( double flTimeout ); +extern void IPAddrToSockAddr( const CIPAddr *pIn, sockaddr_in *pOut ); +extern void SockAddrToIPAddr( const sockaddr_in *pIn, CIPAddr *pOut ); + + +#define SENTINEL_DISCONNECT -1 +#define SENTINEL_KEEPALIVE -2 + + +#define KEEPALIVE_INTERVAL_MS 3000 // keepalives are sent every N MS +#define KEEPALIVE_TIMEOUT_SECONDS 15.0 // connections timeout after this long + + +static bool g_bEnableTCPTimeout = true; + + +class CRecvData +{ +public: + int m_Count; + unsigned char m_Data[1]; +}; + + + +SOCKET TCPBind( const CIPAddr *pAddr ) +{ + // Create a socket to send and receive through. + SOCKET sock = WSASocket( AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED ); + if ( sock == INVALID_SOCKET ) + { + Assert( false ); + return INVALID_SOCKET; + } + + // bind to it! + sockaddr_in addr; + IPAddrToSockAddr( pAddr, &addr ); + + int status = bind( sock, (sockaddr*)&addr, sizeof(addr) ); + if ( status == 0 ) + { + return sock; + } + else + { + closesocket( sock ); + return INVALID_SOCKET; + } +} + + + +// ---------------------------------------------------------------------------------------- // +// TCP sockets. +// ---------------------------------------------------------------------------------------- // + +enum +{ + OP_RECV=111, + OP_SEND +}; + +// We use this for all OVERLAPPED structures. +class COverlappedPlus : public WSAOVERLAPPED +{ +public: + COverlappedPlus() + { + memset( this, 0, sizeof( WSAOVERLAPPED ) ); + } + + int m_OPType; // One of the OP_ defines. +}; + +typedef struct SendBuf_t +{ + COverlappedPlus m_Overlapped; + int m_Index; // Index into m_SendBufs. + int m_DataLength; + char m_Data[1]; +} SendBuf_s; + + +// These manage a thread that calls SendKeepalive() on all TCPSockets. +// AddGlobalTCPSocket shouldn't be called until you're ready for SendKeepalive() to be called. +class CTCPSocket; +void AddGlobalTCPSocket( CTCPSocket *pSocket ); +void RemoveGlobalTCPSocket( CTCPSocket *pSocket ); + + + +// ------------------------------------------------------------------------------------------ // +// CTCPSocket implementation. +// ------------------------------------------------------------------------------------------ // + +class CTCPSocket : public ITCPSocket +{ +friend class CTCPListenSocket; + +public: + + CTCPSocket() + { + m_Socket = INVALID_SOCKET; + m_bConnected = false; + + m_hIOCP = NULL; + + m_bShouldExitThreads = false; + m_bConnectionLost = false; + m_nSizeBytesReceived = 0; + + m_pIncomingData = NULL; + + memset( &m_RecvOverlapped, 0, sizeof( m_RecvOverlapped ) ); + m_RecvOverlapped.m_OPType = OP_RECV; + + m_hRecvSignal = CreateEvent( NULL, FALSE, FALSE, NULL ); + m_RecvStage = -1; + + m_MainThreadID = GetCurrentThreadId(); + } + + virtual ~CTCPSocket() + { + Term(); + CloseHandle( m_hRecvSignal ); + } + + void Term() + { + Assert( GetCurrentThreadId() == m_MainThreadID ); + + RemoveGlobalTCPSocket( this ); + + if ( m_Socket != SOCKET_ERROR && !m_bConnectionLost ) + { + SendDisconnectSentinel(); + + // Give the sends a second to complete. SO_LINGER is having trouble for some reason. + WaitForSendsToComplete( 1 ); + } + + + StopThreads(); + + if ( m_Socket != INVALID_SOCKET ) + { + closesocket( m_Socket ); + m_Socket = INVALID_SOCKET; + } + + if ( m_hIOCP ) + { + CloseHandle( m_hIOCP ); + m_hIOCP = NULL; + } + + m_bConnected = false; + m_bConnectionLost = true; + m_RecvStage = -1; + + FOR_EACH_LL( m_SendBufs, i ) + { + SendBuf_t *pSendBuf = m_SendBufs[i]; + ParanoidMemoryCheck( pSendBuf ); + free( pSendBuf ); + } + m_SendBufs.Purge(); + + FOR_EACH_LL( m_RecvDatas, j ) + { + CRecvData *pRecvData = m_RecvDatas[j]; + ParanoidMemoryCheck( pRecvData ); + free( pRecvData ); + } + m_RecvDatas.Purge(); + + if ( m_pIncomingData ) + { + ParanoidMemoryCheck( m_pIncomingData ); + free( m_pIncomingData ); + m_pIncomingData = 0; + } + } + + virtual void Release() + { + delete this; + } + + + void ParanoidMemoryCheck( void *ptr = NULL ) + { +#if defined( PARANOID ) + Assert( _CrtIsValidHeapPointer( this ) ); + + if ( ptr ) + { + Assert( _CrtIsValidHeapPointer( ptr ) ); + } + + Assert( _CrtCheckMemory() == TRUE ); +#endif + } + + + virtual bool BindToAny( const unsigned short port ) + { + Term(); + + CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY + m_Socket = TCPBind( &addr ); + if ( m_Socket == INVALID_SOCKET ) + { + return false; + } + else + { + SetInitialSocketOptions(); + return true; + } + } + + + // Set the initial socket options that we want. + void SetInitialSocketOptions() + { + // Set nodelay to improve latency. + BOOL val = TRUE; + setsockopt( m_Socket, IPPROTO_TCP, TCP_NODELAY, (const char FAR *)&val, sizeof(BOOL) ); + + // Make it linger for 3 seconds when it exits. + LINGER linger; + linger.l_onoff = 1; + linger.l_linger = 3; + setsockopt( m_Socket, SOL_SOCKET, SO_LINGER, (char*)&linger, sizeof( linger ) ); + } + + + // Called only by main thread interface functions. + // Returns true if the connection is lost. + bool CheckConnectionLost() + { + Assert( GetCurrentThreadId() == m_MainThreadID ); + + if ( m_Socket == SOCKET_ERROR ) + return true; + + // Have we timed out? + if ( g_bEnableTCPTimeout && (Plat_FloatTime() - m_LastRecvTime > KEEPALIVE_TIMEOUT_SECONDS) ) + { + SetConnectionLost( "Connection timed out." ); + } + + // Has any thread posted that the connection has been lost? + CCriticalSectionLock postLock( &m_ConnectionLostCS ); + postLock.Lock(); + if ( m_bConnectionLost ) + { + Term(); + return true; + } + else + { + return false; + } + } + + // Called by any thread. All interface functions call CheckConnectionLost() and return errors if it's lost. + void SetConnectionLost( const char *pErrorString, int err = -1 ) + { + CCriticalSectionLock postLock( &m_ConnectionLostCS ); + postLock.Lock(); + m_bConnectionLost = true; + postLock.Unlock(); + + // Handle it right away if we're in the main thread. If we're in an IO thread, + // it has to wait until the next interface function calls CheckConnectionLost(). + if ( GetCurrentThreadId() == m_MainThreadID ) + { + Term(); + } + + if ( pErrorString ) + { + m_ErrorString.CopyArray( pErrorString, strlen( pErrorString ) + 1 ); + } + else + { + char *lpMsgBuf; + FormatMessage( + FORMAT_MESSAGE_ALLOCATE_BUFFER | + FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, + err, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language + (LPTSTR) &lpMsgBuf, + 0, + NULL + ); + + m_ErrorString.CopyArray( lpMsgBuf, strlen( lpMsgBuf ) + 1 ); + LocalFree( lpMsgBuf ); + } + } + + + // -------------------------------------------------------------------------------------------------- // + // The receive code. + // -------------------------------------------------------------------------------------------------- // + + virtual bool StartWaitingForSize( bool bFresh ) + { + Assert( m_Socket != INVALID_SOCKET ); + Assert( m_bConnected ); + + m_RecvStage = 0; + m_RecvDataSize = -1; + if ( bFresh ) + m_nSizeBytesReceived = 0; + + DWORD dwNumBytesReceived = 0; + WSABUF buf = { sizeof( &m_RecvDataSize ) - m_nSizeBytesReceived, ((char*)&m_RecvDataSize) + m_nSizeBytesReceived }; + DWORD dwFlags = 0; + + int status = WSARecv( + m_Socket, + &buf, + 1, + &dwNumBytesReceived, + &dwFlags, + &m_RecvOverlapped, + NULL ); + + int err = -1; + if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING ) + { + SetConnectionLost( NULL, err ); + return false; + } + else + { + return true; + } + } + + + bool PostNextDataPart() + { + DWORD dwNumBytesReceived = 0; + WSABUF buf = { m_RecvDataSize - m_AmountReceived, (char*)m_pIncomingData->m_Data + m_AmountReceived }; + DWORD dwFlags = 0; + + int status = WSARecv( + m_Socket, + &buf, + 1, + &dwNumBytesReceived, + &dwFlags, + &m_RecvOverlapped, + NULL ); + + int err = -1; + if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING ) + { + SetConnectionLost( NULL, err ); + return false; + } + else + { + return true; + } + } + + + bool StartWaitingForData() + { + Assert( m_Socket != INVALID_SOCKET ); + Assert( m_RecvStage == 0 ); + Assert( m_bConnected ); + Assert( m_RecvDataSize > 0 ); + + m_RecvStage = 1; + + // Add a CRecvData element. + ParanoidMemoryCheck(); + m_pIncomingData = (CRecvData*)malloc( sizeof( CRecvData ) - 1 + m_RecvDataSize ); + if ( !m_pIncomingData ) + { + char str[512]; + _snprintf( str, sizeof( str ), "malloc() failed. m_RecvDataSize = %d\n", m_RecvDataSize ); + SetConnectionLost( str ); + return false; + } + + m_pIncomingData->m_Count = m_RecvDataSize; + + m_AmountReceived = 0; + + return PostNextDataPart(); + } + + virtual bool Recv( CUtlVector<unsigned char> &data, double flTimeout ) + { + if ( CheckConnectionLost() ) + return false; + + // Wait in 50ms chunks, checking for disconnections along the way. + bool bGotData = false; + DWORD msToWait = (DWORD)( flTimeout * 1000.0 ); + do + { + DWORD curWaitTime = min( msToWait, 50 ); + DWORD ret = WaitForSingleObject( m_hRecvSignal, curWaitTime ); + if ( ret == WAIT_OBJECT_0 ) + { + bGotData = true; + break; + } + + // Did the connection timeout? + if ( CheckConnectionLost() ) + return false; + + msToWait -= curWaitTime; + } while ( msToWait ); + + // If we never got a WAIT_OBJECT_0, then we never received anything. + if ( !bGotData ) + return false; + + + CCriticalSectionLock csLock( &m_RecvDataCS ); + csLock.Lock(); + + // Pickup the head m_RecvDatas element. + CRecvData *pRecvData = m_RecvDatas[ m_RecvDatas.Head() ]; + data.CopyArray( pRecvData->m_Data, pRecvData->m_Count ); + + // Now free it. + m_RecvDatas.Remove( m_RecvDatas.Head() ); + ParanoidMemoryCheck( pRecvData ); + free( pRecvData ); + + // Set the event again for the next time around, if there is more data waiting. + if ( m_RecvDatas.Count() > 0 ) + SetEvent( m_hRecvSignal ); + + return true; + } + + // INSIDE IO THREAD. + void HandleRecvCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes ) + { + if ( dwNumBytes == 0 ) + { + SetConnectionLost( "Got 0 bytes in HandleRecvCompletion" ); + return; + } + + m_LastRecvTime = Plat_FloatTime(); + if ( m_RecvStage == 0 ) + { + m_nSizeBytesReceived += dwNumBytes; + if ( m_nSizeBytesReceived == sizeof( m_RecvDataSize ) ) + { + // Size of -1 means the other size is breaking the connection. + if ( m_RecvDataSize == SENTINEL_DISCONNECT ) + { + SetConnectionLost( "Got a graceful disconnect message." ); + return; + } + else if ( m_RecvDataSize == SENTINEL_KEEPALIVE ) + { + // No data follows this. Just let m_LastRecvTime get updated. + StartWaitingForSize( true ); + return; + } + + StartWaitingForData(); + } + else if ( m_nSizeBytesReceived < sizeof( m_RecvDataSize ) ) + { + // Handle the case where we only got some of the data (maybe one of the clients got disconnected). + StartWaitingForSize( false ); + } + else + { + // This case should never ever happen! +#if defined( _DEBUG ) + __asm int 3; +#endif + + SetConnectionLost( "Received too much data in a packet!" ); + return; + } + } + else if ( m_RecvStage == 1 ) + { + // Got the data, make sure we got it all. + m_AmountReceived += dwNumBytes; + + // Sanity check. +#if defined( _DEBUG ) + Assert( m_RecvDataSize == m_pIncomingData->m_Count ); + Assert( m_AmountReceived <= m_RecvDataSize ); // TODO: make this threadsafe for multiple IO threads. +#endif + + if ( m_AmountReceived == m_RecvDataSize ) + { + m_RecvStage = 2; + + // Add the data to the list of packets waiting to be picked up. + CCriticalSectionLock csLock( &m_RecvDataCS ); + csLock.Lock(); + + m_RecvDatas.AddToTail( m_pIncomingData ); + m_pIncomingData = NULL; + + if ( m_RecvDatas.Count() == 1 ) + SetEvent( m_hRecvSignal ); // Notify the Recv() function. + + StartWaitingForSize( true ); + } + else + { + PostNextDataPart(); + } + } + else + { + Assert( false ); + } + } + + + // -------------------------------------------------------------------------------------------------- // + // The send code. + // -------------------------------------------------------------------------------------------------- // + + virtual void WaitForSendsToComplete( double flTimeout ) + { + CWaitTimer waitTimer( flTimeout ); + while ( 1 ) + { + CCriticalSectionLock sendBufLock( &m_SendCS ); + sendBufLock.Lock(); + if( m_SendBufs.Count() == 0 ) + return; + sendBufLock.Unlock(); + + if ( waitTimer.ShouldKeepWaiting() ) + Sleep( 10 ); + else + break; + } + } + + + // This is called in the keepalive thread. + void SendKeepalive() + { + // Send a message saying we're exiting. + ParanoidMemoryCheck(); + SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) ); + if ( !pBuf ) + { + SetConnectionLost( "malloc() in SendKeepalive() failed." ); + return; + } + + pBuf->m_DataLength = sizeof( int ); + *((int*)pBuf->m_Data) = SENTINEL_KEEPALIVE; + InternalSendDataBuf( pBuf ); + } + + + void SendDisconnectSentinel() + { + // Send a message saying we're exiting. + ParanoidMemoryCheck(); + SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) ); + if ( pBuf ) + { + pBuf->m_DataLength = sizeof( int ); + *((int*)pBuf->m_Data) = SENTINEL_DISCONNECT; // This signifies that we're exiting. + InternalSendDataBuf( pBuf ); + } + } + + + virtual bool Send( const void *pData, int len ) + { + const void *pChunks[1] = { pData }; + int chunkLengths[1] = { len }; + return SendChunks( pChunks, chunkLengths, 1 ); + } + + + virtual bool SendChunks( void const * const *pChunks, const int *pChunkLengths, int nChunks ) + { + if ( CheckConnectionLost() ) + return false; + + CChunkWalker walker( pChunks, pChunkLengths, nChunks ); + int totalLength = walker.GetTotalLength(); + + if ( !totalLength ) + return true; + + // Create a buffer to hold the data and copy the data in. + ParanoidMemoryCheck(); + SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + totalLength + sizeof( int ) ); + if ( !pBuf ) + { + char str[512]; + _snprintf( str, sizeof( str ), "malloc() in SendChunks() failed. totalLength = %d.", totalLength ); + SetConnectionLost( str ); + return false; + } + + pBuf->m_DataLength = totalLength + sizeof( int ); + + int *pByteCountPos = (int*)pBuf->m_Data; + *pByteCountPos = totalLength; + + char *pDataPos = &pBuf->m_Data[ sizeof( int ) ]; + walker.CopyTo( pDataPos, totalLength ); + + int status = InternalSendDataBuf( pBuf ); + int err = -1; + if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING ) + { + SetConnectionLost( NULL, err ); + return false; + } + else + { + return true; + } + } + + + int InternalSendDataBuf( SendBuf_t *pBuf ) + { + // Protect against interference from the keepalive thread. + CCriticalSectionLock csLock( &m_SendCS ); + csLock.Lock(); + + + pBuf->m_Overlapped.m_OPType = OP_SEND; + pBuf->m_Overlapped.hEvent = NULL; + + // Add it to our list of buffers. + pBuf->m_Index = m_SendBufs.AddToTail( pBuf ); + + // Tell Winsock to send it. + WSABUF buf = { pBuf->m_DataLength, pBuf->m_Data }; + + DWORD dwNumBytesSent = 0; + return WSASend( + m_Socket, + &buf, + 1, + &dwNumBytesSent, + 0, + &pBuf->m_Overlapped, + NULL ); + } + + + // INSIDE IO THREAD. + void HandleSendCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes ) + { + if ( dwNumBytes == 0 ) + { + SetConnectionLost( "0 bytes in HandleSendCompletion." ); + return; + } + + // Just free the buffer. + SendBuf_t *pBuf = (SendBuf_t*)pInfo; + Assert( dwNumBytes == (DWORD)pBuf->m_DataLength ); + + CCriticalSectionLock sendBufLock( &m_SendCS ); + sendBufLock.Lock(); + m_SendBufs.Remove( pBuf->m_Index ); + sendBufLock.Unlock(); + + ParanoidMemoryCheck( pBuf ); + free( pBuf ); + } + + + // -------------------------------------------------------------------------------------------------- // + // The connect code. + // -------------------------------------------------------------------------------------------------- // + + virtual bool BeginConnect( const CIPAddr &inputAddr ) + { + sockaddr_in addr; + IPAddrToSockAddr( &inputAddr, &addr ); + + m_bConnected = false; + int ret = connect( m_Socket, (struct sockaddr*)&addr, sizeof( addr ) ); + ret=ret; + + return true; + } + + + virtual bool UpdateConnect() + { + // We're still ok.. just wait until the socket becomes writable (is connected) or we timeout. + fd_set writeSet; + writeSet.fd_count = 1; + writeSet.fd_array[0] = m_Socket; + TIMEVAL timeVal = SetupTimeVal( 0 ); + + // See if it has a packet waiting. + int status = select( 0, NULL, &writeSet, NULL, &timeVal ); + if ( status > 0 ) + { + SetupConnected(); + return true; + } + + return false; + } + + + void SetupConnected() + { + m_bConnected = true; + m_bConnectionLost = false; + m_LastRecvTime = Plat_FloatTime(); + + CreateThreads(); + StartWaitingForSize( true ); + AddGlobalTCPSocket( this ); + } + + + virtual bool IsConnected() + { + CheckConnectionLost(); + return m_bConnected; + } + + + virtual void GetDisconnectReason( CUtlVector<char> &reason ) + { + reason = m_ErrorString; + } + + + // -------------------------------------------------------------------------------------------------- // + // Threads code. + // -------------------------------------------------------------------------------------------------- // + + // Create our IO Completion Port threads. + bool CreateThreads() + { + int nThreads = 1; + SetShouldExitThreads( false ); + + // Create our IO completion port and hook it to our socket. + m_hIOCP = CreateIoCompletionPort( + INVALID_HANDLE_VALUE, NULL, 0, 0); + + m_hIOCP = CreateIoCompletionPort( (HANDLE)m_Socket, m_hIOCP, (unsigned long)this, nThreads ); + + for ( int i=0; i < nThreads; i++ ) + { + DWORD dwThreadID = 0; + HANDLE hThread = CreateThread( + NULL, + 0, + &CTCPSocket::StaticThreadFn, + this, + 0, + &dwThreadID ); + + if ( hThread ) + { + SetThreadPriority( hThread, THREAD_PRIORITY_ABOVE_NORMAL ); + m_Threads.AddToTail( hThread ); + } + else + { + StopThreads(); + return false; + } + } + + return true; + } + + + void StopThreads() + { + // Tell the threads to exit, then wait for them to do so. + SetShouldExitThreads( true ); + WaitForMultipleObjects( m_Threads.Count(), m_Threads.Base(), TRUE, INFINITE ); + + for ( int i=0; i < m_Threads.Count(); i++ ) + { + CloseHandle( m_Threads[i] ); + } + m_Threads.Purge(); + } + + + void SetShouldExitThreads( bool bShouldExit ) + { + CCriticalSectionLock lock( &m_ThreadsCS ); + lock.Lock(); + m_bShouldExitThreads = bShouldExit; + } + + + bool ShouldExitThreads() + { + CCriticalSectionLock lock( &m_ThreadsCS ); + lock.Lock(); + + bool bRet = m_bShouldExitThreads; + return bRet; + } + + + DWORD ThreadFn() + { + while ( 1 ) + { + DWORD dwNumBytes = 0; + unsigned long pInputTCPSocket; + LPOVERLAPPED pOverlapped; + + if ( GetQueuedCompletionStatus( + m_hIOCP, // the port we're listening on + &dwNumBytes, // # bytes received on the port + &pInputTCPSocket,// "completion key" = CTCPSocket* + &pOverlapped, // the overlapped info that was passed into AcceptEx, WSARecv, or WSASend. + 100 // listen for 100ms at a time so we can exit gracefully when the socket is deleted. + ) ) + { + COverlappedPlus *pInfo = (COverlappedPlus*)pOverlapped; + ParanoidMemoryCheck( pInfo ); + + if ( pInfo->m_OPType == OP_RECV ) + { + Assert( pInfo == &m_RecvOverlapped ); + HandleRecvCompletion( pInfo, dwNumBytes ); + } + else + { + Assert( pInfo->m_OPType == OP_SEND ); + HandleSendCompletion( pInfo, dwNumBytes ); + } + } + + if ( ShouldExitThreads() ) + break; + } + + return 0; + } + + + static DWORD WINAPI StaticThreadFn( LPVOID pParameter ) + { + return ((CTCPSocket*)pParameter)->ThreadFn(); + } + + + +private: + + SOCKET m_Socket; + bool m_bConnected; + + + // m_RecvOverlapped is setup to first wait for the size, then the data. + // Then it is not posted until the app grabs the data. + HANDLE m_hRecvSignal; // Tells Recv() when we have data. + COverlappedPlus m_RecvOverlapped; + int m_RecvStage; // -1 = not initialized + // 0 = waiting for size + // 1 = waiting for data + // 2 = waiting for app to pickup the data + + CUtlLinkedList<CRecvData*,int> m_RecvDatas; // The head element is the next one to be picked up. + CRecvData *m_pIncomingData; // The packet we're currently receiving. + CCriticalSection m_RecvDataCS; // This protects adds and removes in the list. + + // These reference the element at the tail of m_RecvData. It is the current one getting + volatile int m_nSizeBytesReceived; // How much of m_RecvDataSize have we received yet? + int m_RecvDataSize; // this is received over the network + int m_AmountReceived; // How much we've received so far. + + // Last time we received anything from this connection. Used to determine if the connection is + // still active. + double m_LastRecvTime; + + + // Outgoing send buffers. + CUtlLinkedList<SendBuf_t*,int> m_SendBufs; + CCriticalSection m_SendCS; + + + // All the threads waiting for IO. + CUtlVector<HANDLE> m_Threads; + HANDLE m_hIOCP; + + // Used during shutdown. + volatile bool m_bShouldExitThreads; + CCriticalSection m_ThreadsCS; + + // For debugging. + DWORD m_MainThreadID; + + // Set by the main thread or IO threads to signal connection lost. + bool m_bConnectionLost; + CCriticalSection m_ConnectionLostCS; + + // This is set when we get disconnected. + CUtlVector<char> m_ErrorString; +}; + + +// ------------------------------------------------------------------------------------------ // +// ITCPListenSocket implementation. +// ------------------------------------------------------------------------------------------ // + +class CTCPListenSocket : public ITCPListenSocket +{ +public: + + CTCPListenSocket() + { + m_Socket = INVALID_SOCKET; + } + + + virtual ~CTCPListenSocket() + { + if ( m_Socket != INVALID_SOCKET ) + { + closesocket( m_Socket ); + } + } + + + // The main function to create one of these suckers. + static ITCPListenSocket* Create( const unsigned short port, int nQueueLength ) + { + CTCPListenSocket *pRet = new CTCPListenSocket; + if ( !pRet ) + return NULL; + + // Bind it to a socket and start listening. + CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY + pRet->m_Socket = TCPBind( &addr ); + if ( pRet->m_Socket == INVALID_SOCKET || + listen( pRet->m_Socket, nQueueLength == -1 ? SOMAXCONN : nQueueLength ) != 0 ) + { + pRet->Release(); + return false; + } + + return pRet; + } + + + virtual void Release() + { + delete this; + } + + + virtual ITCPSocket* UpdateListen( CIPAddr *pAddr ) + { + // We're still ok.. just wait until the socket becomes writable (is connected) or we timeout. + fd_set readSet; + readSet.fd_count = 1; + readSet.fd_array[0] = m_Socket; + TIMEVAL timeVal = SetupTimeVal( 0 ); + + // Wait until it connects. + int status = select( 0, &readSet, NULL, NULL, &timeVal ); + if ( status > 0 ) + { + sockaddr_in addr; + int addrSize = sizeof( addr ); + + // Now accept the final connection. + SOCKET newSock = accept( m_Socket, (struct sockaddr*)&addr, &addrSize ); + if ( newSock == INVALID_SOCKET ) + { + Assert( false ); + } + else + { + CTCPSocket *pRet = new CTCPSocket; + if ( !pRet ) + { + closesocket( newSock ); + return NULL; + } + + pRet->m_Socket = newSock; + pRet->SetInitialSocketOptions(); + pRet->SetupConnected(); + + // Report the address.. + SockAddrToIPAddr( &addr, pAddr ); + + return pRet; + } + } + + return NULL; + } + + +private: + SOCKET m_Socket; +}; + + + +ITCPListenSocket* CreateTCPListenSocket( const unsigned short port, int nQueueLength ) +{ + return CTCPListenSocket::Create( port, nQueueLength ); +} + + +ITCPSocket* CreateTCPSocket() +{ + return new CTCPSocket; +} + + +void TCPSocket_EnableTimeout( bool bEnable ) +{ + g_bEnableTCPTimeout = bEnable; +} + + +// --------------------------------------------------------------------------------- // +// This thread sends keepalives on all active TCP sockets. +// --------------------------------------------------------------------------------- // + +HANDLE g_hKeepaliveThread; +HANDLE g_hKeepaliveThreadSignal; +HANDLE g_hKeepaliveThreadReply; +CUtlLinkedList<CTCPSocket*,int> g_TCPSockets; +CCriticalSection g_TCPSocketsCS; + + +DWORD WINAPI TCPKeepaliveThread( LPVOID pParameter ) +{ + while ( 1 ) + { + if ( WaitForSingleObject( g_hKeepaliveThreadSignal, KEEPALIVE_INTERVAL_MS ) == WAIT_OBJECT_0 ) + break; + + // Tell all TCP sockets to send a keepalive. + CCriticalSectionLock csLock( &g_TCPSocketsCS ); + csLock.Lock(); + + FOR_EACH_LL( g_TCPSockets, i ) + { + g_TCPSockets[i]->SendKeepalive(); + } + } + + SetEvent( g_hKeepaliveThreadReply ); + return 0; +} + + +void AddGlobalTCPSocket( CTCPSocket *pSocket ) +{ + CCriticalSectionLock csLock( &g_TCPSocketsCS ); + csLock.Lock(); + + Assert( g_TCPSockets.Find( pSocket ) == g_TCPSockets.InvalidIndex() ); + g_TCPSockets.AddToTail( pSocket ); + + // If this is the first one, create the keepalive thread. + if ( g_TCPSockets.Count() == 1 ) + { + g_hKeepaliveThreadSignal = CreateEvent( NULL, false, false, NULL ); + g_hKeepaliveThreadReply = CreateEvent( NULL, false, false, NULL ); + + DWORD dwThreadID = 0; + g_hKeepaliveThread = CreateThread( + NULL, + 0, + TCPKeepaliveThread, + NULL, + 0, + &dwThreadID + ); + } +} + + +void RemoveGlobalTCPSocket( CTCPSocket *pSocket ) +{ + bool bThreadRunning = false; + DWORD dwExitCode = 0; + if ( GetExitCodeThread( g_hKeepaliveThread, &dwExitCode ) && dwExitCode == STILL_ACTIVE ) + { + bThreadRunning = true; + } + + CCriticalSectionLock csLock( &g_TCPSocketsCS ); + csLock.Lock(); + + int index = g_TCPSockets.Find( pSocket ); + if ( index != g_TCPSockets.InvalidIndex() ) + { + g_TCPSockets.Remove( index ); + + // If this was the last one, delete the thread. + if ( g_TCPSockets.Count() == 0 ) + { + csLock.Unlock(); + + if ( bThreadRunning ) + { + SetEvent( g_hKeepaliveThreadSignal ); + WaitForSingleObject( g_hKeepaliveThreadReply, INFINITE ); + } + + CloseHandle( g_hKeepaliveThreadSignal ); + CloseHandle( g_hKeepaliveThreadReply ); + CloseHandle( g_hKeepaliveThread ); + return; + } + } + + csLock.Unlock(); +} |