diff --git a/CHANGELOG.md b/CHANGELOG.md index 8269f93b4..679ce6905 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,10 +2,10 @@ Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projects/rocPRIM/en/latest/](https://rocm.docs.amd.com/projects/rocPRIM/en/latest/). - ## (Unreleased) rocPRIM 3.4.0 for ROCm 6.4.0 ### Added + * Added extended tests to `rtest.py`. These tests are extra tests that did not fit the criteria of smoke and regression tests. These tests will take much longer to run relative to smoke and regression tests. * Use `python rtest.py [--emulation|-e|--test|-t]=extended` to run these tests. * Added regression tests to `rtest.py`. Regression tests are a subset of tests that caused hardware problems for past emulation environments. @@ -13,12 +13,17 @@ Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projec * Added the parallel `find_first_of` device function with autotuned configurations, this function is similar to `std::find_first_of`, it searches for the first occurrence of any of the provided elements. * Added `--emulation` option added for `rtest.py` * Unit tests can be run with `[--emulation|-e|--test|-t]=` +* Added tuned configurations for segmented radix sort for gfx942 to improve performance on this architecture. ### Changed + * Changed the subset of tests that are run for smoke tests such that the smoke test will complete with faster run-time and to never exceed 2GB of vram usage. Use `python rtest.py [--emulation|-e|--test|-t]=smoke` to run these tests. * The `rtest.py` options have changed. `rtest.py` is now run with at least either `--test|-t` or `--emulation|-e`, but not both options. +* Changed the internal algorithm of block radix sort to use rank match to improve performance of various radix sort related algorithms. +* Disabled padding in various cases where higher occupancy resulted in better performance despite more bank conflicts. ### Resolved issues + * Fixed an issue where `rmake.py` would generate wrong CMAKE commands while using Linux environment * Fixed an issue where `rocprim::partial_sort_copy` would yield a compile error if the input iterator is const. * Fixed incorrect 128-bit signed and unsigned integers type traits. diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index cda46e4a3..fae6d2cfc 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -24,6 +24,9 @@ option(BENCHMARK_CONFIG_TUNING "Benchmark device-level functions using various c include(../cmake/ConfigAutotune.cmake) include(ConfigAutotuneSettings.cmake) +option(BENCHMARK_TUNE_PARAM_NAMES "Tuning parameter names" "") +option(BENCHMARK_TUNE_PARAMS "Tuning parameters" "") + if(BENCHMARK_CONFIG_TUNING) add_custom_target("benchmark_config_tuning") endif() @@ -35,6 +38,12 @@ function(add_rocprim_benchmark BENCHMARK_SOURCE) if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/${BENCHMARK_TARGET}.parallel.cpp.in") message(STATUS "found ${BENCHMARK_TARGET}.parallel.cpp.in file, compiling in parallel.") read_config_autotune_settings(${BENCHMARK_TARGET} list_across_names list_across output_pattern_suffix) + + if(BENCHMARK_TUNE_PARAM_NAMES AND BENCHMARK_TUNE_PARAMS) + set(list_across_names "${BENCHMARK_TUNE_PARAM_NAMES}") + set(list_across "${BENCHMARK_TUNE_PARAMS}") + endif() + #make sure that variables are not empty, i.e. there actually is an entry for that benchmark in benchmark/ConfigAutotuneSettings.cmake if(list_across_names) add_executable(${BENCHMARK_TARGET} ${BENCHMARK_SOURCE}) diff --git a/benchmark/ConfigAutotuneSettings.cmake b/benchmark/ConfigAutotuneSettings.cmake index 8c18e334d..bdc34185b 100644 --- a/benchmark/ConfigAutotuneSettings.cmake +++ b/benchmark/ConfigAutotuneSettings.cmake @@ -83,16 +83,16 @@ binary_search upper_bound lower_bound;${TUNING_TYPES};${LIMITED_TUNING_TYPES};64 set(output_pattern_suffix "@SubAlgorithm@_@ValueType@_@OutputType@_@BlockSize@_@ItemsPerThread@" PARENT_SCOPE) elseif(file STREQUAL "benchmark_device_segmented_radix_sort_keys") set(list_across_names "\ -KeyType;BlockSize;ItemsPerThread;PartitionAllowed" PARENT_SCOPE) - set(list_across "${TUNING_TYPES};128 256;4 8 16;false" PARENT_SCOPE) +KeyType;LongBits;BlockSize;ItemsPerThread;WarpSmallLWS;WarpSmallIPT;WarpSmallBS;WarpPartition;WarpMediumLWS;WarpMediumIPT;WarpMediumBS" PARENT_SCOPE) + set(list_across "${TUNING_TYPES};8;256;4 8 16;8;4;256;64;16;8;256" PARENT_SCOPE) set(output_pattern_suffix "\ -@KeyType@_@BlockSize@_@ItemsPerThread@_@PartitionAllowed@" PARENT_SCOPE) +@KeyType@_@LongBits@_@BlockSize@_@ItemsPerThread@_@WarpSmallLWS@_@WarpSmallIPT@_@WarpSmallBS@_@WarpPartition@_@WarpMediumLWS@_@WarpMediumIPT@_@WarpMediumBS@" PARENT_SCOPE) elseif(file STREQUAL "benchmark_device_segmented_radix_sort_pairs") set(list_across_names "\ -KeyType;ValueType;BlockSize;ItemsPerThread;PartitionAllowed" PARENT_SCOPE) - set(list_across "${TUNING_TYPES};int8_t;64;4 8 16;true false" PARENT_SCOPE) +KeyType;ValueType;LongBits;BlockSize;ItemsPerThread;WarpSmallLWS;WarpSmallIPT;WarpSmallBS;WarpPartition;WarpMediumLWS;WarpMediumIPT;WarpMediumBS" PARENT_SCOPE) + set(list_across "${TUNING_TYPES};int8_t;8;256;4 8 16;8;4;256;64;16;8;256" PARENT_SCOPE) set(output_pattern_suffix "\ -@KeyType@_@ValueType@_@BlockSize@_@ItemsPerThread@_@PartitionAllowed@" PARENT_SCOPE) +@KeyType@_@ValueType@_@LongBits@_@BlockSize@_@ItemsPerThread@_@WarpSmallLWS@_@WarpSmallIPT@_@WarpSmallBS@_@WarpPartition@_@WarpMediumLWS@_@WarpMediumIPT@_@WarpMediumBS@" PARENT_SCOPE) elseif(file STREQUAL "benchmark_device_transform") set(list_across_names "\ DataType;BlockSize;" PARENT_SCOPE) diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.cpp b/benchmark/benchmark_device_segmented_radix_sort_keys.cpp index 157633cc8..7e2bb4477 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_keys.cpp +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.cpp @@ -291,7 +291,7 @@ int main(int argc, char* argv[]) config_autotune_register::register_benchmark_subset(benchmarks, parallel_instance, parallel_instances, - size, + min_size, seed, stream); #else diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in index 4913fdff7..3e506148a 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -27,8 +27,20 @@ namespace { -auto benchmarks = config_autotune_register::create_bulk(device_segmented_radix_sort_benchmark_generator<@BlockSize@, - @ItemsPerThread@, - @KeyType@, - @PartitionAllowed@>::create); +auto benchmarks = config_autotune_register::create_bulk( + device_segmented_radix_sort_benchmark_generator< + @LongBits@, + 0, + @BlockSize@, + @ItemsPerThread@, + @WarpSmallLWS@, + @WarpSmallIPT@, + @WarpSmallBS@, + @WarpPartition@, + @WarpMediumLWS@, + @WarpMediumIPT@, + @WarpMediumBS@, + @KeyType@, + true + >::create); } // namespace diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp index 79534784a..adaad263b 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp @@ -261,92 +261,38 @@ struct device_segmented_radix_sort_benchmark : public config_autotune_interface template class T, bool enable, Tp... Idx> struct decider; -template + +template struct device_segmented_radix_sort_benchmark_generator { - template - struct create_lrb - { - template - struct create_srb - { - template - struct create_euws - { - template - struct create_lwss - { - template - struct create_pt - { - void operator()( - std::vector>& storage) - { - storage.emplace_back( - std::make_unique, - rocprim::WarpSortConfig, - EnableUnpartitionedWarpSort>>>()); - } - }; - - void - operator()(std::vector>& storage) - { - static_for_each, create_pt>(storage); - } - }; - - void operator()(std::vector>& storage) - { - if(PartitionAllowed) - { - - static_for_each, - create_lwss>(storage); - } - else - { - storage.emplace_back( - std::make_unique, - rocprim::DisabledWarpSortConfig, - EnableUnpartitionedWarpSort>>>()); - } - } - }; - - void operator()(std::vector>& storage) - { - decider::do_the_thing( - storage); - } - }; - - void operator()(std::vector>& storage) - { - decider::do_the_thing( - storage); - } - }; - static void create(std::vector>& storage) { - static_for_each, create_lrb>(storage); + storage.emplace_back(std::make_unique, + rocprim::WarpSortConfig, + UnpartitionWarpAllowed>>>()); } }; diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in index 55fc0a849..37923ef01 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -27,9 +27,21 @@ namespace { -auto benchmarks = config_autotune_register::create_bulk(device_segmented_radix_sort_benchmark_generator<@BlockSize@, - @ItemsPerThread@, - @KeyType@, - @ValueType@, - @PartitionAllowed@>::create); +auto benchmarks = config_autotune_register::create_bulk( + device_segmented_radix_sort_benchmark_generator< + @LongBits@, + 8, + @BlockSize@, + @ItemsPerThread@, + @WarpSmallLWS@, + @WarpSmallIPT@, + @WarpSmallBS@, + @WarpPartition@, + @WarpMediumLWS@, + @WarpMediumIPT@, + @WarpMediumBS@, + @KeyType@, + @ValueType@, + true + >::create); } // namespace diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp index 3ed6e7a78..c3beaa863 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp @@ -131,10 +131,10 @@ struct device_segmented_radix_sort_benchmark : public config_autotune_interface seed.get_0()); std::vector values_input - = get_random_data(size, - generate_limits::min(), - generate_limits::max(), - seed.get_0()); + = get_random_data(size, + generate_limits::min(), + generate_limits::max(), + seed.get_0()); size_t batch_size = 1; if(size < target_size) @@ -286,97 +286,40 @@ struct device_segmented_radix_sort_benchmark : public config_autotune_interface template class T, bool enable, Tp... Idx> struct decider; -template + bool UnpartitionWarpAllowed = true> struct device_segmented_radix_sort_benchmark_generator { - template - struct create_lrb - { - template - struct create_srb - { - template - struct create_euws - { - template - struct create_lwss - { - template - struct create_pt - { - void operator()( - std::vector>& storage) - { - storage.emplace_back( - std::make_unique, - rocprim::WarpSortConfig, - EnableUnpartitionedWarpSort>>>()); - } - }; - - void - operator()(std::vector>& storage) - { - static_for_each, create_pt>(storage); - } - }; - - void operator()(std::vector>& storage) - { - if(PartitionAllowed) - { - - static_for_each, - create_lwss>(storage); - } - else - { - storage.emplace_back( - std::make_unique, - rocprim::DisabledWarpSortConfig, - EnableUnpartitionedWarpSort>>>()); - } - } - }; - - void operator()(std::vector>& storage) - { - static_for_each, create_euws>(storage); - } - }; - - void operator()(std::vector>& storage) - { - decider::do_the_thing( - storage); - } - }; - static void create(std::vector>& storage) { - static_for_each, create_lrb>(storage); + storage.emplace_back(std::make_unique, + rocprim::WarpSortConfig, + UnpartitionWarpAllowed>>>()); } }; diff --git a/benchmark/benchmark_utils.hpp b/benchmark/benchmark_utils.hpp index e244f11d6..8b2a39c52 100644 --- a/benchmark/benchmark_utils.hpp +++ b/benchmark/benchmark_utils.hpp @@ -208,7 +208,10 @@ inline auto generate_random_data_n(OutputIterator it, using T = typename std::iterator_traits::value_type; // Generate floats when T is half - using dis_type = std::conditional_t::value, float, T>; + using dis_type = std::conditional_t::value + || std::is_same::value, + float, + T>; std::uniform_real_distribution distribution((dis_type)min, (dis_type)max); std::generate_n(it, std::min(size, max_random_size), [&]() { return distribution(gen); }); for(size_t i = max_random_size; i < size; i += max_random_size) @@ -931,6 +934,11 @@ inline const char* Traits::name() return "rocprim::half"; } template<> +inline const char* Traits::name() +{ + return "rocprim::bfloat16"; +} +template<> inline const char* Traits::name() { return "int64_t"; diff --git a/rocprim/include/rocprim/block/block_exchange.hpp b/rocprim/include/rocprim/block/block_exchange.hpp index 0699a84e5..0962c77b0 100644 --- a/rocprim/include/rocprim/block/block_exchange.hpp +++ b/rocprim/include/rocprim/block/block_exchange.hpp @@ -27,6 +27,9 @@ #include "../functional.hpp" #include "../intrinsics.hpp" #include "../types.hpp" + +#include "config.hpp" + #include /// \addtogroup blockmodule @@ -40,6 +43,7 @@ BEGIN_ROCPRIM_NAMESPACE /// \tparam T - the input type. /// \tparam BlockSize - the number of threads in a block. /// \tparam ItemsPerThread - the number of items contributed by each thread. +/// \tparam PaddingHint - a hint that decides when to use padding. May not always be applicable. /// /// \par Overview /// * The \p block_exchange class supports the following rearrangement methods: @@ -77,7 +81,8 @@ template< unsigned int BlockSizeX, unsigned int ItemsPerThread, unsigned int BlockSizeY = 1, - unsigned int BlockSizeZ = 1 + unsigned int BlockSizeZ = 1, + block_padding_hint PaddingHint = block_padding_hint::avoid_conflicts > class block_exchange { @@ -86,21 +91,43 @@ class block_exchange static constexpr unsigned int warp_size = detail::get_min_warp_size(BlockSize, ::rocprim::device_warp_size()); // Number of warps in block - static constexpr unsigned int warps_no = (BlockSize + warp_size - 1) / warp_size; - - // Minimize LDS bank conflicts for power-of-two strides, i.e. when items accessed - // using `thread_id * ItemsPerThread` pattern where ItemsPerThread is power of two - // (all exchanges from/to blocked). - static constexpr bool has_bank_conflicts - = ItemsPerThread >= 2 && ::rocprim::detail::is_power_of_two(ItemsPerThread); + static constexpr unsigned int warps_no = ::rocprim::detail::ceiling_div(BlockSize, warp_size); static constexpr unsigned int banks_no = ::rocprim::detail::get_lds_banks_no(); static constexpr unsigned int buffer_size = static_cast(rocprim::max(size_t{1}, size_t{4} / sizeof(T))); - static constexpr unsigned int bank_conflicts_padding - = has_bank_conflicts ? (BlockSize * ItemsPerThread / banks_no) : 0; - static constexpr unsigned int storage_count - = BlockSize * ItemsPerThread + bank_conflicts_padding; + struct unpadded_config + { + static constexpr bool has_bank_conflicts = false; + static constexpr unsigned int padding = 0; + }; + + struct padded_config + { + // Minimize LDS bank conflicts for power-of-two strides, i.e. when items accessed + // using `thread_id * ItemsPerThread` pattern where ItemsPerThread is power of two + // (all exchanges from/to blocked). + static constexpr bool has_bank_conflicts + = ItemsPerThread >= 2 && ::rocprim::detail::is_power_of_two(ItemsPerThread); + static constexpr unsigned int padding + = has_bank_conflicts ? (BlockSize * ItemsPerThread / banks_no) : 0; + }; + + template + struct build_config : Config + { + static constexpr unsigned int storage_count = BlockSize * ItemsPerThread + Config::padding; + static constexpr unsigned int storage_size = sizeof(T) * storage_count; + static constexpr unsigned int occupancy = detail::get_min_lds_size() / storage_size; + }; + + using config = detail::select_block_padding_config, + build_config>; + + static constexpr bool has_bank_conflicts = config::has_bank_conflicts; + static constexpr unsigned int bank_conflicts_padding = config::padding; + static constexpr unsigned int storage_count = config::storage_count; struct storage_type_ { @@ -597,6 +624,73 @@ class block_exchange } } + /// \brief Scatters items to a *warp* striped arrangement based on their ranks + /// across the thread block, using temporary storage. + /// + /// \tparam U - [inferred] the output type. + /// \tparam Offset - [inferred] the rank type. + /// + /// \param [in] input - array that data is loaded from. + /// \param [out] output - array that data is loaded to. + /// \param [out] ranks - array that has rank of data. + /// \param [in] storage - reference to a temporary storage object of type storage_type. + /// + /// \par Storage reusage + /// Synchronization barrier should be placed before \p storage is reused + /// or repurposed: \p __syncthreads() or \p rocprim::syncthreads(). + /// + /// \par Example. + /// \code{.cpp} + /// __global__ void example_kernel(...) + /// { + /// // specialize block_exchange for int, block of 128 threads and 8 items per thread + /// using block_exchange_int = rocprim::block_exchange; + /// // allocate storage in shared memory + /// __shared__ block_exchange_int::storage_type storage; + /// + /// int items[8]; + /// int ranks[8]; + /// ... + /// block_exchange_int b_exchange; + /// b_exchange.scatter_to_warp_striped(items, items, ranks, storage); + /// ... + /// } + /// \endcode + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void scatter_to_warp_striped(const T (&input)[ItemsPerThread], + U (&output)[ItemsPerThread], + const Offset (&ranks)[ItemsPerThread], + storage_type& storage) + { + static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= device_warp_size(), + "WarpSize must be a power of two and equal or less" + "than the size of hardware warp."); + const unsigned int flat_id + = ::rocprim::flat_block_thread_id(); + const unsigned int thread_id = detail::logical_lane_id(); + const unsigned int warp_id = flat_id / WarpSize; + const unsigned int warp_offset = warp_id * WarpSize * ItemsPerThread; + const unsigned int thread_offset = thread_id + warp_offset; + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + const Offset rank = ranks[i]; + storage.buffer.emplace(index(rank), input[i]); + } + + ::rocprim::syncthreads(); + + const auto& storage_buffer = storage.buffer.get_unsafe_array(); + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + output[i] = storage_buffer[index(thread_offset + i * WarpSize)]; + } + } + /// \brief Scatters items to a striped arrangement based on their ranks /// across the thread block, guarded by rank. /// diff --git a/rocprim/include/rocprim/block/block_radix_rank.hpp b/rocprim/include/rocprim/block/block_radix_rank.hpp index 7ffd89bcf..8229fe355 100644 --- a/rocprim/include/rocprim/block/block_radix_rank.hpp +++ b/rocprim/include/rocprim/block/block_radix_rank.hpp @@ -25,6 +25,7 @@ #include "../functional.hpp" #include "block_scan.hpp" +#include "config.hpp" #include "detail/block_radix_rank_basic.hpp" #include "detail/block_radix_rank_match.hpp" @@ -60,7 +61,8 @@ struct select_block_radix_rank_impl template + unsigned int BlockSizeZ, + block_padding_hint> using type = block_radix_rank; }; @@ -70,7 +72,8 @@ struct select_block_radix_rank_impl template + unsigned int BlockSizeZ, + block_padding_hint> using type = block_radix_rank; }; @@ -80,8 +83,9 @@ struct select_block_radix_rank_impl template - using type = block_radix_rank_match; + unsigned int BlockSizeZ, + block_padding_hint PaddingHint> + using type = block_radix_rank_match; }; } // namespace detail @@ -96,6 +100,7 @@ struct select_block_radix_rank_impl /// the same values from shared memory twice, at the expense of more register usage. /// \tparam BlockSizeY - the number of threads in a block's y dimension, defaults to 1. /// \tparam BlockSizeZ - the number of threads in a block's z dimension, defaults to 1. +/// \tparam PaddingHint - a hint that decides when to use padding. May not always be applicable. /// /// \par Overview /// * Key type must be an arithmetic type (that is, an integral type or a floating point type). @@ -135,17 +140,18 @@ struct select_block_radix_rank_impl /// \endcode template + block_radix_rank_algorithm Algorithm = block_radix_rank_algorithm::default_algorithm, + unsigned int BlockSizeY = 1, + unsigned int BlockSizeZ = 1, + block_padding_hint PaddingHint = block_padding_hint::avoid_conflicts> class block_radix_rank #ifndef DOXYGEN_SHOULD_SKIP_THIS : private detail::select_block_radix_rank_impl< - Algorithm>::template type + Algorithm>::template type #endif { using base_type = typename detail::select_block_radix_rank_impl< - Algorithm>::template type; + Algorithm>::template type; public: /// \brief The number of digits each thread will process. diff --git a/rocprim/include/rocprim/block/block_radix_sort.hpp b/rocprim/include/rocprim/block/block_radix_sort.hpp index 9928390db..c84ff1a95 100644 --- a/rocprim/include/rocprim/block/block_radix_sort.hpp +++ b/rocprim/include/rocprim/block/block_radix_sort.hpp @@ -25,15 +25,15 @@ #include "../config.hpp" #include "../detail/various.hpp" -#include "../thread/radix_key_codec.hpp" -#include "../warp/detail/warp_scan_crosslane.hpp" - -#include "../intrinsics.hpp" #include "../functional.hpp" +#include "../intrinsics/thread.hpp" +#include "../thread/radix_key_codec.hpp" #include "../types.hpp" +#include "../warp/warp_exchange.hpp" #include "block_exchange.hpp" #include "block_radix_rank.hpp" +#include "rocprim/block/config.hpp" /// \addtogroup blockmodule /// @{ @@ -50,6 +50,7 @@ BEGIN_ROCPRIM_NAMESPACE /// \tparam Value - the value type. Default type empty_type indicates /// a keys-only sort. /// \tparam RadixBitsPerPass - amount of bits to sort per pass. The Default is 4. +/// \tparam RadixRankAlgorithm the rank algorithm used. /// /// \par Overview /// * \p Key type must be an arithmetic type (that is, an integral type or a floating-point @@ -97,54 +98,67 @@ BEGIN_ROCPRIM_NAMESPACE template + class Value = empty_type, + unsigned int BlockSizeY = 1, + unsigned int BlockSizeZ = 1, + unsigned int RadixBitsPerPass + = (BlockSizeX * BlockSizeY * BlockSizeZ) % device_warp_size() == 0 ? 8 /* match */ + : 4 /* basic_memoize */, + block_radix_rank_algorithm RadixRankAlgorithm + = (BlockSizeX * BlockSizeY * BlockSizeZ) % device_warp_size() == 0 + ? block_radix_rank_algorithm::match + : block_radix_rank_algorithm::basic_memoize, + block_padding_hint PaddingHint = block_padding_hint::lds_occupancy_bound> class block_radix_sort { static_assert(RadixBitsPerPass > 0 && RadixBitsPerPass < 32, "The RadixBitsPerPass should be larger than 0 and smaller than the size " "of an unsigned int"); - static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; - static constexpr bool with_values = !std::is_same::value; + static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ; + static constexpr bool with_values = !std::is_same::value; + static constexpr bool warp_striped = RadixRankAlgorithm == block_radix_rank_algorithm::match; + +#if __HIP_DEVICE_COMPILE__ + static_assert(!warp_striped || (BlockSize % device_warp_size()) == 0, + "When using 'block_radix_rank_algorithm::match', the block size should be a " + "multiple of the warp size"); +#endif + + static constexpr bool is_key_and_value_aligned + = alignof(Key) == alignof(Value) && sizeof(Key) == sizeof(Value); - using block_rank_type = ::rocprim::block_radix_rank; + using block_rank_type = ::rocprim:: + block_radix_rank; using keys_exchange_type - = ::rocprim::block_exchange; + = ::rocprim::block_exchange; using values_exchange_type - = ::rocprim::block_exchange; + = ::rocprim::block_exchange; // Struct used for creating a raw_storage object for this primitive's temporary storage. union storage_type_ { - typename keys_exchange_type::storage_type keys_exchange; - typename values_exchange_type::storage_type values_exchange; - typename block_rank_type::storage_type rank; + typename keys_exchange_type::storage_type keys_exchange; + typename values_exchange_type::storage_type values_exchange; + typename block_rank_type::storage_type rank; }; public: - - /// \brief Struct used to allocate a temporary memory that is required for thread - /// communication during operations provided by related parallel primitive. - /// - /// Depending on the implemention the operations exposed by parallel primitive may - /// require a temporary storage for thread communication. The storage should be allocated - /// using keywords __shared__. It can be aliased to - /// an externally allocated memory, or be a part of a union type with other storage types - /// to increase shared memory reusability. - #ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen +/// \brief Struct used to allocate a temporary memory that is required for thread +/// communication during operations provided by related parallel primitive. +/// +/// Depending on the implemention the operations exposed by parallel primitive may +/// require a temporary storage for thread communication. The storage should be allocated +/// using keywords __shared__. It can be aliased to +/// an externally allocated memory, or be a part of a union type with other storage types +/// to increase shared memory reusability. +#ifndef DOXYGEN_SHOULD_SKIP_THIS // hides storage_type implementation for Doxygen ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_WITH_PUSH using storage_type = detail::raw_storage; ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_POP - #else +#else using storage_type = storage_type_; // only for Doxygen - #endif +#endif /// \brief Performs ascending radix sort over keys partitioned across threads in a block. /// @@ -193,11 +207,12 @@ class block_radix_sort /// then after sort they will be equal {[1, 2], [3, 4] ..., [255, 256]}. /// \endparblock template - ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key (&keys)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - Decomposer decomposer = {}) + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(Key (&keys)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { empty_type values[ItemsPerThread]; sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); @@ -275,11 +290,12 @@ class block_radix_sort /// then after sort they will be equal {[256, 255], ..., [4, 3], [2, 1]}. /// \endparblock template - ROCPRIM_DEVICE ROCPRIM_INLINE void sort_desc(Key (&keys)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - Decomposer decomposer = {}) + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_desc(Key (&keys)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { empty_type values[ItemsPerThread]; sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); @@ -367,13 +383,13 @@ class block_radix_sort /// equal {[128, 128], [127, 127] ..., [2, 2], [1, 1]}. /// \endparblock template - ROCPRIM_DEVICE ROCPRIM_INLINE void - sort(Key (&keys)[ItemsPerThread], - typename std::enable_if::type (&values)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - Decomposer decomposer = {}) + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort(Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); } @@ -466,13 +482,13 @@ class block_radix_sort /// will be equal {[1, 1], [2, 2] ..., [128, 128]}. /// \endparblock template - ROCPRIM_DEVICE ROCPRIM_INLINE void - sort_desc(Key (&keys)[ItemsPerThread], - typename std::enable_if::type (&values)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - Decomposer decomposer = {}) + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_desc(Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); } @@ -558,11 +574,12 @@ class block_radix_sort /// then after sort they will be equal {[1, 129], [2, 130] ..., [128, 256]}. /// \endparblock template - ROCPRIM_DEVICE ROCPRIM_INLINE void sort_to_striped(Key (&keys)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - Decomposer decomposer = {}) + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_to_striped(Key (&keys)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { empty_type values[ItemsPerThread]; sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); @@ -644,11 +661,12 @@ class block_radix_sort /// then after sort they will be equal {[256, 128], ..., [130, 2], [129, 1]}. /// \endparblock template - ROCPRIM_DEVICE ROCPRIM_INLINE void sort_desc_to_striped(Key (&keys)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - Decomposer decomposer = {}) + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_desc_to_striped(Key (&keys)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { empty_type values[ItemsPerThread]; sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); @@ -738,13 +756,13 @@ class block_radix_sort /// equal {[-8, -4], [-7, -3], [-6, -2], [-5, -1]}. /// \endparblock template - ROCPRIM_DEVICE ROCPRIM_INLINE void - sort_to_striped(Key (&keys)[ItemsPerThread], - typename std::enable_if::type (&values)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit = 0, - unsigned int end_bit = 8 * sizeof(Key), - Decomposer decomposer = {}) + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_to_striped(Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) { sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); } @@ -835,7 +853,8 @@ class block_radix_sort /// equal {[10, 50], [20, 60], [30, 70], [40, 80]}. /// \endparblock template - ROCPRIM_DEVICE ROCPRIM_INLINE void sort_desc_to_striped( + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_desc_to_striped( Key (&keys)[ItemsPerThread], typename std::enable_if::type (&values)[ItemsPerThread], storage_type& storage, @@ -877,28 +896,251 @@ class block_radix_sort sort_desc_to_striped(keys, values, storage, begin_bit, end_bit, decomposer); } + /// \brief Performs ascending radix sort over key-value pairs in a *warp-striped order* + /// partitioned across threads in a block, results are saved in a striped arrangement. + /// + /// \see block_radix_sort::sort_to_striped + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_warp_striped_to_striped( + Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) + { + static_assert(warp_striped, + "'sort_warp_striped_to_striped' can only be used with " + "'block_radix_rank_algorithm::match'."); + + sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); + } + + /// \brief Performs ascending radix sort over key-value pairs in a *warp-striped order* + /// + /// \see block_radix_sort::sort_to_striped + /// partitioned across threads in a block, results are saved in a striped arrangement. + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_warp_striped_to_striped( + Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) + { + static_assert(warp_striped, + "'sort_warp_striped_to_striped' can only be used with " + "'block_radix_rank_algorithm::match'."); + + ROCPRIM_SHARED_MEMORY storage_type storage; + sort_warp_striped_to_striped(keys, values, storage, begin_bit, end_bit, decomposer); + } + + /// \brief Performs ascending radix sort over key-value pairs in a *warp-striped order* + /// partitioned across threads in a block, results are saved in a striped arrangement. + /// + /// \see block_radix_sort::sort_to_striped + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_warp_striped_to_striped(Key (&keys)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) + { + static_assert(warp_striped, + "'sort_warp_striped_to_striped' can only be used with " + "'block_radix_rank_algorithm::match'."); + + empty_type values[ItemsPerThread]; + sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); + } + + /// \brief Performs ascending radix sort over key-value pairs in a *warp-striped order* + /// partitioned across threads in a block, results are saved in a striped arrangement. + /// + /// \see block_radix_sort::sort_to_striped + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_warp_striped_to_striped(Key (&keys)[ItemsPerThread], + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) + { + static_assert(warp_striped, + "'sort_warp_striped_to_striped' can only be used with " + "'block_radix_rank_algorithm::match'."); + + ROCPRIM_SHARED_MEMORY storage_type storage; + sort_warp_striped_to_striped(keys, storage, begin_bit, end_bit, decomposer); + } + + /// \brief Performs descending radix sort over key-value pairs in a *warp-striped order* + /// partitioned across threads in a block, results are saved in a striped arrangement. + /// + /// \see block_radix_sort::sort_desc_to_striped + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_desc_warp_striped_to_striped( + Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) + { + static_assert(warp_striped, + "'sort_warp_striped_to_striped' can only be used with " + "'block_radix_rank_algorithm::match'."); + + sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); + } + + /// \brief Performs descending radix sort over key-value pairs in a *warp-striped order* + /// partitioned across threads in a block, results are saved in a striped arrangement. + /// + /// \see block_radix_sort::sort_desc_to_striped + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_desc_warp_striped_to_striped( + Key (&keys)[ItemsPerThread], + typename std::enable_if::type (&values)[ItemsPerThread], + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) + { + static_assert(warp_striped, + "'sort_warp_striped_to_striped' can only be used with " + "'block_radix_rank_algorithm::match'."); + + ROCPRIM_SHARED_MEMORY storage_type storage; + sort_desc_warp_striped_to_striped(keys, values, storage, begin_bit, end_bit, decomposer); + } + + /// \brief Performs descending radix sort over key-value pairs in a *warp-striped order* + /// partitioned across threads in a block, results are saved in a striped arrangement. + /// + /// \see block_radix_sort::sort_desc_to_striped + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_desc_warp_striped_to_striped(Key (&keys)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) + { + static_assert(warp_striped, + "'sort_warp_striped_to_striped' can only be used with " + "'block_radix_rank_algorithm::match'."); + + empty_type values[ItemsPerThread]; + sort_impl(keys, values, storage, begin_bit, end_bit, decomposer); + } + + /// \brief Performs descending radix sort over key-value pairs in a *warp-striped order* + /// partitioned across threads in a block, results are saved in a striped arrangement. + /// + /// \see block_radix_sort::sort_desc_to_striped + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_desc_warp_striped_to_striped(Key (&keys)[ItemsPerThread], + unsigned int begin_bit = 0, + unsigned int end_bit = 8 * sizeof(Key), + Decomposer decomposer = {}) + { + static_assert(warp_striped, + "'sort_warp_striped_to_striped' can only be used with " + "'block_radix_rank_algorithm::match'."); + + ROCPRIM_SHARED_MEMORY storage_type storage; + sort_desc_warp_striped_to_striped(keys, storage, begin_bit, end_bit, decomposer); + } + private: - template - ROCPRIM_DEVICE ROCPRIM_INLINE void sort_impl(Key (&keys)[ItemsPerThread], - SortedValue (&values)[ItemsPerThread], - storage_type& storage, - unsigned int begin_bit, - unsigned int end_bit, - Decomposer decomposer) + static constexpr bool use_warp_exchange + = device_warp_size() % ItemsPerThread == 0 && ItemsPerThread <= 4; + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void blocked_to_warp_striped(Key (&keys)[ItemsPerThread], + SortedValue (&values)[ItemsPerThread], + storage_type& storage, + std::false_type) + { + keys_exchange_type().blocked_to_warp_striped(keys, keys, storage.get().keys_exchange); + if ROCPRIM_IF_CONSTEXPR(is_key_and_value_aligned) + { + // If keys and values are aligned, then the LDS for both exchanges is + // local per wave. We can relax the data dependency! + ::rocprim::wave_barrier(); + } + else + { + ::rocprim::syncthreads(); + } + values_exchange_type().blocked_to_warp_striped(values, + values, + storage.get().values_exchange); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void blocked_to_warp_striped(Key (&keys)[ItemsPerThread], + SortedValue (&values)[ItemsPerThread], + storage_type& /* storage */, + std::true_type) + { + ::rocprim::warp_exchange{}.blocked_to_striped_shuffle(keys, keys); + ::rocprim::warp_exchange{}.blocked_to_striped_shuffle(values, + values); + } + + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void sort_impl(Key (&keys)[ItemsPerThread], + SortedValue (&values)[ItemsPerThread], + storage_type& storage, + unsigned int begin_bit, + unsigned int end_bit, + Decomposer decomposer) { using key_codec = ::rocprim::radix_key_codec; + // 'rank_keys' may be invoked multiple times. We encode the key once and move the + // encoded during the majority of sort to save on some compute. ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { key_codec::encode_inplace(keys[i], decomposer); } + // If we're using warp striped radix rank but our input is in a blocked layout, we + // can emulate the correct input through an exchange to a warp striped layout. + if ROCPRIM_IF_CONSTEXPR(TryEmulateWarpStriped && warp_striped && ItemsPerThread > 1) + { + // This appears to be slower with high large items per thread. + constexpr bool use_warp_exchange + = device_warp_size() % ItemsPerThread == 0 && ItemsPerThread <= 4; + blocked_to_warp_striped(keys, + values, + storage, + std::integral_constant{}); + // Storage has been dirtied. 'rank_keys' does not always align nicely with this + // so a full block synchronization is needed. + ::rocprim::syncthreads(); + } + + unsigned int ranks[ItemsPerThread]; while(true) { const int pass_bits = min(RadixBitsPerPass, end_bit - begin_bit); - unsigned int ranks[ItemsPerThread]; block_rank_type().rank_keys( keys, ranks, @@ -907,24 +1149,38 @@ class block_radix_sort { return key_codec::extract_digit(key, begin_bit, pass_bits, decomposer); }); begin_bit += RadixBitsPerPass; - exchange_keys(storage, keys, ranks); - exchange_values(storage, values, ranks); - if(begin_bit >= end_bit) { break; } + if ROCPRIM_IF_CONSTEXPR(warp_striped) + { + exchange_keys_warp_striped(storage, keys, ranks); + exchange_values_warp_striped(storage, values, ranks); + } + else + { + exchange_keys(storage, keys, ranks); + exchange_values(storage, values, ranks); + } + // Synchronization required to make block_rank wait on the next iteration. ::rocprim::syncthreads(); } if ROCPRIM_IF_CONSTEXPR(ToStriped) { - to_striped_keys(storage, keys); - to_striped_values(storage, values); + exchange_to_striped_keys(storage, keys, ranks); + exchange_to_striped_values(storage, values, ranks); + } + else + { + exchange_keys(storage, keys, ranks); + exchange_values(storage, values, ranks); } + // Done with 'rank_keys' so we can decode back to the original key. ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { @@ -932,9 +1188,10 @@ class block_radix_sort } } - ROCPRIM_DEVICE ROCPRIM_INLINE void exchange_keys(storage_type& storage, - Key (&keys)[ItemsPerThread], - const unsigned int (&ranks)[ItemsPerThread]) + ROCPRIM_DEVICE ROCPRIM_INLINE + void exchange_keys(storage_type& storage, + Key (&keys)[ItemsPerThread], + const unsigned int (&ranks)[ItemsPerThread]) { storage_type_& storage_ = storage.get(); ::rocprim::syncthreads(); // Storage will be reused (union), synchronization is needed @@ -957,35 +1214,74 @@ class block_radix_sort empty_type (&values)[ItemsPerThread], const unsigned int (&ranks)[ItemsPerThread]) { - (void) storage; - (void) values; - (void) ranks; + (void)storage; + (void)values; + (void)ranks; + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void exchange_keys_warp_striped(storage_type& storage, + Key (&keys)[ItemsPerThread], + const unsigned int (&ranks)[ItemsPerThread]) + { + storage_type_& storage_ = storage.get(); + ::rocprim::syncthreads(); // Storage will be reused (union), synchronization is needed + keys_exchange_type().scatter_to_warp_striped(keys, keys, ranks, storage_.keys_exchange); } - ROCPRIM_DEVICE ROCPRIM_INLINE void to_striped_keys(storage_type& storage, - Key (&keys)[ItemsPerThread]) + template + ROCPRIM_DEVICE ROCPRIM_INLINE + void exchange_values_warp_striped(storage_type& storage, + SortedValue (&values)[ItemsPerThread], + const unsigned int (&ranks)[ItemsPerThread]) + { + storage_type_& storage_ = storage.get(); + ::rocprim::syncthreads(); // Storage will be reused (union), synchronization is needed + values_exchange_type().scatter_to_warp_striped(values, + values, + ranks, + storage_.values_exchange); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void exchange_values_warp_striped(storage_type& storage, + empty_type (&values)[ItemsPerThread], + const unsigned int (&ranks)[ItemsPerThread]) + { + (void)storage; + (void)values; + (void)ranks; + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + void exchange_to_striped_keys(storage_type& storage, + Key (&keys)[ItemsPerThread], + const unsigned int (&ranks)[ItemsPerThread]) { storage_type_& storage_ = storage.get(); ::rocprim::syncthreads(); // Storage will be reused (union), synchronization is needed - keys_exchange_type().blocked_to_striped(keys, keys, storage_.keys_exchange); + keys_exchange_type().scatter_to_striped(keys, keys, ranks, storage_.keys_exchange); } template ROCPRIM_DEVICE ROCPRIM_INLINE - void to_striped_values(storage_type& storage, - SortedValue (&values)[ItemsPerThread]) + void exchange_to_striped_values(storage_type& storage, + SortedValue (&values)[ItemsPerThread], + const unsigned int (&ranks)[ItemsPerThread]) { storage_type_& storage_ = storage.get(); ::rocprim::syncthreads(); // Storage will be reused (union), synchronization is needed - values_exchange_type().blocked_to_striped(values, values, storage_.values_exchange); + values_exchange_type().scatter_to_striped(values, values, ranks, storage_.values_exchange); } ROCPRIM_DEVICE ROCPRIM_INLINE - void to_striped_values(storage_type& storage, - empty_type * values) + void exchange_to_striped_values(storage_type& storage, + empty_type* values, + const unsigned int (&ranks)[ItemsPerThread]) { - (void) storage; - (void) values; + (void)ranks; + (void)storage; + (void)values; } }; diff --git a/rocprim/include/rocprim/block/config.hpp b/rocprim/include/rocprim/block/config.hpp new file mode 100644 index 000000000..3b0bb78c3 --- /dev/null +++ b/rocprim/include/rocprim/block/config.hpp @@ -0,0 +1,75 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_BLOCK_CONFIG_HELPER_HPP_ +#define ROCPRIM_BLOCK_CONFIG_HELPER_HPP_ + +#include "../config.hpp" +#include "../detail/various.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +/// \brief Padding hints for algorithms. Padding can be used to reduce bank +/// conflicts at the cost of increasing LDS usage and potentially reducing +/// occupancy. +enum class block_padding_hint { + /// Use padding to avoid bank conflicts, if applicable. This allows an + /// algorithm to use more shared memory to reduce bank conflicts. + avoid_conflicts = 0, + + /// Never use padding. This is useful when occupancy needs to be + /// maximized, and bank conflicts are known to be not an issue. + never_pad = 1, + + /// Similar to \p block_padding_hint::avoid_conflicts , but only allows + /// padding when it does not affect theorethical occupancy limited by + /// shared memory. It's advised to use this when LDS usage is restricting + /// occupancy. + lds_occupancy_bound = 2, +}; + +namespace detail +{ +/// \brief Utility wrapper to expose a static constexpr member occupancy of +/// type T as a static constexpr value. +template +struct map_occupancy_to_value +{ + /// \brief The original type. + using type = T; + + /// \brief The value to order this by. + static constexpr auto value = T::occupancy; +}; + +/// \brief Selects the config depending on the padding hint. +template +using select_block_padding_config + = std::conditional_t, + map_occupancy_to_value>::type>>; +} // namespace detail + +END_ROCPRIM_NAMESPACE +#endif diff --git a/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp b/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp index b578009ce..d8d84f3e4 100644 --- a/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp +++ b/rocprim/include/rocprim/block/detail/block_radix_rank_match.hpp @@ -29,16 +29,18 @@ #include "../../thread/radix_key_codec.hpp" #include "../block_scan.hpp" +#include "../config.hpp" BEGIN_ROCPRIM_NAMESPACE namespace detail { -template +template class block_radix_rank_match { using digit_counter_type = unsigned int; @@ -52,18 +54,44 @@ class block_radix_rank_match static constexpr unsigned int block_size = BlockSizeX * BlockSizeY * BlockSizeZ; static constexpr unsigned int radix_digits = 1 << RadixBits; - // Force the number of warps to an uneven amount to reduce the number of lds bank conflicts. - static constexpr unsigned int warps - = ::rocprim::detail::ceiling_div(block_size, device_warp_size()) | 1u; + struct unpadded_config + { + static constexpr unsigned int warps + = ::rocprim::detail::ceiling_div(block_size, device_warp_size()); + }; + + struct padded_config + { + static constexpr unsigned int warps = unpadded_config::warps | 1u; + }; + + template + struct build_config : Config + { + static constexpr unsigned int active_counters = Config::warps * radix_digits; + static constexpr unsigned int counters_per_thread + = ::rocprim::detail::ceiling_div(active_counters, block_size); + static constexpr unsigned int counters = counters_per_thread * block_size; + + // Compute local data share and theorethical occupancy + static constexpr size_t lds_size = max(sizeof(digit_counter_type) * counters, + sizeof(typename block_scan_type::storage_type)); + static constexpr unsigned int occupancy = detail::get_min_lds_size() / lds_size; + }; + + using config = detail::select_block_padding_config, + build_config>; + + static constexpr unsigned int warps = config::warps; // The number of counters that are actively being used. - static constexpr unsigned int active_counters = warps * radix_digits; + static constexpr unsigned int active_counters = config::active_counters; // We want to use a regular block scan to scan the per-warp counters. This requires the // total number of counters to be divisible by the block size. To facilitate this, just add // a bunch of counters that are not otherwise used. - static constexpr unsigned int counters_per_thread - = ::rocprim::detail::ceiling_div(active_counters, block_size); + static constexpr unsigned int counters_per_thread = config::counters_per_thread; // The total number of counters, factoring in the unused ones for the block scan. - static constexpr unsigned int counters = counters_per_thread * block_size; + static constexpr unsigned int counters = config::counters; public: constexpr static unsigned int digits_per_thread @@ -109,7 +137,12 @@ class block_radix_rank_match // Get the digit counter for this key on the current warp. digit_counters[i] = &get_digit_counter(digit, warp_id, storage); - const digit_counter_type warp_digit_prefix = *digit_counters[i]; + + // Read the prefix sum of that digit. We already know it's 0 on the first iteration. So + // we can skip a read-after-write dependency. The conditional gets optimized out due to + // loop unrolling. + const digit_counter_type warp_digit_prefix + = i == 0 ? digit_counter_type(0) : *digit_counters[i]; // Construct a mask of threads in this wave which have the same digit. ::rocprim::lane_mask_type peer_mask = ::rocprim::match_any(digit); diff --git a/rocprim/include/rocprim/detail/various.hpp b/rocprim/include/rocprim/detail/various.hpp index 445c97c92..a8dbfbffd 100644 --- a/rocprim/include/rocprim/detail/various.hpp +++ b/rocprim/include/rocprim/detail/various.hpp @@ -154,6 +154,18 @@ constexpr unsigned int get_lds_banks_no() return 32; } +/// \brief Returns the minimum LDS size in bytes available on this device architecture. +ROCPRIM_HOST_DEVICE +constexpr unsigned int get_min_lds_size() +{ +#if defined(__GFX11__) || defined(__GFX10__) + return (1 << 17) /* 128 KiB*/; +#else + // On host the lowest should be returned! + return (1 << 16) /* 64 KiB */; +#endif +} + // Finds biggest fundamental type for type T that sizeof(T) is // a multiple of that type's size. template @@ -408,6 +420,25 @@ ROCPRIM_HOST_DEVICE auto bit_cast(const Source& source) #endif } +template +struct select_max_by_value; + +template +struct select_max_by_value +{ + using type = T; +}; + +template +struct select_max_by_value +{ + using tail = typename select_max_by_value::type; + using type = std::conditional_t<(T::value >= tail::value), T, tail>; +}; + +template +using select_max_by_value_t = typename select_max_by_value::type; + } // end namespace detail END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp b/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp index d9ea7f0ed..f3e77d427 100644 --- a/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -54,7 +54,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 31> + && (sizeof(value_type) > 4))>> : kernel_config<64, 25> {}; // Based on key_type = double, value_type = int @@ -65,7 +65,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 31> + && (sizeof(value_type) > 2))>> : kernel_config<64, 25> {}; // Based on key_type = double, value_type = short @@ -76,7 +76,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 31> + && (sizeof(value_type) > 1))>> : kernel_config<128, 26> {}; // Based on key_type = double, value_type = int8_t @@ -88,7 +88,7 @@ struct default_radix_sort_block_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 31> + : kernel_config<128, 28> {}; // Based on key_type = double, value_type = empty_type @@ -100,7 +100,7 @@ struct default_radix_sort_block_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (std::is_same::value))>> - : kernel_config<256, 31> + : kernel_config<64, 32> {}; // Based on key_type = float, value_type = int64_t @@ -111,7 +111,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 31> + && (sizeof(value_type) > 4))>> : kernel_config<256, 25> {}; // Based on key_type = float, value_type = int @@ -122,7 +122,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 32> + && (sizeof(value_type) > 2))>> : kernel_config<128, 29> {}; // Based on key_type = float, value_type = short @@ -133,7 +133,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 32> + && (sizeof(value_type) > 1))>> : kernel_config<128, 31> {}; // Based on key_type = float, value_type = int8_t @@ -145,7 +145,7 @@ struct default_radix_sort_block_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 32> + : kernel_config<64, 31> {}; // Based on key_type = float, value_type = empty_type @@ -157,7 +157,7 @@ struct default_radix_sort_block_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (std::is_same::value))>> - : kernel_config<128, 32> + : kernel_config<64, 32> {}; // Based on key_type = rocprim::half, value_type = int64_t @@ -168,7 +168,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 11> + : kernel_config<256, 25> {}; // Based on key_type = rocprim::half, value_type = int @@ -179,7 +179,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<1024, 14> + : kernel_config<256, 32> {}; // Based on key_type = rocprim::half, value_type = short @@ -190,7 +190,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<512, 8> + : kernel_config<128, 29> {}; // Based on key_type = rocprim::half, value_type = int8_t @@ -202,7 +202,7 @@ struct default_radix_sort_block_sort_config< std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<1024, 18> + : kernel_config<128, 31> {}; // Based on key_type = rocprim::half, value_type = empty_type @@ -213,7 +213,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (std::is_same::value))>> - : kernel_config<256, 14> + : kernel_config<128, 32> {}; // Based on key_type = int64_t, value_type = int64_t @@ -224,7 +224,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 31> + && (sizeof(value_type) > 4))>> : kernel_config<64, 25> {}; // Based on key_type = int64_t, value_type = int @@ -235,7 +235,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 31> + && (sizeof(value_type) > 2))>> : kernel_config<128, 30> {}; // Based on key_type = int64_t, value_type = short @@ -246,7 +246,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 31> + && (sizeof(value_type) > 1))>> : kernel_config<128, 29> {}; // Based on key_type = int64_t, value_type = int8_t @@ -258,7 +258,7 @@ struct default_radix_sort_block_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 31> + : kernel_config<64, 28> {}; // Based on key_type = int64_t, value_type = empty_type @@ -270,7 +270,7 @@ struct default_radix_sort_block_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (std::is_same::value))>> - : kernel_config<256, 31> + : kernel_config<64, 28> {}; // Based on key_type = int, value_type = int64_t @@ -281,7 +281,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 31> + && (sizeof(value_type) > 4))>> : kernel_config<256, 25> {}; // Based on key_type = int, value_type = int @@ -292,7 +292,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 32> + && (sizeof(value_type) > 2))>> : kernel_config<128, 29> {}; // Based on key_type = int, value_type = short @@ -303,7 +303,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 32> + && (sizeof(value_type) > 1))>> : kernel_config<128, 31> {}; // Based on key_type = int, value_type = int8_t @@ -315,7 +315,7 @@ struct default_radix_sort_block_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 32> + : kernel_config<128, 31> {}; // Based on key_type = int, value_type = empty_type @@ -327,7 +327,7 @@ struct default_radix_sort_block_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (std::is_same::value))>> - : kernel_config<128, 32> + : kernel_config<64, 32> {}; // Based on key_type = short, value_type = int64_t @@ -338,7 +338,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 16> + && (sizeof(value_type) > 4))>> : kernel_config<256, 26> {}; // Based on key_type = short, value_type = int @@ -360,7 +360,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 32> + && (sizeof(value_type) > 1))>> : kernel_config<128, 29> {}; // Based on key_type = short, value_type = int8_t @@ -372,7 +372,7 @@ struct default_radix_sort_block_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 32> + : kernel_config<64, 31> {}; // Based on key_type = short, value_type = empty_type @@ -384,7 +384,7 @@ struct default_radix_sort_block_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (std::is_same::value))>> - : kernel_config<512, 32> + : kernel_config<64, 32> {}; // Based on key_type = int8_t, value_type = int64_t @@ -395,7 +395,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<1024, 7> + : kernel_config<512, 15> {}; // Based on key_type = int8_t, value_type = int @@ -406,7 +406,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<1024, 15> + : kernel_config<512, 31> {}; // Based on key_type = int8_t, value_type = short @@ -417,7 +417,7 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 32> + : kernel_config<256, 31> {}; // Based on key_type = int8_t, value_type = int8_t @@ -429,7 +429,7 @@ struct default_radix_sort_block_sort_config< std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 28> + : kernel_config<64, 31> {}; // Based on key_type = int8_t, value_type = empty_type @@ -440,875 +440,875 @@ struct default_radix_sort_block_sort_config< value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (std::is_same::value))>> - : kernel_config<512, 32> + : kernel_config<64, 32> {}; // Based on key_type = double, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 13> + && (sizeof(value_type) > 4))>> : kernel_config<128, 25> {}; // Based on key_type = double, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 13> + && (sizeof(value_type) > 2))>> : kernel_config<128, 25> {}; // Based on key_type = double, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 14> + && (sizeof(value_type) > 1))>> : kernel_config<128, 25> {}; // Based on key_type = double, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 15> + : kernel_config<64, 28> {}; // Based on key_type = double, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (std::is_same::value))>> - : kernel_config<256, 15> + : kernel_config<128, 32> {}; // Based on key_type = float, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 13> + && (sizeof(value_type) > 4))>> : kernel_config<128, 25> {}; // Based on key_type = float, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 17> + && (sizeof(value_type) > 2))>> : kernel_config<64, 31> {}; // Based on key_type = float, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 15> + && (sizeof(value_type) > 1))>> : kernel_config<64, 30> {}; // Based on key_type = float, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 17> + : kernel_config<64, 30> {}; // Based on key_type = float, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (std::is_same::value))>> - : kernel_config<512, 21> + : kernel_config<128, 32> {}; // Based on key_type = rocprim::half, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 5> + : kernel_config<256, 25> {}; // Based on key_type = rocprim::half, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 9> + : kernel_config<128, 30> {}; // Based on key_type = rocprim::half, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 7> + : kernel_config<128, 31> {}; // Based on key_type = rocprim::half, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 10> + : kernel_config<64, 27> {}; // Based on key_type = rocprim::half, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (std::is_same::value))>> - : kernel_config<256, 12> + : kernel_config<64, 32> {}; // Based on key_type = int64_t, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 29> + && (sizeof(value_type) > 4))>> : kernel_config<128, 25> {}; // Based on key_type = int64_t, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 13> + && (sizeof(value_type) > 2))>> : kernel_config<128, 25> {}; // Based on key_type = int64_t, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 14> + && (sizeof(value_type) > 1))>> : kernel_config<128, 25> {}; // Based on key_type = int64_t, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 15> + : kernel_config<64, 28> {}; // Based on key_type = int64_t, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (std::is_same::value))>> - : kernel_config<256, 15> + : kernel_config<128, 32> {}; // Based on key_type = int, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 13> + && (sizeof(value_type) > 4))>> : kernel_config<128, 25> {}; // Based on key_type = int, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 17> + && (sizeof(value_type) > 2))>> : kernel_config<64, 31> {}; // Based on key_type = int, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> + && (sizeof(value_type) > 1))>> : kernel_config<64, 30> {}; // Based on key_type = int, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 18> + : kernel_config<64, 30> {}; // Based on key_type = int, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (std::is_same::value))>> - : kernel_config<256, 21> + : kernel_config<64, 30> {}; // Based on key_type = short, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 5> + && (sizeof(value_type) > 4))>> : kernel_config<256, 25> {}; // Based on key_type = short, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 9> + && (sizeof(value_type) > 2))>> : kernel_config<128, 30> {}; // Based on key_type = short, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 7> + && (sizeof(value_type) > 1))>> : kernel_config<64, 31> {}; // Based on key_type = short, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 10> + : kernel_config<128, 30> {}; // Based on key_type = short, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (std::is_same::value))>> - : kernel_config<256, 32> + : kernel_config<64, 32> {}; // Based on key_type = int8_t, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<1024, 6> + : kernel_config<256, 16> {}; // Based on key_type = int8_t, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 7> + : kernel_config<256, 30> {}; // Based on key_type = int8_t, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<512, 7> + : kernel_config<128, 30> {}; // Based on key_type = int8_t, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 20> + : kernel_config<64, 31> {}; // Based on key_type = int8_t, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx900), + static_cast(target_arch::gfx1100), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (std::is_same::value))>> - : kernel_config<512, 32> + : kernel_config<64, 32> {}; // Based on key_type = double, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = double, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 13> + && (sizeof(value_type) > 2))>> : kernel_config<256, 12> {}; // Based on key_type = double, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 14> + && (sizeof(value_type) > 1))>> : kernel_config<256, 13> {}; // Based on key_type = double, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<128, 15> + : kernel_config<256, 12> {}; // Based on key_type = double, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (std::is_same::value))>> - : kernel_config<256, 15> + : kernel_config<256, 16> {}; // Based on key_type = float, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = float, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<512, 31> + && (sizeof(value_type) > 2))>> : kernel_config<256, 18> {}; // Based on key_type = float, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<1024, 15> + && (sizeof(value_type) > 1))>> : kernel_config<256, 16> {}; // Based on key_type = float, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<1024, 15> + : kernel_config<256, 15> {}; // Based on key_type = float, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (std::is_same::value))>> - : kernel_config<512, 31> + : kernel_config<512, 23> {}; // Based on key_type = rocprim::half, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 7> + : kernel_config<256, 12> {}; // Based on key_type = rocprim::half, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 9> + : kernel_config<256, 15> {}; // Based on key_type = rocprim::half, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 7> + : kernel_config<256, 19> {}; // Based on key_type = rocprim::half, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 10> + : kernel_config<512, 17> {}; // Based on key_type = rocprim::half, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (std::is_same::value))>> - : kernel_config<256, 12> + : kernel_config<512, 23> {}; // Based on key_type = int64_t, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = int64_t, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 13> + && (sizeof(value_type) > 2))>> : kernel_config<256, 12> {}; // Based on key_type = int64_t, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 14> + && (sizeof(value_type) > 1))>> : kernel_config<256, 13> {}; // Based on key_type = int64_t, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 15> + : kernel_config<256, 12> {}; // Based on key_type = int64_t, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (std::is_same::value))>> - : kernel_config<256, 15> + : kernel_config<256, 16> {}; // Based on key_type = int, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = int, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<512, 31> + && (sizeof(value_type) > 2))>> : kernel_config<256, 19> {}; // Based on key_type = int, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<1024, 15> + && (sizeof(value_type) > 1))>> : kernel_config<256, 16> {}; // Based on key_type = int, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<1024, 15> + : kernel_config<256, 15> {}; // Based on key_type = int, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (std::is_same::value))>> - : kernel_config<1024, 15> + : kernel_config<512, 23> {}; // Based on key_type = short, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = short, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<1024, 13> + && (sizeof(value_type) > 2))>> : kernel_config<256, 14> {}; // Based on key_type = short, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 32> + && (sizeof(value_type) > 1))>> : kernel_config<256, 19> {}; // Based on key_type = short, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<512, 32> + : kernel_config<256, 15> {}; // Based on key_type = short, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (std::is_same::value))>> - : kernel_config<256, 32> + : kernel_config<256, 23> {}; // Based on key_type = int8_t, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 14> + : kernel_config<256, 13> {}; // Based on key_type = int8_t, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 24> + : kernel_config<256, 10> {}; // Based on key_type = int8_t, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 7> + : kernel_config<256, 16> {}; // Based on key_type = int8_t, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<1024, 20> + : kernel_config<256, 17> {}; // Based on key_type = int8_t, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), + static_cast(target_arch::gfx906), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (std::is_same::value))>> - : kernel_config<1024, 25> + : kernel_config<256, 22> {}; // Based on key_type = double, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = double, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 13> + && (sizeof(value_type) > 2))>> : kernel_config<256, 12> {}; // Based on key_type = double, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 14> + && (sizeof(value_type) > 1))>> : kernel_config<256, 13> {}; // Based on key_type = double, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<128, 15> + : kernel_config<256, 12> {}; // Based on key_type = double, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (std::is_same::value))>> - : kernel_config<256, 15> + : kernel_config<256, 16> {}; // Based on key_type = float, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = float, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) @@ -1319,30 +1319,30 @@ struct default_radix_sort_block_sort_config< // Based on key_type = float, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<1024, 15> + && (sizeof(value_type) > 1))>> : kernel_config<512, 31> {}; // Based on key_type = float, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<1024, 15> + : kernel_config<512, 31> {}; // Based on key_type = float, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) @@ -1354,131 +1354,131 @@ struct default_radix_sort_block_sort_config< // Based on key_type = rocprim::half, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 7> + : kernel_config<256, 12> {}; // Based on key_type = rocprim::half, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 9> + : kernel_config<512, 31> {}; // Based on key_type = rocprim::half, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 7> + : kernel_config<512, 32> {}; // Based on key_type = rocprim::half, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 10> + : kernel_config<512, 31> {}; // Based on key_type = rocprim::half, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (std::is_same::value))>> - : kernel_config<256, 12> + : kernel_config<1024, 23> {}; // Based on key_type = int64_t, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = int64_t, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 13> + && (sizeof(value_type) > 2))>> : kernel_config<256, 12> {}; // Based on key_type = int64_t, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 14> + && (sizeof(value_type) > 1))>> : kernel_config<256, 8> {}; // Based on key_type = int64_t, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 15> + : kernel_config<256, 12> {}; // Based on key_type = int64_t, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (std::is_same::value))>> - : kernel_config<256, 15> + : kernel_config<256, 16> {}; // Based on key_type = int, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = int, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) @@ -1489,75 +1489,75 @@ struct default_radix_sort_block_sort_config< // Based on key_type = int, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<1024, 15> + && (sizeof(value_type) > 1))>> : kernel_config<512, 31> {}; // Based on key_type = int, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<1024, 15> + : kernel_config<512, 31> {}; // Based on key_type = int, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (std::is_same::value))>> - : kernel_config<1024, 15> + : kernel_config<512, 31> {}; // Based on key_type = short, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = short, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<1024, 13> + && (sizeof(value_type) > 2))>> : kernel_config<512, 31> {}; // Based on key_type = short, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 32> + && (sizeof(value_type) > 1))>> : kernel_config<512, 32> {}; // Based on key_type = short, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) @@ -1569,143 +1569,143 @@ struct default_radix_sort_block_sort_config< // Based on key_type = short, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (std::is_same::value))>> - : kernel_config<256, 32> + : kernel_config<1024, 23> {}; // Based on key_type = int8_t, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 14> + : kernel_config<256, 13> {}; // Based on key_type = int8_t, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 24> + : kernel_config<512, 31> {}; // Based on key_type = int8_t, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 7> + : kernel_config<512, 32> {}; // Based on key_type = int8_t, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<1024, 20> + : kernel_config<512, 32> {}; // Based on key_type = int8_t, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), + static_cast(target_arch::gfx908), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (std::is_same::value))>> - : kernel_config<1024, 25> + : kernel_config<1024, 23> {}; // Based on key_type = double, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = double, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 13> + && (sizeof(value_type) > 2))>> : kernel_config<256, 12> {}; // Based on key_type = double, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 14> + && (sizeof(value_type) > 1))>> : kernel_config<256, 13> {}; // Based on key_type = double, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<128, 15> + : kernel_config<256, 12> {}; // Based on key_type = double, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (std::is_same::value))>> - : kernel_config<256, 15> + : kernel_config<256, 16> {}; // Based on key_type = float, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = float, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) @@ -1716,30 +1716,30 @@ struct default_radix_sort_block_sort_config< // Based on key_type = float, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<1024, 15> + && (sizeof(value_type) > 1))>> : kernel_config<512, 31> {}; // Based on key_type = float, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<1024, 15> + : kernel_config<512, 31> {}; // Based on key_type = float, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) @@ -1751,131 +1751,131 @@ struct default_radix_sort_block_sort_config< // Based on key_type = rocprim::half, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 7> + : kernel_config<256, 12> {}; // Based on key_type = rocprim::half, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 9> + : kernel_config<512, 31> {}; // Based on key_type = rocprim::half, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 7> + : kernel_config<512, 32> {}; // Based on key_type = rocprim::half, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 10> + : kernel_config<512, 31> {}; // Based on key_type = rocprim::half, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (std::is_same::value))>> - : kernel_config<256, 12> + : kernel_config<1024, 23> {}; // Based on key_type = int64_t, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = int64_t, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 13> + && (sizeof(value_type) > 2))>> : kernel_config<256, 12> {}; // Based on key_type = int64_t, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 14> + && (sizeof(value_type) > 1))>> : kernel_config<256, 8> {}; // Based on key_type = int64_t, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<256, 15> + : kernel_config<256, 12> {}; // Based on key_type = int64_t, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (std::is_same::value))>> - : kernel_config<256, 15> + : kernel_config<256, 16> {}; // Based on key_type = int, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = int, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) @@ -1886,75 +1886,75 @@ struct default_radix_sort_block_sort_config< // Based on key_type = int, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<1024, 15> + && (sizeof(value_type) > 1))>> : kernel_config<512, 31> {}; // Based on key_type = int, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<1024, 15> + : kernel_config<512, 31> {}; // Based on key_type = int, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) && (sizeof(key_type) > 2) && (std::is_same::value))>> - : kernel_config<1024, 15> + : kernel_config<512, 31> {}; // Based on key_type = short, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 13> + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; // Based on key_type = short, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<1024, 13> + && (sizeof(value_type) > 2))>> : kernel_config<512, 31> {}; // Based on key_type = short, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 32> + && (sizeof(value_type) > 1))>> : kernel_config<512, 32> {}; // Based on key_type = short, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) @@ -1966,234 +1966,863 @@ struct default_radix_sort_block_sort_config< // Based on key_type = short, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) && (sizeof(key_type) > 1) && (std::is_same::value))>> - : kernel_config<256, 32> + : kernel_config<1024, 23> {}; // Based on key_type = int8_t, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 14> + : kernel_config<256, 13> {}; // Based on key_type = int8_t, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 24> + : kernel_config<512, 31> {}; // Based on key_type = int8_t, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 7> + : kernel_config<512, 32> {}; // Based on key_type = int8_t, value_type = int8_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (sizeof(value_type) <= 1) && (!std::is_same::value))>> - : kernel_config<1024, 20> + : kernel_config<512, 32> {}; // Based on key_type = int8_t, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), + static_cast(target_arch::unknown), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) && (std::is_same::value))>> - : kernel_config<1024, 25> + : kernel_config<1024, 23> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), double> - : kernel_config<256, 15> +// Based on key_type = double, value_type = int64_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), float> - : kernel_config<512, 25> +// Based on key_type = double, value_type = int +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : kernel_config<256, 12> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), - int, - double> : kernel_config<256, 13> +// Based on key_type = double, value_type = short +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : kernel_config<256, 13> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), - int, - double2> : kernel_config<256, 5> +// Based on key_type = double, value_type = int8_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : kernel_config<256, 12> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), - int, - float> : kernel_config<64, 17> +// Based on key_type = double, value_type = empty_type +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : kernel_config<256, 16> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), int> - : kernel_config<512, 25> +// Based on key_type = float, value_type = int64_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), - int64_t, - double> : kernel_config<256, 13> +// Based on key_type = float, value_type = int +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : kernel_config<512, 31> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), - int64_t, - float> : kernel_config<256, 13> +// Based on key_type = float, value_type = short +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : kernel_config<512, 31> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), int64_t> - : kernel_config<256, 15> +// Based on key_type = float, value_type = int8_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : kernel_config<512, 31> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), - int8_t, - int8_t> : kernel_config<256, 20> +// Based on key_type = float, value_type = empty_type +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : kernel_config<512, 31> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), int8_t> - : kernel_config<256, 32> +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : kernel_config<256, 12> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), - rocprim::half> : kernel_config<256, 12> +// Based on key_type = rocprim::half, value_type = int +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : kernel_config<512, 31> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), - rocprim::half, - rocprim::half> : kernel_config<512, 6> +// Based on key_type = rocprim::half, value_type = short +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : kernel_config<512, 32> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), short> - : kernel_config<256, 32> +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : kernel_config<512, 31> {}; -template<> -struct default_radix_sort_block_sort_config(target_arch::gfx906), - uint8_t, - uint8_t> : kernel_config<256, 20> +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : kernel_config<1024, 23> {}; -// Based on key_type = double, value_type = rocprim::empty_type +// Based on key_type = int64_t, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), + static_cast(target_arch::gfx90a), key_type, value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 15> + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; -// Based on key_type = float, value_type = rocprim::empty_type +// Based on key_type = int64_t, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), + static_cast(target_arch::gfx90a), key_type, value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<512, 25> + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : kernel_config<256, 12> {}; -// Based on key_type = rocprim::half, value_type = rocprim::empty_type +// Based on key_type = int64_t, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), + static_cast(target_arch::gfx90a), key_type, value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : kernel_config<256, 8> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> : kernel_config<256, 12> {}; -// Based on key_type = int64_t, value_type = rocprim::empty_type +// Based on key_type = int64_t, value_type = empty_type template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), + static_cast(target_arch::gfx90a), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) && (sizeof(key_type) > 4) && (std::is_same::value))>> - : kernel_config<256, 15> + : kernel_config<256, 16> {}; -// Based on key_type = int, value_type = rocprim::empty_type +// Based on key_type = int, value_type = int64_t template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), + static_cast(target_arch::gfx90a), key_type, value_type, std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<512, 25> + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> {}; -// Based on key_type = short, value_type = rocprim::empty_type +// Based on key_type = int, value_type = int template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), + static_cast(target_arch::gfx90a), key_type, value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<256, 32> + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : kernel_config<512, 31> {}; -// Based on key_type = int8_t, value_type = rocprim::empty_type +// Based on key_type = int, value_type = short template struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), + static_cast(target_arch::gfx90a), key_type, value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<256, 32> + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : kernel_config<512, 31> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : kernel_config<512, 31> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : kernel_config<512, 31> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : kernel_config<256, 12> +{}; + +// Based on key_type = short, value_type = int +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : kernel_config<512, 31> +{}; + +// Based on key_type = short, value_type = short +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : kernel_config<512, 32> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : kernel_config<512, 32> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : kernel_config<1024, 23> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : kernel_config<256, 13> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : kernel_config<512, 31> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : kernel_config<512, 32> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : kernel_config<512, 32> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : kernel_config<1024, 23> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : kernel_config<256, 16> +{}; + +// Based on key_type = double, value_type = int +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : kernel_config<256, 16> +{}; + +// Based on key_type = double, value_type = short +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : kernel_config<256, 16> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : kernel_config<256, 16> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : kernel_config<256, 16> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : kernel_config<256, 8> +{}; + +// Based on key_type = float, value_type = int +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : kernel_config<512, 16> +{}; + +// Based on key_type = float, value_type = short +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : kernel_config<256, 32> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : kernel_config<256, 32> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : kernel_config<256, 32> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : kernel_config<512, 8> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : kernel_config<256, 32> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : kernel_config<512, 18> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : kernel_config<256, 32> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : kernel_config<512, 21> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : kernel_config<256, 16> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : kernel_config<256, 16> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : kernel_config<256, 16> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : kernel_config<256, 16> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : kernel_config<256, 16> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : kernel_config<256, 10> +{}; + +// Based on key_type = int, value_type = int +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : kernel_config<256, 32> +{}; + +// Based on key_type = int, value_type = short +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : kernel_config<256, 32> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : kernel_config<256, 32> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : kernel_config<256, 32> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : kernel_config<256, 16> +{}; + +// Based on key_type = short, value_type = int +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : kernel_config<256, 11> +{}; + +// Based on key_type = short, value_type = short +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : kernel_config<512, 18> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : kernel_config<256, 32> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : kernel_config<512, 23> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : kernel_config<256, 16> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : kernel_config<256, 21> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : kernel_config<256, 21> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : kernel_config<256, 23> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_radix_sort_block_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : kernel_config<512, 21> {}; } // end namespace detail diff --git a/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp b/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp index 617cb2999..64c6acb6d 100644 --- a/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp @@ -3148,6 +3148,523 @@ struct default_radix_sort_onesweep_config< block_radix_rank_algorithm::match> {}; +// Based on key_type = double, value_type = int64_t +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : radix_sort_onesweep_config, + kernel_config<1024, 6>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = double, value_type = int +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : radix_sort_onesweep_config, + kernel_config<1024, 6>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = double, value_type = short +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : radix_sort_onesweep_config, + kernel_config<512, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : radix_sort_onesweep_config, + kernel_config<1024, 6>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : radix_sort_onesweep_config, + kernel_config<1024, 6>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : radix_sort_onesweep_config, + kernel_config<1024, 6>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = float, value_type = int +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : radix_sort_onesweep_config, + kernel_config<1024, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = float, value_type = short +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : radix_sort_onesweep_config, + kernel_config<1024, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : radix_sort_onesweep_config, + kernel_config<1024, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : radix_sort_onesweep_config, + kernel_config<1024, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : radix_sort_onesweep_config, + kernel_config<1024, 6>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : radix_sort_onesweep_config, + kernel_config<1024, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : radix_sort_onesweep_config, + kernel_config<512, 22>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : radix_sort_onesweep_config, + kernel_config<512, 22>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : radix_sort_onesweep_config, + kernel_config<512, 22>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : radix_sort_onesweep_config, + kernel_config<1024, 6>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : radix_sort_onesweep_config, + kernel_config<1024, 6>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : radix_sort_onesweep_config, + kernel_config<1024, 6>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : radix_sort_onesweep_config, + kernel_config<512, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : radix_sort_onesweep_config, + kernel_config<1024, 6>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : radix_sort_onesweep_config, + kernel_config<1024, 6>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int, value_type = int +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : radix_sort_onesweep_config, + kernel_config<1024, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int, value_type = short +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : radix_sort_onesweep_config, + kernel_config<1024, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : radix_sort_onesweep_config, + kernel_config<1024, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : radix_sort_onesweep_config, + kernel_config<1024, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : radix_sort_onesweep_config, + kernel_config<1024, 6>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = short, value_type = int +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : radix_sort_onesweep_config, + kernel_config<1024, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = short, value_type = short +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : radix_sort_onesweep_config, + kernel_config<512, 22>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : radix_sort_onesweep_config, + kernel_config<1024, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : radix_sort_onesweep_config, + kernel_config<512, 22>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : radix_sort_onesweep_config, + kernel_config<1024, 6>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : radix_sort_onesweep_config, + kernel_config<1024, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : radix_sort_onesweep_config, + kernel_config<512, 22>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : radix_sort_onesweep_config, + kernel_config<1024, 12>, + 8, + block_radix_rank_algorithm::match> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_radix_sort_onesweep_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : radix_sort_onesweep_config, + kernel_config<1024, 22>, + 8, + block_radix_rank_algorithm::match> +{}; + } // end namespace detail END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp b/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp index 3704b2f3c..57bbb0194 100644 --- a/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp +++ b/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp @@ -4631,6 +4631,663 @@ struct default_segmented_radix_sort_config< 1> {}; +// Based on key_type = double, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 16>, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 1024, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = double, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 8>, + typename std::conditional<1, + WarpSortConfig<32, 2, 256, 2048, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = double, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 17>, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 256, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 13>, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 2048, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 8, + 0, + kernel_config<256, 13>, + typename std::conditional<1, + WarpSortConfig<32, 8, 256, 4096, 32, 16, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 16>, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 1024, 32, 16, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = float, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 16>, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 2048, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = float, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 17>, + typename std::conditional<1, + WarpSortConfig<32, 2, 256, 2048, 32, 16, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 16>, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 128, 32, 16, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 8, + 0, + kernel_config<256, 17>, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 4096, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 8>, + typename std::conditional<1, + WarpSortConfig<32, 2, 256, 2048, 64, 4, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 8>, + typename std::conditional<1, + WarpSortConfig<32, 2, 256, 4096, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 17>, + typename std::conditional<1, + WarpSortConfig<32, 2, 256, 2048, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 16>, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 1024, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 8, + 0, + kernel_config<256, 16>, + typename std::conditional<1, + WarpSortConfig<32, 8, 256, 4096, 32, 16, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 17>, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 1024, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 13>, + typename std::conditional<1, + WarpSortConfig<32, 2, 256, 1024, 32, 16, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 17>, + typename std::conditional<1, + WarpSortConfig<32, 2, 256, 64, 32, 16, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 13>, + typename std::conditional<1, + WarpSortConfig<32, 2, 256, 256, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 8, + 0, + kernel_config<256, 13>, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 1024, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 16>, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 1024, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 17>, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 2048, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 17>, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 256, 32, 16, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 16>, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 256, 32, 16, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 8, + 0, + kernel_config<256, 16>, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 1024, 32, 16, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 8>, + typename std::conditional<1, + WarpSortConfig<32, 2, 256, 4096, 32, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = short, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 17>, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 2048, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = short, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 17>, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 2048, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 17>, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 1024, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 8, + 0, + kernel_config<256, 17>, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 4096, 32, 16, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 8>, + typename std::conditional<1, + WarpSortConfig<32, 2, 256, 2048, 16, 4, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 8>, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 4096, 64, 4, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 8>, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 4096, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : segmented_radix_sort_config< + 8, + 8, + kernel_config<256, 8>, + typename std::conditional<1, + WarpSortConfig<16, 4, 256, 4096, 64, 8, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_segmented_radix_sort_config< + static_cast(target_arch::gfx942), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : segmented_radix_sort_config< + 8, + 0, + kernel_config<256, 17>, + typename std::conditional<1, + WarpSortConfig<8, 8, 256, 4096, 32, 16, 256>, + DisabledWarpSortConfig>::type, + 1> +{}; + } // end namespace detail END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/rocprim/include/rocprim/device/detail/device_config_helper.hpp index 095f4f50f..cadaed26d 100644 --- a/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -470,6 +470,7 @@ struct segmented_radix_sort_config_params /// \brief Number of bits in long iterations. unsigned int long_radix_bits = 0; /// \brief Number of bits in short iterations. + /// \deprecated The short radix bits parameter is no longer used and will be removed in a future version. unsigned int short_radix_bits = 0; /// \brief If set to \p true, warp sort can be used to sort the small segments, even if no partitioning happens. bool enable_unpartitioned_warp_sort = true; @@ -567,6 +568,7 @@ struct DisabledWarpSortConfig /// /// \tparam LongRadixBits - number of bits in long iterations. /// \tparam ShortRadixBits - number of bits in short iterations, must be equal to or less than `LongRadixBits`. +/// Deprecated and no longer used. /// \tparam SortConfig - configuration of radix sort kernel. Must be `kernel_config`. /// \tparam WarpSortConfig - configuration of the warp sort that is used on the short segments. template -ROCPRIM_DEVICE ROCPRIM_INLINE void sort_block(SortType sorter, - SortKey (&keys)[ItemsPerThread], - SortValue (&values)[ItemsPerThread], - typename SortType::storage_type& storage, - Decomposer decomposer, - unsigned int begin_bit, - unsigned int end_bit) +ROCPRIM_DEVICE ROCPRIM_INLINE +void sort_warp_striped_to_striped(SortType sorter, + SortKey (&keys)[ItemsPerThread], + SortValue (&values)[ItemsPerThread], + typename SortType::storage_type& storage, + Decomposer decomposer, + unsigned int begin_bit, + unsigned int end_bit) { - if(Descending) + if ROCPRIM_IF_CONSTEXPR(Descending) { - sorter.sort_desc(keys, values, storage, begin_bit, end_bit, decomposer); + sorter.sort_desc_warp_striped_to_striped(keys, + values, + storage, + begin_bit, + end_bit, + decomposer); } else { - sorter.sort(keys, values, storage, begin_bit, end_bit, decomposer); + sorter.sort_warp_striped_to_striped(keys, values, storage, begin_bit, end_bit, decomposer); } } @@ -77,22 +83,23 @@ template -ROCPRIM_DEVICE ROCPRIM_INLINE void sort_block(SortType sorter, - SortKey (&keys)[ItemsPerThread], - ::rocprim::empty_type (&values)[ItemsPerThread], - typename SortType::storage_type& storage, - Decomposer decomposer, - unsigned int begin_bit, - unsigned int end_bit) +ROCPRIM_DEVICE ROCPRIM_INLINE +void sort_warp_striped_to_striped(SortType sorter, + SortKey (&keys)[ItemsPerThread], + ::rocprim::empty_type (&values)[ItemsPerThread], + typename SortType::storage_type& storage, + Decomposer decomposer, + unsigned int begin_bit, + unsigned int end_bit) { (void) values; - if(Descending) + if ROCPRIM_IF_CONSTEXPR(Descending) { - sorter.sort_desc(keys, storage, begin_bit, end_bit, decomposer); + sorter.sort_desc_warp_striped_to_striped(keys, storage, begin_bit, end_bit, decomposer); } else { - sorter.sort(keys, storage, begin_bit, end_bit, decomposer); + sorter.sort_warp_striped_to_striped(keys, storage, begin_bit, end_bit, decomposer); } } @@ -105,19 +112,27 @@ template< > struct radix_digit_count_helper { - static constexpr unsigned int radix_size = 1 << RadixBits; - + static constexpr unsigned int radix_size = 1 << RadixBits; static constexpr unsigned int warp_size = WarpSize; - static constexpr unsigned int warps_no = BlockSize / warp_size; + static constexpr unsigned int atomic_stripes = 4; + static constexpr unsigned int counters = radix_size * atomic_stripes; + ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT(BlockSize % ::rocprim::device_warp_size() == 0, "BlockSize must be divisible by warp size"); static_assert(radix_size <= BlockSize, "Radix size must not exceed BlockSize"); struct storage_type { - unsigned int digit_counts[warps_no][radix_size]; + unsigned int digit_counters[counters]; }; + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int& + get_counter(const unsigned stripe, const unsigned int digit, storage_type& storage) + { + return storage.digit_counters[digit * atomic_stripes + stripe]; + } + template< bool IsFull = false, class KeysInputIterator, @@ -140,15 +155,20 @@ struct radix_digit_count_helper using bit_key_type = typename key_codec::bit_key_type; const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); - const unsigned int warp_id = ::rocprim::warp_id<0, 1, 1>(); + const unsigned int stripe = flat_id % atomic_stripes; - if(flat_id < radix_size) + constexpr bool block_size_divides_counters_nicely = counters % BlockSize == 0; + + ROCPRIM_UNROLL + for(unsigned int i = 0; i < counters; i += BlockSize) { - for(unsigned int w = 0; w < warps_no; w++) + const unsigned int offset = i + flat_id; + if(block_size_divides_counters_nicely || offset < counters) { - storage.digit_counts[w][flat_id] = 0; + storage.digit_counters[offset] = 0; } } + ::rocprim::syncthreads(); for(Offset block_offset = begin_offset; block_offset < end_offset; block_offset += items_per_block) @@ -168,32 +188,30 @@ struct radix_digit_count_helper block_load_direct_striped(flat_id, keys_input + block_offset, keys, valid_count); } + ROCPRIM_UNROLL for(unsigned int i = 0; i < ItemsPerThread; i++) { const bit_key_type bit_key = key_codec::encode(keys[i]); const unsigned int digit = key_codec::extract_digit(bit_key, bit, current_radix_bits); const unsigned int pos = i * BlockSize + flat_id; - lane_mask_type same_digit_lanes_mask - = ::rocprim::match_any(digit, IsFull || (pos < valid_count)); - - if(::rocprim::group_elect(same_digit_lanes_mask)) + if(IsFull || pos < valid_count) { - // Write the number of lanes having this digit, - // if the current lane is the first (and maybe only) lane with this digit. - storage.digit_counts[warp_id][digit] - += ::rocprim::bit_count(same_digit_lanes_mask); + atomic_add(&get_counter(stripe, digit, storage), 1); } } } + ::rocprim::syncthreads(); digit_count = 0; if(flat_id < radix_size) { - for(unsigned int w = 0; w < warps_no; w++) + // Sum counters from all stripes + ROCPRIM_UNROLL + for(unsigned int stripe = 0; stripe < atomic_stripes; ++stripe) { - digit_count += storage.digit_counts[w][flat_id]; + digit_count += get_counter(stripe, flat_id, storage); } } } @@ -251,249 +269,306 @@ struct radix_sort_single_helper value_type values[ItemsPerThread]; if(!is_incomplete_block) { - block_load_direct_blocked(flat_id, keys_input + block_offset, keys); + block_load_direct_warp_striped(flat_id, keys_input + block_offset, keys); if ROCPRIM_IF_CONSTEXPR(with_values) { - block_load_direct_blocked(flat_id, values_input + block_offset, values); + block_load_direct_warp_striped(flat_id, values_input + block_offset, values); } } else { const key_type out_of_bounds = key_codec::get_out_of_bounds_key(decomposer); - block_load_direct_blocked(flat_id, - keys_input + block_offset, - keys, - valid_in_last_block, - out_of_bounds); + block_load_direct_warp_striped(flat_id, + keys_input + block_offset, + keys, + valid_in_last_block, + out_of_bounds); if ROCPRIM_IF_CONSTEXPR(with_values) { - block_load_direct_blocked(flat_id, - values_input + block_offset, - values, - valid_in_last_block); + block_load_direct_warp_striped(flat_id, + values_input + block_offset, + values, + valid_in_last_block); } } - sort_block(sort_type(), - keys, - values, - storage.sort, - decomposer, - bit, - bit + current_radix_bits); + sort_warp_striped_to_striped(sort_type(), + keys, + values, + storage.sort, + decomposer, + bit, + bit + current_radix_bits); // Store keys and values if(!is_incomplete_block) { - block_store_direct_blocked(flat_id, keys_output + block_offset, keys); + block_store_direct_striped(flat_id, keys_output + block_offset, keys); if ROCPRIM_IF_CONSTEXPR(with_values) { - block_store_direct_blocked(flat_id, values_output + block_offset, values); + block_store_direct_striped(flat_id, + values_output + block_offset, + values); } } else { - block_store_direct_blocked(flat_id, - keys_output + block_offset, - keys, - valid_in_last_block); + block_store_direct_striped(flat_id, + keys_output + block_offset, + keys, + valid_in_last_block); if ROCPRIM_IF_CONSTEXPR(with_values) { - block_store_direct_blocked(flat_id, - values_output + block_offset, - values, - valid_in_last_block); + block_store_direct_striped(flat_id, + values_output + block_offset, + values, + valid_in_last_block); } } } }; -template< - unsigned int BlockSize, - unsigned int ItemsPerThread, - unsigned int RadixBits, - bool Descending, - class Key, - class Value, - class Offset -> +template struct radix_sort_and_scatter_helper { static constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; static constexpr unsigned int radix_size = 1 << RadixBits; + static constexpr unsigned int digits_per_thread = 1; + static constexpr bool with_values = !std::is_same::value; - using key_type = Key; - using value_type = Value; + using key_codec = radix_key_codec; + using radix_rank_type = ::rocprim::block_radix_rank; - using key_codec = ::rocprim::radix_key_codec; - using bit_key_type = typename key_codec::bit_key_type; - using keys_load_type = ::rocprim::block_load< - key_type, BlockSize, ItemsPerThread, - ::rocprim::block_load_method::block_load_transpose>; - using values_load_type = ::rocprim::block_load< - value_type, BlockSize, ItemsPerThread, - ::rocprim::block_load_method::block_load_transpose>; - using sort_type = ::rocprim::block_radix_sort; - using discontinuity_type = ::rocprim::block_discontinuity; - using bit_keys_exchange_type = ::rocprim::block_exchange; - using values_exchange_type = ::rocprim::block_exchange; + static constexpr bool load_warp_striped + = RadixRankAlgorithm == block_radix_rank_algorithm::match; - static constexpr bool with_values = !std::is_same::value; + static_assert(radix_size <= BlockSize, "Radix size must not exceed BlockSize"); - struct storage_type + struct storage_type_ { + Offset digit_offsets[radix_size]; union { - typename keys_load_type::storage_type keys_load; - typename values_load_type::storage_type values_load; - typename sort_type::storage_type sort; - typename discontinuity_type::storage_type discontinuity; - typename bit_keys_exchange_type::storage_type bit_keys_exchange; - typename values_exchange_type::storage_type values_exchange; - }; + typename radix_rank_type::storage_type rank; - unsigned short starts[radix_size]; - unsigned short ends[radix_size]; - - Offset digit_starts[radix_size]; + Key ordered_tile_keys[items_per_block]; + Value ordered_tile_values[items_per_block]; + }; }; - template< - bool IsFull = false, - class KeysInputIterator, - class KeysOutputIterator, - class ValuesInputIterator, - class ValuesOutputIterator - > + ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_WITH_PUSH + using storage_type = detail::raw_storage; + ROCPRIM_DETAIL_SUPPRESS_DEPRECATION_POP + + template ROCPRIM_DEVICE ROCPRIM_INLINE - void sort_and_scatter(KeysInputIterator keys_input, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, + void sort_and_scatter(KeysInputIterator keys_input, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, ValuesOutputIterator values_output, - Offset begin_offset, - Offset end_offset, - unsigned int bit, - unsigned int current_radix_bits, - Offset digit_start, // i-th thread must pass i-th digit's value - storage_type& storage) + Offset begin_offset, + Offset end_offset, + unsigned int bit, + unsigned int current_radix_bits, + Offset digit_start, // i-th thread must pass i-th digit's value + storage_type& storage_) { + auto& storage = storage_.get(); const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); if(flat_id < radix_size) { - storage.digit_starts[flat_id] = digit_start; + storage.digit_offsets[flat_id] = digit_start; } for(Offset block_offset = begin_offset; block_offset < end_offset; block_offset += items_per_block) { - key_type keys[ItemsPerThread]; - value_type values[ItemsPerThread]; - unsigned int valid_count; + Key keys[ItemsPerThread]; + + unsigned int valid_items; if(IsFull || (block_offset + items_per_block <= end_offset)) { - valid_count = items_per_block; - keys_load_type().load(keys_input + block_offset, keys, storage.keys_load); - if(with_values) + valid_items = items_per_block; + if ROCPRIM_IF_CONSTEXPR(load_warp_striped) { - ::rocprim::syncthreads(); - values_load_type().load(values_input + block_offset, values, storage.values_load); + block_load_direct_warp_striped(flat_id, keys_input + block_offset, keys); + } + else + { + block_load_direct_blocked(flat_id, keys_input + block_offset, keys); } } else { - valid_count = end_offset - block_offset; - // Sort will leave "invalid" (out of size) items at the end of the sorted sequence - const key_type out_of_bounds = key_codec::decode(bit_key_type(-1)); - keys_load_type().load(keys_input + block_offset, keys, valid_count, out_of_bounds, storage.keys_load); - if(with_values) + valid_items = end_offset - block_offset; + // Fill the out-of-bounds elements of the key array with the key value with + // the largest digit. This will make sure they are sorted (ranked) last, and + // thus will be omitted when we compare the item offset against `valid_items` later. + // Note that this will lead to an incorrect digit count. Since this is the very last digit, + // it does not matter. It does cause the final digit offset to be increased past its end, + // but again this does not matter since this is the last iteration in which it will be used anyway. + const Key out_of_bounds = key_codec::get_out_of_bounds_key(); + if ROCPRIM_IF_CONSTEXPR(load_warp_striped) + { + block_load_direct_warp_striped(flat_id, + keys_input + block_offset, + keys, + valid_items, + out_of_bounds); + } + else { - ::rocprim::syncthreads(); - values_load_type().load(values_input + block_offset, values, valid_count, storage.values_load); + block_load_direct_blocked(flat_id, + keys_input + block_offset, + keys, + valid_items, + out_of_bounds); } } - if(flat_id < radix_size) + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) { - storage.starts[flat_id] = valid_count; - storage.ends[flat_id] = valid_count; + key_codec::encode_inplace(keys[i]); } + unsigned int ranks[ItemsPerThread]; + unsigned int exclusive_digit_prefix[digits_per_thread]; + unsigned int digit_counts[digits_per_thread]; + + radix_rank_type{}.rank_keys( + keys, + ranks, + storage.rank, + [bit, current_radix_bits](const Key& key) + { return key_codec::extract_digit(key, bit, current_radix_bits); }, + exclusive_digit_prefix, + digit_counts); + ::rocprim::syncthreads(); - sort_block(sort_type(), - keys, - values, - storage.sort, - identity_decomposer{}, - bit, - bit + current_radix_bits); - - bit_key_type bit_keys[ItemsPerThread]; - unsigned int digits[ItemsPerThread]; - for(unsigned int i = 0; i < ItemsPerThread; i++) + + // Subtract the exclusive digit prefix from the digit offsets since we're ordering + // the keys in shared memory already. + if(flat_id < radix_size) { - bit_keys[i] = key_codec::encode(keys[i]); - digits[i] = key_codec::extract_digit(bit_keys[i], bit, current_radix_bits); + storage.digit_offsets[flat_id] -= exclusive_digit_prefix[0]; } - bool head_flags[ItemsPerThread]; - bool tail_flags[ItemsPerThread]; - ::rocprim::not_equal_to flag_op; + // Order keys in shared memory. + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) + { + storage.ordered_tile_keys[ranks[i]] = keys[i]; + } ::rocprim::syncthreads(); - discontinuity_type().flag_heads_and_tails(head_flags, tail_flags, digits, flag_op, storage.discontinuity); - // Fill start and end position of subsequence for every digit - for(unsigned int i = 0; i < ItemsPerThread; i++) + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) { - const unsigned int digit = digits[i]; - const unsigned int pos = flat_id * ItemsPerThread + i; - if(head_flags[i]) - { - storage.starts[digit] = pos; - } - if(tail_flags[i]) + const unsigned int rank = i * BlockSize + flat_id; + if(IsFull || rank < valid_items) { - storage.ends[digit] = pos; + Key key = storage.ordered_tile_keys[rank]; + const unsigned int digit + = key_codec::extract_digit(key, bit, current_radix_bits); + key_codec::decode_inplace(key); + const Offset global_offset = storage.digit_offsets[digit]; + keys_output[rank + global_offset] = key; } } - ::rocprim::syncthreads(); - // Rearrange to striped arrangement to have faster coalesced writes instead of - // scattering of blocked-arranged items - bit_keys_exchange_type().blocked_to_striped(bit_keys, bit_keys, storage.bit_keys_exchange); - if(with_values) + // Gather and scatter values if necessary + if ROCPRIM_IF_CONSTEXPR(with_values) { + Value values[ItemsPerThread]; + if ROCPRIM_IF_CONSTEXPR(IsFull) + { + if ROCPRIM_IF_CONSTEXPR(load_warp_striped) + { + block_load_direct_warp_striped(flat_id, + values_input + block_offset, + values); + } + else + { + block_load_direct_blocked(flat_id, values_input + block_offset, values); + } + } + else + { + if ROCPRIM_IF_CONSTEXPR(load_warp_striped) + { + block_load_direct_warp_striped(flat_id, + values_input + block_offset, + values, + valid_items); + } + else + { + block_load_direct_blocked(flat_id, + values_input + block_offset, + values, + valid_items); + } + } + + // Compute digits up-front so that we can re-use shared memory between ordered_tile_keys and + // ordered_tile_values. + unsigned int digits[ItemsPerThread]; + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) + { + const unsigned int rank = i * BlockSize + flat_id; + if(IsFull || rank < valid_items) + { + const Key key = storage.ordered_tile_keys[rank]; + digits[i] = key_codec::extract_digit(key, bit, current_radix_bits); + } + } + ::rocprim::syncthreads(); - values_exchange_type().blocked_to_striped(values, values, storage.values_exchange); - } - for(unsigned int i = 0; i < ItemsPerThread; i++) - { - const unsigned int digit = key_codec::extract_digit(bit_keys[i], bit, current_radix_bits); - const unsigned int pos = i * BlockSize + flat_id; - if(IsFull || (pos < valid_count)) + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) { - const Offset dst = pos - storage.starts[digit] + storage.digit_starts[digit]; - keys_output[dst] = key_codec::decode(bit_keys[i]); - if(with_values) + storage.ordered_tile_values[ranks[i]] = values[i]; + } + + ::rocprim::syncthreads(); + + // And scatter the values to global memory. + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; ++i) + { + const unsigned int rank = i * BlockSize + flat_id; + if(IsFull || rank < valid_items) { - values_output[dst] = values[i]; + const Value value = storage.ordered_tile_values[rank]; + const Offset global_offset = storage.digit_offsets[digits[i]]; + values_output[rank + global_offset] = value; } } } ::rocprim::syncthreads(); - // Accumulate counts of the current block + // Update the digit offsets if(flat_id < radix_size) { - const unsigned int digit = flat_id; - const unsigned int start = storage.starts[digit]; - const unsigned int end = storage.ends[digit]; - if(start < valid_count) - { - storage.digit_starts[digit] += (::rocprim::min(valid_count - 1, end) - start + 1); - } + storage.digit_offsets[flat_id] + += exclusive_digit_prefix[flat_id] + digit_counts[flat_id]; } } } diff --git a/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp b/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp index ba15c954b..4ab54613a 100644 --- a/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp +++ b/rocprim/include/rocprim/device/detail/device_segmented_radix_sort.hpp @@ -36,8 +36,8 @@ #include "../../block/block_store.hpp" #include "../../block/block_scan.hpp" +#include "../../warp/detail/warp_sort_stable.hpp" #include "../../warp/warp_load.hpp" -#include "../../warp/warp_sort.hpp" #include "../../warp/warp_store.hpp" #include "../../thread/radix_key_codec.hpp" @@ -76,8 +76,8 @@ class segmented_radix_sort_helper union storage_type { - typename segmented_radix_sort_helper::count_helper_type::storage_type count_helper; - typename segmented_radix_sort_helper::sort_and_scatter_helper::storage_type sort_and_scatter_helper; + typename count_helper_type::storage_type count_helper; + typename sort_and_scatter_helper::storage_type sort_and_scatter_helper; }; template< @@ -274,57 +274,43 @@ template< > class segmented_radix_sort_single_block_helper { - using key_type = Key; - using value_type = Value; - - using key_codec = radix_key_codec; + using key_codec = radix_key_codec; using bit_key_type = typename key_codec::bit_key_type; - using keys_load_type = ::rocprim::block_load< - key_type, BlockSize, ItemsPerThread, - ::rocprim::block_load_method::block_load_transpose>; - using values_load_type = ::rocprim::block_load< - value_type, BlockSize, ItemsPerThread, - ::rocprim::block_load_method::block_load_transpose>; - using sort_type = ::rocprim::block_radix_sort; - using keys_store_type = ::rocprim::block_store< - key_type, BlockSize, ItemsPerThread, - ::rocprim::block_store_method::block_store_transpose>; - using values_store_type = ::rocprim::block_store< - value_type, BlockSize, ItemsPerThread, - ::rocprim::block_store_method::block_store_transpose>; - - static constexpr bool with_values = !std::is_same::value; + using sort_type = ::rocprim::block_radix_sort; + + static constexpr bool with_values = !std::is_same::value; public: union storage_type { - typename keys_load_type::storage_type keys_load; - typename values_load_type::storage_type values_load; typename sort_type::storage_type sort; - typename keys_store_type::storage_type keys_store; - typename values_store_type::storage_type values_store; }; - template< - class KeysInputIterator, - class KeysOutputIterator, - class ValuesInputIterator, - class ValuesOutputIterator - > + template ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(KeysInputIterator keys_input, - key_type * keys_tmp, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - value_type * values_tmp, + void sort(KeysInputIterator keys_input, + Key* keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + Value* values_tmp, ValuesOutputIterator values_output, - bool to_output, - unsigned int begin_offset, - unsigned int end_offset, - unsigned int begin_bit, - unsigned int end_bit, - storage_type& storage) + bool to_output, + unsigned int begin_offset, + unsigned int end_offset, + unsigned int begin_bit, + unsigned int end_bit, + storage_type& storage) { if(to_output) { @@ -348,17 +334,17 @@ class segmented_radix_sort_single_block_helper // When all iterators are raw pointers, this overload is used to minimize code duplication in the kernel ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(key_type * keys_input, - key_type * keys_tmp, - key_type * keys_output, - value_type * values_input, - value_type * values_tmp, - value_type * values_output, - bool to_output, - unsigned int begin_offset, - unsigned int end_offset, - unsigned int begin_bit, - unsigned int end_bit, + void sort(Key* keys_input, + Key* keys_tmp, + Key* keys_output, + Value* values_input, + Value* values_tmp, + Value* values_output, + bool to_output, + unsigned int begin_offset, + unsigned int end_offset, + unsigned int begin_bit, + unsigned int end_bit, storage_type& storage) { sort( @@ -388,10 +374,12 @@ class segmented_radix_sort_single_block_helper { constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; - using shorter_single_block_helper = segmented_radix_sort_single_block_helper< - key_type, value_type, - BlockSize, ItemsPerThread / 2, Descending - >; + using shorter_single_block_helper + = segmented_radix_sort_single_block_helper; // Segment is longer than supported by this function if(end_offset - begin_offset > items_per_block) @@ -399,7 +387,7 @@ class segmented_radix_sort_single_block_helper return false; } - // Recursively chech if it is possible to sort the segment using fewer items per thread + // Recursively check if it is possible to sort the segment using fewer items per thread const bool processed_by_shorter = shorter_single_block_helper().sort( keys_input, keys_output, values_input, values_output, @@ -412,33 +400,46 @@ class segmented_radix_sort_single_block_helper return true; } - key_type keys[ItemsPerThread]; - value_type values[ItemsPerThread]; + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); + + Key keys[ItemsPerThread]; + Value values[ItemsPerThread]; const unsigned int valid_count = end_offset - begin_offset; // Sort will leave "invalid" (out of size) items at the end of the sorted sequence - const key_type out_of_bounds = key_codec::decode(bit_key_type(-1)); - keys_load_type().load(keys_input + begin_offset, keys, valid_count, out_of_bounds, storage.keys_load); + const Key out_of_bounds = key_codec::get_out_of_bounds_key(); + block_load_direct_warp_striped(flat_id, + keys_input + begin_offset, + keys, + valid_count, + out_of_bounds); if(with_values) { - ::rocprim::syncthreads(); - values_load_type().load(values_input + begin_offset, values, valid_count, storage.values_load); + block_load_direct_warp_striped(flat_id, + values_input + begin_offset, + values, + valid_count); } ::rocprim::syncthreads(); - sort_block(sort_type(), - keys, - values, - storage.sort, - identity_decomposer{}, - begin_bit, - end_bit); + sort_warp_striped_to_striped(sort_type(), + keys, + values, + storage.sort, + identity_decomposer{}, + begin_bit, + end_bit); ::rocprim::syncthreads(); - keys_store_type().store(keys_output + begin_offset, keys, valid_count, storage.keys_store); - if(with_values) + block_store_direct_striped(flat_id, + keys_output + begin_offset, + keys, + valid_count); + if ROCPRIM_IF_CONSTEXPR(with_values) { - ::rocprim::syncthreads(); - values_store_type().store(values_output + begin_offset, values, valid_count, storage.values_store); + block_store_direct_striped(flat_id, + values_output + begin_offset, + values, + valid_count); } return true; @@ -504,13 +505,7 @@ using select_warp_sort_helper_config_t WarpSortHelperConfig, DisabledWarpSortHelperConfig>; -template< - class Config, - class Key, - class Value, - bool Descending, - class Enable = void -> +template struct segmented_warp_sort_helper { static constexpr unsigned int items_per_warp = 0; @@ -523,44 +518,32 @@ struct segmented_warp_sort_helper } }; -template +template class segmented_warp_sort_helper< Config, Key, Value, + BlockSize, Descending, std::enable_if_t::value>> { static constexpr unsigned int logical_warp_size = Config::logical_warp_size; static constexpr unsigned int items_per_thread = Config::items_per_thread; - using key_type = Key; - using value_type = Value; - using key_codec = ::rocprim::radix_key_codec; + using key_codec = ::rocprim::radix_key_codec; using bit_key_type = typename key_codec::bit_key_type; - using keys_load_type = ::rocprim::warp_load; - using values_load_type = ::rocprim::warp_load; - using keys_store_type = ::rocprim::warp_store; - using values_store_type = ::rocprim::warp_store; + using keys_load_type = ::rocprim::warp_load; + using values_load_type = ::rocprim::warp_load; + using keys_store_type = ::rocprim::warp_store; + using values_store_type = ::rocprim::warp_store; template - using radix_comparator_type = ::rocprim::detail::radix_merge_compare; - using stable_key_type = ::rocprim::tuple; - using sort_type = ::rocprim::warp_sort; + using radix_comparator_type + = ::rocprim::detail::radix_merge_compare; + using sort_type = ::rocprim::detail:: + warp_sort_stable; - static constexpr bool with_values = !std::is_same::value; - - template - ROCPRIM_DEVICE ROCPRIM_INLINE - decltype(auto) make_stable_comparator(ComparatorT comparator) - { - return [comparator](const stable_key_type& a, const stable_key_type& b) -> bool - { - const bool ab = comparator(rocprim::get<0>(a), rocprim::get<0>(b)); - const bool ba = comparator(rocprim::get<0>(b), rocprim::get<0>(a)); - return ab || (!ba && (rocprim::get<1>(a) < rocprim::get<1>(b))); - }; - } + static constexpr bool with_values = !std::is_same::value; public: static constexpr unsigned int items_per_warp = items_per_thread * logical_warp_size; @@ -575,41 +558,58 @@ class segmented_warp_sort_helper< }; private: + template + ROCPRIM_DEVICE + auto invoke_warp_sort(Key (&keys)[items_per_thread], + Value (&values)[items_per_thread], + storage_type& storage, + F comparator) + -> std::enable_if_t::value> + { + (void)values; + sort_type().sort(keys, storage.sort, comparator); + } + + template + ROCPRIM_DEVICE + auto invoke_warp_sort(Key (&keys)[items_per_thread], + Value (&values)[items_per_thread], + storage_type& storage, + F comparator) + -> std::enable_if_t::value> + { + sort_type().sort(keys, values, storage.sort, comparator); + } + template - ROCPRIM_DEVICE auto invoke_warp_sort(stable_key_type (&stable_keys)[items_per_thread], - value_type (&values)[items_per_thread], - storage_type& storage, - unsigned int begin_bit, - unsigned int end_bit) - -> std::enable_if_t::value> + ROCPRIM_DEVICE + auto invoke_warp_sort(Key (&keys)[items_per_thread], + Value (&values)[items_per_thread], + storage_type& storage, + unsigned int begin_bit, + unsigned int end_bit) -> std::enable_if_t::value> { (void)begin_bit; (void)end_bit; - sort_type().sort(stable_keys, - values, - storage.sort, - make_stable_comparator(radix_comparator_type{})); + invoke_warp_sort(keys, values, storage, radix_comparator_type{}); } template - ROCPRIM_DEVICE auto invoke_warp_sort(stable_key_type (&stable_keys)[items_per_thread], - value_type (&values)[items_per_thread], - storage_type& storage, - unsigned int begin_bit, - unsigned int end_bit) - -> std::enable_if_t::value> + ROCPRIM_DEVICE + auto invoke_warp_sort(Key (&keys)[items_per_thread], + Value (&values)[items_per_thread], + storage_type& storage, + unsigned int begin_bit, + unsigned int end_bit) -> std::enable_if_t::value> { - if(begin_bit == 0 && end_bit == 8 * sizeof(key_type)) + if(begin_bit == 0 && end_bit == 8 * sizeof(Key)) { - sort_type().sort(stable_keys, - values, - storage.sort, - make_stable_comparator(radix_comparator_type{})); + invoke_warp_sort(keys, values, storage, radix_comparator_type{}); } else { radix_comparator_type comparator(begin_bit, end_bit - begin_bit); - sort_type().sort(stable_keys, values, storage.sort, make_stable_comparator(comparator)); + invoke_warp_sort(keys, values, storage, comparator); } } @@ -632,64 +632,48 @@ class segmented_warp_sort_helper< storage_type& storage) { const unsigned int num_items = end_offset - begin_offset; - const key_type out_of_bounds = key_codec::decode(bit_key_type(-1)); + const Key out_of_bounds = key_codec::get_out_of_bounds_key(); - key_type keys[items_per_thread]; - stable_key_type stable_keys[items_per_thread]; - value_type values[items_per_thread]; + Key keys[items_per_thread]; + Value values[items_per_thread]; keys_load_type().load(keys_input + begin_offset, keys, num_items, out_of_bounds, storage.keys_load); - ROCPRIM_UNROLL - for(unsigned int i = 0; i < items_per_thread; i++) - { - ::rocprim::get<0>(stable_keys[i]) = keys[i]; - ::rocprim::get<1>(stable_keys[i]) = - ::rocprim::detail::logical_lane_id() + logical_warp_size * i; - } - - if(with_values) + if ROCPRIM_IF_CONSTEXPR(with_values) { ::rocprim::wave_barrier(); values_load_type().load(values_input + begin_offset, values, num_items, storage.values_load); } ::rocprim::wave_barrier(); - invoke_warp_sort(stable_keys, values, storage, begin_bit, end_bit); + invoke_warp_sort(keys, values, storage, begin_bit, end_bit); - ROCPRIM_UNROLL - for(unsigned int i = 0; i < items_per_thread; i++) - { - keys[i] = ::rocprim::get<0>(stable_keys[i]); - } ::rocprim::wave_barrier(); keys_store_type().store(keys_output + begin_offset, keys, num_items, storage.keys_store); - if(with_values) + if ROCPRIM_IF_CONSTEXPR(with_values) { ::rocprim::wave_barrier(); values_store_type().store(values_output + begin_offset, values, num_items, storage.values_store); } } - template< - class KeysInputIterator, - class KeysOutputIterator, - class ValuesInputIterator, - class ValuesOutputIterator - > + template ROCPRIM_DEVICE ROCPRIM_INLINE - void sort(KeysInputIterator keys_input, - key_type * keys_tmp, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - value_type * values_tmp, + void sort(KeysInputIterator keys_input, + Key* keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + Value* values_tmp, ValuesOutputIterator values_output, - bool to_output, - unsigned int begin_offset, - unsigned int end_offset, - unsigned int begin_bit, - unsigned int end_bit, - storage_type& storage) + bool to_output, + unsigned int begin_offset, + unsigned int end_offset, + unsigned int begin_bit, + unsigned int end_bit, + storage_type& storage) { if(to_output) { @@ -731,15 +715,13 @@ void segmented_sort(KeysInputIterator keys_input, bool to_output, OffsetIterator begin_offsets, OffsetIterator end_offsets, - unsigned int long_iterations, - unsigned int short_iterations, + unsigned int iterations, unsigned int begin_bit, unsigned int end_bit) { static constexpr segmented_radix_sort_config_params params = device_params(); static constexpr unsigned int long_radix_bits = params.long_radix_bits; - static constexpr unsigned int short_radix_bits = params.short_radix_bits; static constexpr unsigned int block_size = params.kernel_config.block_size; static constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread; static constexpr unsigned int items_per_block = block_size * items_per_thread; @@ -753,31 +735,29 @@ void segmented_sort(KeysInputIterator keys_input, block_size, items_per_thread, Descending >; - using long_radix_helper_type = segmented_radix_sort_helper< - key_type, value_type, - ::rocprim::device_warp_size(), block_size, items_per_thread, - long_radix_bits, Descending - >; - using short_radix_helper_type = segmented_radix_sort_helper< - key_type, value_type, - ::rocprim::device_warp_size(), block_size, items_per_thread, - short_radix_bits, Descending - >; - using warp_sort_helper_type = segmented_warp_sort_helper< + using long_radix_helper_type = segmented_radix_sort_helper; + using warp_sort_helper_type = segmented_warp_sort_helper< select_warp_sort_helper_config_t, key_type, value_type, + block_size, Descending>; + static constexpr unsigned int items_per_warp = warp_sort_helper_type::items_per_warp; ROCPRIM_SHARED_MEMORY union { typename single_block_helper_type::storage_type single_block_helper; - typename long_radix_helper_type::storage_type long_radix_helper; - typename short_radix_helper_type::storage_type short_radix_helper; + typename long_radix_helper_type::storage_type long_radix_helper; typename warp_sort_helper_type::storage_type warp_sort_helper; } storage; @@ -795,8 +775,7 @@ void segmented_sort(KeysInputIterator keys_input, if(end_offset - begin_offset > items_per_block) { // Large segment - unsigned int bit = begin_bit; - for(unsigned int i = 0; i < long_iterations; i++) + for(unsigned int bit = begin_bit; bit < end_bit; bit += long_radix_bits) { long_radix_helper_type().sort( keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, @@ -807,43 +786,39 @@ void segmented_sort(KeysInputIterator keys_input, ); to_output = !to_output; - bit += long_radix_bits; - } - for(unsigned int i = 0; i < short_iterations; i++) - { - short_radix_helper_type().sort( - keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, - to_output, - begin_offset, end_offset, - bit, begin_bit, end_bit, - storage.short_radix_helper - ); - - to_output = !to_output; - bit += short_radix_bits; } } else if(!warp_sort_enabled || end_offset - begin_offset > items_per_warp) { // Small segment - single_block_helper_type().sort( - keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, - ((long_iterations + short_iterations) % 2 == 0) != to_output, - begin_offset, end_offset, - begin_bit, end_bit, - storage.single_block_helper - ); + single_block_helper_type().sort(keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + (iterations % 2 == 0) != to_output, + begin_offset, + end_offset, + begin_bit, + end_bit, + storage.single_block_helper); } else if(::rocprim::flat_block_thread_id() < params.warp_sort_config.logical_warp_size_small) { // Single warp segment - warp_sort_helper_type().sort( - keys_input, keys_tmp, keys_output, - values_input, values_tmp, values_output, - ((long_iterations + short_iterations) % 2 == 0) != to_output, - begin_offset, end_offset, - begin_bit, end_bit, storage.warp_sort_helper - ); + warp_sort_helper_type().sort(keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + (iterations % 2 == 0) != to_output, + begin_offset, + end_offset, + begin_bit, + end_bit, + storage.warp_sort_helper); } } @@ -868,15 +843,13 @@ void segmented_sort_large(KeysInputIterator keys_input, SegmentIndexIterator segment_indices, OffsetIterator begin_offsets, OffsetIterator end_offsets, - unsigned int long_iterations, - unsigned int short_iterations, + unsigned int iterations, unsigned int begin_bit, unsigned int end_bit) { static constexpr segmented_radix_sort_config_params params = device_params(); static constexpr unsigned int long_radix_bits = params.long_radix_bits; - static constexpr unsigned int short_radix_bits = params.short_radix_bits; static constexpr unsigned int block_size = params.kernel_config.block_size; static constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread; static constexpr unsigned int items_per_block = block_size * items_per_thread; @@ -889,22 +862,18 @@ void segmented_sort_large(KeysInputIterator keys_input, block_size, items_per_thread, Descending >; - using long_radix_helper_type = segmented_radix_sort_helper< - key_type, value_type, - ::rocprim::device_warp_size(), block_size, items_per_thread, - long_radix_bits, Descending - >; - using short_radix_helper_type = segmented_radix_sort_helper< - key_type, value_type, - ::rocprim::device_warp_size(), block_size, items_per_thread, - short_radix_bits, Descending - >; + using long_radix_helper_type = segmented_radix_sort_helper; ROCPRIM_SHARED_MEMORY union { typename single_block_helper_type::storage_type single_block_helper; - typename long_radix_helper_type::storage_type long_radix_helper; - typename short_radix_helper_type::storage_type short_radix_helper; + typename long_radix_helper_type::storage_type long_radix_helper; } storage; const unsigned int block_id = ::rocprim::detail::block_id<0>(); @@ -919,8 +888,7 @@ void segmented_sort_large(KeysInputIterator keys_input, if(end_offset - begin_offset > items_per_block) { - unsigned int bit = begin_bit; - for(unsigned int i = 0; i < long_iterations; i++) + for(unsigned int bit = begin_bit; bit < end_bit; bit += long_radix_bits) { long_radix_helper_type().sort( keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, @@ -931,31 +899,22 @@ void segmented_sort_large(KeysInputIterator keys_input, ); to_output = !to_output; - bit += long_radix_bits; - } - for(unsigned int i = 0; i < short_iterations; i++) - { - short_radix_helper_type().sort( - keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, - to_output, - begin_offset, end_offset, - bit, begin_bit, end_bit, - storage.short_radix_helper - ); - - to_output = !to_output; - bit += short_radix_bits; } } else { - single_block_helper_type().sort( - keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, - ((long_iterations + short_iterations) % 2 == 0) != to_output, - begin_offset, end_offset, - begin_bit, end_bit, - storage.single_block_helper - ); + single_block_helper_type().sort(keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + (iterations % 2 == 0) != to_output, + begin_offset, + end_offset, + begin_bit, + end_bit, + storage.single_block_helper); } } @@ -1003,6 +962,7 @@ void segmented_sort_small(KeysInputIterator keys_input, params.warp_sort_config.block_size_small>, key_type, value_type, + block_size, Descending>; ROCPRIM_SHARED_MEMORY typename warp_sort_helper_type::storage_type storage; @@ -1018,10 +978,12 @@ void segmented_sort_small(KeysInputIterator keys_input, const unsigned int segment_id = segment_indices[segment_index]; const unsigned int begin_offset = begin_offsets[segment_id]; const unsigned int end_offset = end_offsets[segment_id]; + if(end_offset <= begin_offset) { return; } + warp_sort_helper_type().sort(keys_input, keys_tmp, keys_output, @@ -1078,6 +1040,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void segmented_sort_medium( params.warp_sort_config.block_size_medium>, key_type, value_type, + block_size, Descending>; ROCPRIM_SHARED_MEMORY typename warp_sort_helper_type::storage_type storage; @@ -1097,6 +1060,7 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void segmented_sort_medium( { return; } + warp_sort_helper_type().sort( keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, diff --git a/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp b/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp index 2dc2f06fc..1da25f003 100644 --- a/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp +++ b/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp @@ -59,28 +59,33 @@ template ROCPRIM_KERNEL - __launch_bounds__(device_params().kernel_config.block_size) void segmented_sort_kernel( - KeysInputIterator keys_input, - typename std::iterator_traits::value_type* keys_tmp, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - typename std::iterator_traits::value_type* values_tmp, - ValuesOutputIterator values_output, - bool to_output, - OffsetIterator begin_offsets, - OffsetIterator end_offsets, - unsigned int long_iterations, - unsigned int short_iterations, - unsigned int begin_bit, - unsigned int end_bit) + __launch_bounds__(device_params().kernel_config.block_size) +void segmented_sort_kernel( + KeysInputIterator keys_input, + typename std::iterator_traits::value_type* keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type* values_tmp, + ValuesOutputIterator values_output, + bool to_output, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int iterations, + unsigned int begin_bit, + unsigned int end_bit) { - segmented_sort( - keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, - to_output, - begin_offsets, end_offsets, - long_iterations, short_iterations, - begin_bit, end_bit - ); + segmented_sort(keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + to_output, + begin_offsets, + end_offsets, + iterations, + begin_bit, + end_bit); } template() .kernel_config - .block_size) void segmented_sort_large_kernel(KeysInputIterator keys_input, - typename std::iterator_traits< - KeysInputIterator>::value_type* keys_tmp, - KeysOutputIterator keys_output, - ValuesInputIterator values_input, - typename std::iterator_traits< - ValuesInputIterator>::value_type* - values_tmp, - ValuesOutputIterator values_output, - bool to_output, - SegmentIndexIterator segment_indices, - OffsetIterator begin_offsets, - OffsetIterator end_offsets, - unsigned int long_iterations, - unsigned int short_iterations, - unsigned int begin_bit, - unsigned int end_bit) + .block_size) +void segmented_sort_large_kernel( + KeysInputIterator keys_input, + typename std::iterator_traits::value_type* keys_tmp, + KeysOutputIterator keys_output, + ValuesInputIterator values_input, + typename std::iterator_traits::value_type* values_tmp, + ValuesOutputIterator values_output, + bool to_output, + SegmentIndexIterator segment_indices, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + unsigned int iterations, + unsigned int begin_bit, + unsigned int end_bit) { - segmented_sort_large( - keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, - to_output, segment_indices, - begin_offsets, end_offsets, - long_iterations, short_iterations, - begin_bit, end_bit - ); + segmented_sort_large(keys_input, + keys_tmp, + keys_output, + values_input, + values_tmp, + values_output, + to_output, + segment_indices, + begin_offsets, + end_offsets, + iterations, + begin_bit, + end_bit); } template= params.warp_sort_config.partitioning_threshold; @@ -443,11 +445,8 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, std::cout << "end_bit " << end_bit << '\n'; std::cout << "bits " << bits << '\n'; std::cout << "segments " << segments << '\n'; - std::cout << "radix_bits_diff " << radix_bits_diff << '\n'; std::cout << "storage_size " << storage_size << '\n'; std::cout << "iterations " << iterations << '\n'; - std::cout << "long_iterations " << long_iterations << '\n'; - std::cout << "short_iterations " << short_iterations << '\n'; std::cout << "do_partitioning " << do_partitioning << '\n'; std::cout << "params.kernel_config.block_size: " << params.kernel_config.block_size << '\n'; std::cout << "params.kernel_config.items_per_thread: " @@ -523,8 +522,7 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, large_segment_indices_output, begin_offsets, end_offsets, - long_iterations, - short_iterations, + iterations, begin_bit, end_bit); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:large_segments", @@ -607,8 +605,7 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, to_output, begin_offsets, end_offsets, - long_iterations, - short_iterations, + iterations, begin_bit, end_bit); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort", segments, start) diff --git a/rocprim/include/rocprim/rocprim.hpp b/rocprim/include/rocprim/rocprim.hpp index 52a6a0ae8..843a3d172 100644 --- a/rocprim/include/rocprim/rocprim.hpp +++ b/rocprim/include/rocprim/rocprim.hpp @@ -57,6 +57,7 @@ #include "block/block_scan.hpp" #include "block/block_sort.hpp" #include "block/block_store.hpp" +#include "block/config.hpp" #include "device/device_adjacent_difference.hpp" #include "device/device_binary_search.hpp" diff --git a/rocprim/include/rocprim/warp/detail/warp_sort_stable.hpp b/rocprim/include/rocprim/warp/detail/warp_sort_stable.hpp index 96ff48356..726fbcfe9 100644 --- a/rocprim/include/rocprim/warp/detail/warp_sort_stable.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_sort_stable.hpp @@ -30,6 +30,8 @@ #include "../../functional.hpp" #include "../../intrinsics.hpp" +#include "../../detail/merge_path.hpp" + BEGIN_ROCPRIM_NAMESPACE namespace detail diff --git a/scripts/autotune-search/.gitignore b/scripts/autotune-search/.gitignore new file mode 100644 index 000000000..44898dfcf --- /dev/null +++ b/scripts/autotune-search/.gitignore @@ -0,0 +1 @@ +artifacts \ No newline at end of file diff --git a/scripts/autotune-search/main.py b/scripts/autotune-search/main.py new file mode 100755 index 000000000..9bbc4c19e --- /dev/null +++ b/scripts/autotune-search/main.py @@ -0,0 +1,387 @@ +#!/usr/bin/env python3 + +from typing import Union, List +import argparse +import glob +import itertools +import json +import logging +import multiprocessing as mp +import os +import pathos.multiprocessing as pamp +import rich.logging +import rich.progress +import scipy.optimize +import shutil +import subprocess + +parameter_spaces = { + "device_segmented_radix_sort_keys": { + "benchmark": "benchmark_device_segmented_radix_sort_keys", + "types": { + "KeyType": [ + "int64_t", + "int", + "short", + "int8_t", + "double", + "float", + "rocprim::half", + ], + }, + "params": { + "LongBits": [6, 7, 8], + "BlockSize": [256], + "ItemsPerThread": [7, 8, 13, 16, 17], + "WarpSmallLWS": [8, 16, 32, 64], + "WarpSmallIPT": [2, 4, 8], + "WarpSmallBS": [256], + "WarpPartition": [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], + "WarpMediumLWS": [16, 32, 64], + "WarpMediumIPT": [4, 8, 16], + "WarpMediumBS": [256], + }, + }, + "device_segmented_radix_sort_pairs": { + "benchmark": "benchmark_device_segmented_radix_sort_pairs", + "types": { + "KeyType": [ + "int64_t", + "int", + "short", + "int8_t", + "double", + "float", + "rocprim::half", + ], + "ValueType": [ + "int64_t", + "int", + "short", + "int8_t", + ], + }, + "params": { + "LongBits": [6, 7, 8], + "BlockSize": [256], + "ItemsPerThread": [7, 8, 13, 16, 17], + "WarpSmallLWS": [8, 16, 32, 64], + "WarpSmallIPT": [2, 4, 8], + "WarpSmallBS": [256], + "WarpPartition": [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], + "WarpMediumLWS": [16, 32, 64], + "WarpMediumIPT": [4, 8, 16], + "WarpMediumBS": [256], + }, + }, +} + +def get_result_from_json(filename: os.PathLike) -> Union[float, int]: + ''' + Get the result from the benchmark json. + ''' + with open(filename, 'r') as file: + data = json.load(file) + try: + return float(data['benchmarks'][0]['bytes_per_second']) + except Exception as e: + log.error('Could not extract \'bytes_per_second\' from JSON!') + raise e + +def merge_jsons(source_filenames: List[os.PathLike], target_filename: os.PathLike) -> None: + ''' + Merges benchmark JSONs. This is used to collect the singular results generated from the various + benchmark runs. + ''' + merged = { + 'context': {}, + 'benchmarks': [], + } + + # collect jsons + for filename in source_filenames: + with open(filename, 'r') as file: + try: + data = json.load(file) + except json.decoder.JSONDecodeError as e: + log.warning(f'Skipping file \'{filename}\' because of error: {e}') + + # HACK: we reuse the last context since we can only have one + merged['context'] = data['context'] + # append benchmark data + merged['benchmarks'].extend(data['benchmarks']) + + # write out file + with open(target_filename, 'w') as file: + json.dump(merged, file, indent=2) + +def combine(alg_name: str, arch: str): + ''' + combine() + ''' + script_dir = os.path.dirname(os.path.realpath(__file__)) + result_dir = os.path.join(script_dir, 'artifacts') + + alg_space = parameter_spaces[alg_name] + build_target = alg_space['benchmark'] + + merge_jsons( + glob.glob(os.path.join(result_dir, f'{arch}_{build_target}_*.json')), + os.path.join(result_dir, f'{arch}_{build_target}.json'), + ) + +def tune_alg(alg_name: str, arch: str, max_samples: int, num_workers: int, size: int, trials: int) -> None: + ''' + The core tuning procedure. This tunes a single algorithm for multiple types. + ''' + + # get the context of the tuning run + alg_space = parameter_spaces[alg_name] + build_target = alg_space['benchmark'] + + # types to tune, this can be a product of multiple types + types = [ + dict(zip(alg_space['types'], ts)) + for ts in itertools.product( + *[alg_space['types'][type] for type in alg_space['types']] + ) + ] + + # generate bounds by normalizing the parameter space from discrete to real numbers (relaxation) + bounds = dict(zip(alg_space['params'], ((0, 1) for _ in alg_space['params']))) + + # define a utility function to access parameters in 'alg_space' + def param_from_normalized(name: str, value: float) -> str: + ''' + Internal which maps a continious named parameter in [0; 1] to it's discrete value. + ''' + + # get the list of discrete values + params = alg_space['params'][name] + + # get the index, make sure we're not out-of-bounds when value is 1.0 + index = min(int(value * len(params)), len(params) - 1) + + try: + return str(params[index]) + except IndexError as e: + log.error( + f"Could not find parameter '{name}' at '{index}' derived from value '{value}' in {params}." + ) + raise e + + def tune_type(type: str) -> None: + cache = {} + + def sample(xs: List[float]) -> Union[float, int]: + # each worker should get their own build dir + build_dir = os.path.join(source_dir, f'build/tune-{worker_id}') + + # delete *.parallel folder + try: + # HACK: we just delete the benchmark folder because it's easier + shutil.rmtree(os.path.join(build_dir, 'benchmark')) + except FileNotFoundError: + # if the tree doesn't exist we don't have to remove it :) + pass + + tune_param_names = list(type.keys()) + list(alg_space['params']) + tune_param_vals = list(type.values()) + [ + param_from_normalized(name, value) + for (name, value) in zip(alg_space['params'].keys(), xs) + ] + + result_id = '_'.join(tune_param_vals) + if result_id in cache: + log.info(f'[{worker_id}] Skipped already computed result!') + return cache[result_id] + + result_filename = f'{arch}_{build_target}_{result_id}.json' + + tune_param_names = ';'.join(tune_param_names) + tune_param_vals = ';'.join(tune_param_vals) + + # CMake configure + log.info(f'[{worker_id}] Configuring: {result_id}') + configure = subprocess.call( + [ + 'cmake', + '-S', + '.', + '-B', + build_dir, + '-GNinja', + '-DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++', + '-DBUILD_BENCHMARK=ON', + '-DBENCHMARK_CONFIG_TUNING=ON', + f'-DAMDGPU_TARGETS={arch}', + f'-DBENCHMARK_TUNE_PARAM_NAMES={tune_param_names}', + f'-DBENCHMARK_TUNE_PARAMS={tune_param_vals}', + ], + cwd=source_dir, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + if configure != 0: + cache[result_id] = 0.0 + return 0.0 + + # Build target + log.info(f'[{worker_id}] Building: {result_id}') + build = subprocess.call( + ['cmake', '--build', '.', '--target', build_target], + cwd=build_dir, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + if build != 0: + cache[result_id] = 0.0 + result_context = { + 'config': dict( + ( + (name, param_from_normalized(name, val)) + for name, val in zip(alg_space['params'], xs) + ) + ) + } + log.debug(json.dumps(result_context, indent=2)) + return 0.0 + + # Run benchmark + gpu_lock.acquire() + try: + log.info(f'[{worker_id}] Benchmarking: {result_id}') + bench = subprocess.call( + [ + os.path.join(build_dir, 'benchmark', build_target), + '--name_format', + 'json', + '--seed', + 'random', # Random is better... I think? Otherwise we might overfit. + '--size', + f'{size}', + '--trials', + f'{trials}', + '--benchmark_out_format=json', + f'--benchmark_out={result_filename}', + ], + cwd=result_dir, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + timeout=120, + ) + + if bench != 0: + cache[result_id] = 0.0 + return 0.0 + result_value = get_result_from_json( + os.path.join(result_dir, result_filename) + ) + log.info(f'[{worker_id}] Completed: {result_id} @ {result_value / 1e9:.3f} GB/s') + finally: + gpu_lock.release() + + result_context = { + 'config': dict( + ( + (name, param_from_normalized(name, val)) + for name, val in zip(alg_space['params'], xs) + ) + ), + 'bytes_per_second': result_value, + } + log.debug(json.dumps(result_context, indent=2)) + + cache[result_id] = -result_value + + # scipy.optimize does minimization, negate result for maximize + return -result_value + + # Dual annealing is very good for tuning. See: + # - 'Benchmarking optimization algorithms for auto-tuning GPU kernels' by Schoonhoven et al, 2022. + # - 'A methodology for comparing optimization algorithms for auto-tuning' by Willemsen et al, 2024. + scipy.optimize.dual_annealing( + sample, bounds=bounds.values(), maxfun=max_samples + ) + + script_dir = os.path.dirname(os.path.realpath(__file__)) + source_dir = os.path.join(script_dir, '../..') + result_dir = os.path.join(script_dir, 'artifacts') + + os.makedirs(result_dir, exist_ok=True) + + def pool_init(worker_ids, lock): + global worker_id, gpu_lock + gpu_lock = lock + worker_id = worker_ids.get(False) + + man = mp.Manager() + + # create queue to distrubute worker ids + worker_ids = man.Queue(num_workers) + for i in range(num_workers): + worker_ids.put(i) + + # 'pathos' is needed to pickle local scopes to workers + with pamp.Pool( + processes=num_workers, + initializer=pool_init, + initargs=[worker_ids, mp.Lock()], + ) as pool: + pool.map(tune_type, types) + + # We're done with tuning this entire algorithm, collect them into a single file! + merge_jsons( + glob.glob(os.path.join(result_dir, f'{arch}_{build_target}_*.json')), + os.path.join(result_dir, f'{arch}_{build_target}.json'), + ) + +parser = argparse.ArgumentParser( + prog='autotune-search', + description='config tuning using local search', +) + +parser.add_argument('targets',metavar='TARGETS', nargs='*', help='target(s) to optimize, seperated by comma') +parser.add_argument('-a', '--arch', default='gfx942', help='architecture to target, e.g. gfx908') +parser.add_argument('-n', '--evals', default=200, help='maximum number of configs being evaluated per type per target') +parser.add_argument('-v', '--verbose', action='store_true', help='verbose output') +parser.add_argument('-w', '--workers', default=8, help='number of workers') +parser.add_argument('-s', '--size', default=33554432, help='input size to use for tuning') +parser.add_argument('-t', '--trials', default=3, help='number of trials per config to test') +parser.add_argument('-c', '--combine', action='store_true', help='skip tuning and combine the results of a previous run for the given targets and architecture') +parser.add_argument('-l', '--list', action='store_true', help='list available targets') + +args = parser.parse_args() + +if not args.targets: + args.targets = list(parameter_spaces.keys()) + +if args.list: + for target in parameter_spaces.keys(): + print(target) + quit() + +log_level = logging.INFO +if args.verbose: + log_level = logging.DEBUG + +logging.basicConfig(format='%(message)s', handlers=[rich.logging.RichHandler(rich_tracebacks=True, markup=True)], level=log_level) +log = logging.getLogger('rich') + +if args.combine: + for target in args.targets: + combine(alg_name=target, arch=args.arch) + quit() + +for target in args.targets: + log.info(f'Tuning {target} for {args.arch} with {int(args.evals)} max evaluations') + tune_alg( + alg_name=target, + arch=args.arch, + max_samples=int(args.evals), + num_workers=int(args.workers), + size=int(args.size), + trials=int(args.trials) + ) diff --git a/scripts/autotune-search/requirements.txt b/scripts/autotune-search/requirements.txt new file mode 100644 index 000000000..3cef92bbd --- /dev/null +++ b/scripts/autotune-search/requirements.txt @@ -0,0 +1,3 @@ +scipy +pathos +rich diff --git a/scripts/autotune/create_optimization.py b/scripts/autotune/create_optimization.py index 2bebf1fd5..933e8f08d 100755 --- a/scripts/autotune/create_optimization.py +++ b/scripts/autotune/create_optimization.py @@ -41,7 +41,7 @@ from typing import Dict, List, Callable, Optional, Tuple from jinja2 import Environment, PackageLoader, select_autoescape -TARGET_ARCHITECTURES = ['gfx803', 'gfx900', 'gfx906', 'gfx908', 'gfx90a', 'gfx1030', 'gfx1100', 'gfx1102'] +TARGET_ARCHITECTURES = ['gfx803', 'gfx900', 'gfx906', 'gfx908', 'gfx90a', 'gfx942', 'gfx1030', 'gfx1100', 'gfx1102'] # C++ typename used for optional types EMPTY_TYPENAME = "empty_type"