diff options
Diffstat (limited to 'gcsdk/sqlaccess/sqlaccess.cpp')
| -rw-r--r-- | gcsdk/sqlaccess/sqlaccess.cpp | 953 |
1 files changed, 953 insertions, 0 deletions
diff --git a/gcsdk/sqlaccess/sqlaccess.cpp b/gcsdk/sqlaccess/sqlaccess.cpp new file mode 100644 index 0000000..130a446 --- /dev/null +++ b/gcsdk/sqlaccess/sqlaccess.cpp @@ -0,0 +1,953 @@ +//========= Copyright Valve Corporation, All rights reserved. ============// +// +// Purpose: Provides access to SQL at a high level +// +//============================================================================= + +#include "stdafx.h" +#include "gcsdk/sqlaccess/sqlaccess.h" +#include "gcsdk/gcsqlquery.h" + +// memdbgon must be the last include file in a .cpp file!!! +#include "tier0/memdbgon.h" + +template< typename LISTENER_FUNC > +static void RunAndClearListenerList( std::vector< LISTENER_FUNC > &vecListeners ) +{ + // Let us not underestimate the ability of random listeners to re-enter everything. + std::vector< LISTENER_FUNC > listenerCopy; + listenerCopy.swap( vecListeners ); + vecListeners.clear(); + + // Why would you consider such a thing + DO_NOT_YIELD_THIS_SCOPE(); + + for ( const auto &listener : listenerCopy ) + { + listener(); + } +} + + +namespace GCSDK +{ +//------------------------------------------------------------------------------------ +// Purpose: Constructor +//------------------------------------------------------------------------------------ +CSQLAccess::CSQLAccess( ESchemaCatalog eSchemaCatalog ) + : m_eSchemaCatalog( eSchemaCatalog) + , m_pCurrentQuery( NULL ) + , m_bInTransaction( false ) +{ + m_pQueryGroup = CGCSQLQueryGroup::Alloc(); +} + + +//------------------------------------------------------------------------------------ +// Purpose: Destructor +//------------------------------------------------------------------------------------ +CSQLAccess::~CSQLAccess( ) +{ + SAFE_RELEASE( m_pQueryGroup ); + Assert( !m_pCurrentQuery ); + SAFE_DELETE( m_pCurrentQuery ); + AssertMsg( !m_bInTransaction, "GCSDK::CSQLAccess object being destroyed with a transaction pending. Use BCommitTransaction or RollbackTransaction to match your BBeginTransaction call." ); +} + + +//------------------------------------------------------------------------------------ +// Purpose: Perform a query +//------------------------------------------------------------------------------------ +bool CSQLAccess::BYieldingExecute( const char *pchName, const char *pchSQLCommand, uint32 *pcRowsAffected, bool bSpewOnError ) +{ + if ( NULL == pchName ) + { + pchName = pchSQLCommand; + } + + bool bStandalone = !BInTransaction(); + if( bStandalone ) + { + BBeginTransaction( pchName ); + } + + CurrentQuery()->SetCommand( pchSQLCommand ); + m_pQueryGroup->AddQuery( m_pCurrentQuery ); + m_pCurrentQuery = NULL; + + bool bSuccess = true; + if( bStandalone ) + { + bSuccess = BCommitTransaction(); + if( bSuccess && pcRowsAffected ) + { + *pcRowsAffected = m_pQueryGroup->GetResults()->GetRowsAffected( 0 ); + } + } + return bSuccess; +} + + +//------------------------------------------------------------------------------------ +// Purpose: Starts a transaction +//------------------------------------------------------------------------------------ +bool CSQLAccess::BBeginTransaction( const char *pchName ) +{ + Assert( !m_bInTransaction ); + if( m_bInTransaction ) + return false; + m_pQueryGroup->Clear(); + m_pQueryGroup->SetName( pchName ); + m_bInTransaction = true; + return true; +} + +//------------------------------------------------------------------------------------ +// Purpose: Returns the string last passed to BBeginTransaction +//------------------------------------------------------------------------------------ +const char *CSQLAccess::PchTransactionName( ) const +{ + return m_pQueryGroup->PchName(); +} + + +//------------------------------------------------------------------------------------ +// Purpose: Commits a transaction to the database +//------------------------------------------------------------------------------------ +bool CSQLAccess::BCommitTransaction( bool bAllowEmpty ) +{ + Assert( BInTransaction() ); + if( !BInTransaction() ) + return false; + + if( !m_pCurrentQuery && !m_pQueryGroup->GetStatementCount() ) + { + if( bAllowEmpty ) + { + // No-op success + m_bInTransaction = false; + RunListeners_Commit(); + return true; + } + else + { + AssertMsg1( false, "BCommitTransaction with empty transaction at %s", m_pQueryGroup->PchName() ); + return false; + } + } + + AssertMsg1( !m_pCurrentQuery, "Unexecuted query present in BCommitTransaction: %s", m_pCurrentQuery->PchCommand() ); + if( m_pCurrentQuery ) + return false; + + m_bInTransaction = false; + + if( !GJobCur().BYieldingRunQuery( m_pQueryGroup, m_eSchemaCatalog ) ) + { + // Notify listeners that the transaction did not succeed + RunListeners_Rollback(); + return false; + } + + // The transaction presumably did make the database, so we do not notify rollback listeners beyond here. + RunListeners_Commit(); + + if( !m_pQueryGroup->GetResults() ) + return false; + + return true; +} + + +//------------------------------------------------------------------------------------ +// Purpose: Rolls back a transaction and clears any queries +//------------------------------------------------------------------------------------ +void CSQLAccess::RollbackTransaction() +{ + bool bWasTransaction = BInTransaction(); + + Assert( bWasTransaction ); + SAFE_DELETE( m_pCurrentQuery ); + m_bInTransaction = false; + + if ( bWasTransaction ) + { + RunListeners_Rollback(); + } + else + { + m_vecCommitListeners.clear(); + m_vecRollbackListeners.clear(); + } +} + +//------------------------------------------------------------------------------------ +// Purpose: Adds a listener to be called synchronously should the transaction successfully commit +//------------------------------------------------------------------------------------ +void CSQLAccess::AddCommitListener( std::function<void (void)> &&listener ) +{ + if ( !BInTransaction() ) + { + AssertMsg( BInTransaction(), "Adding a listener to a non-transaction access, will never fire" ); + return; + } + + m_vecCommitListeners.push_back( std::move( listener ) ); +} + +//------------------------------------------------------------------------------------ +// Purpose: Adds a listener to be called synchronously should the transaction fail or explicitly rollback +//------------------------------------------------------------------------------------ +void CSQLAccess::AddRollbackListener( std::function<void (void)> &&listener ) +{ + if ( !BInTransaction() ) + { + AssertMsg( BInTransaction(), "Adding a listener to a non-transaction access, will never fire" ); + return; + } + + m_vecRollbackListeners.push_back( std::move( listener ) ); +} + +//------------------------------------------------------------------------------------ +// Purpose: Notifies listeners of successful commit. +//------------------------------------------------------------------------------------ +void CSQLAccess::RunListeners_Commit() +{ + RunAndClearListenerList( m_vecCommitListeners ); + // Clear the unused set + m_vecRollbackListeners.clear(); +} + +//------------------------------------------------------------------------------------ +// Purpose: Notifies listeners of a implicitly or explicitly rolled back transactions and clears the listener list. +//------------------------------------------------------------------------------------ +void CSQLAccess::RunListeners_Rollback() +{ + RunAndClearListenerList( m_vecRollbackListeners ); + // Clear the unused set + m_vecCommitListeners.clear(); +} + +//------------------------------------------------------------------------------------ +// Purpose: Perform a query that returns a single string +//------------------------------------------------------------------------------------ +CSQLAccess::EReadSingleResultResult CSQLAccess::BYieldingExecuteSingleResultDataInternal( const char *pchName, const char *pchSQLCommand, EGCSQLType eType, uint8 **ppubData, uint32 *punSize, uint32 *pcRowsAffected, bool bHasDefaultValue ) +{ + AssertMsg( !BInTransaction(), "BYieldingExecuteSingleResultData is not supported in a transaction" ); + if( BInTransaction() ) + return eReadSingle_Error; + + bool bRet = BYieldingExecute( pchName, pchSQLCommand, pcRowsAffected ); + if ( !bRet ) + return eReadSingle_Error; + + if( m_pQueryGroup->GetResults()->GetResultSetCount() != 1 ) + { + AssertMsg1( false, "Expected single result set, found %d", m_pQueryGroup->GetResults()->GetResultSetCount() ); + return eReadSingle_Error; + } + + IGCSQLResultSet *pResultSet = m_pQueryGroup->GetResults()->GetResultSet( 0 ); + + // If we have a default value, getting back zero rows is acceptable. + if( pResultSet->GetRowCount() == 0 && bHasDefaultValue ) + { + return eReadSingle_UseDefault; + } + + // If we either have more than one row or no default value specified, that's an error. + if( pResultSet->GetRowCount() != 1 ) + { + AssertMsg1( false, "Expected single result, found %d", pResultSet->GetRowCount() ); + return eReadSingle_Error; + } + + if( pResultSet->GetColumnCount() != 1 ) + { + AssertMsg1( false, "Expected single column, found %d", pResultSet->GetColumnCount() ); + return eReadSingle_Error; + } + if( pResultSet->GetColumnType( 0 ) != eType ) + { + AssertMsg2( false, "Expected column of type %s, found %s", PchNameFromEGCSQLType( eType ), PchNameFromEGCSQLType( pResultSet->GetColumnType( 0 ) ) ); + return eReadSingle_Error; + } + + return pResultSet->GetData( 0, 0, ppubData, punSize ) + ? eReadSingle_ResultFound + : eReadSingle_Error; +} + + + + +//------------------------------------------------------------------------------------ +// Purpose: Perform a query that returns a single string +//------------------------------------------------------------------------------------ +bool CSQLAccess::BYieldingExecuteString( const char *pchName, const char *pchSQLCommand, CFmtStr1024 *psResult, uint32 *pcRowsAffected ) +{ + uint8 *pubData; + uint32 cubData; + if( CSQLAccess::BYieldingExecuteSingleResultDataInternal( pchName, pchSQLCommand, k_EGCSQLType_String, &pubData, &cubData, pcRowsAffected, false ) != eReadSingle_ResultFound ) + return false; + + *psResult = (char *)pubData; + + return true; +} + +//------------------------------------------------------------------------------------ +// Purpose: Perform a query that returns a single int +//------------------------------------------------------------------------------------ +bool CSQLAccess::BYieldingExecuteScalarInt( const char *pchName, const char *pchSQLCommand, int *pnResult, uint32 *pcRowsAffected ) +{ + return BYieldingExecuteSingleResult<int32, uint32>( pchName, pchSQLCommand, k_EGCSQLType_int32, pnResult, pcRowsAffected ); +} + +bool CSQLAccess::BYieldingExecuteScalarIntWithDefault( const char *pchName, const char *pchSQLCommand, int *pnResult, int iDefaultValue, uint32 *pcRowsAffected ) +{ + return BYieldingExecuteSingleResultWithDefault<int32, uint32>( pchName, pchSQLCommand, k_EGCSQLType_int32, pnResult, iDefaultValue, pcRowsAffected ); +} + +//------------------------------------------------------------------------------------ +// Purpose: Perform a query that returns a single uint32 +//------------------------------------------------------------------------------------ +bool CSQLAccess::BYieldingExecuteScalarUint32( const char *pchName, const char *pchSQLCommand, uint32 *punResult, uint32 *pcRowsAffected ) +{ + return BYieldingExecuteSingleResult<uint32, uint32>( pchName, pchSQLCommand, k_EGCSQLType_int32, punResult, pcRowsAffected ); +} + +bool CSQLAccess::BYieldingExecuteScalarUint32WithDefault( const char *pchName, const char *pchSQLCommand, uint32 *punResult, uint32 unDefaultValue, uint32 *pcRowsAffected ) +{ + return BYieldingExecuteSingleResultWithDefault<uint32, uint32>( pchName, pchSQLCommand, k_EGCSQLType_int32, punResult, unDefaultValue, pcRowsAffected ); +} + +//------------------------------------------------------------------------------------ +// Purpose: A bunch of pass throughs to the query itself +//------------------------------------------------------------------------------------ +void CSQLAccess::AddBindParam( const char *pchValue ) +{ + CurrentQuery()->AddBindParam( pchValue ); +} + +void CSQLAccess::AddBindParam( const int16 nValue ) +{ + CurrentQuery()->AddBindParam( nValue ); +} + +void CSQLAccess::AddBindParam( const uint16 uValue ) +{ + CurrentQuery()->AddBindParam( uValue ); +} + +void CSQLAccess::AddBindParam( const int32 nValue ) +{ + CurrentQuery()->AddBindParam( nValue ); +} + +void CSQLAccess::AddBindParam( const uint32 uValue ) +{ + CurrentQuery()->AddBindParam( uValue ); +} + +void CSQLAccess::AddBindParam( const uint64 ulValue ) +{ + CurrentQuery()->AddBindParam( ulValue ); +} + +void CSQLAccess::AddBindParam( const uint8 *ubValue, const int cubValue ) +{ + CurrentQuery()->AddBindParam( ubValue, cubValue ); +} + +void CSQLAccess::AddBindParam( const float fValue ) +{ + CurrentQuery()->AddBindParam( fValue ); +} + +void CSQLAccess::AddBindParam( const double dValue ) +{ + CurrentQuery()->AddBindParam( dValue ); +} + +void CSQLAccess::AddBindParamRaw( EGCSQLType eType, const byte *pubData, uint32 cubData ) +{ + CurrentQuery()->AddBindParamRaw( eType, pubData, cubData ); +} + +void CSQLAccess::ClearParams() +{ + if( m_pCurrentQuery ) + { + delete m_pCurrentQuery; + m_pCurrentQuery = NULL; + } +} + + +IGCSQLResultSetList *CSQLAccess::GetResults() +{ + return m_pQueryGroup->GetResults(); +} + + +//------------------------------------------------------------------------------------ +// Purpose: Returns the number of result sets +//------------------------------------------------------------------------------------ +uint32 CSQLAccess::GetResultSetCount() +{ + if( m_pQueryGroup->GetResults() ) + return m_pQueryGroup->GetResults()->GetResultSetCount(); + else + return 0; +} + + +//------------------------------------------------------------------------------------ +// Purpose: Returns the number of rows in a result set +//------------------------------------------------------------------------------------ +uint32 CSQLAccess::GetResultSetRowCount( uint32 unResultSet ) +{ + if( m_pQueryGroup->GetResults() && unResultSet < m_pQueryGroup->GetResults()->GetResultSetCount() ) + return m_pQueryGroup->GetResults()->GetResultSet( unResultSet )->GetRowCount(); + else + return 0; +} + + +//------------------------------------------------------------------------------------ +// Purpose: Returns a CSQLRecord object that represents a row in a result set +//------------------------------------------------------------------------------------ +CSQLRecord CSQLAccess::GetResultRecord( uint32 unResultSet, uint32 unRow ) +{ + if( m_pQueryGroup->GetResults() && unResultSet < m_pQueryGroup->GetResults()->GetResultSetCount() ) + { + IGCSQLResultSet *pResultSet = m_pQueryGroup->GetResults()->GetResultSet( unResultSet ); + if( unRow < pResultSet->GetRowCount() ) + return CSQLRecord( unRow, pResultSet ); + } + return CSQLRecord(); // if there was a problem return an empty record +} + +//----------------------------------------------------------------------------- +// Purpose: Inserts a new record into the DS +// Input: pRecordBase - record to insert +// Output: true if successful, false otherwise +//----------------------------------------------------------------------------- +bool CSQLAccess::BYieldingInsertRecord( const CRecordBase *pRecordBase ) +{ + ClearParams(); + + const CRecordInfo *pRecordInfo = pRecordBase->GetPSchema()->GetRecordInfo(); + int cColumns = pRecordInfo->GetNumColumns(); + for ( int nColumn = 0; nColumn < cColumns; nColumn++ ) + { + const CColumnInfo &columnInfo = pRecordInfo->GetColumnInfo( nColumn ); + if ( !columnInfo.BIsInsertable() ) + continue; + + uint8 *pubData; + uint32 cubData; + DbgVerify( pRecordBase->BGetField( nColumn, &pubData, &cubData ) ); + + CurrentQuery()->AddBindParamRaw( columnInfo.GetType(), pubData, cubData ); + } + + uint32 nRows; + const char *pchStatement = pRecordBase->GetPSchema()->GetInsertStatementText(); + + bool bRet = BYieldingExecute( pchStatement, pchStatement, &nRows ); + return ( nRows == 1 || BInTransaction() ) && bRet; +} + + +//----------------------------------------------------------------------------- +// Purpose: Inserts a new record into the DS if such row doesn't exist +// Input: pRecordBase - record to insert +// Output: true if successful, false otherwise +//----------------------------------------------------------------------------- +bool CSQLAccess::BYieldingInsertWhenNotMatchedOnPK( CRecordBase *pRecordBase ) +{ + ClearParams(); + + const CRecordInfo *pRecordInfo = pRecordBase->GetPSchema()->GetRecordInfo(); + int cColumns = pRecordInfo->GetNumColumns(); + for ( int nColumn = 0; nColumn < cColumns; nColumn++ ) + { + const CColumnInfo &columnInfo = pRecordInfo->GetColumnInfo( nColumn ); + if ( !columnInfo.BIsInsertable() ) + { + Assert( columnInfo.BIsInsertable() ); + return false; + } + + uint8 *pubData; + uint32 cubData; + DbgVerify( pRecordBase->BGetField( nColumn, &pubData, &cubData ) ); + + CurrentQuery()->AddBindParamRaw( columnInfo.GetType(), pubData, cubData ); + } + + uint32 nRows; + const char *pchStatement = pRecordBase->GetPSchema()->GetMergeStatementTextOnPKWhenNotMatchedInsert(); + + bool bRet = BYieldingExecute( pchStatement, pchStatement, &nRows ); + return ( nRows == 1 || nRows == 0 || BInTransaction() ) && bRet; +} + +//----------------------------------------------------------------------------- +// Purpose: Inserts a new record into the DS if such row doesn't exist +// updates an existing row if such row is matched by PK +// Input: pRecordBase - record to insert +// Output: true if successful, false otherwise +//----------------------------------------------------------------------------- +bool CSQLAccess::BYieldingInsertOrUpdateOnPK( CRecordBase *pRecordBase ) +{ + ClearParams(); + + const CRecordInfo *pRecordInfo = pRecordBase->GetPSchema()->GetRecordInfo(); + int cColumns = pRecordInfo->GetNumColumns(); + for ( int nColumn = 0; nColumn < cColumns; nColumn++ ) + { + const CColumnInfo &columnInfo = pRecordInfo->GetColumnInfo( nColumn ); + if ( !columnInfo.BIsInsertable() ) + { + Assert( columnInfo.BIsInsertable() ); + return false; + } + + uint8 *pubData; + uint32 cubData; + DbgVerify( pRecordBase->BGetField( nColumn, &pubData, &cubData ) ); + + CurrentQuery()->AddBindParamRaw( columnInfo.GetType(), pubData, cubData ); + } + + uint32 nRows; + const char *pchStatement = pRecordBase->GetPSchema()->GetMergeStatementTextOnPKWhenMatchedUpdateWhenNotMatchedInsert(); + + bool bRet = BYieldingExecute( pchStatement, pchStatement, &nRows ); + return ( nRows == 1 || BInTransaction() ) && bRet; +} + +//----------------------------------------------------------------------------- +// Purpose: Inserts a new record into the DB and reads non-insertable fields back +// into the record. +// Input: pRecordBase - record to insert +// Output: true if successful, false otherwise +//----------------------------------------------------------------------------- +bool CSQLAccess::BYieldingInsertWithIdentity( CRecordBase* pRecordBase ) +{ + AssertMsg( !BInTransaction(), "BYieldingInsertWithIdentity is not supported in a transaction" ); + if( BInTransaction() ) + return false; + ClearParams(); + + TSQLCmdStr sStatement; + CUtlVector<int> vecOutputFields; + CRecordInfo *pRecordInfo = pRecordBase->GetPSchema()->GetRecordInfo(); + BuildInsertAndReadStatementText( &sStatement, &vecOutputFields, pRecordInfo ); + + AssertMsg( vecOutputFields.Count() > 0, "BYieldingInsertAndReadRecord called for a record type with no non-insertable columns" ); + if ( vecOutputFields.Count() == 0 ) + return false; + + int cColumns = pRecordInfo->GetNumColumns(); + for ( int nColumn = 0; nColumn < cColumns; nColumn++ ) + { + const CColumnInfo &columnInfo = pRecordInfo->GetColumnInfo( nColumn ); + if ( !columnInfo.BIsInsertable() ) + { + continue; + } + + uint8 *pubData; + uint32 cubData; + DbgVerify( pRecordBase->BGetField( nColumn, &pubData, &cubData ) ); + + CurrentQuery()->AddBindParamRaw( columnInfo.GetType(), pubData, cubData ); + } + + bool bRet = BYieldingExecute( sStatement, sStatement ); + if( !bRet ) + return false; + + Assert( 1 == GetResultSetCount() ); + if ( 1 != GetResultSetCount() ) + return false; + + IGCSQLResultSet *pResultSet = m_pQueryGroup->GetResults()->GetResultSet( 0 ); + Assert( 1 == pResultSet->GetRowCount() ); + if ( 1 != pResultSet->GetRowCount() ) + return false; + + Assert( (uint32)vecOutputFields.Count() == pResultSet->GetColumnCount() ); + if ( (uint32)vecOutputFields.Count() != pResultSet->GetColumnCount() ) + return false; + + for( uint32 nColumn = 0; nColumn < pResultSet->GetColumnCount(); nColumn++ ) + { + uint8 *pubData; + uint32 cubData; + DbgVerify( pResultSet->GetData( 0, nColumn, &pubData, &cubData ) ); + + int nSchColumn = vecOutputFields[nColumn]; + Assert( pResultSet->GetColumnType( nColumn ) == pRecordInfo->GetColumnInfo( nSchColumn ).GetType() ); + DbgVerify( pRecordBase->BSetField( nSchColumn, pubData, cubData ) ); + } + + return true; +} + + +//----------------------------------------------------------------------------- +// Purpose: Reads a list of records from the DB according to the specified where +// clause +// Input: pRecordBase - record to read +// readSet - The set of columns to read +// whereSet - The set of columns to query on +// Output: true if successful, false otherwise +//----------------------------------------------------------------------------- +EResult CSQLAccess::YieldingReadRecordWithWhereColumns( CRecordBase *pRecord, const CColumnSet & readSet, const CColumnSet & whereSet, const char* pchOrderClause ) +{ + AssertMsg( !BInTransaction(), "BYieldingReadRecordWithWhereColumns is not supported in a transaction" ); + if( BInTransaction() ) + return k_EResultInvalidState; + + //if there is an order by clause, only take the top one, if there isn't, then validate that we have a single instance + const char* pszTopClause = ( pchOrderClause ) ? "TOP (1)" : "TOP (2)"; + + TSQLCmdStr sStatement; + BuildSelectStatementText( &sStatement, readSet, pszTopClause ); + + // if we actually have some columns for the where clause, + // append a where clause. + if( whereSet.GetColumnCount() ) + { + sStatement.Append( " WHERE " ); + AppendWhereClauseText( &sStatement, whereSet ); + AddRecordParameters( *pRecord, whereSet ); + } + //append the order by if they added one + if( pchOrderClause ) + { + sStatement.Append( " ORDER BY " ); + sStatement.Append( pchOrderClause ); + } + + Assert(!readSet.IsEmpty() ); + if( !BYieldingExecute( sStatement, sStatement ) ) + return k_EResultFail; + + if ( GetResultSetCount() != 1 ) + { + AssertMsg( GetResultSetCount() == 1, "Unexpected number of result sets returned from select statement" ); + return k_EResultFail; + } + + // make sure the types are the same + IGCSQLResultSet *pResultSet = m_pQueryGroup->GetResults()->GetResultSet( 0 ); + if ( pResultSet->GetRowCount() == 0 ) + return k_EResultNoMatch; + + //note that since we only take the top one when there is an order by clause, we don't need to handle that case down here, only if top 2 is selected + if( pResultSet->GetRowCount() != 1 ) + { + // Make sure we aren't failing because there are multiple matching records. + // That is probably a misuse of the API or some unexpected condition. + AssertMsg1( false, "BYieldingReadRecordWithWhereColumns from %s failing because multiple records match WHERE clause", readSet.GetRecordInfo()->GetName() ); + return k_EResultLimitExceeded; + } + FOR_EACH_COLUMN_IN_SET( readSet, nColumnIndex ) + { + EGCSQLType eRecordType = readSet.GetColumnInfo( nColumnIndex ).GetType(); + EGCSQLType eResultType = pResultSet->GetColumnType( nColumnIndex ); + + AssertMsg2( eResultType == eRecordType, "Column %d type mismatch in %s", nColumnIndex, readSet.GetRecordInfo()->GetName() ); + if( eRecordType != eResultType ) + return k_EResultInvalidParam; + } + + CSQLRecord sqlRecord = GetResultRecord( 0, 0 ); + + FOR_EACH_COLUMN_IN_SET( readSet, nColumnIndex ) + { + uint8 *pubData; + uint32 cubData; + + DbgVerify( sqlRecord.BGetColumnData( nColumnIndex, &pubData, (int*)&cubData ) ); + DbgVerify( pRecord->BSetField( readSet.GetColumn( nColumnIndex), pubData, cubData ) ); + } + + return k_EResultOK; +} + +//----------------------------------------------------------------------------- +// Purpose: Updates a record in the DB +// Input: record - data source for columns to match against (whereColumns) and +// columns to assign (updateColumns) +// whereColumns - columns to match against +// updateColumns - columns to update +// Output: true if successful, false otherwise +//----------------------------------------------------------------------------- +bool CSQLAccess::BYieldingUpdateRecord( const CRecordBase & record, const CColumnSet & whereColumns, const CColumnSet & updateColumns, const CSQLOutputParams *pOptionalOutputParams /* = NULL */ ) +{ + return BYieldingUpdateRecords( record, whereColumns, record, updateColumns, pOptionalOutputParams ); +} + +//----------------------------------------------------------------------------- +// Purpose: +//----------------------------------------------------------------------------- +bool CSQLAccess::BYieldingUpdateRecords( const CRecordBase & whereRecord, const CColumnSet & whereColumns, const CRecordBase & updateRecord, const CColumnSet & updateColumns, const CSQLOutputParams *pOptionalOutputParams /* = NULL */ ) +{ + ClearParams(); + + Assert( whereColumns.GetRecordInfo() == updateColumns.GetRecordInfo() ); + if ( whereColumns.GetRecordInfo() != updateColumns.GetRecordInfo() ) + return false; + Assert( whereColumns.GetRecordInfo() == whereRecord.GetPSchema()->GetRecordInfo() ); + if ( whereColumns.GetRecordInfo() != whereRecord.GetPSchema()->GetRecordInfo() ) + return false; + Assert( whereColumns.GetRecordInfo() == updateRecord.GetPSchema()->GetRecordInfo() ); + if ( whereColumns.GetRecordInfo() != updateRecord.GetPSchema()->GetRecordInfo() ) + return false; + + AssertMsg( !updateColumns.IsEmpty(), "Someone is calling BYieldingUpdateRecord with no columns to update." ); + if ( updateColumns.IsEmpty() ) + return false; + + // add the columns we're updating as bound params + TSQLCmdStr sStatement; + BuildUpdateStatementText( &sStatement, updateColumns ); + + AddRecordParameters( updateRecord, updateColumns ); + + // did the users specify an OUTPUT block? + if ( pOptionalOutputParams ) + { + TSQLCmdStr sOutput; + BuildOutputClauseText( &sOutput, pOptionalOutputParams->GetColumnSet() ); + sStatement.Append( sOutput ); + + AddRecordParameters( pOptionalOutputParams->GetRecord(), pOptionalOutputParams->GetColumnSet() ); + } + + if ( !whereColumns.IsEmpty() ) + { + sStatement.Append( " WHERE " ); + AppendWhereClauseText( &sStatement, whereColumns ); + + // add the columns we're querying on as bound params + AddRecordParameters( whereRecord, whereColumns ); + } + + return BYieldingExecute( sStatement, sStatement ); +} + +//----------------------------------------------------------------------------- +// Purpose: Deletes this record's row in the table +// Input: record - record to delete +// whereColumns - columns to use when searching for this record +//----------------------------------------------------------------------------- +bool CSQLAccess::BYieldingDeleteRecords( const CRecordBase & record, const CColumnSet & whereColumns ) +{ + Assert( whereColumns.GetRecordInfo() == record.GetPSchema()->GetRecordInfo() ); + if ( whereColumns.GetRecordInfo() != record.GetPSchema()->GetRecordInfo() ) + return false; + + ClearParams(); + AddRecordParameters( record, whereColumns ); + + TSQLCmdStr sStatement; + BuildDeleteStatementText( &sStatement, record.GetPRecordInfo() ); + sStatement.Append( " WHERE " ); + AppendWhereClauseText( &sStatement, whereColumns ); + + uint32 unRowsAffected; + if( !BYieldingExecute( sStatement, sStatement, &unRowsAffected ) ) + return false; + + return unRowsAffected > 0 || BInTransaction(); +} + +//-------------------------------------------------------------------------------------------------------------------------------- +// CSQLUpdateOrInsert +//-------------------------------------------------------------------------------------------------------------------------------- + +CSQLUpdateOrInsert::CSQLUpdateOrInsert( const char* pszName, int nTable, const CColumnSet & whereColumns, const CColumnSet & updateColumns, const char* pszWhereClause, const char* pszUpdateClause ) +{ + const CRecordInfo* pRecordInfo = GSchemaFull().GetSchema( nTable ).GetRecordInfo(); + + //how many columns do we have + const int nNumColumns = pRecordInfo->GetNumColumns(); + + TSQLCmdStr sStatement; + sStatement = "MERGE INTO "; + sStatement.Append( GSchemaFull().GetDefaultSchemaNameForCatalog( pRecordInfo->GetESchemaCatalog() ) ); + sStatement.Append( '.' ); + sStatement.Append( pRecordInfo->GetName() ); + sStatement.Append( " WITH(HOLDLOCK) AS D USING(VALUES(" ); + sStatement.AppendFormat( "%.*s", GetInsertArgStringChars( nNumColumns ), GetInsertArgString() ); + sStatement.Append( "))AS S(" ); + + //add each column that we are adding the values for, along with the parameter from the structure + for( int nCurrColumn = 0; nCurrColumn < nNumColumns; nCurrColumn++ ) + { + const CColumnInfo& colInfo = pRecordInfo->GetColumnInfo( nCurrColumn ); + if( nCurrColumn != 0 ) + sStatement.Append( ',' ); + sStatement.Append( colInfo.GetName() ); + } + + //our where clause + sStatement.Append( ")ON " ); + + if( pszWhereClause ) + { + sStatement.Append( pszWhereClause ); + } + else + { + FOR_EACH_COLUMN_IN_SET( whereColumns, nCurrColumn ) + { + const char* pszColName = pRecordInfo->GetColumnInfo( whereColumns.GetColumn( nCurrColumn ) ).GetName(); + if( nCurrColumn > 0 ) + sStatement.Append( " AND " ); + sStatement.AppendFormat( "D.%s=S.%s", pszColName, pszColName ); + } + } + + //our update clause (if they have provided fields that they want to update) + if( pszUpdateClause || !updateColumns.IsEmpty() ) + { + sStatement.Append( " WHEN MATCHED THEN UPDATE SET " ); + if( pszUpdateClause ) + { + sStatement.Append( pszUpdateClause ); + } + else + { + FOR_EACH_COLUMN_IN_SET( updateColumns, nCurrColumn ) + { + const char* pszColName = pRecordInfo->GetColumnInfo( updateColumns.GetColumn( nCurrColumn ) ).GetName(); + if( nCurrColumn > 0 ) + sStatement.Append( ',' ); + sStatement.AppendFormat( "%s=S.%s", pszColName, pszColName ); + } + } + } + + //our insert clause + sStatement.Append( " WHEN NOT MATCHED THEN INSERT(" ); + bool bFirstColumn = true; + for( int nCurrColumn = 0; nCurrColumn < nNumColumns; nCurrColumn++ ) + { + const CColumnInfo& colInfo = pRecordInfo->GetColumnInfo( nCurrColumn ); + if( !colInfo.BIsInsertable() ) + continue; + + if( !bFirstColumn ) + sStatement.Append( ',' ); + bFirstColumn = false; + sStatement.Append( colInfo.GetName() ); + } + + sStatement.Append( ")VALUES(" ); + bFirstColumn = true; + for( int nCurrColumn = 0; nCurrColumn < nNumColumns; nCurrColumn++ ) + { + const CColumnInfo& colInfo = pRecordInfo->GetColumnInfo( nCurrColumn ); + if( !colInfo.BIsInsertable() ) + continue; + + if( !bFirstColumn ) + sStatement.Append( ',' ); + bFirstColumn = false; + sStatement.AppendFormat( "S.%s", colInfo.GetName() ); + } + sStatement.Append( ");" ); + + //save our results so we can execute it in the future + m_nTable = nTable; + m_sName = pszName; + m_sQuery = sStatement; +} + +bool CSQLUpdateOrInsert::BYieldingExecute( CSQLAccess& sqlAccess, const CRecordBase& record, uint32 *out_punRowsAffected /* = NULL */ ) const +{ + AssertMsg2( record.GetITable() == m_nTable, "Error: Merge was compiled for table %s, but was attempted to be executed against %s", GSchemaFull().GetSchema( m_nTable ).GetRecordInfo()->GetName(), record.GetPRecordInfo()->GetName() ); + + const CRecordInfo* pRecordInfo = record.GetPRecordInfo(); + //how many columns do we have + const int nNumColumns = pRecordInfo->GetNumColumns(); + + sqlAccess.ClearParams(); + for( int nCurrColumn = 0; nCurrColumn < nNumColumns; nCurrColumn++ ) + { + const CColumnInfo& colInfo = pRecordInfo->GetColumnInfo( nCurrColumn ); + uint8 *pubData; + uint32 cubData; + DbgVerify( record.BGetField( nCurrColumn, &pubData, &cubData ) ); + sqlAccess.AddBindParamRaw( colInfo.GetType(), pubData, cubData ); + } + + return sqlAccess.BYieldingExecute( m_sName, m_sQuery, out_punRowsAffected ); +} + + +//----------------------------------------------------------------------------- +// Purpose: Adds bind parameters to the list based on a set of fields in a record +// Input: record - record to insert +// columnSet - The set of columns to add as params +//----------------------------------------------------------------------------- +void CSQLAccess::AddRecordParameters( const CRecordBase &record, const CColumnSet & columnSet ) +{ + Assert( record.GetPSchema()->GetRecordInfo() == columnSet.GetRecordInfo() ); + if ( record.GetPSchema()->GetRecordInfo() != columnSet.GetRecordInfo() ) + return; + + FOR_EACH_COLUMN_IN_SET( columnSet, nColumnIndex ) + { + const CColumnInfo &columnInfo = columnSet.GetColumnInfo( nColumnIndex ); + uint8 *pubData; + uint32 cubData; + DbgVerify( record.BGetField( columnSet.GetColumn( nColumnIndex ), &pubData, &cubData ) ); + EGCSQLType eType = columnInfo.GetType(); + CurrentQuery()->AddBindParamRaw( eType, pubData, cubData ); + } +} + +//----------------------------------------------------------------------------- +// Purpose: Deletes all records from a table +// Input: iTable - table to wipe +// Output: true if the operation was successful +// Note: PERFORMANCE WARNING: this is slow on big tables, not intended for use +// in production +//----------------------------------------------------------------------------- +bool CSQLAccess::BYieldingWipeTable( int iTable ) +{ + // make a wipe operation + CRecordInfo *pRecordInfo = GSchemaFull().GetSchema( iTable ).GetRecordInfo(); + + CUtlString buf; + buf.Format( "DELETE FROM %s", pRecordInfo->GetName() ); + return BYieldingExecute( buf.String(), buf.String() ); +} + + +//----------------------------------------------------------------------------- +// Purpose: Returns the current query to add stuff to, creating it if there isn't +// already a current query +//----------------------------------------------------------------------------- +CGCSQLQuery *CSQLAccess::CurrentQuery() +{ + if( m_pCurrentQuery ) + return m_pCurrentQuery; + + m_pCurrentQuery = new CGCSQLQuery(); + return m_pCurrentQuery; +} + + +} // namespace GCSDK |