/****************************************************************************** * Copyright (c) 2011, Duane Merrill. All rights reserved. * Copyright (c) 2011-2014, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the * names of its contributors may be used to endorse or promote products * derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * ******************************************************************************/ /** * \file * cub::WarpScanSmem provides smem-based variants of parallel prefix scan of items partitioned across a CUDA thread warp. 
*/

#pragma once

#include "../../thread/thread_operators.cuh"
#include "../../thread/thread_load.cuh"
#include "../../thread/thread_store.cuh"
#include "../../util_type.cuh"
#include "../../util_namespace.cuh"

/// Optional outer namespace(s)
CUB_NS_PREFIX

/// CUB namespace
namespace cub {

/**
 * \brief WarpScanSmem provides smem-based variants of parallel prefix scan of items partitioned across a CUDA thread warp.
 *
 * \par Storage layout
 * Each logical warp owns WARP_SMEM_ELEMENTS = LOGICAL_WARP_THREADS + HALF_WARP_THREADS
 * cells.  Lane \c i publishes its running partial at index <tt>HALF_WARP_THREADS + i</tt>.
 * The leading HALF_WARP_THREADS cells are padding: when an identity value is
 * supplied they are filled with it, so that a lane whose addend would fall before
 * lane 0 reads the identity instead of needing a range check.
 *
 * \par Synchronization
 * NOTE(review): no barrier is issued between a lane's store and its neighbour's
 * subsequent load; the exchange relies on volatile accesses plus warp-synchronous
 * execution.  Confirm that assumption holds for every PTX_ARCH this is
 * instantiated for.
 */
template <
    typename    T,                      ///< Data type being scanned
    int         LOGICAL_WARP_THREADS,   ///< Number of threads per logical warp
    int         PTX_ARCH>               ///< The PTX compute capability for which to specialize this collective
struct WarpScanSmem
{
    /******************************************************************************
     * Constants and type definitions
     ******************************************************************************/

    enum
    {
        /// Whether the logical warp size and the PTX warp size coincide
        IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(PTX_ARCH)),

        /// The number of warp scan steps
        STEPS = Log2<LOGICAL_WARP_THREADS>::VALUE,

        /// The number of threads in half a warp
        HALF_WARP_THREADS = 1 << (STEPS - 1),

        /// The number of shared memory elements per warp
        WARP_SMEM_ELEMENTS = LOGICAL_WARP_THREADS + HALF_WARP_THREADS,
    };

    /// Storage cell type (workaround for SM1x compiler bugs with custom-ops like Max() on signed chars)
    typedef typename If<((Equals<T, char>::VALUE || Equals<T, signed char>::VALUE) && (PTX_ARCH < 200)), int, T>::Type CellT;

    /// Shared memory storage layout type (1.5 warps-worth of elements for each warp)
    typedef CellT _TempStorage[WARP_SMEM_ELEMENTS];

    // Alias wrapper allowing storage to be unioned
    struct TempStorage : Uninitialized<_TempStorage> {};


    /******************************************************************************
     * Thread fields
     ******************************************************************************/

    _TempStorage    &temp_storage;  ///< Reference to this logical warp's scratch cells
    unsigned int    lane_id;        ///< Calling thread's lane within its logical warp


    /******************************************************************************
     * Construction
     ******************************************************************************/

    /// Constructor.  Binds the scratch storage and computes the calling thread's
    /// logical lane (the hardware lane, reduced modulo LOGICAL_WARP_THREADS when
    /// the logical warp differs from the architectural warp).
    __device__ __forceinline__ WarpScanSmem(
        TempStorage     &temp_storage)      ///< [in] Scratch storage for this logical warp
    :
        temp_storage(temp_storage.Alias()),
        lane_id(IS_ARCH_WARP ?
            LaneId() :
            LaneId() % LOGICAL_WARP_THREADS)
    {}


    /******************************************************************************
     * Utility methods
     ******************************************************************************/

    /// Basic inclusive scan iteration (template unrolled, base-case specialization).
    /// Selected when the step counter reaches STEPS; terminates the recursion.
    template <
        bool        HAS_IDENTITY,
        typename    ScanOp>
    __device__ __forceinline__ void ScanStep(
        T                       &partial,   ///< [in,out] Running partial (left untouched)
        ScanOp                  scan_op,    ///< [in] Binary scan operator (unused)
        Int2Type<STEPS>         step)       ///< [in] Marker type carrying the step index
    {}


    /// Basic inclusive scan iteration (template unrolled, inductive-case specialization).
    /// Publishes the running partial, folds in the partial of the lane OFFSET
    /// positions lower, then recurses with STEP + 1.
    template <
        bool        HAS_IDENTITY,
        int         STEP,
        typename    ScanOp>
    __device__ __forceinline__ void ScanStep(
        T                       &partial,   ///< [in,out] Running partial for the calling lane
        ScanOp                  scan_op,    ///< [in] Binary scan operator
        Int2Type<STEP>          step)       ///< [in] Marker type carrying the step index
    {
        const int OFFSET = 1 << STEP;

        // Share partial into buffer
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) partial);

        // Update partial if addend is in range.  With an identity the padding
        // cells below HALF_WARP_THREADS hold that identity, so every lane may
        // read unconditionally (OFFSET never exceeds HALF_WARP_THREADS).
        if (HAS_IDENTITY || (lane_id >= OFFSET))
        {
            T addend = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - OFFSET]);
            partial = scan_op(addend, partial);
        }

        ScanStep<HAS_IDENTITY>(partial, scan_op, Int2Type<STEP + 1>());
    }


    /// Inclusive prefix scan with identity
    template <typename ScanOp>
    __device__ __forceinline__ void InclusiveScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
        T               identity,           ///< [in] Identity value
        ScanOp          scan_op)            ///< [in] Binary scan operator
    {
        // Seed the cell at this lane's index with the identity; collectively
        // this covers the padding cells that out-of-range addend reads land on.
        ThreadStore<STORE_VOLATILE>(&temp_storage[lane_id], (CellT) identity);

        // Iterate scan steps
        output = input;
        ScanStep<true>(output, scan_op, Int2Type<0>());
    }


    /// Inclusive prefix scan (specialized for summation across primitive types)
    __device__ __forceinline__ void InclusiveScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
        Sum             scan_op,            ///< [in] Binary scan operator
        Int2Type<true>  is_primitive)       ///< [in] Marker type indicating whether T is primitive type
    {
        // A zero-initialized T serves as the identity, enabling the
        // range-check-free identity path.
        T identity = ZeroInitialize<T>();
        InclusiveScan(input, output, identity, scan_op);
    }


    /// Inclusive prefix scan
    template <typename ScanOp, int IS_PRIMITIVE>
    __device__ __forceinline__ void InclusiveScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
        ScanOp          scan_op,            ///< [in] Binary scan operator
        Int2Type<IS_PRIMITIVE> is_primitive) ///< [in] Marker type indicating whether T is primitive type
    {
        // Iterate scan steps (no identity available: each step range-checks the addend)
        output = input;
        ScanStep<false>(output, scan_op, Int2Type<0>());
    }


    /******************************************************************************
     * Interface
     ******************************************************************************/

    /// Broadcast.  Lane \p src_lane writes \p input to the first storage cell and
    /// every lane returns the value read back from that cell.
    __device__ __forceinline__ T Broadcast(
        T               input,              ///< [in] The value to broadcast
        unsigned int    src_lane)           ///< [in] Which warp lane is to do the broadcasting
    {
        if (lane_id == src_lane)
        {
            ThreadStore<STORE_VOLATILE>(temp_storage, (CellT) input);
        }

        return (T) ThreadLoad<LOAD_VOLATILE>(temp_storage);
    }


    //---------------------------------------------------------------------
    // Inclusive operations
    //---------------------------------------------------------------------

    /// Inclusive scan.  Dispatches on whether T is a primitive type so that
    /// Sum over primitives can take the identity-based path.
    template <typename ScanOp>
    __device__ __forceinline__ void InclusiveScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
        ScanOp          scan_op)            ///< [in] Binary scan operator
    {
        InclusiveScan(input, output, scan_op, Int2Type<Traits<T>::PRIMITIVE>());
    }


    /// Inclusive scan with aggregate
    template <typename ScanOp>
    __device__ __forceinline__ void InclusiveScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
        ScanOp          scan_op,            ///< [in] Binary scan operator
        T               &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.
    {
        InclusiveScan(input, output, scan_op);

        // Retrieve aggregate: publish the final inclusive results, then read
        // the last lane's cell.
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) output);

        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);
    }


    //---------------------------------------------------------------------
    // Combo (inclusive & exclusive) operations
    //---------------------------------------------------------------------

    /// Combination scan without identity.
    ///
    /// Lane 0 reads its exclusive result from the padding cell just below the
    /// partials.  That cell is only written on the identity-based path, so
    /// callers should treat lane 0's \p exclusive_output as undefined here.
    template <typename ScanOp>
    __device__ __forceinline__ void Scan(
        T               input,              ///< [in] Calling thread's input item.
        T               &inclusive_output,  ///< [out] Calling thread's inclusive-scan output item.
        T               &exclusive_output,  ///< [out] Calling thread's exclusive-scan output item.
        ScanOp          scan_op)            ///< [in] Binary scan operator
    {
        // Compute inclusive scan
        InclusiveScan(input, inclusive_output, scan_op);

        // Grab result from predecessor
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive_output);

        exclusive_output = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 1]);
    }


    /// Combination scan with identity.  Lane 0's predecessor cell lies in the
    /// identity-filled padding, so its exclusive result is \p identity.
    template <typename ScanOp>
    __device__ __forceinline__ void Scan(
        T               input,              ///< [in] Calling thread's input item.
        T               &inclusive_output,  ///< [out] Calling thread's inclusive-scan output item.
        T               &exclusive_output,  ///< [out] Calling thread's exclusive-scan output item.
        T               identity,           ///< [in] Identity value
        ScanOp          scan_op)            ///< [in] Binary scan operator
    {
        // Compute inclusive scan
        InclusiveScan(input, inclusive_output, identity, scan_op);

        // Grab result from predecessor
        ThreadStore<STORE_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id], (CellT) inclusive_output);

        exclusive_output = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[HALF_WARP_THREADS + lane_id - 1]);
    }


    //---------------------------------------------------------------------
    // Exclusive operations
    //---------------------------------------------------------------------

    /// Exclusive scan with aggregate
    template <typename ScanOp>
    __device__ __forceinline__ void ExclusiveScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
        T               identity,           ///< [in] Identity value
        ScanOp          scan_op,            ///< [in] Binary scan operator
        T               &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.
    {
        T inclusive_output;
        Scan(input, inclusive_output, output, identity, scan_op);

        // Retrieve aggregate (Scan left every lane's inclusive result in
        // storage; the last cell belongs to the last lane)
        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);
    }


    /// Exclusive scan with aggregate, without identity.
    /// Lane 0's \p output should be treated as undefined (see the identity-less Scan).
    template <typename ScanOp>
    __device__ __forceinline__ void ExclusiveScan(
        T               input,              ///< [in] Calling thread's input item.
        T               &output,            ///< [out] Calling thread's output item.  May be aliased with \p input.
        ScanOp          scan_op,            ///< [in] Binary scan operator
        T               &warp_aggregate)    ///< [out] Warp-wide aggregate reduction of input items.
    {
        T inclusive_output;
        Scan(input, inclusive_output, output, scan_op);

        // Retrieve aggregate (Scan left every lane's inclusive result in
        // storage; the last cell belongs to the last lane)
        warp_aggregate = (T) ThreadLoad<LOAD_VOLATILE>(&temp_storage[WARP_SMEM_ELEMENTS - 1]);
    }

};


}               // CUB namespace
CUB_NS_POSTFIX  // Optional outer namespace(s)