From b349f12b3aa5d63a3dcdab02595c8a0d293d9e71 Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Tue, 28 Jan 2025 17:34:25 +0100 Subject: [PATCH] Tune cub::DeviceTransform for Blackwell (#3543) --- .../dispatch/tuning/tuning_transform.cuh | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/cub/cub/device/dispatch/tuning/tuning_transform.cuh b/cub/cub/device/dispatch/tuning/tuning_transform.cuh index 9315fc630c0..888b00557ae 100644 --- a/cub/cub/device/dispatch/tuning/tuning_transform.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_transform.cuh @@ -169,11 +169,11 @@ struct policy_hub + template + struct bulkcopy_policy { - static constexpr int min_bif = arch_to_min_bytes_in_flight(900); - using async_policy = async_copy_policy_t<256>; + static constexpr int min_bif = arch_to_min_bytes_in_flight(PtxVersion); + using async_policy = async_copy_policy_t; static constexpr bool exhaust_smem = bulk_copy_smem_for_tile_size( async_policy::block_threads * async_policy::min_items_per_thread) @@ -188,10 +188,20 @@ struct policy_hub, async_policy>; + using algo_policy = ::cuda::std::_If, async_policy>; }; - using max_policy = policy900; + struct policy900 + : bulkcopy_policy<256, 900> + , ChainedPolicy<900, policy900, policy300> + {}; + + struct policy1000 + : bulkcopy_policy<128, 1000> + , ChainedPolicy<1000, policy1000, policy900> + {}; + + using max_policy = policy1000; #else // _CUB_HAS_TRANSFORM_UBLKCP using max_policy = policy300; #endif // _CUB_HAS_TRANSFORM_UBLKCP