From 77a6a45c7ea6e46bf127cfe466ce973c5a675a87 Mon Sep 17 00:00:00 2001 From: Federico Busato <50413820+fbusato@users.noreply.github.com> Date: Thu, 30 Jan 2025 12:51:37 -0800 Subject: [PATCH] Remove `LEGACY_PTX_ARCH` (#3551) --- cub/cub/agent/agent_batch_memcpy.cuh | 2 -- cub/cub/agent/agent_histogram.cuh | 12 +++------- cub/cub/agent/agent_reduce_by_key.cuh | 2 +- cub/cub/agent/agent_rle.cuh | 2 +- cub/cub/agent/agent_scan.cuh | 2 +- cub/cub/agent/agent_scan_by_key.cuh | 2 +- cub/cub/agent/agent_select_if.cuh | 2 +- cub/cub/agent/agent_three_way_partition.cuh | 2 +- cub/cub/agent/agent_unique_by_key.cuh | 5 ++--- cub/cub/agent/single_pass_scan_operators.cuh | 1 - cub/cub/block/block_adjacent_difference.cuh | 2 +- cub/cub/block/block_discontinuity.cuh | 4 +--- cub/cub/block/block_exchange.cuh | 5 +---- cub/cub/block/block_histogram.cuh | 5 +---- cub/cub/block/block_load.cuh | 5 +---- cub/cub/block/block_radix_rank.cuh | 8 ++----- cub/cub/block/block_radix_sort.cuh | 5 +---- cub/cub/block/block_raking_layout.cuh | 4 +--- cub/cub/block/block_reduce.cuh | 5 +---- cub/cub/block/block_scan.cuh | 5 +---- cub/cub/block/block_shuffle.cuh | 4 +--- cub/cub/block/block_store.cuh | 5 +---- .../specializations/block_histogram_sort.cuh | 22 +++---------------- .../specializations/block_reduce_raking.cuh | 9 +++----- .../block_reduce_raking_commutative_only.cuh | 9 +++----- .../block_reduce_warp_reductions.cuh | 9 +++----- .../specializations/block_scan_raking.cuh | 9 +++----- .../specializations/block_scan_warp_scans.cuh | 9 +++----- cub/cub/util_ptx.cuh | 2 +- .../warp/specializations/warp_reduce_shfl.cuh | 9 +++----- .../warp/specializations/warp_reduce_smem.cuh | 9 +++----- .../warp/specializations/warp_scan_shfl.cuh | 9 +++----- .../warp/specializations/warp_scan_smem.cuh | 9 +++----- cub/cub/warp/warp_exchange.cuh | 4 ---- cub/cub/warp/warp_load.cuh | 5 +---- cub/cub/warp/warp_merge_sort.cuh | 9 +------- cub/cub/warp/warp_reduce.cuh | 8 +++---- cub/cub/warp/warp_scan.cuh | 4 +--- cub/cub/warp/warp_store.cuh | 5 +---- cub/test/catch2_test_warp_exchange.cuh | 12 +++++----- docs/cub/developer_overview.rst | 18 ++++++--------- thrust/thrust/system/cuda/detail/core/util.h | 9 +------- thrust/thrust/system/cuda/detail/reduce.h | 2 +- .../thrust/system/cuda/detail/reduce_by_key.h | 8 +++---- .../system/cuda/detail/set_operations.h | 4 ++-- thrust/thrust/system/cuda/detail/unique.h | 6 ++--- 46 files changed, 85 insertions(+), 203 deletions(-) diff --git a/cub/cub/agent/agent_batch_memcpy.cuh b/cub/cub/agent/agent_batch_memcpy.cuh index 2b926f582fe..c2cf936bd87 100644 --- a/cub/cub/agent/agent_batch_memcpy.cuh +++ b/cub/cub/agent/agent_batch_memcpy.cuh @@ -642,14 +642,12 @@ private: TilePrefixCallbackOp, BLevBufferOffsetTileState, - 0, typename AgentMemcpySmallBuffersPolicyT::buff_delay_constructor>; using BLevBlockScanPrefixCallbackOpT = TilePrefixCallbackOp, BLevBlockOffsetTileState, - 0, typename AgentMemcpySmallBuffersPolicyT::block_delay_constructor>; //----------------------------------------------------------------------------- diff --git a/cub/cub/agent/agent_histogram.cuh b/cub/cub/agent/agent_histogram.cuh index 2e98bf76771..400d1778b11 100644 --- a/cub/cub/agent/agent_histogram.cuh +++ b/cub/cub/agent/agent_histogram.cuh @@ -172,9 +172,6 @@ namespace histogram * * @tparam OffsetT * Signed integer type for global offsets - * - * @tparam LEGACY_PTX_ARCH - * PTX compute capability (unused) */ template + typename OffsetT> struct AgentHistogram { //--------------------------------------------------------------------- @@ -930,8 +926,7 @@ template + typename OffsetT> using AgentHistogram CCCL_DEPRECATED_BECAUSE("This class is considered an implementation detail and the public " "interface will be removed.") = detail::histogram::AgentHistogram< @@ -943,7 +938,6 @@ using AgentHistogram CCCL_DEPRECATED_BECAUSE("This class is considered an implem CounterT, PrivatizedDecodeOpT, OutputDecodeOpT, - OffsetT, - LEGACY_PTX_ARCH>; + OffsetT>; CUB_NAMESPACE_END diff --git a/cub/cub/agent/agent_reduce_by_key.cuh b/cub/cub/agent/agent_reduce_by_key.cuh index a90399f4325..fffa5a88e57 100644 --- a/cub/cub/agent/agent_reduce_by_key.cuh +++ b/cub/cub/agent/agent_reduce_by_key.cuh @@ -276,7 +276,7 @@ struct AgentReduceByKey // Callback type for obtaining tile prefix during block scan using DelayConstructorT = typename AgentReduceByKeyPolicyT::detail::delay_constructor_t; using TilePrefixCallbackOpT = - TilePrefixCallbackOp; + TilePrefixCallbackOp; // Key and value exchange types using KeyExchangeT = KeyOutputT[TILE_ITEMS + 1]; diff --git a/cub/cub/agent/agent_rle.cuh b/cub/cub/agent/agent_rle.cuh index 2ea0729db92..fabc0b721ae 100644 --- a/cub/cub/agent/agent_rle.cuh +++ b/cub/cub/agent/agent_rle.cuh @@ -258,7 +258,7 @@ struct AgentRle // Callback type for obtaining tile prefix during block scan using DelayConstructorT = typename AgentRlePolicyT::detail::delay_constructor_t; using TilePrefixCallbackOpT = - TilePrefixCallbackOp; + TilePrefixCallbackOp; // Warp exchange types using WarpExchangePairs = WarpExchange; diff --git a/cub/cub/agent/agent_scan.cuh b/cub/cub/agent/agent_scan.cuh index c3cc02b69a1..9f29615a5cd 100644 --- a/cub/cub/agent/agent_scan.cuh +++ b/cub/cub/agent/agent_scan.cuh @@ -201,7 +201,7 @@ struct AgentScan // Callback type for obtaining tile prefix during block scan using DelayConstructorT = typename AgentScanPolicyT::detail::delay_constructor_t; - using TilePrefixCallbackOpT = TilePrefixCallbackOp; + using TilePrefixCallbackOpT = TilePrefixCallbackOp; // Stateful BlockScan prefix callback type for managing a running total while // scanning consecutive tiles diff --git a/cub/cub/agent/agent_scan_by_key.cuh b/cub/cub/agent/agent_scan_by_key.cuh index 722a44ac074..161a8a5c237 100644 --- a/cub/cub/agent/agent_scan_by_key.cuh +++ b/cub/cub/agent/agent_scan_by_key.cuh @@ -179,7 +179,7 @@ struct AgentScanByKey using DelayConstructorT = typename AgentScanByKeyPolicyT::detail::delay_constructor_t; using TilePrefixCallbackT = - TilePrefixCallbackOp; + TilePrefixCallbackOp; using BlockScanT = BlockScan; diff --git a/cub/cub/agent/agent_select_if.cuh b/cub/cub/agent/agent_select_if.cuh index 37e7b838adf..b1785651f12 100644 --- a/cub/cub/agent/agent_select_if.cuh +++ b/cub/cub/agent/agent_select_if.cuh @@ -274,7 +274,7 @@ struct AgentSelectIf // Callback type for obtaining tile prefix during block scan using DelayConstructorT = typename AgentSelectIfPolicyT::detail::delay_constructor_t; using TilePrefixCallbackOpT = - TilePrefixCallbackOp, MemoryOrderedTileStateT, 0, DelayConstructorT>; + TilePrefixCallbackOp, MemoryOrderedTileStateT, DelayConstructorT>; // Item exchange type using ItemExchangeT = InputT[TILE_ITEMS]; diff --git a/cub/cub/agent/agent_three_way_partition.cuh b/cub/cub/agent/agent_three_way_partition.cuh index 047861254ac..f36151f916f 100644 --- a/cub/cub/agent/agent_three_way_partition.cuh +++ b/cub/cub/agent/agent_three_way_partition.cuh @@ -207,7 +207,7 @@ struct AgentThreeWayPartition // Callback type for obtaining tile prefix during block scan using DelayConstructorT = typename PolicyT::detail::delay_constructor_t; using TilePrefixCallbackOpT = - cub::TilePrefixCallbackOp, ScanTileStateT, 0, DelayConstructorT>; + cub::TilePrefixCallbackOp, ScanTileStateT, DelayConstructorT>; // Item exchange type using ItemExchangeT = InputT[TILE_ITEMS]; diff --git a/cub/cub/agent/agent_unique_by_key.cuh b/cub/cub/agent/agent_unique_by_key.cuh index a1a731f150f..52ca1d9b3a2 100644 --- a/cub/cub/agent/agent_unique_by_key.cuh +++ b/cub/cub/agent/agent_unique_by_key.cuh @@ -179,9 +179,8 @@ struct AgentUniqueByKey using BlockScanT = cub::BlockScan; // Parameterized BlockDiscontinuity type for items - using DelayConstructorT = typename AgentUniqueByKeyPolicyT::detail::delay_constructor_t; - using TilePrefixCallback = - cub::TilePrefixCallbackOp, ScanTileStateT, 0, DelayConstructorT>; + using DelayConstructorT = typename AgentUniqueByKeyPolicyT::detail::delay_constructor_t; + using TilePrefixCallback = cub::TilePrefixCallbackOp, ScanTileStateT, DelayConstructorT>; // Key exchange type using KeyExchangeT = KeyT[ITEMS_PER_TILE]; diff --git a/cub/cub/agent/single_pass_scan_operators.cuh b/cub/cub/agent/single_pass_scan_operators.cuh index bd6551b8f8d..98769aa7791 100644 --- a/cub/cub/agent/single_pass_scan_operators.cuh +++ b/cub/cub/agent/single_pass_scan_operators.cuh @@ -1170,7 +1170,6 @@ struct ReduceByKeyScanTileState template > struct TilePrefixCallbackOp { diff --git a/cub/cub/block/block_adjacent_difference.cuh b/cub/cub/block/block_adjacent_difference.cuh index 38636571e80..119ca4f328e 100644 --- a/cub/cub/block/block_adjacent_difference.cuh +++ b/cub/cub/block/block_adjacent_difference.cuh @@ -122,7 +122,7 @@ CUB_NAMESPACE_BEGIN //! ``{ [4,-2,-1,0], [0,0,0,0], [1,1,0,0], [0,1,-3,3], ... }``. //! //! @endrst -template +template class BlockAdjacentDifference { private: diff --git a/cub/cub/block/block_discontinuity.cuh b/cub/cub/block/block_discontinuity.cuh index e4998f32510..c175ac96458 100644 --- a/cub/cub/block/block_discontinuity.cuh +++ b/cub/cub/block/block_discontinuity.cuh @@ -122,9 +122,7 @@ CUB_NAMESPACE_BEGIN //! @tparam BLOCK_DIM_Z //! **[optional]** The thread block length in threads along the Z dimension (default: 1) //! -//! @tparam LEGACY_PTX_ARCH -//! **[optional]** Unused -template +template class BlockDiscontinuity { private: diff --git a/cub/cub/block/block_exchange.cuh b/cub/cub/block/block_exchange.cuh index d1ae91c223d..402c60fe5a4 100644 --- a/cub/cub/block/block_exchange.cuh +++ b/cub/cub/block/block_exchange.cuh @@ -137,15 +137,12 @@ CUB_NAMESPACE_BEGIN //! @tparam BLOCK_DIM_Z //! **[optional]** The thread block length in threads along the Z dimension (default: 1) //! -//! @tparam LEGACY_PTX_ARCH -//! [optional] Unused. template + int BLOCK_DIM_Z = 1> class BlockExchange { static constexpr int BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z; ///< The thread block size in threads diff --git a/cub/cub/block/block_histogram.cuh b/cub/cub/block/block_histogram.cuh index 41abbd588b3..5ebd5c9371d 100644 --- a/cub/cub/block/block_histogram.cuh +++ b/cub/cub/block/block_histogram.cuh @@ -179,16 +179,13 @@ enum BlockHistogramAlgorithm //! @tparam BLOCK_DIM_Z //! **[optional]** The thread block length in threads along the Z dimension (default: 1) //! -//! @tparam LEGACY_PTX_ARCH -//! **[optional]** Unused. template + int BLOCK_DIM_Z = 1> class BlockHistogram { private: diff --git a/cub/cub/block/block_load.cuh b/cub/cub/block/block_load.cuh index c1e9b95ac56..f4a693f4750 100644 --- a/cub/cub/block/block_load.cuh +++ b/cub/cub/block/block_load.cuh @@ -790,15 +790,12 @@ enum BlockLoadAlgorithm //! @tparam BLOCK_DIM_Z //! **[optional]** The thread block length in threads along the Z dimension (default: 1) //! -//! @tparam LEGACY_PTX_ARCH -//! **[optional]** Unused. template + int BLOCK_DIM_Z = 1> class BlockLoad { static constexpr int BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z; // total threads in the block diff --git a/cub/cub/block/block_radix_rank.cuh b/cub/cub/block/block_radix_rank.cuh index ad495e1db31..6a899b1440a 100644 --- a/cub/cub/block/block_radix_rank.cuh +++ b/cub/cub/block/block_radix_rank.cuh @@ -204,8 +204,6 @@ struct warp_in_block_matcher_t //! @tparam BLOCK_DIM_Z //! **[optional]** The thread block length in threads along the Z dimension (default: 1) //! -//! @tparam LEGACY_PTX_ARCH -//! **[optional]** Unused. template + int BLOCK_DIM_Z = 1> class BlockRadixRank { private: @@ -560,8 +557,7 @@ template + int BLOCK_DIM_Z = 1> class BlockRadixRankMatch { private: diff --git a/cub/cub/block/block_radix_sort.cuh b/cub/cub/block/block_radix_sort.cuh index 080053348d7..55dd8747ee4 100644 --- a/cub/cub/block/block_radix_sort.cuh +++ b/cub/cub/block/block_radix_sort.cuh @@ -238,8 +238,6 @@ CUB_NAMESPACE_BEGIN //! @tparam BLOCK_DIM_Z //! **[optional]** The thread block length in threads along the Z dimension (default: 1) //! -//! @tparam LEGACY_PTX_ARCH -//! **[optional]** Unused template + int BLOCK_DIM_Z = 1> class BlockRadixSort { private: diff --git a/cub/cub/block/block_raking_layout.cuh b/cub/cub/block/block_raking_layout.cuh index 4d675b626b8..8f7f8b138c0 100644 --- a/cub/cub/block/block_raking_layout.cuh +++ b/cub/cub/block/block_raking_layout.cuh @@ -68,9 +68,7 @@ CUB_NAMESPACE_BEGIN //! @tparam BLOCK_THREADS //! The thread block size in threads. //! -//! @tparam LEGACY_PTX_ARCH -//! **[optional]** Unused. -template +template struct BlockRakingLayout { //--------------------------------------------------------------------- diff --git a/cub/cub/block/block_reduce.cuh b/cub/cub/block/block_reduce.cuh index 6cf578963fc..356134d3b40 100644 --- a/cub/cub/block/block_reduce.cuh +++ b/cub/cub/block/block_reduce.cuh @@ -232,14 +232,11 @@ enum BlockReduceAlgorithm //! @tparam BLOCK_DIM_Z //! **[optional]** The thread block length in threads along the Z dimension (default: 1) //! -//! @tparam LEGACY_PTX_ARCH -//! **[optional]** Unused. template + int BLOCK_DIM_Z = 1> class BlockReduce { private: diff --git a/cub/cub/block/block_scan.cuh b/cub/cub/block/block_scan.cuh index c25bd2d258d..de019116956 100644 --- a/cub/cub/block/block_scan.cuh +++ b/cub/cub/block/block_scan.cuh @@ -221,14 +221,11 @@ enum BlockScanAlgorithm //! @tparam BLOCK_DIM_Z //! **[optional]** The thread block length in threads along the Z dimension (default: 1) //! -//! @tparam LEGACY_PTX_ARCH -//! **[optional]** Unused. template + int BLOCK_DIM_Z = 1> class BlockScan { private: diff --git a/cub/cub/block/block_shuffle.cuh b/cub/cub/block/block_shuffle.cuh index 93d8715c63b..0cb42eba3a0 100644 --- a/cub/cub/block/block_shuffle.cuh +++ b/cub/cub/block/block_shuffle.cuh @@ -73,9 +73,7 @@ CUB_NAMESPACE_BEGIN //! @tparam BLOCK_DIM_Z //! **[optional]** The thread block length in threads along the Z dimension (default: 1) //! -//! @tparam LEGACY_PTX_ARCH -//! **[optional]** Unused -template +template class BlockShuffle { private: diff --git a/cub/cub/block/block_store.cuh b/cub/cub/block/block_store.cuh index e207a1d76c1..a2cd74fcd90 100644 --- a/cub/cub/block/block_store.cuh +++ b/cub/cub/block/block_store.cuh @@ -639,15 +639,12 @@ enum BlockStoreAlgorithm //! @tparam BLOCK_DIM_Z //! **[optional]** The thread block length in threads along the Z dimension (default: 1) //! -//! @tparam LEGACY_PTX_ARCH -//! **[optional]** Unused. template + int BLOCK_DIM_Z = 1> class BlockStore { private: diff --git a/cub/cub/block/specializations/block_histogram_sort.cuh b/cub/cub/block/specializations/block_histogram_sort.cuh index 127f30953b2..b5e0f7beae2 100644 --- a/cub/cub/block/specializations/block_histogram_sort.cuh +++ b/cub/cub/block/specializations/block_histogram_sort.cuh @@ -72,17 +72,8 @@ namespace detail * * @tparam BLOCK_DIM_Z * The thread block length in threads along the Z dimension - * - * @tparam LEGACY_PTX_ARCH - * The PTX compute capability for which to to specialize this collective (unused) */ -template +template struct BlockHistogramSort { /// Constants @@ -246,16 +237,9 @@ struct BlockHistogramSort }; } // namespace detail -template +template using BlockHistogramSort CCCL_DEPRECATED_BECAUSE( "This class is considered an implementation detail and the public interface will be " - "removed.") = - detail::BlockHistogramSort; + "removed.") = detail::BlockHistogramSort; CUB_NAMESPACE_END diff --git a/cub/cub/block/specializations/block_reduce_raking.cuh b/cub/cub/block/specializations/block_reduce_raking.cuh index 90f8f12236f..a45a16f6e0d 100644 --- a/cub/cub/block/specializations/block_reduce_raking.cuh +++ b/cub/cub/block/specializations/block_reduce_raking.cuh @@ -77,11 +77,8 @@ namespace detail * * @tparam BLOCK_DIM_Z * The thread block length in threads along the Z dimension - * - * @tparam LEGACY_PTX_ARCH - * The PTX compute capability for which to to specialize this collective */ -template +template struct BlockReduceRaking { /// Constants @@ -260,9 +257,9 @@ struct BlockReduceRaking }; } // namespace detail -template +template using BlockReduceRaking CCCL_DEPRECATED_BECAUSE( "This class is considered an implementation detail and the public interface will be " - "removed.") = detail::BlockReduceRaking; + "removed.") = detail::BlockReduceRaking; CUB_NAMESPACE_END diff --git a/cub/cub/block/specializations/block_reduce_raking_commutative_only.cuh b/cub/cub/block/specializations/block_reduce_raking_commutative_only.cuh index 7841db5f18a..28ff55b5fe0 100644 --- a/cub/cub/block/specializations/block_reduce_raking_commutative_only.cuh +++ b/cub/cub/block/specializations/block_reduce_raking_commutative_only.cuh @@ -68,11 +68,8 @@ namespace detail * * @tparam BLOCK_DIM_Z * The thread block length in threads along the Z dimension - * - * @tparam LEGACY_PTX_ARCH - * The PTX compute capability for which to to specialize this collective */ -template +template struct BlockReduceRakingCommutativeOnly { /// Constants @@ -234,9 +231,9 @@ struct BlockReduceRakingCommutativeOnly }; } // namespace detail -template +template using BlockReduceRakingCommutativeOnly CCCL_DEPRECATED_BECAUSE( "This class is considered an implementation detail and the public interface will be " - "removed.") = detail::BlockReduceRakingCommutativeOnly; + "removed.") = detail::BlockReduceRakingCommutativeOnly; CUB_NAMESPACE_END diff --git a/cub/cub/block/specializations/block_reduce_warp_reductions.cuh b/cub/cub/block/specializations/block_reduce_warp_reductions.cuh index 2dfa526771f..b6e70248b1e 100644 --- a/cub/cub/block/specializations/block_reduce_warp_reductions.cuh +++ b/cub/cub/block/specializations/block_reduce_warp_reductions.cuh @@ -67,11 +67,8 @@ namespace detail * * @tparam BLOCK_DIM_Z * The thread block length in threads along the Z dimension - * - * @tparam LEGACY_PTX_ARCH - * The PTX compute capability for which to to specialize this collective */ -template +template struct BlockReduceWarpReductions { /// Constants @@ -259,9 +256,9 @@ struct BlockReduceWarpReductions }; } // namespace detail -template +template using BlockReduceWarpReductions CCCL_DEPRECATED_BECAUSE( "This class is considered an implementation detail and the public interface will be " - "removed.") = detail::BlockReduceWarpReductions; + "removed.") = detail::BlockReduceWarpReductions; CUB_NAMESPACE_END diff --git a/cub/cub/block/specializations/block_scan_raking.cuh b/cub/cub/block/specializations/block_scan_raking.cuh index 2af4b8693fc..26d9d949226 100644 --- a/cub/cub/block/specializations/block_scan_raking.cuh +++ b/cub/cub/block/specializations/block_scan_raking.cuh @@ -73,11 +73,8 @@ namespace detail * @tparam MEMOIZE * Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the * expense of higher register pressure - * - * @tparam LEGACY_PTX_ARCH - * The PTX compute capability for which to to specialize this collective */ -template +template struct BlockScanRaking { //--------------------------------------------------------------------- @@ -797,9 +794,9 @@ struct BlockScanRaking }; } // namespace detail -template +template using BlockScanRaking CCCL_DEPRECATED_BECAUSE( "This class is considered an implementation detail and the public interface will be " - "removed.") = detail::BlockScanRaking; + "removed.") = detail::BlockScanRaking; CUB_NAMESPACE_END diff --git a/cub/cub/block/specializations/block_scan_warp_scans.cuh b/cub/cub/block/specializations/block_scan_warp_scans.cuh index d034d2838ea..4fc74b423ce 100644 --- a/cub/cub/block/specializations/block_scan_warp_scans.cuh +++ b/cub/cub/block/specializations/block_scan_warp_scans.cuh @@ -64,11 +64,8 @@ namespace detail * * @tparam BLOCK_DIM_Z * The thread block length in threads along the Z dimension - * - * @tparam LEGACY_PTX_ARCH - * The PTX compute capability for which to to specialize this collective */ -template +template struct BlockScanWarpScans { //--------------------------------------------------------------------- @@ -539,9 +536,9 @@ struct BlockScanWarpScans } }; } // namespace detail -template +template using BlockScanWarpScans CCCL_DEPRECATED_BECAUSE( "This class is considered an implementation detail and the public interface will be " - "removed.") = detail::BlockScanWarpScans; + "removed.") = detail::BlockScanWarpScans; CUB_NAMESPACE_END diff --git a/cub/cub/util_ptx.cuh b/cub/cub/util_ptx.cuh index e6bb45c4a31..8e37c287109 100644 --- a/cub/cub/util_ptx.cuh +++ b/cub/cub/util_ptx.cuh @@ -384,7 +384,7 @@ _CCCL_DEVICE _CCCL_FORCEINLINE unsigned int WarpId() * hardware warp threads). * @param warp_id Id of virtual warp within architectural warp */ -template +template _CCCL_HOST_DEVICE _CCCL_FORCEINLINE unsigned int WarpMask(unsigned int warp_id) { constexpr bool is_pow_of_two = PowerOfTwo::VALUE; diff --git a/cub/cub/warp/specializations/warp_reduce_shfl.cuh b/cub/cub/warp/specializations/warp_reduce_shfl.cuh index 8c4ad78d1ad..3592e7c920a 100644 --- a/cub/cub/warp/specializations/warp_reduce_shfl.cuh +++ b/cub/cub/warp/specializations/warp_reduce_shfl.cuh @@ -92,11 +92,8 @@ struct reduce_max_exists : ::cu * * @tparam LOGICAL_WARP_THREADS * Number of threads per logical warp (must be a power-of-two) - * - * @tparam LEGACY_PTX_ARCH - * The PTX compute capability for which to to specialize this collective */ -template +template struct WarpReduceShfl { static_assert(PowerOfTwo::VALUE, "LOGICAL_WARP_THREADS must be a power of two"); @@ -739,9 +736,9 @@ struct WarpReduceShfl }; } // namespace detail -template +template using WarpReduceShfl CCCL_DEPRECATED_BECAUSE( "This class is considered an implementation detail and the public interface will be " - "removed.") = detail::WarpReduceShfl; + "removed.") = detail::WarpReduceShfl; CUB_NAMESPACE_END diff --git a/cub/cub/warp/specializations/warp_reduce_smem.cuh b/cub/cub/warp/specializations/warp_reduce_smem.cuh index ade195ee6cb..b4e509d6766 100644 --- a/cub/cub/warp/specializations/warp_reduce_smem.cuh +++ b/cub/cub/warp/specializations/warp_reduce_smem.cuh @@ -63,11 +63,8 @@ namespace detail * * @tparam LOGICAL_WARP_THREADS * Number of threads per logical warp - * - * @tparam LEGACY_PTX_ARCH - * The PTX compute capability for which to to specialize this collective */ -template +template struct WarpReduceSmem { /****************************************************************************** @@ -414,8 +411,8 @@ struct WarpReduceSmem }; } // namespace detail -template +template using WarpReduceSmem CCCL_DEPRECATED_BECAUSE( "This class is considered an implementation detail and the public interface will be " - "removed.") = detail::WarpReduceSmem; + "removed.") = detail::WarpReduceSmem; CUB_NAMESPACE_END diff --git a/cub/cub/warp/specializations/warp_scan_shfl.cuh b/cub/cub/warp/specializations/warp_scan_shfl.cuh index 402b476c4e4..4b3b115b266 100644 --- a/cub/cub/warp/specializations/warp_scan_shfl.cuh +++ b/cub/cub/warp/specializations/warp_scan_shfl.cuh @@ -62,11 +62,8 @@ namespace detail * * @tparam LOGICAL_WARP_THREADS * Number of threads per logical warp (must be a power-of-two) - * - * @tparam LEGACY_PTX_ARCH - * The PTX compute capability for which to to specialize this collective */ -template +template struct WarpScanShfl { //--------------------------------------------------------------------- @@ -677,9 +674,9 @@ struct WarpScanShfl }; } // namespace detail -template +template using WarpScanShfl CCCL_DEPRECATED_BECAUSE( "This class is considered an implementation detail and the public interface will be " - "removed.") = detail::WarpScanShfl; + "removed.") = detail::WarpScanShfl; CUB_NAMESPACE_END diff --git a/cub/cub/warp/specializations/warp_scan_smem.cuh b/cub/cub/warp/specializations/warp_scan_smem.cuh index 090f0f96cb5..e6d18fb561f 100644 --- a/cub/cub/warp/specializations/warp_scan_smem.cuh +++ b/cub/cub/warp/specializations/warp_scan_smem.cuh @@ -63,11 +63,8 @@ namespace detail * * @tparam LOGICAL_WARP_THREADS * Number of threads per logical warp - * - * @tparam LEGACY_PTX_ARCH - * The PTX compute capability for which to to specialize this collective */ -template +template struct WarpScanSmem { /****************************************************************************** @@ -435,9 +432,9 @@ struct WarpScanSmem }; } // namespace detail -template +template using WarpScanSmem CCCL_DEPRECATED_BECAUSE( "This class is considered an implementation detail and the public interface will be " - "removed.") = detail::WarpScanSmem; + "removed.") = detail::WarpScanSmem; CUB_NAMESPACE_END diff --git a/cub/cub/warp/warp_exchange.cuh b/cub/cub/warp/warp_exchange.cuh index 79f422f5abe..7ce5997a446 100644 --- a/cub/cub/warp/warp_exchange.cuh +++ b/cub/cub/warp/warp_exchange.cuh @@ -83,9 +83,6 @@ using InternalWarpExchangeImpl = * targeted CUDA compute-capability (e.g., 32 threads for SM86). Must be a * power of two. * - * @tparam LEGACY_PTX_ARCH - * Unused. - * * @par Overview * - It is commonplace for a warp of threads to rearrange data items between * threads. For example, the global memory accesses prefer patterns where @@ -139,7 +136,6 @@ using InternalWarpExchangeImpl = template class WarpExchange : private detail::InternalWarpExchangeImpl diff --git a/cub/cub/warp/warp_load.cuh b/cub/cub/warp/warp_load.cuh index 3f11129c35a..b945a5355b2 100644 --- a/cub/cub/warp/warp_load.cuh +++ b/cub/cub/warp/warp_load.cuh @@ -216,13 +216,10 @@ enum WarpLoadAlgorithm //! targeted CUDA compute-capability (e.g., 32 threads for SM86). Must be a //! power of two. //! -//! @tparam LEGACY_PTX_ARCH -//! Unused. template + int LOGICAL_WARP_THREADS = CUB_PTX_WARP_THREADS> class WarpLoad { static constexpr bool IS_ARCH_WARP = LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0); diff --git a/cub/cub/warp/warp_merge_sort.cuh b/cub/cub/warp/warp_merge_sort.cuh index de3d311ae59..447dc4d00c2 100644 --- a/cub/cub/warp/warp_merge_sort.cuh +++ b/cub/cub/warp/warp_merge_sort.cuh @@ -122,14 +122,7 @@ CUB_NAMESPACE_BEGIN //! [optional] Value type (default: cub::NullType, which indicates a //! keys-only sort) //! -//! @tparam LEGACY_PTX_ARCH -//! Unused. -//! -template +template class WarpMergeSort : public BlockMergeSortStrategy[optional] Unused. -template +template class WarpReduce { private: @@ -663,8 +661,8 @@ public: }; #ifndef _CCCL_DOXYGEN_INVOKED // Do not document -template -class WarpReduce +template +class WarpReduce { private: using _TempStorage = cub::NullType; diff --git a/cub/cub/warp/warp_scan.cuh b/cub/cub/warp/warp_scan.cuh index 6eb6a35562b..e1c07c82691 100644 --- a/cub/cub/warp/warp_scan.cuh +++ b/cub/cub/warp/warp_scan.cuh @@ -156,9 +156,7 @@ CUB_NAMESPACE_BEGIN //! hardware warp threads). Default is the warp size associated with the CUDA Compute Capability //! targeted by the compiler (e.g., 32 threads for SM20). //! -//! @tparam LEGACY_PTX_ARCH -//! **[optional]** Unused. -template +template class WarpScan { private: diff --git a/cub/cub/warp/warp_store.cuh b/cub/cub/warp/warp_store.cuh index f0a9929e24f..a7ccb899607 100644 --- a/cub/cub/warp/warp_store.cuh +++ b/cub/cub/warp/warp_store.cuh @@ -223,13 +223,10 @@ enum WarpStoreAlgorithm //! targeted CUDA compute-capability (e.g., 32 threads for SM86). Must be a //! power of two. //! -//! @tparam LEGACY_PTX_ARCH -//! Unused. template + int LOGICAL_WARP_THREADS = CUB_PTX_WARP_THREADS> class WarpStore { static_assert(PowerOfTwo::VALUE, "LOGICAL_WARP_THREADS must be a power of two"); diff --git a/cub/test/catch2_test_warp_exchange.cuh b/cub/test/catch2_test_warp_exchange.cuh index 4b3b13563c0..e240abc7f48 100644 --- a/cub/test/catch2_test_warp_exchange.cuh +++ b/cub/test/catch2_test_warp_exchange.cuh @@ -53,7 +53,7 @@ struct exchange_data_t inline __device__ void - scatter(cub::WarpExchange& exchange, int (&ranks)[ItemsPerThread]) + scatter(cub::WarpExchange& exchange, int (&ranks)[ItemsPerThread]) { exchange.ScatterToStriped(input, ranks); } @@ -71,7 +71,7 @@ struct exchange_data_t inline __device__ void - scatter(cub::WarpExchange& exchange, int (&ranks)[ItemsPerThread]) + scatter(cub::WarpExchange& exchange, int (&ranks)[ItemsPerThread]) { exchange.ScatterToStriped(input, output, ranks); } @@ -85,7 +85,7 @@ template __global__ void scatter_kernel(const InputT* input_data, OutputT* output_data) { - using warp_exchange_t = cub::WarpExchange; + using warp_exchange_t = cub::WarpExchange; using storage_t = typename warp_exchange_t::TempStorage; constexpr int tile_size = ITEMS_PER_THREAD * LOGICAL_WARP_THREADS; @@ -147,7 +147,7 @@ template __global__ void kernel(const InputT* input_data, OutputT* output_data, ActionT action) { - using warp_exchange_t = cub::WarpExchange; + using warp_exchange_t = cub::WarpExchange; using storage_t = typename warp_exchange_t::TempStorage; constexpr int tile_size = ITEMS_PER_THREAD * LOGICAL_WARP_THREADS; @@ -205,7 +205,7 @@ struct blocked_to_striped cub::WarpExchangeAlgorithm Alg> __device__ void operator()(InputT (&input)[ITEMS_PER_THREAD], OutputT (&output)[ITEMS_PER_THREAD], - cub::WarpExchange& exchange) + cub::WarpExchange& exchange) { exchange.BlockedToStriped(input, output); } @@ -221,7 +221,7 @@ struct striped_to_blocked cub::WarpExchangeAlgorithm Alg> __device__ void operator()(InputT (&input)[ITEMS_PER_THREAD], OutputT (&output)[ITEMS_PER_THREAD], - cub::WarpExchange& exchange) + cub::WarpExchange& exchange) { exchange.StripedToBlocked(input, output); } diff --git a/docs/cub/developer_overview.rst b/docs/cub/developer_overview.rst index 8b31dab6283..29f02995ac4 100644 --- a/docs/cub/developer_overview.rst +++ b/docs/cub/developer_overview.rst @@ -157,8 +157,7 @@ For example, :cpp:struct:`cub::WarpReduce` is a class template: .. code-block:: c++ template + int LOGICAL_WARP_THREADS = 32> class WarpReduce { // ... // (1) define `_TempStorage` type @@ -193,10 +192,6 @@ There is a vital difference in the behavior of warp-level algorithms that depend .. TODO: Add diagram showing non-power of two logical warps. -It's important to note that ``LEGACY_PTX_ARCH`` has been recently deprecated. -This parameter used to affect specialization selection (see below). -It was conflicting with the PTX dispatch refactoring and limited NVHPC support. - Temporary storage usage ==================================== @@ -258,13 +253,15 @@ and algorithm implementation look like: .Reduce(input, valid_items, ::cuda::std::plus<>{}); } -Due to ``LEGACY_PTX_ARCH`` issues described above, -we can't specialize on the PTX version. + + +``__CUDA_ARCH__`` cannot be used because it is conflicting with the PTX dispatch refactoring and limited NVHPC support. +Due to this limitation, we can't specialize on the PTX version. ``NV_IF_TARGET`` shall be used by specializations instead: .. code-block:: c++ - template + template struct WarpReduceShfl { @@ -314,8 +311,7 @@ Block-scope algorithms are provided by structures as well: int BLOCK_DIM_X, BlockReduceAlgorithm ALGORITHM = BLOCK_REDUCE_WARP_REDUCTIONS, int BLOCK_DIM_Y = 1, - int BLOCK_DIM_Z = 1, - int LEGACY_PTX_ARCH = 0> + int BLOCK_DIM_Z = 1> class BlockReduce { public: struct TempStorage : Uninitialized<_TempStorage> {}; diff --git a/thrust/thrust/system/cuda/detail/core/util.h b/thrust/thrust/system/cuda/detail/core/util.h index b3bdcf1f086..1b11f459c71 100644 --- a/thrust/thrust/system/cuda/detail/core/util.h +++ b/thrust/thrust/system/cuda/detail/core/util.h @@ -488,14 +488,7 @@ struct get_arch> template ::value_type> struct BlockLoad { - using type = - cub::BlockLoad::type::ver>; + using type = cub::BlockLoad; }; // cuda_optional diff --git a/thrust/thrust/system/cuda/detail/reduce.h b/thrust/thrust/system/cuda/detail/reduce.h index 61ec2086adf..8ef245dc082 100644 --- a/thrust/thrust/system/cuda/detail/reduce.h +++ b/thrust/thrust/system/cuda/detail/reduce.h @@ -156,7 +156,7 @@ struct ReduceAgent using Vector = typename cub::CubVector; using LoadIt = typename core::detail::LoadIterator::type; - using BlockReduce = cub::BlockReduce; + using BlockReduce = cub::BlockReduce; using VectorLoadIt = cub::CacheModifiedInputIterator; diff --git a/thrust/thrust/system/cuda/detail/reduce_by_key.h b/thrust/thrust/system/cuda/detail/reduce_by_key.h index 8c1db436085..fc6ceefa21b 100644 --- a/thrust/thrust/system/cuda/detail/reduce_by_key.h +++ b/thrust/thrust/system/cuda/detail/reduce_by_key.h @@ -169,12 +169,10 @@ struct ReduceByKeyAgent using BlockLoadKeys = typename core::detail::BlockLoad::type; using BlockLoadValues = typename core::detail::BlockLoad::type; - using BlockDiscontinuityKeys = cub::BlockDiscontinuity; + using BlockDiscontinuityKeys = cub::BlockDiscontinuity; - using TilePrefixCallback = - cub::TilePrefixCallbackOp; - using BlockScan = - cub::BlockScan; + using TilePrefixCallback = cub::TilePrefixCallbackOp; + using BlockScan = cub::BlockScan; union TempStorage { diff --git a/thrust/thrust/system/cuda/detail/set_operations.h b/thrust/thrust/system/cuda/detail/set_operations.h index b336f8e55fa..85d03ae51cb 100644 --- a/thrust/thrust/system/cuda/detail/set_operations.h +++ b/thrust/thrust/system/cuda/detail/set_operations.h @@ -300,9 +300,9 @@ struct SetOpAgent using BlockLoadValues1 = typename core::detail::BlockLoad::type; using BlockLoadValues2 = typename core::detail::BlockLoad::type; - using TilePrefixCallback = cub::TilePrefixCallbackOp, ScanTileState, Arch::ver>; + using TilePrefixCallback = cub::TilePrefixCallbackOp, ScanTileState>; - using BlockScan = cub::BlockScan; + using BlockScan = cub::BlockScan; // gather required temporary storage in a union // diff --git a/thrust/thrust/system/cuda/detail/unique.h b/thrust/thrust/system/cuda/detail/unique.h index 1d39b161866..bb5092ba9ef 100644 --- a/thrust/thrust/system/cuda/detail/unique.h +++ b/thrust/thrust/system/cuda/detail/unique.h @@ -153,10 +153,10 @@ struct UniqueAgent using BlockLoadItems = typename core::detail::BlockLoad::type; - using BlockDiscontinuityItems = cub::BlockDiscontinuity; + using BlockDiscontinuityItems = cub::BlockDiscontinuity; - using TilePrefixCallback = cub::TilePrefixCallbackOp, ScanTileState, Arch::ver>; - using BlockScan = cub::BlockScan; + using TilePrefixCallback = cub::TilePrefixCallbackOp, ScanTileState>; + using BlockScan = cub::BlockScan; using shared_items_t = core::detail::uninitialized_array;