From 400883412d56067a331b881a29e5f906d4cb4e70 Mon Sep 17 00:00:00 2001 From: Beatriz Navidad Vilches <61422851+Beanavil@users.noreply.github.com> Date: Wed, 20 Nov 2024 20:04:53 +0100 Subject: [PATCH] Develop stream 2024-10-29 (#631) * remove HIP-CPU support * Resolve: IssueMove ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR to seperate header file * rebase and add RETURN_ON_ERROR to the header * Added naive implementation for adjacent_find plus tests and benchmarks * Improved benchmark by only taking into account relevant processed elements * Use a faster reduction operation * Added block-reduction kernel with early exit * Improved test with random first pair * Get grid_size for maximum occupancy * Improved test coverage * Implement early exit with sequential blocks execution * Use a dynamic tile_id as in find_first_of for faster stable results * Added documentation for adjacent_find * Added tuning for adjacent_find * Modified tuning so that non-arithmetic types use default configs * Changed initialization mechanism of kernel's output element * Fixed tests from review comments - Simplified adjacent_find_impl functor definition - Added test for indirect_iterator * Simplified input transform logic * Added tuned configs * Removed duplicated ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR * Resolve "Refactor benchmarks to use a byte-based size" * Added a rocprim::numeric_limits to support uint128 and int128 and changed all std::numeric_limits to test_utils::numeric_limits * Create generate_limit to ensure floating point custom types are handled correctly. 
* Add rocprim::numeric_limits to numeric_limits_custom_test_type * Expected output fix block_radix_sort test for custom_test_type and custom_test_type * Docs fix numeric_limits * Added numeric_limits to changelog * Added a rocprim::uint128_t and rocprim::int128_t * Implemented find_end with tests and benchmark * Updated find_end benchmark with generate_limits * Added different input pattern for benchmark and added multiple items per thread * Added different key_size to tests for find_end * Added shared memory kernel for find_end * Changed find_end to search with reverse iterator * Added tests for different compare function * Change benchmark to no longer early exit and choosing shared mem kernel as config variable * Extra check search kernel to prevent unnecessary global search * Documentation for find_end * Changed find_end to make it easier to create search * Fix docs errors find_end * Changes for reviews find_end * Fix rebasing issues find_end * Added find_end to rocprim header * Fix build error after adding headers * Use byte-based size in benchmark * Remove double defines * Added search function with tests and benchmark * Fix documentation find_end and search * Add device_search to rocprim.hpp header * add device_ptr utility Authored-By: Cenxuan Tian * replace high_resolution_clock with steady_clock Authorized-By: Cenxuan Tian * properly namespace ROCPRIM_RETURN_ON_ERROR * Set c++ version to 17 and create warning * Fix no_discard warning c++17 * Set CI tests to c++14 * Build for both c++ 14 and 17 * Add large sizes test to device_radix_sort * Added more test coverage segmented_radix_sort * fix not working with const_iterators * fix: use bytes instead of size for scan tuning benchmarks * Resolve "Partial sort optimization: make use of radix sort" * doc: address the upper bound restrictions on Channels for device_histogram * doc: explicitly state that ActiveChannels is bounded by Channels * batch memcpy tests with random seed * follow clang format * add newline at 
the end * make rocprim::reverse_iterator align with that of std * minor change * add constexpr * adjust format * add warnings * adjust format * change the way of triggering warnings * adjust format * minor change * adjust format * clear warnings * adjust format * correct warning behaviours * adjust format * adjust format * update changelog and fix warning issue * fix ambiguous issue * move a CHANGELOG entry to Deprecations section * feat: add support for predicated flagged device select * feat: add tests (with large indices) for predicated flagged device select * feat: add config tuning and benchmarks for predicated and flagged device select * fix: add missing template parameter to partition-based autotune templates * Add tuned configs * Fix clang-format hang * Fix ambiguous error make_reverse_iterator * Resolve "Config tuning and dynamic dispatch for device merge" * add search_n algo * add test * Add google test for search_n & tested the functionality * Add benchmark * Add Doc & add custom type for benchmark * Remove unused variables * Add NonBlockStream support * Remove unused type alias * Refactor search_n for loop, &dit comments & * Add More tests & Fixed some bugs * Add more benchmarks * Add document * Refactor benchmarks * Replace another DOXYGEN_DOCUMENTATION_BUILD and some minor modifications * Fix build debug error * Optimize algo with large input * add impl2 * Optimize * Move hipMalloc vars to temp_memory * Rewrite benchmarks * Resolve * Fix bugs -- several occurrences of consecutive full blocks * Many modifications, fixed the bugs and edited the tests and benchmarks * Optimised the block_search_n_kernel * 2nd version search_n implementation for large input * Add thread level search_n algorithm * Add optimizations * Edit benchmarks * remove unused variables * remove unused variables and remove __restrict__ * fix the bug on windows * fix bug and modify benchmarks and tests * fix bugs in benchmarks and search_n_impl * Oh yes * Apply 1 suggestion(s) to 1 
file(s) Co-authored-by: Beatriz Navidad Vilches * apply some suggestions * edit doc * replace search_n_min_kernal by rocprim:reduce * fixed some benchmarks bugs * remove graph support * resolve not compile on win * Add graph support and modified the design a little * resolve test fail on windws * fix gfx960 benchmark dead lock * Add device_search_n to rocprim.hpp * replace HIP_CHECK by ROCPRIM_RETURN_ON_ERROR * fix: fix doxygen error due to __launch_bounds__ macro * Implement 6.3 hotfixes for added/modified tests * Workaround CI memory usage limit * Reduce memory usage even more --------- Co-authored-by: Robin Voetter Co-authored-by: Cenxuan Tian Co-authored-by: Milo Lurati Co-authored-by: Nick Breed Co-authored-by: Bence Parajdi Co-authored-by: Yung-sheng Tu --- .gitlab-ci.yml | 11 + CHANGELOG.md | 11 + CMakeLists.txt | 77 +- README.md | 7 +- benchmark/CMakeLists.txt | 26 +- benchmark/ConfigAutotuneSettings.cmake | 8 + .../benchmark_block_adjacent_difference.cpp | 30 +- benchmark/benchmark_block_discontinuity.cpp | 23 +- benchmark/benchmark_block_exchange.cpp | 29 +- benchmark/benchmark_block_histogram.cpp | 20 +- benchmark/benchmark_block_radix_rank.cpp | 22 +- benchmark/benchmark_block_radix_sort.cpp | 21 +- benchmark/benchmark_block_reduce.cpp | 22 +- .../benchmark_block_run_length_decode.cpp | 18 +- benchmark/benchmark_block_scan.cpp | 24 +- benchmark/benchmark_block_sort.cpp | 10 +- benchmark/benchmark_block_sort.parallel.hpp | 5 +- benchmark/benchmark_config_dispatch.cpp | 6 +- ...rk_device_adjacent_difference.parallel.hpp | 6 +- benchmark/benchmark_device_adjacent_find.cpp | 145 + ...hmark_device_adjacent_find.parallel.cpp.in | 32 + ...enchmark_device_adjacent_find.parallel.hpp | 248 ++ benchmark/benchmark_device_binary_search.cpp | 20 +- benchmark/benchmark_device_find_end.cpp | 131 + benchmark/benchmark_device_find_end.hpp | 199 ++ benchmark/benchmark_device_histogram.cpp | 57 +- benchmark/benchmark_device_merge.cpp | 331 +-- 
.../benchmark_device_merge.parallel.cpp.in | 32 + benchmark/benchmark_device_merge.parallel.hpp | 422 +++ benchmark/benchmark_device_merge_sort.cpp | 12 +- benchmark/benchmark_device_merge_sort.hpp | 12 +- ...enchmark_device_merge_sort_block_merge.cpp | 12 +- ...device_merge_sort_block_merge.parallel.hpp | 48 +- ...benchmark_device_merge_sort_block_sort.cpp | 12 +- ..._device_merge_sort_block_sort.parallel.hpp | 12 +- benchmark/benchmark_device_nth_element.cpp | 12 +- benchmark/benchmark_device_nth_element.hpp | 4 +- benchmark/benchmark_device_partial_sort.cpp | 12 +- benchmark/benchmark_device_partial_sort.hpp | 6 +- .../benchmark_device_partial_sort_copy.cpp | 12 +- .../benchmark_device_partial_sort_copy.hpp | 4 +- benchmark/benchmark_device_partition.cpp | 22 +- .../benchmark_device_partition.parallel.hpp | 85 +- benchmark/benchmark_device_radix_sort.cpp | 14 +- benchmark/benchmark_device_radix_sort.hpp | 16 +- ...benchmark_device_radix_sort_block_sort.cpp | 14 +- ..._device_radix_sort_block_sort.parallel.hpp | 10 +- .../benchmark_device_radix_sort_onesweep.cpp | 16 +- ...rk_device_radix_sort_onesweep.parallel.hpp | 16 +- benchmark/benchmark_device_reduce.cpp | 14 +- .../benchmark_device_reduce.parallel.hpp | 5 +- .../benchmark_device_run_length_encode.cpp | 30 +- benchmark/benchmark_device_scan.cpp | 14 +- benchmark/benchmark_device_scan.parallel.hpp | 5 +- benchmark/benchmark_device_scan_by_key.cpp | 14 +- .../benchmark_device_scan_by_key.parallel.hpp | 5 +- ...hmark_device_scan_by_key_deterministic.cpp | 12 +- .../benchmark_device_scan_deterministic.cpp | 12 +- benchmark/benchmark_device_search.cpp | 131 + benchmark/benchmark_device_search.hpp | 199 ++ benchmark/benchmark_device_search_n.cpp | 77 + .../benchmark_device_search_n.parallel.cpp.in | 33 + .../benchmark_device_search_n.parallel.hpp | 431 +++ ...hmark_device_segmented_radix_sort_keys.cpp | 32 +- ...ice_segmented_radix_sort_keys.parallel.hpp | 5 +- ...mark_device_segmented_radix_sort_pairs.cpp | 44 
+- ...ce_segmented_radix_sort_pairs.parallel.hpp | 5 +- .../benchmark_device_segmented_reduce.cpp | 21 +- benchmark/benchmark_device_select.cpp | 40 +- .../benchmark_device_select.parallel.hpp | 234 +- benchmark/benchmark_device_transform.cpp | 14 +- .../benchmark_device_transform.parallel.hpp | 9 +- benchmark/benchmark_predicate_iterator.cpp | 32 +- benchmark/benchmark_utils.hpp | 46 +- benchmark/benchmark_warp_exchange.cpp | 17 +- benchmark/benchmark_warp_reduce.cpp | 25 +- benchmark/benchmark_warp_scan.cpp | 23 +- benchmark/benchmark_warp_sort.cpp | 19 +- cmake/Dependencies.cmake | 42 +- cmake/Summary.cmake | 5 +- docs/device_ops/adjacent_difference.rst | 2 +- docs/device_ops/adjacent_find.rst | 20 + docs/device_ops/find_end.rst | 19 + docs/device_ops/index.rst | 4 + docs/device_ops/search.rst | 19 + docs/device_ops/search_n.rst | 19 + docs/reference/ops_summary.rst | 4 + docs/sphinx/_toc.yml.in | 4 + example/CMakeLists.txt | 27 +- .../include/rocprim/block/block_load_func.hpp | 3 +- .../detail/block_reduce_raking_reduce.hpp | 12 - rocprim/include/rocprim/common.hpp | 62 + rocprim/include/rocprim/config.hpp | 1 - .../detail/config/device_adjacent_find.hpp | 461 ++++ .../device/detail/config/device_merge.hpp | 2435 +++++++++++++++++ .../config/device_select_predicated_flag.hpp | 1949 +++++++++++++ .../device/detail/device_adjacent_find.hpp | 156 ++ .../device/detail/device_batch_memcpy.hpp | 7 +- .../device/detail/device_config_helper.hpp | 154 ++ .../device/detail/device_nth_element.hpp | 70 +- .../device/detail/device_partition.hpp | 70 +- .../rocprim/device/detail/device_search.hpp | 440 +++ .../rocprim/device/detail/device_search_n.hpp | 439 +++ .../device/device_adjacent_difference.hpp | 24 +- .../rocprim/device/device_adjacent_find.hpp | 299 ++ .../device/device_adjacent_find_config.hpp | 105 + .../rocprim/device/device_binary_search.hpp | 4 +- .../rocprim/device/device_find_end.hpp | 155 ++ .../rocprim/device/device_find_first_of.hpp | 24 +- 
.../rocprim/device/device_histogram.hpp | 56 +- .../include/rocprim/device/device_merge.hpp | 161 +- .../rocprim/device/device_merge_config.hpp | 146 +- .../rocprim/device/device_merge_sort.hpp | 32 +- .../rocprim/device/device_partial_sort.hpp | 386 ++- .../device/device_partial_sort_config.hpp | 11 +- .../rocprim/device/device_partition.hpp | 37 +- .../device/device_partition_config.hpp | 25 + .../rocprim/device/device_radix_sort.hpp | 32 +- .../include/rocprim/device/device_reduce.hpp | 29 +- .../rocprim/device/device_reduce_by_key.hpp | 26 +- .../device/device_run_length_encode.hpp | 32 +- .../include/rocprim/device/device_scan.hpp | 32 +- .../rocprim/device/device_scan_by_key.hpp | 26 +- .../include/rocprim/device/device_search.hpp | 155 ++ .../rocprim/device/device_search_config.hpp | 77 + .../rocprim/device/device_search_n.hpp | 111 + .../rocprim/device/device_search_n_config.hpp | 71 + .../device/device_segmented_radix_sort.hpp | 42 +- .../device/device_segmented_reduce.hpp | 22 +- .../rocprim/device/device_segmented_scan.hpp | 76 +- .../include/rocprim/device/device_select.hpp | 148 +- .../rocprim/device/device_transform.hpp | 22 +- .../device_radix_block_sort.hpp | 24 +- rocprim/include/rocprim/intrinsics/atomic.hpp | 47 +- rocprim/include/rocprim/intrinsics/thread.hpp | 5 - rocprim/include/rocprim/intrinsics/warp.hpp | 35 +- .../rocprim/intrinsics/warp_shuffle.hpp | 17 +- rocprim/include/rocprim/iterator.hpp | 2 - .../rocprim/iterator/reverse_iterator.hpp | 103 +- rocprim/include/rocprim/rocprim.hpp | 4 + .../rocprim/thread/radix_key_codec.hpp | 10 +- .../include/rocprim/thread/thread_load.hpp | 9 - .../include/rocprim/thread/thread_store.hpp | 7 - rocprim/include/rocprim/type_traits.hpp | 86 +- rocprim/include/rocprim/types.hpp | 22 +- .../warp/detail/warp_segment_bounds.hpp | 17 +- scripts/autotune/create_optimization.py | 37 +- .../templates/adjacent_find_config_template | 20 + .../autotune/templates/merge_config_template | 25 + 
.../templates/partition_flag_config_template | 2 +- .../partition_predicate_config_template | 2 +- .../partition_three_way_config_template | 2 +- .../partition_two_way_flag_config_template | 2 +- ...artition_two_way_predicate_config_template | 2 +- .../templates/select_flag_config_template | 2 +- .../select_predicate_config_template | 2 +- .../select_predicated_flag_config_template | 20 + .../select_unique_by_key_config_template | 2 +- .../templates/select_unique_config_template | 2 +- test/CMakeLists.txt | 24 +- test/common_test_header.hpp | 19 +- test/extra/CMakeLists.txt | 12 +- test/hip/test_hip_api.cpp | 58 +- test/hip/test_ordered_block_id.cpp | 16 +- test/rocprim/CMakeLists.txt | 29 +- test/rocprim/indirect_iterator.hpp | 2 +- test/rocprim/test_arg_index_iterator.cpp | 9 +- test/rocprim/test_block_radix_rank.hpp | 17 +- .../rocprim/test_block_radix_sort.kernels.hpp | 34 +- test/rocprim/test_block_run_length_decode.cpp | 6 +- test/rocprim/test_constant_iterator.cpp | 4 +- test/rocprim/test_counting_iterator.cpp | 4 +- .../test_device_adjacent_difference.cpp | 12 +- test/rocprim/test_device_adjacent_find.cpp | 237 ++ test/rocprim/test_device_batch_memcpy.cpp | 326 +-- test/rocprim/test_device_find_end.cpp | 445 +++ test/rocprim/test_device_find_first_of.cpp | 6 +- test/rocprim/test_device_histogram.cpp | 22 +- test/rocprim/test_device_merge.cpp | 179 +- test/rocprim/test_device_merge_sort.cpp | 2 +- test/rocprim/test_device_nth_element.cpp | 19 +- test/rocprim/test_device_partial_sort.cpp | 75 +- test/rocprim/test_device_partition.cpp | 108 +- test/rocprim/test_device_radix_sort.cpp.in | 5 +- test/rocprim/test_device_radix_sort.hpp | 92 +- test/rocprim/test_device_reduce.cpp | 32 +- .../rocprim/test_device_run_length_encode.cpp | 2 + test/rocprim/test_device_scan.cpp | 50 +- test/rocprim/test_device_search.cpp | 442 +++ test/rocprim/test_device_search_n.cpp | 1392 ++++++++++ .../test_device_segmented_radix_sort.cpp.in | 14 +- 
.../test_device_segmented_radix_sort.hpp | 709 ++++- test/rocprim/test_device_segmented_reduce.cpp | 6 +- test/rocprim/test_device_select.cpp | 473 +++- test/rocprim/test_device_transform.cpp | 47 +- test/rocprim/test_intrinsics.cpp | 34 +- test/rocprim/test_radix_key_codec.cpp | 4 +- .../test_temporary_storage_partitioning.cpp | 2 +- test/rocprim/test_texture_cache_iterator.cpp | 10 +- test/rocprim/test_transform_iterator.cpp | 8 +- test/rocprim/test_utils_assertions.hpp | 5 +- test/rocprim/test_utils_data_generation.hpp | 60 +- test/rocprim/test_utils_device_ptr.hpp | 242 ++ test/rocprim/test_utils_sort_comparator.hpp | 50 +- test/rocprim/test_utils_types.hpp | 4 +- test/rocprim/test_warp_load.cpp | 3 +- test/rocprim/test_zip_iterator.cpp | 20 +- 207 files changed, 15989 insertions(+), 2572 deletions(-) create mode 100644 benchmark/benchmark_device_adjacent_find.cpp create mode 100644 benchmark/benchmark_device_adjacent_find.parallel.cpp.in create mode 100644 benchmark/benchmark_device_adjacent_find.parallel.hpp create mode 100644 benchmark/benchmark_device_find_end.cpp create mode 100644 benchmark/benchmark_device_find_end.hpp create mode 100644 benchmark/benchmark_device_merge.parallel.cpp.in create mode 100644 benchmark/benchmark_device_merge.parallel.hpp create mode 100644 benchmark/benchmark_device_search.cpp create mode 100644 benchmark/benchmark_device_search.hpp create mode 100644 benchmark/benchmark_device_search_n.cpp create mode 100644 benchmark/benchmark_device_search_n.parallel.cpp.in create mode 100644 benchmark/benchmark_device_search_n.parallel.hpp create mode 100644 docs/device_ops/adjacent_find.rst create mode 100644 docs/device_ops/find_end.rst create mode 100644 docs/device_ops/search.rst create mode 100644 docs/device_ops/search_n.rst create mode 100644 rocprim/include/rocprim/common.hpp create mode 100644 rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp create mode 100644 
rocprim/include/rocprim/device/detail/config/device_merge.hpp create mode 100644 rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp create mode 100644 rocprim/include/rocprim/device/detail/device_adjacent_find.hpp create mode 100644 rocprim/include/rocprim/device/detail/device_search.hpp create mode 100644 rocprim/include/rocprim/device/detail/device_search_n.hpp create mode 100644 rocprim/include/rocprim/device/device_adjacent_find.hpp create mode 100644 rocprim/include/rocprim/device/device_adjacent_find_config.hpp create mode 100644 rocprim/include/rocprim/device/device_find_end.hpp create mode 100644 rocprim/include/rocprim/device/device_search.hpp create mode 100644 rocprim/include/rocprim/device/device_search_config.hpp create mode 100644 rocprim/include/rocprim/device/device_search_n.hpp create mode 100644 rocprim/include/rocprim/device/device_search_n_config.hpp create mode 100644 scripts/autotune/templates/adjacent_find_config_template create mode 100644 scripts/autotune/templates/merge_config_template create mode 100644 scripts/autotune/templates/select_predicated_flag_config_template create mode 100644 test/rocprim/test_device_adjacent_find.cpp create mode 100644 test/rocprim/test_device_find_end.cpp create mode 100644 test/rocprim/test_device_search.cpp create mode 100644 test/rocprim/test_device_search_n.cpp create mode 100644 test/rocprim/test_utils_device_ptr.hpp diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3e4b0205b..0f76d8c77 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -114,6 +114,7 @@ copyright-date: -D AMDGPU_TEST_TARGETS=$GPU_TARGETS -D CMAKE_C_COMPILER_LAUNCHER=phc_sccache_c -D CMAKE_CXX_COMPILER_LAUNCHER=phc_sccache_cxx + -D CMAKE_CXX_STANDARD=14 -S $CI_PROJECT_DIR -B $BUILD_DIR - cmake @@ -182,6 +183,7 @@ build:cmake-minimum-apt: -D BUILD_EXAMPLE=ON -D GPU_TARGETS=$GPU_TARGETS -D AMDGPU_TEST_TARGETS=$GPU_TARGETS + -D CMAKE_CXX_STANDARD="$BUILD_VERSION" -S $CI_PROJECT_DIR -B $BUILD_DIR - cmake --build 
$BUILD_DIR @@ -210,6 +212,7 @@ build:cmake-latest: matrix: - BUILD_TYPE: Release BUILD_TARGET: [BENCHMARK, TEST] + BUILD_VERSION: [14, 17] build:cmake-minimum: needs: [] @@ -220,6 +223,7 @@ build:cmake-minimum: matrix: - BUILD_TYPE: [Debug, Release] BUILD_TARGET: [BENCHMARK, TEST] + BUILD_VERSION: 14 build:package: stage: build @@ -236,6 +240,7 @@ build:package: -G Ninja -D CMAKE_CXX_COMPILER="$AMDCLANG" -D CMAKE_BUILD_TYPE=Release + -D CMAKE_CXX_STANDARD=14 -B $PACKAGE_DIR -S $CI_PROJECT_DIR - cd $PACKAGE_DIR @@ -268,6 +273,7 @@ build:windows: -D CMAKE_CXX_COMPILER:PATH="${env:HIP_PATH}\bin\clang++.exe" -D CMAKE_PREFIX_PATH:PATH="${env:HIP_PATH}" -D CMAKE_BUILD_TYPE="$BUILD_TYPE" + -D CMAKE_CXX_STANDARD=14 - cmake --build "$CI_PROJECT_DIR/build" artifacts: paths: @@ -314,6 +320,7 @@ autotune:build: -D GPU_TARGETS=$GPU_TARGETS -D CMAKE_C_COMPILER_LAUNCHER=phc_sccache_c -D CMAKE_CXX_COMPILER_LAUNCHER=phc_sccache_cxx + -D CMAKE_CXX_STANDARD=14 - cmake --build . --target $BENCHMARK_TARGETS - 'rm -rf $BUILD_DIR/benchmark/benchmark*.parallel' # The autotune benchmarks get very large, above GitLabs upload limit. Fortunately they compress well. 
@@ -339,6 +346,7 @@ test: matrix: - BUILD_TYPE: Release BUILD_TARGET: TEST + BUILD_VERSION: 14 script: - cd $BUILD_DIR - cmake @@ -395,6 +403,7 @@ test-windows-release: -D CMAKE_CXX_COMPILER="$AMDCLANG" -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS=$GPU_TARGETS + -D CMAKE_CXX_STANDARD=14 -S "$CI_PROJECT_DIR/test/extra" -B "$CI_PROJECT_DIR/package_test" - cmake --build "$CI_PROJECT_DIR/package_test" @@ -416,6 +425,7 @@ test:install: -G Ninja -D CMAKE_CXX_COMPILER="$AMDCLANG" -D CMAKE_BUILD_TYPE=Release + -D CMAKE_CXX_STANDARD=14 -B build -S $CI_PROJECT_DIR - $SUDO_CMD cmake --build build --target install @@ -458,6 +468,7 @@ benchmark: matrix: - BUILD_TYPE: Release BUILD_TARGET: BENCHMARK + BUILD_VERSION: 14 extends: - .cmake-minimum - .gpus:rocm diff --git a/CHANGELOG.md b/CHANGELOG.md index 679ce6905..8e9a7ad69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,13 @@ Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projec * Added `--emulation` option added for `rtest.py` * Unit tests can be run with `[--emulation|-e|--test|-t]=` * Added tuned configurations for segmented radix sort for gfx942 to improve performance on this architecture. +* Added a parallel device-level function, `rocprim::adjacent_find`, similar to the C++ Standard Library `std::adjacent_find` algorithm. +* Added configuration autotuning to device adjacent find (`rocprim::adjacent_find`) for improved performance on selected architectures. +* Added rocprim::numeric_limits which is an extension of `std::numeric_limits`, which includes support for 128-bit integers. +* Added rocprim::int128_t and rocprim::uint128_t which are the __int128_t and __uint128_t types. +* Added the parallel `search` and `find_end` device functions similar to `std::search` and `std::find_end`, these functions search for the first and last occurrence of the sequence respectively. 
+* Added a parallel device-level function, `rocprim::search_n`, similar to the C++ Standard Library `std::search_n` algorithm. +* Added new constructors and a `base` function, and added `constexpr` specifier to all functions in `rocprim::reverse_iterator` to improve parity with the C++17 `std::reverse_iterator`. ### Changed @@ -22,6 +29,9 @@ Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projec * Changed the internal algorithm of block radix sort to use rank match to improve performance of various radix sort related algorithms. * Disabled padding in various cases where higher occupancy resulted in better performance despite more bank conflicts. +* Removed HIP-CPU support. HIP-CPU support was experimental and broken. +* Changed the C++ version from 14 to 17. C++14 will be deprecated in the next major release. + ### Resolved issues * Fixed an issue where `rmake.py` would generate wrong CMAKE commands while using Linux environment @@ -30,6 +40,7 @@ Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projec * Fixed compilation issue when `rocprim::radix_key_codec<...>` is specialized with a 128-bit integer. ### Upcoming changes +* Using the initialisation constructor of `rocprim::reverse_iterator` will throw a deprecation warning. It will be marked as explicit in the next major release. 
## rocPRIM 3.3.0 for ROCm 6.3.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index a5b9b1274..bedfa8542 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,7 +48,6 @@ option(BUILD_BENCHMARK "Build benchmarks" OFF) option(BUILD_NAIVE_BENCHMARK "Build naive benchmarks" OFF) option(BUILD_EXAMPLE "Build examples" OFF) option(BUILD_DOCS "Build documentation (requires sphinx)" OFF) -option(USE_HIP_CPU "Prefer HIP-CPU runtime instead of HW acceleration" OFF) # Disables building tests, benchmarks, examples option(ONLY_INSTALL "Only install" OFF) option(BUILD_CODE_COVERAGE "Build with code coverage enabled" OFF) @@ -70,50 +69,57 @@ endif() set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE CACHE BOOL "Add paths to linker search and installed rpath") # Set CXX flags -set(CMAKE_CXX_STANDARD 14) +if (NOT DEFINED CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) +if (CMAKE_CXX_STANDARD EQUAL 14) + message(WARNING "C++14 will be deprecated in the next major release") +elseif(NOT CMAKE_CXX_STANDARD EQUAL 17) + message(FATAL_ERROR "Only C++14 and C++17 are supported") +endif() + if(DEFINED BUILD_SHARED_LIBS) set(PKG_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) -else() +else() set(PKG_BUILD_SHARED_LIBS ON) -endif() +endif() set(BUILD_SHARED_LIBS OFF) # don't build client dependencies as shared -if(NOT USE_HIP_CPU) - # Get dependencies (required here to get rocm-cmake) - include(cmake/Dependencies.cmake) - # Use target ID syntax if supported for GPU_TARGETS - if (NOT DEFINED AMDGPU_TARGETS) - set(GPU_TARGETS "all" CACHE STRING "GPU architectures to compile for") + +# Get dependencies (required here to get rocm-cmake) +include(cmake/Dependencies.cmake) +# Use target ID syntax if supported for GPU_TARGETS +if (NOT DEFINED AMDGPU_TARGETS) + set(GPU_TARGETS "all" CACHE STRING "GPU architectures to compile for") +else() + set(GPU_TARGETS "${AMDGPU_TARGETS}" CACHE STRING "GPU architectures to compile for") +endif() 
+set_property(CACHE GPU_TARGETS PROPERTY STRINGS "all") + +if(GPU_TARGETS STREQUAL "all") + if(BUILD_ADDRESS_SANITIZER) + # ASAN builds require xnack + rocm_check_target_ids(DEFAULT_AMDGPU_TARGETS + TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+" + ) else() - set(GPU_TARGETS "${AMDGPU_TARGETS}" CACHE STRING "GPU architectures to compile for") - endif() - set_property(CACHE GPU_TARGETS PROPERTY STRINGS "all") - - if(GPU_TARGETS STREQUAL "all") - if(BUILD_ADDRESS_SANITIZER) - # ASAN builds require xnack - rocm_check_target_ids(DEFAULT_AMDGPU_TARGETS - TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+" - ) - else() - rocm_check_target_ids(DEFAULT_AMDGPU_TARGETS - TARGETS "gfx803;gfx900:xnack-;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack-;gfx90a:xnack+;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201" - ) - endif() - - set(GPU_TARGETS "${DEFAULT_AMDGPU_TARGETS}" CACHE STRING "GPU architectures to compile for" FORCE) + rocm_check_target_ids(DEFAULT_AMDGPU_TARGETS + TARGETS "gfx803;gfx900:xnack-;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack-;gfx90a:xnack+;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201" + ) endif() - # TODO: Fix VerifyCompiler for HIP on Windows - if (NOT WIN32) - include(cmake/VerifyCompiler.cmake) - endif() - list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH} ${ROCM_PATH}/hip ${ROCM_PATH}/llvm ${ROCM_ROOT}/llvm ${ROCM_ROOT} ${ROCM_ROOT}/hip) - find_package(hip REQUIRED CONFIG PATHS ${HIP_DIR} ${ROCM_PATH} /opt/rocm) + set(GPU_TARGETS "${DEFAULT_AMDGPU_TARGETS}" CACHE STRING "GPU architectures to compile for" FORCE) endif() +# TODO: Fix VerifyCompiler for HIP on Windows +if (NOT WIN32) + include(cmake/VerifyCompiler.cmake) +endif() +list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH} ${ROCM_PATH}/hip ${ROCM_PATH}/llvm ${ROCM_ROOT}/llvm ${ROCM_ROOT} ${ROCM_ROOT}/hip) +find_package(hip REQUIRED CONFIG PATHS ${HIP_DIR} ${ROCM_PATH} /opt/rocm) + # FOR HANDLING 
ENABLE/DISABLE OPTIONAL BACKWARD COMPATIBILITY for FILE/FOLDER REORG option(BUILD_FILE_REORG_BACKWARD_COMPATIBILITY "Build with file/folder reorg with backward compatibility enabled" OFF) if(ROCPRIM_INSTALL AND BUILD_FILE_REORG_BACKWARD_COMPATIBILITY AND NOT WIN32) @@ -130,11 +136,6 @@ if(BUILD_CODE_COVERAGE) add_link_options(--coverage) endif() -if(USE_HIP_CPU) - # Get dependencies - include(cmake/Dependencies.cmake) -endif() - # Setup VERSION set(VERSION_STRING "3.3.0") rocm_setup_version(VERSION ${VERSION_STRING}) diff --git a/README.md b/README.md index fe87e326a..5517f9d63 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ develop performant GPU-accelerated code on AMD ROCm platforms. * Including [HIP-clang](https://github.com/ROCm/HIP/blob/master/INSTALL.md#hip-clang) compiler -* C++14 +* C++17 * Python 3.6 or higher (HIP on Windows only, required only for install script) * Visual Studio 2019 with Clang support (HIP on Windows only) * Strawberry Perl (HIP on Windows only) @@ -110,11 +110,6 @@ You can build and install rocPRIM on Linux or Windows. # before 'cmake' or setting cmake option 'CMAKE_CXX_COMPILER' to path to the compiler. # Using HIP-clang: [CXX=hipcc] cmake -DBUILD_BENCHMARK=ON ../. - # - # ! EXPERIMENTAL ! - # Alternatively one may build using the experimental (and highly incomplete) HIP-CPU back-end for host-side - # execution using any C++17 conforming compiler (supported by HIP-CPU). AMDGPU_* options are unavailable in this case. 
- # USE_HIP_CPU - OFF by default # Build make -j4 diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index fae6d2cfc..d8ef35f90 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -77,24 +77,10 @@ function(add_rocprim_benchmark BENCHMARK_SOURCE) rocprim benchmark::benchmark ) - if(NOT USE_HIP_CPU) - target_link_libraries(${BENCHMARK_TARGET} - PRIVATE - rocprim_hip - ) - else() - target_link_libraries(${BENCHMARK_TARGET} - PRIVATE - Threads::Threads - hip_cpu_rt::hip_cpu_rt - ) - if(STL_DEPENDS_ON_TBB) - target_link_libraries(${BENCHMARK_TARGET} - PRIVATE - TBB::tbb - ) - endif() - endif() + target_link_libraries(${BENCHMARK_TARGET} + PRIVATE + rocprim_hip + ) target_compile_options(${BENCHMARK_TARGET} PRIVATE @@ -143,9 +129,11 @@ add_rocprim_benchmark(benchmark_block_scan.cpp) add_rocprim_benchmark(benchmark_block_sort.cpp) add_rocprim_benchmark(benchmark_config_dispatch.cpp) add_rocprim_benchmark(benchmark_device_adjacent_difference.cpp) +add_rocprim_benchmark(benchmark_device_adjacent_find.cpp) add_rocprim_benchmark(benchmark_device_batch_memcpy.cpp) add_rocprim_benchmark(benchmark_device_binary_search.cpp) add_rocprim_benchmark(benchmark_device_find_first_of.cpp) +add_rocprim_benchmark(benchmark_device_find_end.cpp) add_rocprim_benchmark(benchmark_device_histogram.cpp) add_rocprim_benchmark(benchmark_device_merge.cpp) add_rocprim_benchmark(benchmark_device_merge_sort.cpp) @@ -165,7 +153,9 @@ add_rocprim_benchmark(benchmark_device_run_length_encode.cpp) add_rocprim_benchmark(benchmark_device_scan.cpp) add_rocprim_benchmark(benchmark_device_scan_deterministic.cpp) add_rocprim_benchmark(benchmark_device_scan_by_key.cpp) +add_rocprim_benchmark(benchmark_device_search.cpp) add_rocprim_benchmark(benchmark_device_scan_by_key_deterministic.cpp) +add_rocprim_benchmark(benchmark_device_search_n.cpp) add_rocprim_benchmark(benchmark_device_select.cpp) add_rocprim_benchmark(benchmark_device_segmented_radix_sort_keys.cpp) 
add_rocprim_benchmark(benchmark_device_segmented_radix_sort_pairs.cpp) diff --git a/benchmark/ConfigAutotuneSettings.cmake b/benchmark/ConfigAutotuneSettings.cmake index bdc34185b..d0d11c62b 100644 --- a/benchmark/ConfigAutotuneSettings.cmake +++ b/benchmark/ConfigAutotuneSettings.cmake @@ -33,6 +33,10 @@ function(read_config_autotune_settings file list_across_names list_across output set(list_across "${TUNING_TYPES};\ true;false true;32 64 128 256 512 1024" PARENT_SCOPE) set(output_pattern_suffix "@DataType@_@Left@_@InPlace@_@BlockSize@" PARENT_SCOPE) + elseif(file STREQUAL "benchmark_device_adjacent_find") + set(list_across_names "InputType;BlockSize" PARENT_SCOPE) + set(list_across "${TUNING_TYPES};64 128 256 512 1024" PARENT_SCOPE) + set(output_pattern_suffix "@InputType@_@BlockSize@" PARENT_SCOPE) elseif(file STREQUAL "benchmark_device_histogram") set(list_across_names "DataType;BlockSize" PARENT_SCOPE) set(list_across "${TUNING_TYPES};64 128 256" PARENT_SCOPE) @@ -115,5 +119,9 @@ DataType;BlockSize;" PARENT_SCOPE) set(list_across_names "DataType;BlockSize" PARENT_SCOPE) set(list_across "${LIMITED_TUNING_TYPES};32 64 128 256 512 1024" PARENT_SCOPE) set(output_pattern_suffix "@DataType@_@BlockSize@" PARENT_SCOPE) + elseif(file STREQUAL "benchmark_device_merge") + set(list_across_names "KeyType;ValueType;BlockSize" PARENT_SCOPE) + set(list_across "${TUNING_TYPES};rocprim::empty_type ${LIMITED_TUNING_TYPES};32 64 128 256 512 1024" PARENT_SCOPE) + set(output_pattern_suffix "@KeyType@_@ValueType@_@BlockSize@" PARENT_SCOPE) endif() endfunction() diff --git a/benchmark/benchmark_block_adjacent_difference.cpp b/benchmark/benchmark_block_adjacent_difference.cpp index c26689119..8d237f763 100644 --- a/benchmark/benchmark_block_adjacent_difference.cpp +++ b/benchmark/benchmark_block_adjacent_difference.cpp @@ -45,7 +45,7 @@ #include #ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 128; +const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; #endif namespace rp = 
rocprim; @@ -228,10 +228,13 @@ template -auto run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, hipStream_t stream) +auto run_benchmark(benchmark::State& state, size_t bytes, const managed_seed& seed, hipStream_t stream) -> std::enable_if_t::value && !std::is_same::value> { + // Calculate the number of elements N + size_t N = bytes / sizeof(T); + constexpr auto items_per_block = BlockSize * ItemsPerThread; const auto num_blocks = (N + items_per_block - 1) / items_per_block; // Round up size to the next multiple of items_per_block @@ -296,10 +299,13 @@ template -auto run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, hipStream_t stream) +auto run_benchmark(benchmark::State& state, size_t bytes, const managed_seed& seed, hipStream_t stream) -> std::enable_if_t::value || std::is_same::value> { + // Calculate the number of elements N + size_t N = bytes / sizeof(T); + static constexpr auto items_per_block = BlockSize * ItemsPerThread; const auto num_blocks = (N + items_per_block - 1) / items_per_block; // Round up size to the next multiple of items_per_block @@ -383,7 +389,7 @@ auto run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, ",with_tile:" #WITH_TILE "}}") \ .c_str(), \ run_benchmark, \ - size, \ + bytes, \ seed, \ stream) @@ -398,7 +404,7 @@ auto run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, template void add_benchmarks(const std::string& name, std::vector& benchmarks, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { @@ -429,7 +435,7 @@ void add_benchmarks(const std::string& name, int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -440,7 +446,7 @@ int 
main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -451,17 +457,17 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks std::vector benchmarks; - add_benchmarks("subtract_left", benchmarks, size, seed, stream); - add_benchmarks("subtract_right", benchmarks, size, seed, stream); - add_benchmarks("subtract_left_partial", benchmarks, size, seed, stream); + add_benchmarks("subtract_left", benchmarks, bytes, seed, stream); + add_benchmarks("subtract_right", benchmarks, bytes, seed, stream); + add_benchmarks("subtract_left_partial", benchmarks, bytes, seed, stream); add_benchmarks("subtract_right_partial", benchmarks, - size, + bytes, seed, stream); diff --git a/benchmark/benchmark_block_discontinuity.cpp b/benchmark/benchmark_block_discontinuity.cpp index 78e817ee5..1d07cdb9b 100644 --- a/benchmark/benchmark_block_discontinuity.cpp +++ b/benchmark/benchmark_block_discontinuity.cpp @@ -44,7 +44,7 @@ #include #ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 128; +const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; #endif namespace rp = rocprim; @@ -201,8 +201,11 @@ template -void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, hipStream_t stream) +void run_benchmark(benchmark::State& state, size_t bytes, const managed_seed& seed, hipStream_t stream) { + // Calculate the number of elements N + size_t N = bytes / sizeof(T); + constexpr auto items_per_block = BlockSize * ItemsPerThread; const auto size = items_per_block * ((N + items_per_block - 1)/items_per_block); @@ 
-266,7 +269,7 @@ void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, ",with_tile:" #WITH_TILE "}}") \ .c_str(), \ run_benchmark, \ - size, \ + bytes, \ seed, \ stream) @@ -280,7 +283,7 @@ void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, template void add_benchmarks(const std::string& name, std::vector& benchmarks, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { @@ -304,7 +307,7 @@ void add_benchmarks(const std::string& name, int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -315,7 +318,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -326,14 +329,14 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks std::vector benchmarks; - add_benchmarks("flag_heads", benchmarks, size, seed, stream); - add_benchmarks("flag_tails", benchmarks, size, seed, stream); - add_benchmarks("flag_heads_and_tails", benchmarks, size, seed, stream); + add_benchmarks("flag_heads", benchmarks, bytes, seed, stream); + add_benchmarks("flag_tails", benchmarks, bytes, seed, stream); + add_benchmarks("flag_heads_and_tails", benchmarks, bytes, seed, stream); // Use manual timing for(auto& b : benchmarks) diff --git 
a/benchmark/benchmark_block_exchange.cpp b/benchmark/benchmark_block_exchange.cpp index 102b75ffd..ec3c95f48 100644 --- a/benchmark/benchmark_block_exchange.cpp +++ b/benchmark/benchmark_block_exchange.cpp @@ -44,7 +44,7 @@ #include #ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif namespace rp = rocprim; @@ -246,8 +246,11 @@ template -void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, hipStream_t stream) +void run_benchmark(benchmark::State& state, size_t bytes, const managed_seed& seed, hipStream_t stream) { + // Calculate the number of elements N + size_t N = bytes / sizeof(T); + constexpr auto items_per_block = BlockSize * ItemsPerThread; const auto size = items_per_block * ((N + items_per_block - 1)/items_per_block); @@ -332,7 +335,7 @@ void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, + ",key_type:" #T ",cfg:{bs:" #BS ",ipt:" #IPT "}}") \ .c_str(), \ run_benchmark, \ - size, \ + bytes, \ seed, \ stream) @@ -347,7 +350,7 @@ void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, template void add_benchmarks(const std::string& name, std::vector& benchmarks, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { @@ -373,7 +376,7 @@ void add_benchmarks(const std::string& name, int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("bytes", "bytes", DEFAULT_BYTES, "number of values"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -384,7 +387,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("bytes"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const 
std::string seed_type = parser.get("seed"); @@ -395,25 +398,25 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks std::vector benchmarks; - add_benchmarks("blocked_to_striped", benchmarks, size, seed, stream); - add_benchmarks("striped_to_blocked", benchmarks, size, seed, stream); + add_benchmarks("blocked_to_striped", benchmarks, bytes, seed, stream); + add_benchmarks("striped_to_blocked", benchmarks, bytes, seed, stream); add_benchmarks("blocked_to_warp_striped", benchmarks, - size, + bytes, seed, stream); add_benchmarks("warp_striped_to_blocked", benchmarks, - size, + bytes, seed, stream); - add_benchmarks("scatter_to_blocked", benchmarks, size, seed, stream); - add_benchmarks("scatter_to_striped", benchmarks, size, seed, stream); + add_benchmarks("scatter_to_blocked", benchmarks, bytes, seed, stream); + add_benchmarks("scatter_to_striped", benchmarks, bytes, seed, stream); // Use manual timing for(auto& b : benchmarks) diff --git a/benchmark/benchmark_block_histogram.cpp b/benchmark/benchmark_block_histogram.cpp index 676845f67..7dc282749 100644 --- a/benchmark/benchmark_block_histogram.cpp +++ b/benchmark/benchmark_block_histogram.cpp @@ -44,7 +44,7 @@ #include #ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 128; +const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; #endif namespace rp = rocprim; @@ -121,8 +121,10 @@ template< unsigned int BinSize = BlockSize, unsigned int Trials = 100 > -void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) +void run_benchmark(benchmark::State& state, hipStream_t stream, size_t bytes) { + // Calculate the number of elements N + size_t N = bytes / sizeof(T); // Make sure size is a multiple of BlockSize constexpr auto items_per_block = BlockSize * ItemsPerThread; const auto size = 
items_per_block * ((N + items_per_block - 1)/items_per_block); @@ -188,7 +190,7 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) .c_str(), \ run_benchmark, \ stream, \ - size) + bytes) #define BENCHMARK_TYPE(type, block) \ CREATE_BENCHMARK(type, block, 1), \ @@ -202,7 +204,7 @@ template void add_benchmarks(std::vector& benchmarks, const std::string& method_name, hipStream_t stream, - size_t size) + size_t bytes) { std::vector new_benchmarks = { @@ -219,7 +221,7 @@ void add_benchmarks(std::vector& benchmarks, int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -229,7 +231,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); @@ -238,16 +240,16 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); // Add benchmarks std::vector benchmarks; // using_atomic using histogram_a_t = histogram; - add_benchmarks(benchmarks, "using_atomic", stream, size); + add_benchmarks(benchmarks, "using_atomic", stream, bytes); // using_sort using histogram_s_t = histogram; - add_benchmarks(benchmarks, "using_sort", stream, size); + add_benchmarks(benchmarks, "using_sort", stream, bytes); // Use manual timing for(auto& b : benchmarks) diff --git a/benchmark/benchmark_block_radix_rank.cpp b/benchmark/benchmark_block_radix_rank.cpp index f49097cb6..faee6d37c 100644 --- a/benchmark/benchmark_block_radix_rank.cpp +++ 
b/benchmark/benchmark_block_radix_rank.cpp @@ -40,7 +40,7 @@ #include #ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 128; +const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; #endif namespace rp = rocprim; @@ -97,8 +97,10 @@ template -void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, hipStream_t stream) +void run_benchmark(benchmark::State& state, size_t bytes, const managed_seed& seed, hipStream_t stream) { + // Calculate the number of elements N + size_t N = bytes / sizeof(T); constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; const unsigned int grid_size = ((N + items_per_block - 1) / items_per_block); const unsigned int size = items_per_block * grid_size; @@ -117,7 +119,7 @@ void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, for(auto _ : state) { - auto start = std::chrono::high_resolution_clock::now(); + auto start = std::chrono::steady_clock::now(); hipLaunchKernelGGL(HIP_KERNEL_NAME(rank_kernel>(end - start); state.SetIterationTime(elapsed_seconds.count()); @@ -153,7 +155,7 @@ void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, ",ipt:" #IPT ",method:" #KIND "}}") \ .c_str(), \ run_benchmark, \ - size, \ + bytes, \ seed, \ stream) @@ -173,7 +175,7 @@ void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, // clang-format on void add_benchmarks(std::vector& benchmarks, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { @@ -197,7 +199,7 @@ void add_benchmarks(std::vector& benchmarks, int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -208,7 +210,7 @@ int main(int argc, char* argv[]) // Parse argv 
benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -219,12 +221,12 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks std::vector benchmarks; - add_benchmarks(benchmarks, size, seed, stream); + add_benchmarks(benchmarks, bytes, seed, stream); // Use manual timing for(auto& b : benchmarks) diff --git a/benchmark/benchmark_block_radix_sort.cpp b/benchmark/benchmark_block_radix_sort.cpp index d82253c58..ff144732f 100644 --- a/benchmark/benchmark_block_radix_sort.cpp +++ b/benchmark/benchmark_block_radix_sort.cpp @@ -45,7 +45,7 @@ #include #ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 128; +const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; #endif enum class benchmark_kinds @@ -129,10 +129,13 @@ template void run_benchmark(benchmark::State& state, benchmark_kinds benchmark_kind, - size_t N, + size_t bytes, const managed_seed& seed, hipStream_t stream) { + // Calculate the number of elements N + size_t N = bytes / sizeof(T); + constexpr auto items_per_block = BlockSize * ItemsPerThread; const auto size = items_per_block * ((N + items_per_block - 1)/items_per_block); @@ -217,7 +220,7 @@ void run_benchmark(benchmark::State& state, .c_str(), \ run_benchmark, \ benchmark_kind, \ - size, \ + bytes, \ seed, \ stream) @@ -229,7 +232,7 @@ void run_benchmark(benchmark::State& state, void add_benchmarks(benchmark_kinds benchmark_kind, const std::string& name, std::vector& benchmarks, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { @@ -303,7 +306,7 @@ void add_benchmarks(benchmark_kinds benchmark_kind int 
main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -314,7 +317,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -325,13 +328,13 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks std::vector benchmarks; - add_benchmarks(benchmark_kinds::sort_keys, "keys", benchmarks, size, seed, stream); - add_benchmarks(benchmark_kinds::sort_pairs, "pairs", benchmarks, size, seed, stream); + add_benchmarks(benchmark_kinds::sort_keys, "keys", benchmarks, bytes, seed, stream); + add_benchmarks(benchmark_kinds::sort_pairs, "pairs", benchmarks, bytes, seed, stream); // Use manual timing for(auto& b : benchmarks) diff --git a/benchmark/benchmark_block_reduce.cpp b/benchmark/benchmark_block_reduce.cpp index 446649040..257165459 100644 --- a/benchmark/benchmark_block_reduce.cpp +++ b/benchmark/benchmark_block_reduce.cpp @@ -42,7 +42,7 @@ #include #ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif namespace rp = rocprim; @@ -106,8 +106,10 @@ template< unsigned int ItemsPerThread, unsigned int Trials = 100 > -void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) +void run_benchmark(benchmark::State& state, hipStream_t stream, size_t bytes) { + // 
Calculate the number of elements N + size_t N = bytes / sizeof(T); // Make sure size is a multiple of BlockSize constexpr auto items_per_block = BlockSize * ItemsPerThread; const auto size = items_per_block * ((N + items_per_block - 1)/items_per_block); @@ -171,7 +173,7 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) .c_str(), \ run_benchmark, \ stream, \ - size) + bytes) #define BENCHMARK_TYPE(type, block) \ CREATE_BENCHMARK(type, block, 1), \ @@ -186,7 +188,7 @@ template void add_benchmarks(std::vector& benchmarks, const std::string& method_name, hipStream_t stream, - size_t size) + size_t bytes) { using custom_float2 = custom_type; using custom_double2 = custom_type; @@ -234,7 +236,7 @@ void add_benchmarks(std::vector& benchmarks, int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -244,7 +246,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); @@ -253,19 +255,19 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); // Add benchmarks std::vector benchmarks; // using_warp_scan using reduce_uwr_t = reduce; - add_benchmarks(benchmarks, "using_warp_reduce", stream, size); + add_benchmarks(benchmarks, "using_warp_reduce", stream, bytes); // reduce then scan using reduce_rr_t = reduce; - add_benchmarks(benchmarks, "raking_reduce", stream, size); + add_benchmarks(benchmarks, "raking_reduce", stream, 
bytes); // reduce commutative only using reduce_rrco_t = reduce; - add_benchmarks(benchmarks, "raking_reduce_commutative_only", stream, size); + add_benchmarks(benchmarks, "raking_reduce_commutative_only", stream, bytes); // Use manual timing for(auto& b : benchmarks) diff --git a/benchmark/benchmark_block_run_length_decode.cpp b/benchmark/benchmark_block_run_length_decode.cpp index b3cd5f066..e56bf098e 100644 --- a/benchmark/benchmark_block_run_length_decode.cpp +++ b/benchmark/benchmark_block_run_length_decode.cpp @@ -33,7 +33,7 @@ #include #ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif template -void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, hipStream_t stream) +void run_benchmark(benchmark::State& state, size_t bytes, const managed_seed& seed, hipStream_t stream) { + // Calculate the number of elements N + size_t N = bytes / sizeof(ItemT); constexpr auto runs_per_block = BlockSize * RunsPerThread; const auto target_num_runs = 2 * N / (MinRunLength + MaxRunLength); const auto num_runs @@ -140,7 +142,7 @@ void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, for(auto _ : state) { - auto start = std::chrono::high_resolution_clock::now(); + auto start = std::chrono::steady_clock::now(); hipLaunchKernelGGL(HIP_KERNEL_NAME(block_run_length_decode_kernel>(end - start); @@ -179,14 +181,14 @@ void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, ",run_per_thread:" #RPT ",decoded_items_per_thread:" #DIPT "}}") \ .c_str(), \ &run_benchmark, \ - size, \ + bytes, \ seed, \ stream) int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -197,7 
+199,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -208,7 +210,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks diff --git a/benchmark/benchmark_block_scan.cpp b/benchmark/benchmark_block_scan.cpp index 071312b62..9c49c47d2 100644 --- a/benchmark/benchmark_block_scan.cpp +++ b/benchmark/benchmark_block_scan.cpp @@ -42,7 +42,7 @@ #include #ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; #endif namespace rp = rocprim; @@ -145,8 +145,10 @@ template< unsigned int ItemsPerThread, unsigned int Trials = 100 > -void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) +void run_benchmark(benchmark::State& state, hipStream_t stream, size_t bytes) { + // Calculate the number of elements N + size_t N = bytes / sizeof(T); // Make sure size is a multiple of BlockSize constexpr auto items_per_block = BlockSize * ItemsPerThread; const auto size = items_per_block * ((N + items_per_block - 1)/items_per_block); @@ -211,7 +213,7 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) .c_str(), \ run_benchmark, \ stream, \ - size) + bytes) #define BENCHMARK_TYPE(type, block) \ CREATE_BENCHMARK(type, block, 1), \ @@ -227,7 +229,7 @@ void add_benchmarks(std::vector& benchmarks, const std::string& method_name, const std::string& algorithm_name, hipStream_t stream, - size_t size) + size_t bytes) { using custom_float2 = custom_type; using custom_double2 = custom_type; @@ -275,7 +277,7 @@ void 
add_benchmarks(std::vector& benchmarks, int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -285,7 +287,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); @@ -294,29 +296,29 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); // Add benchmarks std::vector benchmarks; // inclusive_scan using_warp_scan using inclusive_scan_uws_t = inclusive_scan; add_benchmarks( - benchmarks, "inclusive_scan", "using_warp_scan", stream, size + benchmarks, "inclusive_scan", "using_warp_scan", stream, bytes ); // exclusive_scan using_warp_scan using exclusive_scan_uws_t = exclusive_scan; add_benchmarks( - benchmarks, "exclusive_scan", "using_warp_scan", stream, size + benchmarks, "exclusive_scan", "using_warp_scan", stream, bytes ); // inclusive_scan reduce then scan using inclusive_scan_rts_t = inclusive_scan; add_benchmarks( - benchmarks, "inclusive_scan", "reduce_then_scan", stream, size + benchmarks, "inclusive_scan", "reduce_then_scan", stream, bytes ); // exclusive_scan reduce then scan using exclusive_scan_rts_t = exclusive_scan; add_benchmarks( - benchmarks, "exclusive_scan", "reduce_then_scan", stream, size + benchmarks, "exclusive_scan", "reduce_then_scan", stream, bytes ); // Use manual timing diff --git a/benchmark/benchmark_block_sort.cpp b/benchmark/benchmark_block_sort.cpp index 1b69a60ec..9556f5f0d 100644 --- 
a/benchmark/benchmark_block_sort.cpp +++ b/benchmark/benchmark_block_sort.cpp @@ -41,7 +41,7 @@ #include #ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 128; +const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; #endif #define CREATE_BENCHMARK_IPT(K, V, BS, IPT) \ @@ -80,7 +80,7 @@ const size_t DEFAULT_N = 1024 * 1024 * 128; int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -91,7 +91,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -102,7 +102,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // If we are NOT config tuning run a selection of benchmarks @@ -123,7 +123,7 @@ int main(int argc, char* argv[]) #endif std::vector benchmarks = {}; - config_autotune_register::register_benchmark_subset(benchmarks, 0, 1, size, seed, stream); + config_autotune_register::register_benchmark_subset(benchmarks, 0, 1, bytes, seed, stream); // Use manual timing for(auto& b : benchmarks) diff --git a/benchmark/benchmark_block_sort.parallel.hpp b/benchmark/benchmark_block_sort.parallel.hpp index 9a69205cf..d6e2fe15b 100644 --- a/benchmark/benchmark_block_sort.parallel.hpp +++ b/benchmark/benchmark_block_sort.parallel.hpp @@ -228,10 +228,13 @@ struct block_sort_benchmark : public config_autotune_interface } void 
run(benchmark::State& state, - size_t N, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { + // Calculate the number of elements N + size_t N = bytes / sizeof(KeyType); + const auto size = items_per_block * ((N + items_per_block - 1) / items_per_block); std::vector input = get_random_data(size, diff --git a/benchmark/benchmark_config_dispatch.cpp b/benchmark/benchmark_config_dispatch.cpp index fc47fabc0..0ff5c22c6 100644 --- a/benchmark/benchmark_config_dispatch.cpp +++ b/benchmark/benchmark_config_dispatch.cpp @@ -10,7 +10,7 @@ #include #ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif @@ -63,7 +63,7 @@ static void BM_kernel_launch(benchmark::State& state) hipLaunchKernelGGL(empty_kernel, dim3(1), dim3(1), 0, stream); HIP_CHECK(hipGetLastError()); } - hipStreamSynchronize(stream); + HIP_CHECK(hipStreamSynchronize(stream)); } #define CREATE_BENCHMARK(ST, SK) \ @@ -81,7 +81,7 @@ static void BM_kernel_launch(benchmark::State& state) int main(int argc, char** argv) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", 100, "number of iterations"); parser.set_optional("name_format", "name_format", diff --git a/benchmark/benchmark_device_adjacent_difference.parallel.hpp b/benchmark/benchmark_device_adjacent_difference.parallel.hpp index e257cbeed..02133b913 100644 --- a/benchmark/benchmark_device_adjacent_difference.parallel.hpp +++ b/benchmark/benchmark_device_adjacent_difference.parallel.hpp @@ -227,12 +227,12 @@ struct device_adjacent_difference_benchmark : public config_autotune_interface state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(T)); state.SetItemsProcessed(state.iterations() * batch_size * size); - hipFree(d_input); + HIP_CHECK(hipFree(d_input)); if(!InPlace) { - 
hipFree(d_output); + HIP_CHECK(hipFree(d_output)); } - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_temp_storage)); } }; diff --git a/benchmark/benchmark_device_adjacent_find.cpp b/benchmark/benchmark_device_adjacent_find.cpp new file mode 100644 index 000000000..719061de1 --- /dev/null +++ b/benchmark/benchmark_device_adjacent_find.cpp @@ -0,0 +1,145 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include "benchmark_device_adjacent_find.parallel.hpp" +#include "benchmark_utils.hpp" +#include "cmdparser.hpp" + +// gbench +#include + +// HIP +#include + +// C++ Standard Library +#include +#include + +#ifndef DEFAULT_N +const size_t DEFAULT_BYTES = size_t{2} << 30; // 2 GiB +#endif + +#define CREATE_BENCHMARK(T, P) \ + { \ + const device_adjacent_find_benchmark instance; \ + REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + } + +#define CREATE_ADJACENT_FIND_BENCHMARKS(T) \ + CREATE_BENCHMARK(T, 1) \ + CREATE_BENCHMARK(T, 5) \ + CREATE_BENCHMARK(T, 9) + +int main(int argc, char* argv[]) +{ + cli::Parser parser(argc, argv); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of input bytes"); + parser.set_optional("trials", "trials", -1, "number of iterations"); + parser.set_optional("name_format", + "name_format", + "human", + "either: json,human,txt"); + parser.set_optional("seed", "seed", "random", get_seed_message()); +#ifdef BENCHMARK_CONFIG_TUNING + // optionally run an evenly split subset of benchmarks, when making multiple program invocations + parser.set_optional("parallel_instance", + "parallel_instance", + 0, + "parallel instance index"); + parser.set_optional("parallel_instances", + "parallel_instances", + 1, + "total parallel instances"); +#endif + parser.run_and_exit_if_error(); + + // Parse argv + benchmark::Initialize(&argc, argv); + const size_t size = parser.get("size"); + const int trials = parser.get("trials"); + bench_naming::set_format(parser.get("name_format")); + const std::string seed_type = parser.get("seed"); + const managed_seed seed(seed_type); + + // HIP + hipStream_t stream = 0; // default + + // Benchmark info + add_common_benchmark_info(); + benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("seed", seed_type); + + // Add benchmarks + std::vector benchmarks{}; +#ifdef BENCHMARK_CONFIG_TUNING + const int parallel_instance = parser.get("parallel_instance"); + const 
int parallel_instances = parser.get("parallel_instances"); + config_autotune_register::register_benchmark_subset(benchmarks, + parallel_instance, + parallel_instances, + size, + seed, + stream); +#else // BENCHMARK_CONFIG_TUNING \ + // add_adjacent_find_benchmarks(benchmarks, size, seed, stream); + using custom_float2 = custom_type; + using custom_double2 = custom_type; + using custom_int2 = custom_type; + using custom_char_double = custom_type; + using custom_longlong_double = custom_type; + + // Tuned types + CREATE_ADJACENT_FIND_BENCHMARKS(int8_t) + CREATE_ADJACENT_FIND_BENCHMARKS(int16_t) + CREATE_ADJACENT_FIND_BENCHMARKS(int32_t) + CREATE_ADJACENT_FIND_BENCHMARKS(int64_t) + CREATE_ADJACENT_FIND_BENCHMARKS(rocprim::half) + CREATE_ADJACENT_FIND_BENCHMARKS(float) + CREATE_ADJACENT_FIND_BENCHMARKS(double) + // Custom types + CREATE_ADJACENT_FIND_BENCHMARKS(custom_float2) + CREATE_ADJACENT_FIND_BENCHMARKS(custom_double2) + CREATE_ADJACENT_FIND_BENCHMARKS(custom_int2) + CREATE_ADJACENT_FIND_BENCHMARKS(custom_char_double) + CREATE_ADJACENT_FIND_BENCHMARKS(custom_longlong_double) +#endif // BENCHMARK_CONFIG_TUNING + + // Use manual timing + for(auto& b : benchmarks) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + } + + // Force number of iterations + if(trials > 0) + { + for(auto& b : benchmarks) + { + b->Iterations(trials); + } + } + + // Run benchmarks + benchmark::RunSpecifiedBenchmarks(); + return 0; +} diff --git a/benchmark/benchmark_device_adjacent_find.parallel.cpp.in b/benchmark/benchmark_device_adjacent_find.parallel.cpp.in new file mode 100644 index 000000000..95f53bdf3 --- /dev/null +++ b/benchmark/benchmark_device_adjacent_find.parallel.cpp.in @@ -0,0 +1,32 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + + +#include "benchmark_utils.hpp" +#include "benchmark_device_adjacent_find.parallel.hpp" + +namespace { + auto benchmarks = config_autotune_register::create_bulk( + device_adjacent_find_benchmark_generator< + @InputType@, + @BlockSize@>::create); +} diff --git a/benchmark/benchmark_device_adjacent_find.parallel.hpp b/benchmark/benchmark_device_adjacent_find.parallel.hpp new file mode 100644 index 000000000..016ee3d1e --- /dev/null +++ b/benchmark/benchmark_device_adjacent_find.parallel.hpp @@ -0,0 +1,248 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#ifndef ROCPRIM_BENCHMARK_DEVICE_ADJACENT_FIND_PARALLEL_HPP_ +#define ROCPRIM_BENCHMARK_DEVICE_ADJACENT_FIND_PARALLEL_HPP_ + +#include "benchmark_utils.hpp" + +// gbench +#include + +// HIP +#include + +// rocPRIM +#include +#include +#include + +// C++ Standard Library +#include +#include +#include +#include +#include +#include + +template +std::string config_name() +{ + auto config = Config(); + return "{bs:" + std::to_string(config.kernel_config.block_size) + + ",ipt:" + std::to_string(config.kernel_config.items_per_thread) + "}"; +} + +template<> +inline std::string config_name() +{ + return "default_config"; +} + +template +struct device_adjacent_find_benchmark : public config_autotune_interface +{ + + std::string name() const override + { + + using namespace std::string_literals; + return bench_naming::format_name( + "{lvl:device,algo:adjacent_find,input_type:" + std::string(Traits::name()) + + ",first_adj_pos:" + std::to_string(FirstAdjPosDecimal * 0.1f) + + ",cfg:" + config_name() + "}"); + } + + static constexpr size_t warmup_size = 5; + static constexpr size_t batch_size = 10; + + void run(benchmark::State& state, + size_t bytes, + const managed_seed& seed, + hipStream_t stream) const override + { + using input_type = InputT; + using output_type = std::size_t; + + const size_t size = bytes / sizeof(input_type); + + // Get index of the first adjacent equal pair + std::size_t first_adj_index = static_cast(size * FirstAdjPosDecimal * 0.1f); + if(first_adj_index >= size - 1) + { + first_adj_index = size - 2; + } + + // Generate data ensuring there is no adjacent pair before first_adj_index + std::vector input(size); + if(std::is_same::value) + { + // For int8_t that has a very limited range of values, iota initialization + // seems to give a more reliable benchmark input + std::iota(input.begin(), input.end(), 0); + } + else + { + input = get_random_data(size, + generate_limits::min(), + generate_limits::max(), + seed.get_0()); + std::vector iota(size); + 
std::iota(iota.begin(), iota.end(), 0); + std::transform(iota.begin() + 1, + iota.begin() + first_adj_index + 1, + input.begin() + 1, + [&](std::size_t& idx) + { + while(input[idx] == input[idx - 1]) + { + input[idx] = get_random_value( + generate_limits::min(), + generate_limits::max(), + seed.get_0()); + } + return input[idx]; + }); + } + + // Insert first adjacent pair + input[first_adj_index] = input[first_adj_index + 1]; + + input_type* d_input; + output_type* d_output; + HIP_CHECK(hipMalloc(&d_input, size * sizeof(*d_input))); + HIP_CHECK(hipMalloc(&d_output, sizeof(*d_output))); + HIP_CHECK(hipMemcpy(d_input, + input.data(), + input.size() * sizeof(*d_input), + hipMemcpyHostToDevice)); + + std::size_t tmp_storage_size; + void* d_tmp_storage = nullptr; + auto launch_adjacent_find = [&]() + { + HIP_CHECK(::rocprim::adjacent_find(d_tmp_storage, + tmp_storage_size, + d_input, + d_output, + size, + rocprim::equal_to{}, + stream, + false)); + }; + + // Get size of tmporary storage + launch_adjacent_find(); + HIP_CHECK(hipMalloc(&d_tmp_storage, tmp_storage_size)); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + launch_adjacent_find(); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + // Run + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + launch_adjacent_find(); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + 
state.SetBytesProcessed(state.iterations() * batch_size * first_adj_index + * sizeof(*d_input)); + state.SetItemsProcessed(state.iterations() * batch_size * first_adj_index); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_tmp_storage)); + } +}; + +template +struct device_adjacent_find_benchmark_generator +{ + static constexpr unsigned int min_items_per_thread = 1; + static constexpr unsigned int max_items_per_thread_arg + = TUNING_SHARED_MEMORY_MAX / (BlockSize * sizeof(InputT) * 2); + + template + struct create_pos + { + template + struct create_ipt + { + static constexpr unsigned int items_per_thread = 1u << ItemsPerThreadExp; + using generated_config = rocprim::adjacent_find_config; + + void operator()(std::vector>& storage) + { + storage.emplace_back( + std::make_unique>()); + } + }; + void operator()(std::vector>& storage) + { + static constexpr unsigned int max_items_per_thread_exponent + = rocprim::Log2::VALUE - 1; + static_for_each< + make_index_range, + create_ipt>(storage); + } + }; + + static void create(std::vector>& storage) + { + static_for_each, create_pos>(storage); + } +}; + +#endif // ROCPRIM_BENCHMARK_DEVICE_ADJACENT_FIND_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_binary_search.cpp b/benchmark/benchmark_device_binary_search.cpp index 69913c82d..3b244795d 100644 --- a/benchmark/benchmark_device_binary_search.cpp +++ b/benchmark/benchmark_device_binary_search.cpp @@ -44,7 +44,7 @@ #include #ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif const unsigned int batch_size = 10; @@ -52,10 +52,10 @@ const unsigned int warmup_size = 5; template void run_benchmark(benchmark::State& state, - size_t haystack_size, + size_t haystack_bytes, const managed_seed& seed, hipStream_t stream, - size_t needles_size, + size_t needles_bytes, bool sorted_needles) { using haystack_type = T; @@ -63,6 +63,10 @@ void run_benchmark(benchmark::State& state, 
using output_type = size_t; using compare_op_type = typename std::conditional::value, half_less, rocprim::less>::type; + // Calculate the number of elements from byte size + size_t haystack_size = haystack_bytes / sizeof(haystack_type); + size_t needles_size = needles_bytes / sizeof(needle_type); + compare_op_type compare_op; // Generate data std::vector haystack(haystack_size); @@ -185,7 +189,7 @@ void run_benchmark(benchmark::State& state, + std::string(SORTED ? "sorted" : "random") + "_needles,cfg:default_config}") \ .c_str(), \ [=](benchmark::State& state) \ - { run_benchmark(state, size, seed, stream, size * K / 100, SORTED); }) + { run_benchmark(state, bytes, seed, stream, bytes * K / 100, SORTED); }) #define BENCHMARK_ALGORITHMS(T, K, SORTED) \ CREATE_BENCHMARK(T, K, SORTED, binary_search_subalgorithm), \ @@ -198,7 +202,7 @@ void run_benchmark(benchmark::State& state, int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -220,7 +224,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -231,7 +235,7 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); using custom_float2 = custom_type; @@ -245,7 +249,7 @@ int main(int argc, char *argv[]) config_autotune_register::register_benchmark_subset(benchmarks, parallel_instance, 
parallel_instances, - size, + bytes, seed, stream); #else // BENCHMARK_CONFIG_TUNING diff --git a/benchmark/benchmark_device_find_end.cpp b/benchmark/benchmark_device_find_end.cpp new file mode 100644 index 000000000..91ecc9c01 --- /dev/null +++ b/benchmark/benchmark_device_find_end.cpp @@ -0,0 +1,131 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include "benchmark_device_find_end.hpp" +#include "benchmark_utils.hpp" + +// CmdParser +#include "cmdparser.hpp" + +// Google Benchmark +#include + +// HIP API +#include + +#include +#include + +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; +#endif + +#define CREATE_BENCHMARK_FIND_END(TYPE, KEY_SIZE, REPEATING) \ + { \ + const device_find_end_benchmark instance(KEY_SIZE, REPEATING); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ + } + +#define CREATE_BENCHMARK_PATTERN(TYPE, REPEATING) \ + { \ + CREATE_BENCHMARK_FIND_END(TYPE, 10, REPEATING) \ + CREATE_BENCHMARK_FIND_END(TYPE, 100, REPEATING) \ + CREATE_BENCHMARK_FIND_END(TYPE, 1000, REPEATING) \ + CREATE_BENCHMARK_FIND_END(TYPE, 10000, REPEATING) \ + } + +#define CREATE_BENCHMARK(TYPE) \ + { \ + CREATE_BENCHMARK_PATTERN(TYPE, true) \ + CREATE_BENCHMARK_PATTERN(TYPE, false) \ + } + +int main(int argc, char* argv[]) +{ + cli::Parser parser(argc, argv); + parser.set_optional("bytes", "bytes", DEFAULT_BYTES, "number of values"); + parser.set_optional("trials", "trials", -1, "number of iterations"); + parser.set_optional("name_format", + "name_format", + "human", + "either: json,human,txt"); + parser.set_optional("seed", "seed", "random", get_seed_message()); + parser.run_and_exit_if_error(); + + // Parse argv + benchmark::Initialize(&argc, argv); + const size_t bytes = parser.get("bytes"); + const int trials = parser.get("trials"); + bench_naming::set_format(parser.get("name_format")); + const std::string seed_type = parser.get("seed"); + const managed_seed seed(seed_type); + + // HIP + hipStream_t stream = 0; // default + + // Benchmark info + add_common_benchmark_info(); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); + benchmark::AddCustomContext("seed", seed_type); + + // Add benchmarks + std::vector benchmarks{}; + CREATE_BENCHMARK(int) + CREATE_BENCHMARK(long long) + CREATE_BENCHMARK(int8_t) + CREATE_BENCHMARK(uint8_t) + 
CREATE_BENCHMARK(rocprim::half) + CREATE_BENCHMARK(short) + CREATE_BENCHMARK(float) + + using custom_float2 = custom_type; + using custom_double2 = custom_type; + using custom_int2 = custom_type; + using custom_char_double = custom_type; + using custom_longlong_double = custom_type; + + CREATE_BENCHMARK(custom_float2) + CREATE_BENCHMARK(custom_double2) + CREATE_BENCHMARK(custom_int2) + CREATE_BENCHMARK(custom_char_double) + CREATE_BENCHMARK(custom_longlong_double) + + // Use manual timing + for(auto& b : benchmarks) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + } + + // Force number of iterations + if(trials > 0) + { + for(auto& b : benchmarks) + { + b->Iterations(trials); + } + } + + // Run benchmarks + benchmark::RunSpecifiedBenchmarks(); + return 0; +} diff --git a/benchmark/benchmark_device_find_end.hpp b/benchmark/benchmark_device_find_end.hpp new file mode 100644 index 000000000..e8ef9008c --- /dev/null +++ b/benchmark/benchmark_device_find_end.hpp @@ -0,0 +1,199 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ROCPRIM_BENCHMARK_DEVICE_FIND_END_PARALLEL_HPP_ +#define ROCPRIM_BENCHMARK_DEVICE_FIND_END_PARALLEL_HPP_ + +#include "benchmark_utils.hpp" + +// Google Benchmark +#include + +// HIP API +#include + +// rocPRIM +#include + +#include +#include + +#include + +template +struct device_find_end_benchmark : public config_autotune_interface +{ + size_t key_size_ = 10; + bool repeating_ = false; + + device_find_end_benchmark(size_t KeySize, bool repeating) + { + key_size_ = KeySize; + repeating_ = repeating; + } + + std::string name() const override + { + using namespace std::string_literals; + return bench_naming::format_name( + "{lvl:device,algo:find_end,value_pattern:" + (repeating_ ? "repeating"s : "random"s) + + ",key_size:" + std::to_string(key_size_) + + ",value_type:" + std::string(Traits::name()) + ",cfg:default_config}"); + } + + static constexpr unsigned int batch_size = 10; + static constexpr unsigned int warmup_size = 5; + + void run(benchmark::State& state, + size_t bytes, + const managed_seed& seed, + hipStream_t stream) const override + { + using key_type = Key; + using output_type = size_t; + + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); + size_t key_size = std::min(size, key_size_); + + // Generate data + std::vector keys_input + = get_random_data(key_size, + generate_limits::min(), + generate_limits::max(), + seed.get_0()); + + std::vector input(size); + if(repeating_) + { + // Repeating similar pattern without early exits. 
+ keys_input[0] = 0; + for(size_t i = 0; i < size; i++) + { + input[i] = keys_input[i % key_size]; + } + keys_input[0] = 1; + } + else + { + input = get_random_data(size, + generate_limits::min(), + generate_limits::max(), + seed.get_0() + 1); + } + + key_type* d_keys_input; + key_type* d_input; + output_type* d_output; + HIP_CHECK(hipMalloc(&d_keys_input, key_size * sizeof(*d_keys_input))); + HIP_CHECK(hipMalloc(&d_input, size * sizeof(*d_input))); + HIP_CHECK(hipMalloc(&d_output, sizeof(*d_output))); + + HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(*d_input), hipMemcpyHostToDevice)); + + HIP_CHECK(hipMemcpy(d_keys_input, + keys_input.data(), + key_size * sizeof(*d_keys_input), + hipMemcpyHostToDevice)); + + rocprim::equal_to compare_op; + + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + + HIP_CHECK(rocprim::find_end(d_temporary_storage, + temporary_storage_bytes, + d_input, + d_keys_input, + d_output, + size, + key_size, + compare_op, + stream, + false)); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(rocprim::find_end(d_temporary_storage, + temporary_storage_bytes, + d_input, + d_keys_input, + d_output, + size, + key_size, + compare_op, + stream, + false)); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(rocprim::find_end(d_temporary_storage, + temporary_storage_bytes, + d_input, + d_keys_input, + d_output, + size, + key_size, + compare_op, + stream, + false)); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + 
HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(*d_input)); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + } +}; + +#endif // ROCPRIM_BENCHMARK_DEVICE_FIND_END_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_histogram.cpp b/benchmark/benchmark_device_histogram.cpp index 56b4b4b2b..4e7664ffb 100644 --- a/benchmark/benchmark_device_histogram.cpp +++ b/benchmark/benchmark_device_histogram.cpp @@ -30,6 +30,7 @@ #include // HIP API +#include #include // rocPRIM @@ -41,8 +42,8 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif namespace rp = rocprim; @@ -67,13 +68,16 @@ const int entropy_reductions[] = {0, 2, 4, 6}; template void run_even_benchmark(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed&, hipStream_t stream, size_t bins, size_t scale, int entropy_reduction) { + // Calculate the number of elements + size_t size = bytes / sizeof(T); + using counter_type = unsigned int; using level_type = typename std::conditional_t::value && sizeof(T) < sizeof(int), int, T>; @@ -169,13 +173,16 @@ void run_even_benchmark(benchmark::State& state, template void run_multi_even_benchmark(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed&, hipStream_t stream, size_t bins, size_t scale, int entropy_reduction) { + // Calculate the number of elements + size_t size = bytes / sizeof(T); + using counter_type = unsigned int; using level_type = typename std::conditional_t::value && sizeof(T) < sizeof(int), int, T>; @@ 
-285,8 +292,11 @@ void run_multi_even_benchmark(benchmark::State& state, template void run_range_benchmark( - benchmark::State& state, size_t size, const managed_seed& seed, hipStream_t stream, size_t bins) + benchmark::State& state, size_t bytes, const managed_seed& seed, hipStream_t stream, size_t bins) { + // Calculate the number of elements + size_t size = bytes / sizeof(T); + using counter_type = unsigned int; using level_type = typename std::conditional_t::value && sizeof(T) < sizeof(int), int, T>; @@ -389,8 +399,11 @@ void run_range_benchmark( template void run_multi_range_benchmark( - benchmark::State& state, size_t size, const managed_seed& seed, hipStream_t stream, size_t bins) + benchmark::State& state, size_t bytes, const managed_seed& seed, hipStream_t stream, size_t bins) { + // Calculate the number of elements + size_t size = bytes / sizeof(T); + using counter_type = unsigned int; using level_type = typename std::conditional_t::value && sizeof(T) < sizeof(int), int, T>; @@ -519,7 +532,7 @@ void run_multi_range_benchmark( + ",bins:" + std::to_string(BINS) + ",cfg:default_config}") \ .c_str(), \ [=](benchmark::State& state) \ - { run_even_benchmark(state, size, seed, stream, BINS, SCALE, entropy_reduction); })); + { run_even_benchmark(state, bytes, seed, stream, BINS, SCALE, entropy_reduction); })); #define BENCHMARK_EVEN_TYPE(VECTOR, T, S) \ CREATE_EVEN_BENCHMARK(VECTOR, T, 10, S); \ @@ -528,7 +541,7 @@ void run_multi_range_benchmark( CREATE_EVEN_BENCHMARK(VECTOR, T, 10000, S); void add_even_benchmarks(std::vector& benchmarks, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { @@ -556,7 +569,7 @@ void add_even_benchmarks(std::vector& benchmark [=](benchmark::State& state) \ { \ run_multi_even_benchmark(state, \ - size, \ + bytes, \ seed, \ stream, \ BINS, \ @@ -573,7 +586,7 @@ void add_even_benchmarks(std::vector& benchmark // clang-format on void add_multi_even_benchmarks(std::vector& benchmarks, - size_t size, + size_t 
bytes, const managed_seed& seed, hipStream_t stream) { @@ -595,7 +608,7 @@ void add_multi_even_benchmarks(std::vector& ben bench_naming::format_name("{lvl:device,algo:histogram_range,value_type:" #T ",bins:" \ + std::to_string(BINS) + ",cfg:default_config}") \ .c_str(), \ - [=](benchmark::State& state) { run_range_benchmark(state, size, seed, stream, BINS); }) + [=](benchmark::State& state) { run_range_benchmark(state, bytes, seed, stream, BINS); }) // clang-format off #define BENCHMARK_RANGE_TYPE(T) \ @@ -606,7 +619,7 @@ void add_multi_even_benchmarks(std::vector& ben // clang-format on void add_range_benchmarks(std::vector& benchmarks, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { @@ -632,7 +645,7 @@ void add_range_benchmarks(std::vector& benchmar .c_str(), \ [=](benchmark::State& state) { \ run_multi_range_benchmark(state, \ - size, \ + bytes, \ seed, \ stream, \ BINS); \ @@ -647,7 +660,7 @@ void add_range_benchmarks(std::vector& benchmar // clang-format on void add_multi_range_benchmarks(std::vector& benchmarks, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { @@ -665,7 +678,7 @@ void add_multi_range_benchmarks(std::vector& be int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -687,7 +700,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -698,7 +711,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - 
benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks @@ -709,14 +722,14 @@ int main(int argc, char* argv[]) config_autotune_register::register_benchmark_subset(benchmarks, parallel_instance, parallel_instances, - size, + bytes, seed, stream); #else // BENCHMARK_CONFIG_TUNING - add_even_benchmarks(benchmarks, size, seed, stream); - add_multi_even_benchmarks(benchmarks, size, seed, stream); - add_range_benchmarks(benchmarks, size, seed, stream); - add_multi_range_benchmarks(benchmarks, size, seed, stream); + add_even_benchmarks(benchmarks, bytes, seed, stream); + add_multi_even_benchmarks(benchmarks, bytes, seed, stream); + add_range_benchmarks(benchmarks, bytes, seed, stream); + add_multi_range_benchmarks(benchmarks, bytes, seed, stream); #endif // BENCHMARK_CONFIG_TUNING // Use manual timing diff --git a/benchmark/benchmark_device_merge.cpp b/benchmark/benchmark_device_merge.cpp index 4f0b9b1bc..fe90dea97 100644 --- a/benchmark/benchmark_device_merge.cpp +++ b/benchmark/benchmark_device_merge.cpp @@ -20,6 +20,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. 
+#include "benchmark_device_merge.parallel.hpp" #include "benchmark_utils.hpp" // CmdParser #include "cmdparser.hpp" @@ -41,267 +42,22 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif -namespace rp = rocprim; - -const unsigned int batch_size = 10; -const unsigned int warmup_size = 5; - -template -void run_merge_keys_benchmark(benchmark::State& state, - size_t size, - const managed_seed& seed, - hipStream_t stream) -{ - using key_type = Key; - using compare_op_type = typename std::conditional::value, half_less, rocprim::less>::type; - - const size_t size1 = size / 2; - const size_t size2 = size - size1; - - compare_op_type compare_op; - - // Generate data - const auto random_range = limit_random_range(0, size); - - std::vector keys_input1 - = get_random_data(size1, random_range.first, random_range.second, seed.get_0()); - std::vector keys_input2 - = get_random_data(size2, random_range.first, random_range.second, seed.get_1()); - std::sort(keys_input1.begin(), keys_input1.end(), compare_op); - std::sort(keys_input2.begin(), keys_input2.end(), compare_op); - - key_type * d_keys_input1; - key_type * d_keys_input2; - key_type * d_keys_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input1), size1 * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input2), size2 * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_output), size * sizeof(key_type))); - HIP_CHECK( - hipMemcpy( - d_keys_input1, keys_input1.data(), - size1 * sizeof(key_type), - hipMemcpyHostToDevice - ) - ); - HIP_CHECK( - hipMemcpy( - d_keys_input2, keys_input2.data(), - size2 * sizeof(key_type), - hipMemcpyHostToDevice - ) - ); - - void * d_temporary_storage = nullptr; - size_t temporary_storage_bytes = 0; - HIP_CHECK( - rp::merge( - d_temporary_storage, temporary_storage_bytes, - d_keys_input1, d_keys_input2, d_keys_output, size1, size2, - 
compare_op, stream, false - ) - ); - - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - - // Warm-up - for(size_t i = 0; i < warmup_size; i++) - { - HIP_CHECK( - rp::merge( - d_temporary_storage, temporary_storage_bytes, - d_keys_input1, d_keys_input2, d_keys_output, size1, size2, - compare_op, stream, false - ) - ); - } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for (auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; i++) - { - HIP_CHECK( - rp::merge( - d_temporary_storage, temporary_storage_bytes, - d_keys_input1, d_keys_input2, d_keys_output, size1, size2, - compare_op, stream, false - ) - ); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input1)); - HIP_CHECK(hipFree(d_keys_input2)); - HIP_CHECK(hipFree(d_keys_output)); -} - -template -void run_merge_pairs_benchmark(benchmark::State& state, - size_t size, - const managed_seed& seed, - hipStream_t stream) -{ - using key_type = Key; - using value_type = Value; - using compare_op_type = typename std::conditional::value, half_less, rocprim::less>::type; - - const size_t size1 = size / 2; - const size_t size2 = size - size1; - - compare_op_type compare_op; - - // Generate data - const auto random_range = limit_random_range(0, size); - 
std::vector keys_input1 - = get_random_data(size1, random_range.first, random_range.second, seed.get_0()); - std::vector keys_input2 - = get_random_data(size2, random_range.first, random_range.second, seed.get_1()); - std::sort(keys_input1.begin(), keys_input1.end(), compare_op); - std::sort(keys_input2.begin(), keys_input2.end(), compare_op); - std::vector values_input1(size1); - std::vector values_input2(size2); - std::iota(values_input1.begin(), values_input1.end(), 0); - std::iota(values_input2.begin(), values_input2.end(), size1); - - key_type * d_keys_input1; - key_type * d_keys_input2; - key_type * d_keys_output; - value_type * d_values_input1; - value_type * d_values_input2; - value_type * d_values_output; - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input1), size1 * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input2), size2 * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_output), size * sizeof(key_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_values_input1), size1 * sizeof(value_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_values_input2), size2 * sizeof(value_type))); - HIP_CHECK(hipMalloc(reinterpret_cast(&d_values_output), size * sizeof(value_type))); - HIP_CHECK( - hipMemcpy( - d_keys_input1, keys_input1.data(), - size1 * sizeof(key_type), - hipMemcpyHostToDevice - ) - ); - HIP_CHECK( - hipMemcpy( - d_keys_input2, keys_input2.data(), - size2 * sizeof(key_type), - hipMemcpyHostToDevice - ) - ); - - void * d_temporary_storage = nullptr; - size_t temporary_storage_bytes = 0; - HIP_CHECK( - rp::merge( - d_temporary_storage, temporary_storage_bytes, - d_keys_input1, d_keys_input2, d_keys_output, - d_values_input1, d_values_input2, d_values_output, - size1, size2, - compare_op, stream, false - ) - ); - - HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Warm-up - for(size_t i = 0; i < warmup_size; i++) - { - HIP_CHECK( - rp::merge( - 
d_temporary_storage, temporary_storage_bytes, - d_keys_input1, d_keys_input2, d_keys_output, - d_values_input1, d_values_input2, d_values_output, - size1, size2, - compare_op, stream, false - ) - ); +#define CREATE_BENCHMARK(...) \ + { \ + const device_merge_benchmark<__VA_ARGS__> instance; \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } - HIP_CHECK(hipDeviceSynchronize()); - - // HIP events creation - hipEvent_t start, stop; - HIP_CHECK(hipEventCreate(&start)); - HIP_CHECK(hipEventCreate(&stop)); - - for (auto _ : state) - { - // Record start event - HIP_CHECK(hipEventRecord(start, stream)); - - for(size_t i = 0; i < batch_size; i++) - { - HIP_CHECK( - rp::merge( - d_temporary_storage, temporary_storage_bytes, - d_keys_input1, d_keys_input2, d_keys_output, - d_values_input1, d_values_input2, d_values_output, - size1, size2, - compare_op, stream, false - ) - ); - } - - // Record stop event and wait until it completes - HIP_CHECK(hipEventRecord(stop, stream)); - HIP_CHECK(hipEventSynchronize(stop)); - - float elapsed_mseconds; - HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); - state.SetIterationTime(elapsed_mseconds / 1000); - } - - // Destroy HIP events - HIP_CHECK(hipEventDestroy(start)); - HIP_CHECK(hipEventDestroy(stop)); - - state.SetBytesProcessed(state.iterations() * batch_size * size * (sizeof(key_type) + sizeof(value_type))); - state.SetItemsProcessed(state.iterations() * batch_size * size); - - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input1)); - HIP_CHECK(hipFree(d_keys_input2)); - HIP_CHECK(hipFree(d_keys_output)); - HIP_CHECK(hipFree(d_values_input1)); - HIP_CHECK(hipFree(d_values_input2)); - HIP_CHECK(hipFree(d_values_output)); -} #define CREATE_MERGE_KEYS_BENCHMARK(Key) \ benchmark::RegisterBenchmark( \ bench_naming::format_name("{lvl:device,algo:merge,key_type:" #Key ",cfg:default_config}") \ .c_str(), \ [=](benchmark::State& state) \ - { run_merge_keys_benchmark(state, size, seed, 
stream); }) + { run_merge_keys_benchmark(state, bytes, seed, stream); }) #define CREATE_MERGE_PAIRS_BENCHMARK(Key, Value) \ benchmark::RegisterBenchmark( \ @@ -309,23 +65,34 @@ void run_merge_pairs_benchmark(benchmark::State& state, ",cfg:default_config}") \ .c_str(), \ [=](benchmark::State& state) \ - { run_merge_pairs_benchmark(state, size, seed, stream); }) + { run_merge_pairs_benchmark(state, bytes, seed, stream); }) int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", "human", "either: json,human,txt"); parser.set_optional("seed", "seed", "random", get_seed_message()); +#ifdef BENCHMARK_CONFIG_TUNING + // optionally run an evenly split subset of benchmarks, when making multiple program invocations + parser.set_optional("parallel_instance", + "parallel_instance", + 0, + "parallel instance index"); + parser.set_optional("parallel_instances", + "parallel_instances", + 1, + "total parallel instances"); +#endif parser.run_and_exit_if_error(); // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -336,33 +103,43 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); + // Add benchmarks + std::vector benchmarks = {}; +#ifdef BENCHMARK_CONFIG_TUNING + const int parallel_instance = parser.get("parallel_instance"); + const int parallel_instances = 
parser.get("parallel_instances"); + config_autotune_register::register_benchmark_subset(benchmarks, + parallel_instance, + parallel_instances, + bytes, + seed, + stream); +#else // BENCHMARK_CONFIG_TUNING using custom_int2 = custom_type; using custom_double2 = custom_type; - // Add benchmarks - std::vector benchmarks = - { - CREATE_MERGE_KEYS_BENCHMARK(int), - CREATE_MERGE_KEYS_BENCHMARK(long long), - CREATE_MERGE_KEYS_BENCHMARK(int8_t), - CREATE_MERGE_KEYS_BENCHMARK(uint8_t), - CREATE_MERGE_KEYS_BENCHMARK(rocprim::half), - CREATE_MERGE_KEYS_BENCHMARK(short), - CREATE_MERGE_KEYS_BENCHMARK(custom_int2), - CREATE_MERGE_KEYS_BENCHMARK(custom_double2), - - CREATE_MERGE_PAIRS_BENCHMARK(int, int), - CREATE_MERGE_PAIRS_BENCHMARK(long long, long long), - CREATE_MERGE_PAIRS_BENCHMARK(int8_t, int8_t), - CREATE_MERGE_PAIRS_BENCHMARK(uint8_t, uint8_t), - CREATE_MERGE_PAIRS_BENCHMARK(rocprim::half, rocprim::half), - CREATE_MERGE_PAIRS_BENCHMARK(short, short), - CREATE_MERGE_PAIRS_BENCHMARK(custom_int2, custom_int2), - CREATE_MERGE_PAIRS_BENCHMARK(custom_double2, custom_double2), - }; + CREATE_BENCHMARK(int) + CREATE_BENCHMARK(long long) + CREATE_BENCHMARK(int8_t) + CREATE_BENCHMARK(uint8_t) + CREATE_BENCHMARK(rocprim::half) + CREATE_BENCHMARK(short) + CREATE_BENCHMARK(custom_int2) + CREATE_BENCHMARK(custom_double2) + + CREATE_BENCHMARK(int, int) + CREATE_BENCHMARK(long long, long long) + CREATE_BENCHMARK(int8_t, int8_t) + CREATE_BENCHMARK(uint8_t, uint8_t) + CREATE_BENCHMARK(rocprim::half, rocprim::half) + CREATE_BENCHMARK(short, short) + CREATE_BENCHMARK(custom_int2, custom_int2) + CREATE_BENCHMARK(custom_double2, custom_double2) + +#endif // BENCHMARK_CONFIG_TUNING // Use manual timing for(auto& b : benchmarks) diff --git a/benchmark/benchmark_device_merge.parallel.cpp.in b/benchmark/benchmark_device_merge.parallel.cpp.in new file mode 100644 index 000000000..822f00c89 --- /dev/null +++ b/benchmark/benchmark_device_merge.parallel.cpp.in @@ -0,0 +1,32 @@ +// MIT License +// 
+// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include + +#include "benchmark_utils.hpp" +#include "benchmark_device_merge.parallel.hpp" + +namespace +{ +auto benchmarks = config_autotune_register::create_bulk( + device_merge_benchmark_generator<@KeyType@, @ValueType@, @BlockSize@>::create); +} diff --git a/benchmark/benchmark_device_merge.parallel.hpp b/benchmark/benchmark_device_merge.parallel.hpp new file mode 100644 index 000000000..6aff37388 --- /dev/null +++ b/benchmark/benchmark_device_merge.parallel.hpp @@ -0,0 +1,422 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#ifndef ROCPRIM_BENCHMARK_DEVICE_MERGE_PARALLEL_HPP_ +#define ROCPRIM_BENCHMARK_DEVICE_MERGE_PARALLEL_HPP_ + +#include "benchmark_utils.hpp" + +// Google Benchmark +#include + +// HIP API +#include +#include + +// rocPRIM HIP API +#include +#include + +#include +#include +#include +#include + +#include +#include + +namespace rp = rocprim; + +template +std::string config_name() +{ + const rocprim::detail::merge_config_params params = Config(); + return "{bs:" + std::to_string(params.kernel_config.block_size) + + ",ipt:" + std::to_string(params.kernel_config.items_per_thread) + "}"; +} + +template<> +inline std::string config_name() +{ + return "default_config"; +} + +template +struct device_merge_benchmark : public config_autotune_interface +{ + std::string name() const override + { + return bench_naming::format_name("{lvl:device,algo:merge,key_type:" + + std::string(Traits::name()) + + ",value_type:" + std::string(Traits::name()) + + ",cfg:" + config_name() + "}"); + } + + static constexpr unsigned int batch_size = 10; + static constexpr unsigned int warmup_size = 5; + + // keys benchmark + template + auto do_run(benchmark::State& state, + size_t bytes, + const managed_seed& seed, + hipStream_t stream) const -> + typename std::enable_if::value, void>::type + { + using key_type = KeyType; + using compare_op_type = + typename std::conditional::value, + half_less, + rocprim::less>::type; + + size_t size = bytes / sizeof(key_type); + + const size_t size1 = size / 2; + const size_t size2 = size - size1; + + compare_op_type compare_op; + + // Generate data + const auto random_range = limit_random_range(0, size); + + std::vector keys_input1 = get_random_data(size1, + random_range.first, + random_range.second, + seed.get_0()); + std::vector keys_input2 = get_random_data(size2, + random_range.first, + random_range.second, + seed.get_1()); + std::sort(keys_input1.begin(), keys_input1.end(), compare_op); + std::sort(keys_input2.begin(), keys_input2.end(), compare_op); + + 
key_type* d_keys_input1; + key_type* d_keys_input2; + key_type* d_keys_output; + HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input1), size1 * sizeof(key_type))); + HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input2), size2 * sizeof(key_type))); + HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_output), size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_keys_input1, + keys_input1.data(), + size1 * sizeof(key_type), + hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_keys_input2, + keys_input2.data(), + size2 * sizeof(key_type), + hipMemcpyHostToDevice)); + + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + HIP_CHECK(rp::merge(d_temporary_storage, + temporary_storage_bytes, + d_keys_input1, + d_keys_input2, + d_keys_output, + size1, + size2, + compare_op, + stream, + false)); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(rp::merge(d_temporary_storage, + temporary_storage_bytes, + d_keys_input1, + d_keys_input2, + d_keys_output, + size1, + size2, + compare_op, + stream, + false)); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(rp::merge(d_temporary_storage, + temporary_storage_bytes, + d_keys_input1, + d_keys_input2, + d_keys_output, + size1, + size2, + compare_op, + stream, + false)); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + 
state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(key_type)); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_keys_input1)); + HIP_CHECK(hipFree(d_keys_input2)); + HIP_CHECK(hipFree(d_keys_output)); + } + + // pairs benchmark + template + auto do_run(benchmark::State& state, + size_t bytes, + const managed_seed& seed, + hipStream_t stream) const -> + typename std::enable_if::value, void>::type + { + using key_type = KeyType; + using value_type = ValueType; + using compare_op_type = + typename std::conditional::value, + half_less, + rocprim::less>::type; + + size_t size = bytes / sizeof(key_type); + + const size_t size1 = size / 2; + const size_t size2 = size - size1; + + compare_op_type compare_op; + + // Generate data + const auto random_range = limit_random_range(0, size); + std::vector keys_input1 = get_random_data(size1, + random_range.first, + random_range.second, + seed.get_0()); + std::vector keys_input2 = get_random_data(size2, + random_range.first, + random_range.second, + seed.get_1()); + std::sort(keys_input1.begin(), keys_input1.end(), compare_op); + std::sort(keys_input2.begin(), keys_input2.end(), compare_op); + std::vector values_input1(size1); + std::vector values_input2(size2); + std::iota(values_input1.begin(), values_input1.end(), 0); + std::iota(values_input2.begin(), values_input2.end(), size1); + + key_type* d_keys_input1; + key_type* d_keys_input2; + key_type* d_keys_output; + value_type* d_values_input1; + value_type* d_values_input2; + value_type* d_values_output; + HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input1), size1 * sizeof(key_type))); + HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_input2), size2 * sizeof(key_type))); + HIP_CHECK(hipMalloc(reinterpret_cast(&d_keys_output), size * sizeof(key_type))); + HIP_CHECK( + hipMalloc(reinterpret_cast(&d_values_input1), size1 * sizeof(value_type))); + HIP_CHECK( + 
hipMalloc(reinterpret_cast(&d_values_input2), size2 * sizeof(value_type))); + HIP_CHECK(hipMalloc(reinterpret_cast(&d_values_output), size * sizeof(value_type))); + HIP_CHECK(hipMemcpy(d_keys_input1, + keys_input1.data(), + size1 * sizeof(key_type), + hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_keys_input2, + keys_input2.data(), + size2 * sizeof(key_type), + hipMemcpyHostToDevice)); + + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + HIP_CHECK(rp::merge(d_temporary_storage, + temporary_storage_bytes, + d_keys_input1, + d_keys_input2, + d_keys_output, + d_values_input1, + d_values_input2, + d_values_output, + size1, + size2, + compare_op, + stream, + false)); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(rp::merge(d_temporary_storage, + temporary_storage_bytes, + d_keys_input1, + d_keys_input2, + d_keys_output, + d_values_input1, + d_values_input2, + d_values_output, + size1, + size2, + compare_op, + stream, + false)); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(rp::merge(d_temporary_storage, + temporary_storage_bytes, + d_keys_input1, + d_keys_input2, + d_keys_output, + d_values_input1, + d_values_input2, + d_values_output, + size1, + size2, + compare_op, + stream, + false)); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + 
HIP_CHECK(hipEventDestroy(stop)); + + state.SetBytesProcessed(state.iterations() * batch_size * size + * (sizeof(key_type) + sizeof(value_type))); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_keys_input1)); + HIP_CHECK(hipFree(d_keys_input2)); + HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_values_input1)); + HIP_CHECK(hipFree(d_values_input2)); + HIP_CHECK(hipFree(d_values_output)); + } + + void run(benchmark::State& state, + size_t bytes, + const managed_seed& seed, + hipStream_t stream) const override + { + do_run(state, bytes, seed, stream); + } +}; + +#ifdef BENCHMARK_CONFIG_TUNING +template +struct device_merge_benchmark_generator +{ + + template + struct create_ipt + { + static constexpr unsigned int items_per_thread = 1u << ItemsPerThreadExponent; + using generated_config = rocprim::merge_config; + using benchmark_struct = device_merge_benchmark; + + void operator()(std::vector>& storage) + { + storage.emplace_back(std::make_unique()); + } + }; + + struct create_default_config + { + using default_config = + typename rocprim::detail::default_merge_config_base::type; + using benchmark_struct = device_merge_benchmark; + + void operator()(std::vector>& storage) + { + storage.emplace_back(std::make_unique()); + } + }; + + static void create(std::vector>& storage) + { + static constexpr unsigned int min_items_per_thread_exponent = 0u; + + // Very large block sizes don't work with large items_per_thread since + // shared memory is limited + static constexpr unsigned int max_shared_memory = TUNING_SHARED_MEMORY_MAX; + static constexpr unsigned int max_size_per_element = sizeof(KeyType) + sizeof(ValueType); + static constexpr unsigned int max_items_per_thread + = max_shared_memory / (BlockSize * max_size_per_element); + static constexpr unsigned int max_items_per_thread_exponent + = rocprim::Log2::VALUE - 1; + + create_default_config()(storage); + + static_for_each, + 
create_ipt>(storage); + } +}; + +#endif // BENCHMARK_CONFIG_TUNING + +#endif // ROCPRIM_BENCHMARK_DEVICE_MERGE_PARALLEL_HPP_ \ No newline at end of file diff --git a/benchmark/benchmark_device_merge_sort.cpp b/benchmark/benchmark_device_merge_sort.cpp index df5f6a7b8..00a339caa 100644 --- a/benchmark/benchmark_device_merge_sort.cpp +++ b/benchmark/benchmark_device_merge_sort.cpp @@ -35,20 +35,20 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif #define CREATE_BENCHMARK(...) \ { \ const device_merge_sort_benchmark<__VA_ARGS__> instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -59,7 +59,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -70,7 +70,7 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks diff --git a/benchmark/benchmark_device_merge_sort.hpp b/benchmark/benchmark_device_merge_sort.hpp index c8f0e8dc3..c935730ae 100644 --- a/benchmark/benchmark_device_merge_sort.hpp +++ b/benchmark/benchmark_device_merge_sort.hpp @@ -59,13 +59,15 @@ struct 
device_merge_sort_benchmark : public config_autotune_interface // keys benchmark template auto do_run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const -> typename std::enable_if::value, void>::type { using key_type = Key; + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); // Generate data std::vector keys_input = get_random_data(size, @@ -158,7 +160,7 @@ struct device_merge_sort_benchmark : public config_autotune_interface // pairs benchmark template auto do_run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const -> typename std::enable_if::value, void>::type @@ -166,6 +168,8 @@ struct device_merge_sort_benchmark : public config_autotune_interface using key_type = Key; using value_type = Value; + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); // Generate data std::vector keys_input = get_random_data(size, @@ -277,11 +281,11 @@ struct device_merge_sort_benchmark : public config_autotune_interface } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { - do_run(state, size, seed, stream); + do_run(state, bytes, seed, stream); } }; diff --git a/benchmark/benchmark_device_merge_sort_block_merge.cpp b/benchmark/benchmark_device_merge_sort_block_merge.cpp index adc72773f..6138771da 100644 --- a/benchmark/benchmark_device_merge_sort_block_merge.cpp +++ b/benchmark/benchmark_device_merge_sort_block_merge.cpp @@ -37,19 +37,19 @@ #include #ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif #define CREATE_BENCHMARK(...) 
\ { \ const device_merge_sort_block_merge_benchmark<__VA_ARGS__> instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -71,7 +71,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -82,7 +82,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks @@ -93,7 +93,7 @@ int main(int argc, char* argv[]) config_autotune_register::register_benchmark_subset(benchmarks, parallel_instance, parallel_instances, - size, + bytes, seed, stream); #else // BENCHMARK_CONFIG_TUNING diff --git a/benchmark/benchmark_device_merge_sort_block_merge.parallel.hpp b/benchmark/benchmark_device_merge_sort_block_merge.parallel.hpp index 12576f251..ff04119ff 100644 --- a/benchmark/benchmark_device_merge_sort_block_merge.parallel.hpp +++ b/benchmark/benchmark_device_merge_sort_block_merge.parallel.hpp @@ -80,13 +80,16 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac // keys benchmark template auto do_run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const -> typename std::enable_if::value, void>::type { using 
key_type = Key; + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); + // Generate data std::vector keys_input = get_random_data(size, @@ -102,7 +105,7 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac keys_input.data(), size * sizeof(key_type), hipMemcpyHostToDevice)); - hipDeviceSynchronize(); + HIP_CHECK(hipDeviceSynchronize()); ::rocprim::less lesser_op; rocprim::empty_type* values_ptr = nullptr; @@ -172,11 +175,11 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac for(auto _ : state) { // Record start event - hipMemcpyAsync(d_keys, - d_keys_input, - size * sizeof(key_type), - hipMemcpyDeviceToDevice, - stream); + HIP_CHECK(hipMemcpyAsync(d_keys, + d_keys_input, + size * sizeof(key_type), + hipMemcpyDeviceToDevice, + stream)); HIP_CHECK(hipEventRecord(start, stream)); HIP_CHECK(rp::detail::merge_sort_block_merge(d_temporary_storage, temporary_storage_bytes, @@ -212,7 +215,7 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac // pairs benchmark template auto do_run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const -> typename std::enable_if::value, void>::type @@ -220,6 +223,9 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac using key_type = Key; using value_type = Value; + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); + // Generate data std::vector keys_input = get_random_data(size, @@ -248,7 +254,7 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac size * sizeof(value_type), hipMemcpyHostToDevice)); - hipDeviceSynchronize(); + HIP_CHECK(hipDeviceSynchronize()); ::rocprim::less lesser_op; @@ -319,16 +325,16 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac for(auto _ : state) { // Record start event - hipMemcpyAsync(d_keys, - d_keys_input, - size * 
sizeof(key_type), - hipMemcpyDeviceToDevice, - stream); - hipMemcpyAsync(d_values, - d_values_input, - size * sizeof(key_type), - hipMemcpyDeviceToDevice, - stream); + HIP_CHECK(hipMemcpyAsync(d_keys, + d_keys_input, + size * sizeof(key_type), + hipMemcpyDeviceToDevice, + stream)); + HIP_CHECK(hipMemcpyAsync(d_values, + d_values_input, + size * sizeof(key_type), + hipMemcpyDeviceToDevice, + stream)); HIP_CHECK(hipEventRecord(start, stream)); HIP_CHECK(rp::detail::merge_sort_block_merge(d_temporary_storage, temporary_storage_bytes, @@ -364,11 +370,11 @@ struct device_merge_sort_block_merge_benchmark : public config_autotune_interfac } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { - do_run(state, size, seed, stream); + do_run(state, bytes, seed, stream); } }; diff --git a/benchmark/benchmark_device_merge_sort_block_sort.cpp b/benchmark/benchmark_device_merge_sort_block_sort.cpp index 7bf0b9a34..46d6b2bdf 100644 --- a/benchmark/benchmark_device_merge_sort_block_sort.cpp +++ b/benchmark/benchmark_device_merge_sort_block_sort.cpp @@ -37,19 +37,19 @@ #include #ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif #define CREATE_BENCHMARK(...) 
\ { \ const device_merge_sort_block_sort_benchmark<__VA_ARGS__> instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -71,7 +71,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -82,7 +82,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks @@ -93,7 +93,7 @@ int main(int argc, char* argv[]) config_autotune_register::register_benchmark_subset(benchmarks, parallel_instance, parallel_instances, - size, + bytes, seed, stream); #else // BENCHMARK_CONFIG_TUNING diff --git a/benchmark/benchmark_device_merge_sort_block_sort.parallel.hpp b/benchmark/benchmark_device_merge_sort_block_sort.parallel.hpp index 24db039bd..796c165e4 100644 --- a/benchmark/benchmark_device_merge_sort_block_sort.parallel.hpp +++ b/benchmark/benchmark_device_merge_sort_block_sort.parallel.hpp @@ -89,13 +89,15 @@ struct device_merge_sort_block_sort_benchmark : public config_autotune_interface // keys benchmark template auto do_run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const -> typename std::enable_if::value, void>::type { using key_type = 
Key; + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); // Generate data std::vector keys_input = get_random_data(size, @@ -176,7 +178,7 @@ struct device_merge_sort_block_sort_benchmark : public config_autotune_interface // pairs benchmark template auto do_run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const -> typename std::enable_if::value, void>::type @@ -184,6 +186,8 @@ struct device_merge_sort_block_sort_benchmark : public config_autotune_interface using key_type = Key; using value_type = Value; + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); // Generate data std::vector keys_input = get_random_data(size, @@ -280,11 +284,11 @@ struct device_merge_sort_block_sort_benchmark : public config_autotune_interface } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { - do_run(state, size, seed, stream); + do_run(state, bytes, seed, stream); } }; diff --git a/benchmark/benchmark_device_nth_element.cpp b/benchmark/benchmark_device_nth_element.cpp index 4cdb998d1..0c0bcb2a2 100644 --- a/benchmark/benchmark_device_nth_element.cpp +++ b/benchmark/benchmark_device_nth_element.cpp @@ -35,14 +35,14 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif #define CREATE_BENCHMARK_NTH_ELEMENT(TYPE, SMALL_N) \ { \ const device_nth_element_benchmark instance(SMALL_N); \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define CREATE_BENCHMARK(TYPE) \ @@ -54,7 +54,7 @@ const size_t DEFAULT_N = 1024 * 1024 * 32; int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, 
"number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -65,7 +65,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -76,7 +76,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks diff --git a/benchmark/benchmark_device_nth_element.hpp b/benchmark/benchmark_device_nth_element.hpp index e85c2d6b4..f517a1a77 100644 --- a/benchmark/benchmark_device_nth_element.hpp +++ b/benchmark/benchmark_device_nth_element.hpp @@ -61,12 +61,14 @@ struct device_nth_element_benchmark : public config_autotune_interface static constexpr unsigned int warmup_size = 5; void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { using key_type = Key; + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); size_t nth = 10; if(!small_n) diff --git a/benchmark/benchmark_device_partial_sort.cpp b/benchmark/benchmark_device_partial_sort.cpp index 49db25d25..ea0a7b2b1 100644 --- a/benchmark/benchmark_device_partial_sort.cpp +++ b/benchmark/benchmark_device_partial_sort.cpp @@ -35,14 +35,14 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif #define CREATE_BENCHMARK_PARTIAL_SORT(TYPE, SMALL_N) \ { \ const device_partial_sort_benchmark instance(SMALL_N); \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + 
REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define CREATE_BENCHMARK(TYPE) \ @@ -54,7 +54,7 @@ const size_t DEFAULT_N = 1024 * 1024 * 32; int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -65,7 +65,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -76,7 +76,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks diff --git a/benchmark/benchmark_device_partial_sort.hpp b/benchmark/benchmark_device_partial_sort.hpp index c93ea8453..efcb1c5d5 100644 --- a/benchmark/benchmark_device_partial_sort.hpp +++ b/benchmark/benchmark_device_partial_sort.hpp @@ -61,12 +61,14 @@ struct device_partial_sort_benchmark : public config_autotune_interface static constexpr unsigned int warmup_size = 5; void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { using key_type = Key; - + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); + size_t middle = 10; if(!small_n) diff --git a/benchmark/benchmark_device_partial_sort_copy.cpp b/benchmark/benchmark_device_partial_sort_copy.cpp index e8097b635..2b5a1b840 100644 --- a/benchmark/benchmark_device_partial_sort_copy.cpp +++ 
b/benchmark/benchmark_device_partial_sort_copy.cpp @@ -35,14 +35,14 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif #define CREATE_BENCHMARK_PARTIAL_SORT_COPY(TYPE, SMALL_N) \ { \ const device_partial_sort_copy_benchmark instance(SMALL_N); \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define CREATE_BENCHMARK(TYPE) \ @@ -54,7 +54,7 @@ const size_t DEFAULT_N = 1024 * 1024 * 32; int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -65,7 +65,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -76,7 +76,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks diff --git a/benchmark/benchmark_device_partial_sort_copy.hpp b/benchmark/benchmark_device_partial_sort_copy.hpp index 12f8b0e48..833c51599 100644 --- a/benchmark/benchmark_device_partial_sort_copy.hpp +++ b/benchmark/benchmark_device_partial_sort_copy.hpp @@ -61,12 +61,14 @@ struct device_partial_sort_copy_benchmark : public config_autotune_interface static constexpr unsigned int warmup_size = 5; void run(benchmark::State& state, - size_t size, + size_t 
bytes, const managed_seed& seed, hipStream_t stream) const override { using key_type = Key; + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); size_t middle = 10; if(!small_n) diff --git a/benchmark/benchmark_device_partition.cpp b/benchmark/benchmark_device_partition.cpp index de60fc392..46212e408 100644 --- a/benchmark/benchmark_device_partition.cpp +++ b/benchmark/benchmark_device_partition.cpp @@ -44,39 +44,39 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif #define CREATE_PARTITION_FLAG_BENCHMARK(T, F, p) \ { \ const device_partition_flag_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define CREATE_PARTITION_PREDICATE_BENCHMARK(T, p) \ { \ const device_partition_predicate_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define CREATE_PARTITION_TWO_WAY_FLAG_BENCHMARK(T, F, p) \ { \ const device_partition_two_way_flag_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define CREATE_PARTITION_TWO_WAY_PREDICATE_BENCHMARK(T, p) \ { \ const device_partition_two_way_predicate_benchmark \ instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define CREATE_PARTITION_THREE_WAY_BENCHMARK(T, p) \ { \ const device_partition_three_way_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define BENCHMARK_FLAG_TYPE(type, flag_type) \ @@ -112,7 +112,7 @@ const size_t DEFAULT_N = 1024 * 1024 * 32; int main(int argc, 
char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -134,7 +134,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -145,7 +145,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks @@ -156,7 +156,7 @@ int main(int argc, char* argv[]) config_autotune_register::register_benchmark_subset(benchmarks, parallel_instance, parallel_instances, - size, + bytes, seed, stream); #else diff --git a/benchmark/benchmark_device_partition.parallel.hpp b/benchmark/benchmark_device_partition.parallel.hpp index ab094fe01..afc39b669 100644 --- a/benchmark/benchmark_device_partition.parallel.hpp +++ b/benchmark/benchmark_device_partition.parallel.hpp @@ -131,10 +131,13 @@ struct device_partition_flag_benchmark : public config_autotune_interface } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, const hipStream_t stream) const override { + // Calculate the number of elements + size_t size = bytes / sizeof(DataType); + std::vector input = get_random_data(size, generate_limits::min(), generate_limits::max(), @@ -244,16 +247,16 @@ struct device_partition_flag_benchmark : public config_autotune_interface state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); 
state.SetItemsProcessed(state.iterations() * batch_size * size); - hipFree(d_input); + HIP_CHECK(hipFree(d_input)); if(is_tuning) { - hipFree(d_flags_2); - hipFree(d_flags_1); + HIP_CHECK(hipFree(d_flags_2)); + HIP_CHECK(hipFree(d_flags_1)); } - hipFree(d_flags_0); - hipFree(d_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_flags_0)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); } static constexpr bool is_tuning = Probability == partition_probability::tuning; @@ -274,10 +277,13 @@ struct device_partition_predicate_benchmark : public config_autotune_interface } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, const hipStream_t stream) const override { + // Calculate the number of elements + size_t size = bytes / sizeof(DataType); + // all data types can represent [0, 127], -1 so a predicate can select all std::vector input = get_random_data(size, static_cast(0), @@ -360,10 +366,10 @@ struct device_partition_predicate_benchmark : public config_autotune_interface state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); state.SetItemsProcessed(state.iterations() * batch_size * size); - hipFree(d_input); - hipFree(d_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); } static constexpr bool is_tuning = Probability == partition_probability::tuning; @@ -386,10 +392,13 @@ struct device_partition_two_way_flag_benchmark : public config_autotune_interfac } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, const hipStream_t stream) const override { + // Calculate the number of elements + size_t size = bytes / sizeof(DataType); + std::vector input = get_random_data(size, 
generate_limits::min(), generate_limits::max(), @@ -503,17 +512,17 @@ struct device_partition_two_way_flag_benchmark : public config_autotune_interfac state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); state.SetItemsProcessed(state.iterations() * batch_size * size); - hipFree(d_input); + HIP_CHECK(hipFree(d_input)); if(is_tuning) { - hipFree(d_flags_2); - hipFree(d_flags_1); + HIP_CHECK(hipFree(d_flags_2)); + HIP_CHECK(hipFree(d_flags_1)); } - hipFree(d_flags_0); - hipFree(d_output_selected); - hipFree(d_output_rejected); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_flags_0)); + HIP_CHECK(hipFree(d_output_selected)); + HIP_CHECK(hipFree(d_output_rejected)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); } static constexpr bool is_tuning = Probability == partition_probability::tuning; @@ -534,10 +543,13 @@ struct device_partition_two_way_predicate_benchmark : public config_autotune_int } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, const hipStream_t stream) const override { + // Calculate the number of elements + size_t size = bytes / sizeof(DataType); + // all data types can represent [0, 127], -1 so a predicate can select all std::vector input = get_random_data(size, static_cast(0), @@ -623,11 +635,11 @@ struct device_partition_two_way_predicate_benchmark : public config_autotune_int state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); state.SetItemsProcessed(state.iterations() * batch_size * size); - hipFree(d_input); - hipFree(d_output_selected); - hipFree(d_output_rejected); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output_selected)); + HIP_CHECK(hipFree(d_output_rejected)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); } static constexpr bool is_tuning = 
Probability == partition_probability::tuning; @@ -648,10 +660,13 @@ struct device_partition_three_way_benchmark : public config_autotune_interface } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, const hipStream_t stream) const override { + // Calculate the number of elements + size_t size = bytes / sizeof(DataType); + // all data types can represent [0, 127], -1 so a predicate can select all std::vector input = get_random_data(size, static_cast(0), @@ -759,12 +774,12 @@ struct device_partition_three_way_benchmark : public config_autotune_interface state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); state.SetItemsProcessed(state.iterations() * batch_size * size); - hipFree(d_input); - hipFree(d_output_first); - hipFree(d_output_second); - hipFree(d_output_unselected); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output_first)); + HIP_CHECK(hipFree(d_output_second)); + HIP_CHECK(hipFree(d_output_unselected)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); } static constexpr bool is_tuning = Probability == partition_three_way_probability::tuning; diff --git a/benchmark/benchmark_device_radix_sort.cpp b/benchmark/benchmark_device_radix_sort.cpp index 221ce5eec..afc021c46 100644 --- a/benchmark/benchmark_device_radix_sort.cpp +++ b/benchmark/benchmark_device_radix_sort.cpp @@ -35,14 +35,14 @@ #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -53,7 +53,7 @@ int 
main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -64,13 +64,13 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks std::vector benchmarks = {}; - add_sort_keys_benchmarks(benchmarks, size, seed, stream); - add_sort_pairs_benchmarks(benchmarks, size, seed, stream); + add_sort_keys_benchmarks(benchmarks, bytes, seed, stream); + add_sort_pairs_benchmarks(benchmarks, bytes, seed, stream); // Use manual timing for(auto& b : benchmarks) diff --git a/benchmark/benchmark_device_radix_sort.hpp b/benchmark/benchmark_device_radix_sort.hpp index 5d0c0551e..3f38d0975 100644 --- a/benchmark/benchmark_device_radix_sort.hpp +++ b/benchmark/benchmark_device_radix_sort.hpp @@ -60,13 +60,16 @@ struct device_radix_sort_benchmark : public config_autotune_interface // keys benchmark template auto do_run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const -> std::enable_if_t::value, void> { using key_type = Key; + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); + std::vector keys_input = get_random_data(size, generate_limits::min(), @@ -157,7 +160,7 @@ struct device_radix_sort_benchmark : public config_autotune_interface // pairs benchmark template auto do_run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const -> std::enable_if_t::value, void> @@ -165,6 +168,9 @@ struct device_radix_sort_benchmark : public config_autotune_interface using key_type = Key; using value_type = 
Value; + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); + std::vector keys_input = get_random_data(size, generate_limits::min(), @@ -376,11 +382,11 @@ struct device_radix_sort_benchmark : public config_autotune_interface #define CREATE_RADIX_SORT_BENCHMARK(...) \ { \ const device_radix_sort_benchmark<__VA_ARGS__> instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } inline void add_sort_keys_benchmarks(std::vector& benchmarks, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { @@ -396,7 +402,7 @@ inline void add_sort_keys_benchmarks(std::vector& benchmarks, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { diff --git a/benchmark/benchmark_device_radix_sort_block_sort.cpp b/benchmark/benchmark_device_radix_sort_block_sort.cpp index 5ce6c070f..151add31d 100644 --- a/benchmark/benchmark_device_radix_sort_block_sort.cpp +++ b/benchmark/benchmark_device_radix_sort_block_sort.cpp @@ -36,20 +36,20 @@ #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif #define CREATE_BENCHMARK(...) 
\ { \ const device_radix_sort_block_sort_benchmark<__VA_ARGS__> instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -71,7 +71,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -82,7 +82,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks @@ -93,7 +93,7 @@ int main(int argc, char* argv[]) config_autotune_register::register_benchmark_subset(benchmarks, parallel_instance, parallel_instances, - size, + bytes, seed, stream); #else // BENCHMARK_CONFIG_TUNING diff --git a/benchmark/benchmark_device_radix_sort_block_sort.parallel.hpp b/benchmark/benchmark_device_radix_sort_block_sort.parallel.hpp index 07f91db35..3c9c956e2 100644 --- a/benchmark/benchmark_device_radix_sort_block_sort.parallel.hpp +++ b/benchmark/benchmark_device_radix_sort_block_sort.parallel.hpp @@ -74,13 +74,16 @@ struct device_radix_sort_block_sort_benchmark : public config_autotune_interface // keys benchmark template auto do_run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const -> typename std::enable_if::value, void>::type { using key_type = 
Key; + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); + // Generate data std::vector keys_input = get_random_data(size, @@ -165,7 +168,7 @@ struct device_radix_sort_block_sort_benchmark : public config_autotune_interface // pairs benchmark template auto do_run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const -> typename std::enable_if::value, void>::type @@ -173,6 +176,9 @@ struct device_radix_sort_block_sort_benchmark : public config_autotune_interface using key_type = Key; using value_type = Value; + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); + // Generate data std::vector keys_input = get_random_data(size, diff --git a/benchmark/benchmark_device_radix_sort_onesweep.cpp b/benchmark/benchmark_device_radix_sort_onesweep.cpp index 2e7944f0e..bd2b81800 100644 --- a/benchmark/benchmark_device_radix_sort_onesweep.cpp +++ b/benchmark/benchmark_device_radix_sort_onesweep.cpp @@ -36,14 +36,14 @@ #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -65,7 +65,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -76,7 +76,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); 
+ benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks @@ -87,12 +87,12 @@ int main(int argc, char* argv[]) config_autotune_register::register_benchmark_subset(benchmarks, parallel_instance, parallel_instances, - size, + bytes, seed, stream); #else // BENCHMARK_CONFIG_TUNING - add_sort_keys_benchmarks(benchmarks, size, seed, stream); - add_sort_pairs_benchmarks(benchmarks, size, seed, stream); + add_sort_keys_benchmarks(benchmarks, bytes, seed, stream); + add_sort_pairs_benchmarks(benchmarks, bytes, seed, stream); #endif // BENCHMARK_CONFIG_TUNING // Use manual timing diff --git a/benchmark/benchmark_device_radix_sort_onesweep.parallel.hpp b/benchmark/benchmark_device_radix_sort_onesweep.parallel.hpp index b946244f3..40495dd63 100644 --- a/benchmark/benchmark_device_radix_sort_onesweep.parallel.hpp +++ b/benchmark/benchmark_device_radix_sort_onesweep.parallel.hpp @@ -91,13 +91,16 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface // keys benchmark template auto do_run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const -> typename std::enable_if::value, void>::type { using key_type = Key; + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); + std::vector keys_input = get_random_data(size, generate_limits::min(), @@ -214,7 +217,7 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface // pairs benchmark template auto do_run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const -> typename std::enable_if::value, void>::type @@ -222,6 +225,9 @@ struct device_radix_sort_onesweep_benchmark : public config_autotune_interface using key_type = Key; using value_type = Value; + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); + std::vector keys_input = get_random_data(size, 
generate_limits::min(), @@ -443,11 +449,11 @@ struct device_radix_sort_onesweep_benchmark_generator #define CREATE_RADIX_SORT_BENCHMARK(...) \ { \ const device_radix_sort_onesweep_benchmark<__VA_ARGS__> instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } inline void add_sort_keys_benchmarks(std::vector& benchmarks, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { @@ -461,7 +467,7 @@ inline void add_sort_keys_benchmarks(std::vector& benchmarks, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { diff --git a/benchmark/benchmark_device_reduce.cpp b/benchmark/benchmark_device_reduce.cpp index d9bcb46a7..b4d9817e1 100644 --- a/benchmark/benchmark_device_reduce.cpp +++ b/benchmark/benchmark_device_reduce.cpp @@ -35,20 +35,20 @@ #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 128; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; #endif #define CREATE_BENCHMARK(T, REDUCE_OP) \ { \ const device_reduce_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -70,7 +70,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -81,7 +81,7 @@ int main(int argc, char *argv[]) // Benchmark info 
add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks @@ -92,7 +92,7 @@ int main(int argc, char *argv[]) config_autotune_register::register_benchmark_subset(benchmarks, parallel_instance, parallel_instances, - size, + bytes, seed, stream); #else diff --git a/benchmark/benchmark_device_reduce.parallel.hpp b/benchmark/benchmark_device_reduce.parallel.hpp index 275d13015..32a867051 100644 --- a/benchmark/benchmark_device_reduce.parallel.hpp +++ b/benchmark/benchmark_device_reduce.parallel.hpp @@ -85,10 +85,13 @@ struct device_reduce_benchmark : public config_autotune_interface static constexpr unsigned int warmup_size = 5; void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { + // Calculate the number of elements + size_t size = bytes / sizeof(T); + BinaryFunction reduce_op{}; const auto random_range = limit_random_range(0, 1000); std::vector input diff --git a/benchmark/benchmark_device_run_length_encode.cpp b/benchmark/benchmark_device_run_length_encode.cpp index 93f0e3e67..345708983 100644 --- a/benchmark/benchmark_device_run_length_encode.cpp +++ b/benchmark/benchmark_device_run_length_encode.cpp @@ -137,12 +137,15 @@ void run_encode_benchmark(benchmark::State& state, for(size_t i = 0; i < batch_size; i++) { - rp::run_length_encode( - d_temporary_storage, temporary_storage_bytes, - d_input, size, - d_unique_output, d_counts_output, d_runs_count_output, - stream, false - ); + HIP_CHECK(rp::run_length_encode(d_temporary_storage, + temporary_storage_bytes, + d_input, + size, + d_unique_output, + d_counts_output, + d_runs_count_output, + stream, + false)); } // Record stop event and wait until it completes @@ -261,12 +264,15 @@ void run_non_trivial_runs_benchmark(benchmark::State& state, for(size_t i = 0; i < batch_size; i++) { - 
rp::run_length_encode_non_trivial_runs( - d_temporary_storage, temporary_storage_bytes, - d_input, size, - d_offsets_output, d_counts_output, d_runs_count_output, - stream, false - ); + HIP_CHECK(rp::run_length_encode_non_trivial_runs(d_temporary_storage, + temporary_storage_bytes, + d_input, + size, + d_offsets_output, + d_counts_output, + d_runs_count_output, + stream, + false)); } // Record stop event and wait until it completes diff --git a/benchmark/benchmark_device_scan.cpp b/benchmark/benchmark_device_scan.cpp index 936fa4f1b..7a8fce15f 100644 --- a/benchmark/benchmark_device_scan.cpp +++ b/benchmark/benchmark_device_scan.cpp @@ -35,14 +35,14 @@ #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif #define CREATE_EXCL_INCL_BENCHMARK(EXCL, T, SCAN_OP) \ { \ const device_scan_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define CREATE_BENCHMARK(T, SCAN_OP) \ @@ -52,7 +52,7 @@ const size_t DEFAULT_N = 1024 * 1024 * 32; int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -74,7 +74,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -85,7 +85,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + 
benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks @@ -96,7 +96,7 @@ int main(int argc, char* argv[]) config_autotune_register::register_benchmark_subset(benchmarks, parallel_instance, parallel_instances, - size, + bytes, seed, stream); #else diff --git a/benchmark/benchmark_device_scan.parallel.hpp b/benchmark/benchmark_device_scan.parallel.hpp index bb93dde93..8fb58b83e 100644 --- a/benchmark/benchmark_device_scan.parallel.hpp +++ b/benchmark/benchmark_device_scan.parallel.hpp @@ -146,10 +146,13 @@ struct device_scan_benchmark : public config_autotune_interface } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { + // Calculate the number of elements + size_t size = bytes / sizeof(T); + ScanOp scan_op{}; const auto random_range = limit_random_range(0, 1000); std::vector input diff --git a/benchmark/benchmark_device_scan_by_key.cpp b/benchmark/benchmark_device_scan_by_key.cpp index 54834e46f..5528afa0d 100644 --- a/benchmark/benchmark_device_scan_by_key.cpp +++ b/benchmark/benchmark_device_scan_by_key.cpp @@ -35,8 +35,8 @@ #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif #define CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, MAX_SEGMENT_LENGTH) \ @@ -48,7 +48,7 @@ const size_t DEFAULT_N = 1024 * 1024 * 32; rocprim::equal_to, \ MAX_SEGMENT_LENGTH> \ instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define CREATE_EXCL_INCL_BENCHMARK(EXCL, T, SCAN_OP) \ @@ -65,7 +65,7 @@ const size_t DEFAULT_N = 1024 * 1024 * 32; int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); 
parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -87,7 +87,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -98,7 +98,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks @@ -109,7 +109,7 @@ int main(int argc, char* argv[]) config_autotune_register::register_benchmark_subset(benchmarks, parallel_instance, parallel_instances, - size, + bytes, seed, stream); #else diff --git a/benchmark/benchmark_device_scan_by_key.parallel.hpp b/benchmark/benchmark_device_scan_by_key.parallel.hpp index e85611782..b0c80d842 100644 --- a/benchmark/benchmark_device_scan_by_key.parallel.hpp +++ b/benchmark/benchmark_device_scan_by_key.parallel.hpp @@ -162,10 +162,13 @@ struct device_scan_by_key_benchmark : public config_autotune_interface } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { + // Calculate the number of elements + size_t size = bytes / sizeof(Value); + constexpr bool debug = false; const std::vector keys diff --git a/benchmark/benchmark_device_scan_by_key_deterministic.cpp b/benchmark/benchmark_device_scan_by_key_deterministic.cpp index 76fa1078b..126f99a02 100644 --- a/benchmark/benchmark_device_scan_by_key_deterministic.cpp +++ b/benchmark/benchmark_device_scan_by_key_deterministic.cpp @@ -35,8 +35,8 @@ #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 
* 32 * 4; #endif #define CREATE_BY_KEY_BENCHMARK(EXCL, T, SCAN_OP, MAX_SEGMENT_LENGTH) \ @@ -49,7 +49,7 @@ const size_t DEFAULT_N = 1024 * 1024 * 32; MAX_SEGMENT_LENGTH, \ true> \ instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define CREATE_EXCL_INCL_BENCHMARK(EXCL, T, SCAN_OP) \ @@ -66,7 +66,7 @@ const size_t DEFAULT_N = 1024 * 1024 * 32; int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -77,7 +77,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -88,7 +88,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks diff --git a/benchmark/benchmark_device_scan_deterministic.cpp b/benchmark/benchmark_device_scan_deterministic.cpp index 1bb62a481..b13c82bf9 100644 --- a/benchmark/benchmark_device_scan_deterministic.cpp +++ b/benchmark/benchmark_device_scan_deterministic.cpp @@ -35,14 +35,14 @@ #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif #define CREATE_EXCL_INCL_BENCHMARK(EXCL, T, SCAN_OP) \ { \ const device_scan_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + 
REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define CREATE_BENCHMARK(T, SCAN_OP) \ @@ -52,7 +52,7 @@ const size_t DEFAULT_N = 1024 * 1024 * 32; int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -63,7 +63,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -74,7 +74,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks diff --git a/benchmark/benchmark_device_search.cpp b/benchmark/benchmark_device_search.cpp new file mode 100644 index 000000000..5f00d9416 --- /dev/null +++ b/benchmark/benchmark_device_search.cpp @@ -0,0 +1,131 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include "benchmark_device_search.hpp" +#include "benchmark_utils.hpp" + +// CmdParser +#include "cmdparser.hpp" + +// Google Benchmark +#include + +// HIP API +#include + +#include +#include + +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; +#endif + +#define CREATE_BENCHMARK_SEARCH(TYPE, KEY_SIZE, REPEATING) \ + { \ + const device_search_benchmark instance(KEY_SIZE, REPEATING); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ + } + +#define CREATE_BENCHMARK_PATTERN(TYPE, REPEATING) \ + { \ + CREATE_BENCHMARK_SEARCH(TYPE, 10, REPEATING) \ + CREATE_BENCHMARK_SEARCH(TYPE, 100, REPEATING) \ + CREATE_BENCHMARK_SEARCH(TYPE, 1000, REPEATING) \ + CREATE_BENCHMARK_SEARCH(TYPE, 10000, REPEATING) \ + } + +#define CREATE_BENCHMARK(TYPE) \ + { \ + CREATE_BENCHMARK_PATTERN(TYPE, true) \ + CREATE_BENCHMARK_PATTERN(TYPE, false) \ + } + +int main(int argc, char* argv[]) +{ + cli::Parser parser(argc, argv); + parser.set_optional("bytes", "bytes", DEFAULT_BYTES, "number of values"); + parser.set_optional("trials", "trials", -1, "number of iterations"); + parser.set_optional("name_format", + "name_format", + "human", + "either: json,human,txt"); + parser.set_optional("seed", "seed", "random", get_seed_message()); + parser.run_and_exit_if_error(); + + // Parse argv + benchmark::Initialize(&argc, argv); + const size_t bytes = parser.get("bytes"); + const int trials = parser.get("trials"); + bench_naming::set_format(parser.get("name_format")); + const std::string seed_type = parser.get("seed"); + const managed_seed seed(seed_type); + + // HIP + hipStream_t stream = 0; // default + + // Benchmark info + add_common_benchmark_info(); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); + benchmark::AddCustomContext("seed", seed_type); + + // Add benchmarks + std::vector benchmarks{}; + CREATE_BENCHMARK(int) + CREATE_BENCHMARK(long long) + CREATE_BENCHMARK(int8_t) + CREATE_BENCHMARK(uint8_t) + CREATE_BENCHMARK(rocprim::half) + 
CREATE_BENCHMARK(short) + CREATE_BENCHMARK(float) + + using custom_float2 = custom_type; + using custom_double2 = custom_type; + using custom_int2 = custom_type; + using custom_char_double = custom_type; + using custom_longlong_double = custom_type; + + CREATE_BENCHMARK(custom_float2) + CREATE_BENCHMARK(custom_double2) + CREATE_BENCHMARK(custom_int2) + CREATE_BENCHMARK(custom_char_double) + CREATE_BENCHMARK(custom_longlong_double) + + // Use manual timing + for(auto& b : benchmarks) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + } + + // Force number of iterations + if(trials > 0) + { + for(auto& b : benchmarks) + { + b->Iterations(trials); + } + } + + // Run benchmarks + benchmark::RunSpecifiedBenchmarks(); + return 0; +} diff --git a/benchmark/benchmark_device_search.hpp b/benchmark/benchmark_device_search.hpp new file mode 100644 index 000000000..7b6afa4f9 --- /dev/null +++ b/benchmark/benchmark_device_search.hpp @@ -0,0 +1,199 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ROCPRIM_BENCHMARK_DEVICE_SEARCH_PARALLEL_HPP_ +#define ROCPRIM_BENCHMARK_DEVICE_SEARCH_PARALLEL_HPP_ + +#include "benchmark_utils.hpp" + +// Google Benchmark +#include + +// HIP API +#include + +// rocPRIM +#include + +#include +#include + +#include + +template +struct device_search_benchmark : public config_autotune_interface +{ + size_t key_size_ = 10; + bool repeating_ = false; + + device_search_benchmark(size_t KeySize, bool repeating) + { + key_size_ = KeySize; + repeating_ = repeating; + } + + std::string name() const override + { + using namespace std::string_literals; + return bench_naming::format_name( + "{lvl:device,algo:search,value_pattern:" + (repeating_ ? "repeating"s : "random"s) + + ",key_size:" + std::to_string(key_size_) + + ",value_type:" + std::string(Traits::name()) + ",cfg:default_config}"); + } + + static constexpr unsigned int batch_size = 10; + static constexpr unsigned int warmup_size = 5; + + void run(benchmark::State& state, + size_t bytes, + const managed_seed& seed, + hipStream_t stream) const override + { + using key_type = Key; + using output_type = size_t; + + // Calculate the number of elements + size_t size = bytes / sizeof(key_type); + size_t key_size = std::min(size, key_size_); + + // Generate data + std::vector keys_input + = get_random_data(key_size, + generate_limits::min(), + generate_limits::max(), + seed.get_0()); + + std::vector input(size); + if(repeating_) + { + // Repeating similar pattern without early exits. 
+ keys_input[key_size - 1] = 0; + for(size_t i = 0; i < size; i++) + { + input[i] = keys_input[i % key_size]; + } + keys_input[key_size - 1] = 1; + } + else + { + input = get_random_data(size, + generate_limits::min(), + generate_limits::max(), + seed.get_0() + 1); + } + + key_type* d_keys_input; + key_type* d_input; + output_type* d_output; + HIP_CHECK(hipMalloc(&d_keys_input, key_size * sizeof(*d_keys_input))); + HIP_CHECK(hipMalloc(&d_input, size * sizeof(*d_input))); + HIP_CHECK(hipMalloc(&d_output, sizeof(*d_output))); + + HIP_CHECK(hipMemcpy(d_input, input.data(), size * sizeof(*d_input), hipMemcpyHostToDevice)); + + HIP_CHECK(hipMemcpy(d_keys_input, + keys_input.data(), + key_size * sizeof(*d_keys_input), + hipMemcpyHostToDevice)); + + rocprim::equal_to compare_op; + + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + + HIP_CHECK(rocprim::search(d_temporary_storage, + temporary_storage_bytes, + d_input, + d_keys_input, + d_output, + size, + key_size, + compare_op, + stream, + false)); + + HIP_CHECK(hipMalloc(&d_temporary_storage, temporary_storage_bytes)); + + // Warm-up + for(size_t i = 0; i < warmup_size; i++) + { + HIP_CHECK(rocprim::search(d_temporary_storage, + temporary_storage_bytes, + d_input, + d_keys_input, + d_output, + size, + key_size, + compare_op, + stream, + false)); + } + HIP_CHECK(hipDeviceSynchronize()); + + // HIP events creation + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(start, stream)); + + for(size_t i = 0; i < batch_size; i++) + { + HIP_CHECK(rocprim::search(d_temporary_storage, + temporary_storage_bytes, + d_input, + d_keys_input, + d_output, + size, + key_size, + compare_op, + stream, + false)); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds; + 
HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + // Destroy HIP events + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(*d_input)); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + } +}; + +#endif // ROCPRIM_BENCHMARK_DEVICE_SEARCH_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_search_n.cpp b/benchmark/benchmark_device_search_n.cpp new file mode 100644 index 000000000..6705ca0d9 --- /dev/null +++ b/benchmark/benchmark_device_search_n.cpp @@ -0,0 +1,77 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include "benchmark_device_search_n.parallel.hpp" + +int main(int argc, char* argv[]) +{ + cli::Parser parser(argc, argv); + parser.set_optional("size", "size", size_t{2} << 30, "number of input bytes"); + parser.set_optional("trials", "trials", -1, "number of iterations"); + parser.set_optional("name_format", + "name_format", + "human", + "either: json,human,txt"); + parser.set_optional("seed", "seed", "random", get_seed_message()); + parser.run_and_exit_if_error(); + + // Parse argv + benchmark::Initialize(&argc, argv); + const size_t size = parser.get("size"); + const int trials = parser.get("trials"); + bench_naming::set_format(parser.get("name_format")); + const std::string seed_type = parser.get("seed"); + const managed_seed seed(seed_type); + + // HIP + hipStream_t stream = 0; // default + + // Benchmark info + add_common_benchmark_info(); + benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("seed", seed_type); + + // Add benchmarks + std::vector benchmarks{}; + add_benchmark_search_n(benchmarks, seed, stream, size); + + // Use manual timing + for(auto& b : benchmarks) + { + b->UseManualTime(); + b->Unit(benchmark::kMillisecond); + } + + // Force number of iterations + if(trials > 0) + { + for(auto& b : benchmarks) + { + b->Iterations(trials); + } + } + + // Run benchmarks + benchmark::RunSpecifiedBenchmarks(); + clean_up_benchmarks_search_n(); + return 0; +} diff --git a/benchmark/benchmark_device_search_n.parallel.cpp.in b/benchmark/benchmark_device_search_n.parallel.cpp.in new file mode 100644 index 000000000..6e3db47f0 --- /dev/null +++ b/benchmark/benchmark_device_search_n.parallel.cpp.in @@ -0,0 +1,33 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include + +#include "benchmark_device_search_n.parallel.hpp" +#include "benchmark_utils.hpp" + +namespace { + auto benchmarks = config_autotune_register::create_bulk( + device_search_n_benchmark_generator< + @InputType@, + @BlockSize@>::create); +} \ No newline at end of file diff --git a/benchmark/benchmark_device_search_n.parallel.hpp b/benchmark/benchmark_device_search_n.parallel.hpp new file mode 100644 index 000000000..6bec544ed --- /dev/null +++ b/benchmark/benchmark_device_search_n.parallel.hpp @@ -0,0 +1,431 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#ifndef ROCPRIM_BENCHMARK_DEVICE_SEARCH_N_PARALLEL_HPP_ +#define ROCPRIM_BENCHMARK_DEVICE_SEARCH_N_PARALLEL_HPP_ + +#include "benchmark_utils.hpp" +#include "cmdparser.hpp" + +// gbench +#include + +// HIP +#include + +// rocPRIM +#include +#include + +// C++ Standard Library +#include +#include +#include +#include +#include +#include +#include +#include + +using custom_int2 = custom_type; +using custom_double2 = custom_type; +using custom_longlong_double = custom_type; + +namespace +{ +template +struct type_arr +{ + using type = First; + using next = type_arr; +}; +template +struct type_arr +{ + using type = First; +}; +template +using void_type = void; +template +constexpr bool is_type_arr_end = true; +template +constexpr bool is_type_arr_end> = false; + +template +inline unsigned int search_n_get_item_per_block() +{ + using input_type = InputType; + using config = Config; + using wrapped_config = rocprim::detail::wrapped_search_n_config; + + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; + HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); + const auto params = rocprim::detail::dispatch_target_arch(target_arch); + const unsigned int block_size = params.kernel_config.block_size; + const unsigned int items_per_thread = params.kernel_config.items_per_thread; + const unsigned int items_per_block = block_size * items_per_thread; + return items_per_block; +} + +enum class benchmark_search_n_mode +{ + NORMAL = 0, + NOISE = 1, +}; + +inline std::string to_string(benchmark_search_n_mode e) noexcept +{ + switch(e) + { + case benchmark_search_n_mode::NORMAL: return "NORMAL"; + case benchmark_search_n_mode::NOISE: return "NOISE"; + default: return "UNKNOWN"; + } +} + +} // namespace + +template +class benchmark_search_n +{ +public: + const managed_seed seed; + const hipStream_t stream; + size_t size_byte; + size_t count_byte; + size_t start_pos_byte; + InputType value; + std::vector input; + +private: + size_t size; + 
size_t count; + size_t start_pos; + const size_t warmup_size = 10; + const size_t batch_size = 10; + size_t temp_storage_size = 0; + size_t noise_sequence = 0; + bool create_noise = false; + + hipEvent_t start; + hipEvent_t stop; + + void* d_temp_storage = nullptr; + InputType* d_input; + OutputType* d_output; + InputType* d_value; + + void create() noexcept + { + switch(mode) + { + case benchmark_search_n_mode::NORMAL: + { + input.resize(size); + if(start_pos + count < size) + { + std::fill(input.begin(), input.begin() + start_pos, 0); + std::fill(input.begin() + start_pos, + input.begin() + count + start_pos, + value); + std::fill(input.begin() + count + start_pos, input.end(), 0); + } + else + { + std::fill(input.begin(), input.end(), 0); + } + break; + } + case benchmark_search_n_mode::NOISE: + { + InputType h_noise{0}; + input = std::vector(size, value); + + if(create_noise) + { + size_t cur_tile = 0; + size_t last_tile = size / count - 1; + while(cur_tile != last_tile) + { + input[cur_tile * count + count - 1] = h_noise; + ++cur_tile; + } + } + break; + } + default: + { + break; + } + } + + HIP_CHECK(hipMallocAsync(&d_value, sizeof(InputType), stream)); + HIP_CHECK(hipMallocAsync(&d_input, sizeof(InputType) * input.size(), stream)); + HIP_CHECK(hipMallocAsync(&d_output, sizeof(OutputType), stream)); + HIP_CHECK( + hipMemcpyAsync(d_value, &value, sizeof(InputType), hipMemcpyHostToDevice, stream)); + HIP_CHECK(hipMemcpyAsync(d_input, + input.data(), + sizeof(InputType) * input.size(), + hipMemcpyHostToDevice, + stream)); + + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + } + + void release() noexcept + { + decltype(input) tmp; + input.swap(tmp); // clear input memspace + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + HIP_CHECK(hipFree(d_value)); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + } + + void launch_search_n() + { + HIP_CHECK(::rocprim::search_n(d_temp_storage, + 
temp_storage_size, + d_input, + d_output, + size, + count, + d_value, + rocprim::equal_to{}, + stream, + false)); + } + + static void run(benchmark::State& state, benchmark_search_n const* _self) + { + auto& self = *const_cast(_self); + self.create(); + + // allocate memory + self.launch_search_n(); + HIP_CHECK(hipMallocAsync(&self.d_temp_storage, self.temp_storage_size, self.stream)); + // Warm-up + for(size_t i = 0; i < self.warmup_size; i++) + { + self.launch_search_n(); + } + HIP_CHECK(hipStreamSynchronize(self.stream)); + + // Run + for(auto _ : state) + { + // Record start event + HIP_CHECK(hipEventRecord(self.start, self.stream)); + + for(size_t i = 0; i < self.batch_size; i++) + { + self.launch_search_n(); + } + + // Record stop event and wait until it completes + HIP_CHECK(hipEventRecord(self.stop, self.stream)); + HIP_CHECK(hipEventSynchronize(self.stop)); + + float elapsed_mseconds; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, self.start, self.stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + // Clean-up + HIP_CHECK(hipFree(self.d_temp_storage)); + self.d_temp_storage = nullptr; + self.temp_storage_size = 0; + state.SetBytesProcessed(state.iterations() * self.batch_size * self.size + * sizeof(*(self.d_input))); + state.SetItemsProcessed(state.iterations() * self.batch_size * self.size); + self.release(); + } + +public: + benchmark_search_n( + const managed_seed _seed, + const hipStream_t _stream, + const size_t _size_byte, + const size_t _count_byte, // for NOISE benchmarks, this is the multiple of count + const size_t _start_pos_byte) noexcept + : seed(_seed) + , stream(_stream) + , size_byte(_size_byte) + , count_byte(_count_byte) + , start_pos_byte(_start_pos_byte) + , value{1} + , input() + { + switch(mode) + { + case benchmark_search_n_mode::NORMAL: + { + size = size_byte / sizeof(InputType); + count = count_byte / sizeof(InputType); + start_pos = start_pos_byte / sizeof(InputType); + break; + } + case 
benchmark_search_n_mode::NOISE: + { + size = size_byte / sizeof(InputType); + count = count_byte; + noise_sequence + = _start_pos_byte == (size_t)-1 + ? search_n_get_item_per_block() + : _start_pos_byte; + + if(size > noise_sequence * count) + { + count = noise_sequence * count; + create_noise = true; + } + break; + } + } + } + + benchmark::internal::Benchmark* bench_register() const noexcept + { + return benchmark::RegisterBenchmark( + bench_naming::format_name( + "{lvl:device,algo:search_n,input_type:" + std::string(typeid(InputType).name()) + + ",size:" + std::to_string(size) + ",count:" + std::to_string(count) + + ",mode:" + to_string(mode) + ",cfg:default_config}") + .c_str(), + run, + this); + } +}; + +using destructor_t = std::function; +static std::vector destructors; + +static void clean_up_benchmarks_search_n() +{ + for(auto& i : destructors) + { + i(); + } + destructors = {}; +} + +template +inline void add_one_benchmark_search_n(std::vector& benchmarks, + const managed_seed _seed, + const hipStream_t _stream, + const size_t _size_byte) +{ + // normal + auto start_from_0 + = new benchmark_search_n(_seed, + _stream, + _size_byte, + _size_byte, + 0); + auto start_from_mid + = new benchmark_search_n(_seed, + _stream, + _size_byte, + _size_byte / 2, + _size_byte / 2); + // small count test + auto small_count6 + = new benchmark_search_n(_seed, + _stream, + _size_byte, + 1, // count times + 6); + // mid count test + auto mid_count4095 + = new benchmark_search_n(_seed, + _stream, + _size_byte, + 1, // count times + 4095); + // big input + auto big_count6 + = new benchmark_search_n(_seed, + _stream, + _size_byte, + 6, // count times + (size_t)-1); + std::vector bs = {start_from_0->bench_register(), + start_from_mid->bench_register(), + small_count6->bench_register(), + mid_count4095->bench_register(), + big_count6->bench_register()}; + destructors.emplace_back( + [=]() + { + delete start_from_0; + delete start_from_mid; + delete small_count6; + delete 
mid_count4095; + delete big_count6; + }); + benchmarks.insert(benchmarks.end(), bs.begin(), bs.end()); +} + +template, bool> = true> +inline void add_benchmark_search_n(std::vector& benchmarks, + const managed_seed _seed, + const hipStream_t _stream, + const size_t _size_byte) +{ + add_one_benchmark_search_n(benchmarks, _seed, _stream, _size_byte); + add_benchmark_search_n(benchmarks, _seed, _stream, _size_byte); +} +template, bool> = true> +inline void add_benchmark_search_n(std::vector& benchmarks, + const managed_seed _seed, + const hipStream_t _stream, + const size_t _size_byte) +{ + add_one_benchmark_search_n(benchmarks, _seed, _stream, _size_byte); +} + +typedef type_arr + benchmark_search_n_types; + +template +struct device_search_n_benchmark_generator +{ + // TODO: add implementation + struct create_search_n_algorithm + {}; + // TODO: add implementation + static void create(std::vector>&) {} +}; + +#endif // ROCPRIM_BENCHMARK_DEVICE_SEARCH_N_PARALLEL_HPP_ diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.cpp b/benchmark/benchmark_device_segmented_radix_sort_keys.cpp index 7e2bb4477..0ec242d0b 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_keys.cpp +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.cpp @@ -39,8 +39,8 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif namespace rp = rocprim; @@ -62,13 +62,16 @@ template void run_sort_keys_benchmark(benchmark::State& state, size_t num_segments, size_t mean_segment_length, - size_t target_size, + size_t target_bytes, const managed_seed& seed, hipStream_t stream) { using offset_type = int; using key_type = Key; + // Calculate the number of elements + size_t target_size = target_bytes / sizeof(key_type); + std::vector offsets; offsets.push_back(0); @@ -205,12 +208,15 @@ void run_sort_keys_benchmark(benchmark::State& state, template void 
add_sort_keys_benchmarks(std::vector& benchmarks, - size_t max_size, + size_t max_bytes, size_t min_size, size_t target_size, const managed_seed& seed, hipStream_t stream) { + // Calculate the number of elements + size_t max_size = max_bytes / sizeof(KeyT); + std::string key_name = Traits::name(); std::string value_name = Traits::name(); for(const auto segment_count : segment_counts) @@ -244,7 +250,7 @@ void add_sort_keys_benchmarks(std::vector& benc int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -268,7 +274,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -279,7 +285,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks @@ -295,12 +301,12 @@ int main(int argc, char* argv[]) seed, stream); #else - add_sort_keys_benchmarks(benchmarks, size, min_size, size / 2, seed, stream); - add_sort_keys_benchmarks(benchmarks, size, min_size, size / 2, seed, stream); - add_sort_keys_benchmarks(benchmarks, size, min_size, size / 2, seed, stream); - add_sort_keys_benchmarks(benchmarks, size, min_size, size / 2, seed, stream); - add_sort_keys_benchmarks(benchmarks, size, min_size, size / 2, seed, stream); - add_sort_keys_benchmarks(benchmarks, size, min_size, size / 2, seed, stream); + 
add_sort_keys_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); + add_sort_keys_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); + add_sort_keys_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); + add_sort_keys_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); + add_sort_keys_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); + add_sort_keys_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); #endif // Use manual timing diff --git a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp index adaad263b..ad798af9f 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp +++ b/benchmark/benchmark_device_segmented_radix_sort_keys.parallel.hpp @@ -235,10 +235,13 @@ struct device_segmented_radix_sort_benchmark : public config_autotune_interface } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { + // Calculate the number of elements + size_t size = bytes / sizeof(Key); + constexpr std::array segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; constexpr std::array segment_lengths{30, 256, 3000, 300000}; diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp b/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp index fa8831d71..d0666deeb 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.cpp @@ -39,8 +39,8 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif namespace rp = rocprim; @@ -62,7 +62,7 @@ template void run_sort_pairs_benchmark(benchmark::State& state, size_t num_segments, size_t mean_segment_length, - size_t target_size, + size_t target_bytes, const managed_seed& seed, 
hipStream_t stream) { @@ -70,6 +70,9 @@ void run_sort_pairs_benchmark(benchmark::State& state, using key_type = Key; using value_type = Value; + // Calculate the number of elements + size_t target_size = target_bytes / sizeof(key_type); + // Generate data std::vector offsets; offsets.push_back(0); @@ -228,12 +231,15 @@ void run_sort_pairs_benchmark(benchmark::State& state, template void add_sort_pairs_benchmarks(std::vector& benchmarks, - size_t max_size, + size_t max_bytes, size_t min_size, size_t target_size, const managed_seed& seed, hipStream_t stream) { + // Calculate the number of elements + size_t max_size = max_bytes / sizeof(KeyT); + std::string key_name = Traits::name(); std::string value_name = Traits::name(); for(const auto segment_count : segment_counts) @@ -267,7 +273,7 @@ void add_sort_pairs_benchmarks(std::vector& ben int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -290,7 +296,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -301,7 +307,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks @@ -313,37 +319,37 @@ int main(int argc, char* argv[]) config_autotune_register::register_benchmark_subset(benchmarks, parallel_instance, parallel_instances, - size, + bytes, seed, stream); #else using 
custom_float2 = custom_type; using custom_double2 = custom_type; - add_sort_pairs_benchmarks(benchmarks, size, min_size, size / 2, seed, stream); + add_sort_pairs_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); add_sort_pairs_benchmarks(benchmarks, - size, + bytes, min_size, - size / 2, + bytes / 2, seed, stream); - add_sort_pairs_benchmarks(benchmarks, size, min_size, size / 2, seed, stream); - add_sort_pairs_benchmarks(benchmarks, size, min_size, size / 2, seed, stream); + add_sort_pairs_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); + add_sort_pairs_benchmarks(benchmarks, bytes, min_size, bytes / 2, seed, stream); add_sort_pairs_benchmarks(benchmarks, - size, + bytes, min_size, - size / 2, + bytes / 2, seed, stream); add_sort_pairs_benchmarks(benchmarks, - size, + bytes, min_size, - size / 2, + bytes / 2, seed, stream); add_sort_pairs_benchmarks(benchmarks, - size, + bytes, min_size, - size / 2, + bytes / 2, seed, stream); #endif diff --git a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp index c3beaa863..3a5e3ecfd 100644 --- a/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp +++ b/benchmark/benchmark_device_segmented_radix_sort_pairs.parallel.hpp @@ -260,10 +260,13 @@ struct device_segmented_radix_sort_benchmark : public config_autotune_interface } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { + // Calculate the number of elements + size_t size = bytes / sizeof(Key); + constexpr std::array segment_counts{10, 100, 1000, 2500, 5000, 7500, 10000, 100000}; constexpr std::array segment_lengths{30, 256, 3000, 300000}; diff --git a/benchmark/benchmark_device_segmented_reduce.cpp b/benchmark/benchmark_device_segmented_reduce.cpp index 1d8495861..ddc5fff8f 100644 --- a/benchmark/benchmark_device_segmented_reduce.cpp +++ 
b/benchmark/benchmark_device_segmented_reduce.cpp @@ -39,8 +39,8 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif namespace rp = rocprim; @@ -51,13 +51,16 @@ const unsigned int warmup_size = 5; template void run_benchmark(benchmark::State& state, size_t desired_segments, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { using offset_type = int; using value_type = T; + // Calculate the number of elements + size_t size = bytes / sizeof(T); + // Generate data engine_type gen(seed.get_0()); @@ -192,7 +195,7 @@ void run_benchmark(benchmark::State& state, .c_str(), \ run_benchmark, \ SEGMENTS, \ - size, \ + bytes, \ seed, \ stream) @@ -204,7 +207,7 @@ void run_benchmark(benchmark::State& state, CREATE_BENCHMARK(type, 10000) void add_benchmarks(std::vector& benchmarks, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { @@ -229,7 +232,7 @@ void add_benchmarks(std::vector& benchmarks, int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -241,7 +244,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -252,12 +255,12 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); 
// Add benchmarks std::vector benchmarks; - add_benchmarks(benchmarks, size, seed, stream); + add_benchmarks(benchmarks, bytes, seed, stream); // Use manual timing for(auto& b : benchmarks) diff --git a/benchmark/benchmark_device_select.cpp b/benchmark/benchmark_device_select.cpp index 8cbf12c04..c19e5835b 100644 --- a/benchmark/benchmark_device_select.cpp +++ b/benchmark/benchmark_device_select.cpp @@ -43,34 +43,46 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif +#define CREATE_SELECT_PREDICATED_FLAG_BENCHMARK(T, F, p) \ + { \ + const device_select_predicated_flag_benchmark instance; \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ + } + #define CREATE_SELECT_FLAG_BENCHMARK(T, F, p) \ { \ const device_select_flag_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define CREATE_SELECT_PREDICATE_BENCHMARK(T, p) \ { \ const device_select_predicate_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define CREATE_UNIQUE_BENCHMARK(T, p) \ { \ const device_select_unique_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } #define CREATE_UNIQUE_BY_KEY_BENCHMARK(K, V, p) \ { \ const device_select_unique_by_key_benchmark instance; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } +#define BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(type, value) \ + CREATE_SELECT_PREDICATED_FLAG_BENCHMARK(type, value, select_probability::p005); \ + CREATE_SELECT_PREDICATED_FLAG_BENCHMARK(type, value, select_probability::p025); \ + 
CREATE_SELECT_PREDICATED_FLAG_BENCHMARK(type, value, select_probability::p050); \ + CREATE_SELECT_PREDICATED_FLAG_BENCHMARK(type, value, select_probability::p075) + #define BENCHMARK_SELECT_FLAG_TYPE(type, value) \ CREATE_SELECT_FLAG_BENCHMARK(type, value, select_probability::p005); \ CREATE_SELECT_FLAG_BENCHMARK(type, value, select_probability::p025); \ @@ -98,7 +110,7 @@ const size_t DEFAULT_N = 1024 * 1024 * 32; int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -120,7 +132,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -131,7 +143,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks @@ -142,7 +154,7 @@ int main(int argc, char* argv[]) config_autotune_register::register_benchmark_subset(benchmarks, parallel_instance, parallel_instances, - size, + bytes, seed, stream); #else @@ -165,6 +177,14 @@ int main(int argc, char* argv[]) BENCHMARK_SELECT_PREDICATE_TYPE(rocprim::half); BENCHMARK_SELECT_PREDICATE_TYPE(custom_int_double); + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(int, unsigned char); + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(float, unsigned char); + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(double, unsigned char); + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(uint8_t, uint8_t); + 
BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(int8_t, int8_t); + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(rocprim::half, int8_t); + BENCHMARK_SELECT_PREDICATED_FLAG_TYPE(custom_double2, unsigned char); + BENCHMARK_UNIQUE_TYPE(int); BENCHMARK_UNIQUE_TYPE(float); BENCHMARK_UNIQUE_TYPE(double); diff --git a/benchmark/benchmark_device_select.parallel.hpp b/benchmark/benchmark_device_select.parallel.hpp index 9ac19a684..793f5e130 100644 --- a/benchmark/benchmark_device_select.parallel.hpp +++ b/benchmark/benchmark_device_select.parallel.hpp @@ -94,10 +94,13 @@ struct device_select_flag_benchmark : public config_autotune_interface } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { + // Calculate the number of elements + size_t size = bytes / sizeof(DataType); + std::vector input = get_random_data(size, generate_limits::min(), generate_limits::max(), @@ -207,16 +210,16 @@ struct device_select_flag_benchmark : public config_autotune_interface state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); state.SetItemsProcessed(state.iterations() * batch_size * size); - hipFree(d_input); + HIP_CHECK(hipFree(d_input)); if(is_tuning) { - hipFree(d_flags_2); - hipFree(d_flags_1); + HIP_CHECK(hipFree(d_flags_2)); + HIP_CHECK(hipFree(d_flags_1)); } - hipFree(d_flags_0); - hipFree(d_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_flags_0)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); } static constexpr bool is_tuning = Probability == select_probability::tuning; @@ -237,10 +240,13 @@ struct device_select_predicate_benchmark : public config_autotune_interface } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { + // Calculate the number of elements + size_t size = bytes / sizeof(DataType); + // all 
data types can represent [0, 127], -1 so a predicate can select all std::vector input = get_random_data(size, static_cast(0), @@ -321,10 +327,160 @@ struct device_select_predicate_benchmark : public config_autotune_interface state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); state.SetItemsProcessed(state.iterations() * batch_size * size); - hipFree(d_input); - hipFree(d_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); + } + + static constexpr bool is_tuning = Probability == select_probability::tuning; +}; + +template +struct device_select_predicated_flag_benchmark : public config_autotune_interface +{ + std::string name() const override + { + using namespace std::string_literals; + return bench_naming::format_name( + "{lvl:device,algo:select,subalgo:predicated_flag,data_type:" + + std::string(Traits::name()) + + ",flag_type:" + std::string(Traits::name()) + ",probability:" + + get_probability_name(Probability) + ",cfg:" + partition_config_name() + "}"); + } + + void run(benchmark::State& state, + size_t bytes, + const managed_seed& seed, + hipStream_t stream) const override + { + // Calculate the number of elements + size_t size = bytes / sizeof(DataType); + + std::vector input = get_random_data(size, + generate_limits::min(), + generate_limits::max(), + seed.get_0()); + + std::vector flags_0; + std::vector flags_1; + std::vector flags_2; + + if(is_tuning) + { + flags_0 = get_random_data01(size, 0.0f, seed.get_1()); + flags_1 = get_random_data01(size, 0.5f, seed.get_1()); + flags_2 = get_random_data01(size, 1.0f, seed.get_1()); + } + else + { + flags_0 = get_random_data01(size, get_probability(Probability), seed.get_1()); + } + + DataType* d_input{}; + HIP_CHECK(hipMalloc(&d_input, size * sizeof(*d_input))); + HIP_CHECK(hipMemcpy(d_input, input.data(), size * 
sizeof(*d_input), hipMemcpyHostToDevice)); + + FlagType* d_flags_0{}; + FlagType* d_flags_1{}; + FlagType* d_flags_2{}; + HIP_CHECK(hipMalloc(&d_flags_0, size * sizeof(*d_flags_0))); + HIP_CHECK( + hipMemcpy(d_flags_0, flags_0.data(), size * sizeof(*d_flags_0), hipMemcpyHostToDevice)); + if(is_tuning) + { + HIP_CHECK(hipMalloc(&d_flags_1, size * sizeof(*d_flags_1))); + HIP_CHECK(hipMemcpy(d_flags_1, + flags_1.data(), + size * sizeof(*d_flags_1), + hipMemcpyHostToDevice)); + HIP_CHECK(hipMalloc(&d_flags_2, size * sizeof(*d_flags_2))); + HIP_CHECK(hipMemcpy(d_flags_2, + flags_2.data(), + size * sizeof(*d_flags_2), + hipMemcpyHostToDevice)); + } + + DataType* d_output{}; + HIP_CHECK(hipMalloc(&d_output, size * sizeof(*d_output))); + + unsigned int* d_selected_count_output{}; + HIP_CHECK(hipMalloc(&d_selected_count_output, sizeof(*d_selected_count_output))); + + const auto dispatch = [&](void* d_temp_storage, size_t& temp_storage_size_bytes) + { + const auto dispatch_predicated_flags = [&](FlagType* d_flags) + { + auto predicate = [](const FlagType& value) -> bool { return value; }; + HIP_CHECK(rocprim::select(d_temp_storage, + temp_storage_size_bytes, + d_input, + d_flags, + d_output, + d_selected_count_output, + size, + predicate, + stream)); + }; + + dispatch_predicated_flags(d_flags_0); + if(is_tuning) + { + dispatch_predicated_flags(d_flags_1); + dispatch_predicated_flags(d_flags_2); + } + }; + + // Allocate temporary storage memory + size_t temp_storage_size_bytes{}; + dispatch(nullptr, temp_storage_size_bytes); + void* d_temp_storage{}; + HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); + + for(int i = 0; i < warmup_iter; i++) + { + dispatch(d_temp_storage, temp_storage_size_bytes); + } + HIP_CHECK(hipDeviceSynchronize()); + + hipEvent_t start, stop; + HIP_CHECK(hipEventCreate(&start)); + HIP_CHECK(hipEventCreate(&stop)); + + for(auto _ : state) + { + HIP_CHECK(hipEventRecord(start, stream)); + for(int i = 0; i < batch_size; ++i) + { + 
dispatch(d_temp_storage, temp_storage_size_bytes); + } + HIP_CHECK(hipEventRecord(stop, stream)); + HIP_CHECK(hipEventSynchronize(stop)); + + float elapsed_mseconds{}; + HIP_CHECK(hipEventElapsedTime(&elapsed_mseconds, start, stop)); + state.SetIterationTime(elapsed_mseconds / 1000); + } + + HIP_CHECK(hipEventDestroy(start)); + HIP_CHECK(hipEventDestroy(stop)); + + state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(DataType)); + state.SetItemsProcessed(state.iterations() * batch_size * size); + + HIP_CHECK(hipFree(d_input)); + if(is_tuning) + { + HIP_CHECK(hipFree(d_flags_2)); + HIP_CHECK(hipFree(d_flags_1)); + } + HIP_CHECK(hipFree(d_flags_0)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); } static constexpr bool is_tuning = Probability == select_probability::tuning; @@ -364,10 +520,13 @@ struct device_select_unique_benchmark : public config_autotune_interface } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { + // Calculate the number of elements + size_t size = bytes / sizeof(DataType); + std::vector input_0; std::vector input_1; std::vector input_2; @@ -470,13 +629,13 @@ struct device_select_unique_benchmark : public config_autotune_interface if(is_tuning) { - hipFree(d_input_2); - hipFree(d_input_1); + HIP_CHECK(hipFree(d_input_2)); + HIP_CHECK(hipFree(d_input_1)); } - hipFree(d_input_0); - hipFree(d_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input_0)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); } static constexpr bool is_tuning = Probability == select_probability::tuning; @@ -499,10 +658,13 @@ struct device_select_unique_by_key_benchmark : public config_autotune_interface } void run(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, 
hipStream_t stream) const override { + // Calculate the number of elements + size_t size = bytes / sizeof(KeyType); + std::vector input_keys_0; std::vector input_keys_1; std::vector input_keys_2; @@ -628,15 +790,15 @@ struct device_select_unique_by_key_benchmark : public config_autotune_interface if(is_tuning) { - hipFree(d_keys_input_2); - hipFree(d_keys_input_1); + HIP_CHECK(hipFree(d_keys_input_2)); + HIP_CHECK(hipFree(d_keys_input_1)); } - hipFree(d_keys_input_0); - hipFree(d_values_input); - hipFree(d_keys_output); - hipFree(d_values_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_keys_input_0)); + HIP_CHECK(hipFree(d_values_input)); + HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_values_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); } static constexpr bool is_tuning = Probability == select_probability::tuning; @@ -647,10 +809,24 @@ struct device_select_unique_by_key_benchmark : public config_autotune_interface template struct create_benchmark { + static constexpr unsigned int block_size = Config().kernel_config.block_size; + static constexpr unsigned int items_per_thread = Config().kernel_config.items_per_thread; + static constexpr unsigned int max_shared_memory = TUNING_SHARED_MEMORY_MAX; + static constexpr unsigned int max_size_per_element = sizeof(KeyType) + sizeof(ValueType); + static constexpr unsigned int max_items_per_thread + = max_shared_memory / (block_size * max_size_per_element); + void operator()(std::vector>& storage) { storage.emplace_back( std::make_unique>()); + + if(items_per_thread <= max_items_per_thread) + { + storage.emplace_back( + std::make_unique< + device_select_predicated_flag_benchmark>()); + } } }; diff --git a/benchmark/benchmark_device_transform.cpp b/benchmark/benchmark_device_transform.cpp index c21efc5c7..f55715346 100644 --- a/benchmark/benchmark_device_transform.cpp +++ b/benchmark/benchmark_device_transform.cpp @@ 
-43,20 +43,20 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 128; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; #endif #define CREATE_BENCHMARK(T) \ { \ const device_transform_benchmark instance{}; \ - REGISTER_BENCHMARK(benchmarks, size, seed, stream, instance); \ + REGISTER_BENCHMARK(benchmarks, bytes, seed, stream, instance); \ } int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -78,7 +78,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -89,7 +89,7 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks @@ -100,7 +100,7 @@ int main(int argc, char *argv[]) config_autotune_register::register_benchmark_subset(benchmarks, parallel_instance, parallel_instances, - size, + bytes, seed, stream); #else // BENCHMARK_CONFIG_TUNING diff --git a/benchmark/benchmark_device_transform.parallel.hpp b/benchmark/benchmark_device_transform.parallel.hpp index 3a8a956e9..6bad9a392 100644 --- a/benchmark/benchmark_device_transform.parallel.hpp +++ b/benchmark/benchmark_device_transform.parallel.hpp @@ -70,12 +70,15 @@ struct device_transform_benchmark : public config_autotune_interface static constexpr unsigned int warmup_size = 5; void run(benchmark::State& state, - 
size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) const override { using output_type = T; + // Calculate the number of elements + size_t size = bytes / sizeof(T); + static constexpr bool debug_synchronous = false; // Generate data @@ -143,8 +146,8 @@ struct device_transform_benchmark : public config_autotune_interface state.SetBytesProcessed(state.iterations() * batch_size * size * sizeof(T)); state.SetItemsProcessed(state.iterations() * batch_size * size); - hipFree(d_input); - hipFree(d_output); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); } }; diff --git a/benchmark/benchmark_predicate_iterator.cpp b/benchmark/benchmark_predicate_iterator.cpp index 87389bd95..d38cf3a1d 100644 --- a/benchmark/benchmark_predicate_iterator.cpp +++ b/benchmark/benchmark_predicate_iterator.cpp @@ -40,8 +40,8 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 128; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 128 * 4; #endif const unsigned int batch_size = 10; @@ -74,6 +74,16 @@ struct increment } }; +template +struct transform_op +{ + __device__ + auto operator()(T v) const + { + return Predicate{}(v) ? Transform{}(v) : v; + } +}; + template struct transform_it { @@ -81,9 +91,8 @@ struct transform_it void operator()(T* d_input, T* d_output, const size_t size, const hipStream_t stream) { - auto t_it = rocprim::make_transform_iterator( - d_input, - [&] __device__(T v) { return Predicate{}(v) ? 
Transform{}(v) : v; }); + auto t_it + = rocprim::make_transform_iterator(d_input, transform_op{}); HIP_CHECK(rocprim::transform(t_it, d_output, size, identity{}, stream)); } }; @@ -116,12 +125,15 @@ struct write_predicate_it template void run_benchmark(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { using T = typename IteratorBenchmark::value_type; + // Calculate the number of elements + size_t size = bytes / sizeof(T); + const auto random_range = limit_random_range(0, 99); std::vector input = get_random_data(size, random_range.first, random_range.second, seed.get_0()); @@ -179,7 +191,7 @@ void run_benchmark(benchmark::State& state, ",key_type:" #T ",cfg:default_config}") \ .c_str(), \ run_benchmark, increment>>, \ - size, \ + bytes, \ seed, \ stream) @@ -205,7 +217,7 @@ void run_benchmark(benchmark::State& state, int main(int argc, char* argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -216,7 +228,7 @@ int main(int argc, char* argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -227,7 +239,7 @@ int main(int argc, char* argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); using custom_128 = custom_type; diff --git a/benchmark/benchmark_utils.hpp b/benchmark/benchmark_utils.hpp index 8b2a39c52..6b5c11cd4 100644 --- a/benchmark/benchmark_utils.hpp +++ 
b/benchmark/benchmark_utils.hpp @@ -340,11 +340,11 @@ struct generate_limits::value>> { static inline T min() { - return std::numeric_limits::min(); + return rocprim::numeric_limits::min(); } static inline T max() { - return std::numeric_limits::max(); + return rocprim::numeric_limits::max(); } }; @@ -447,19 +447,19 @@ auto limit_cast(U value) -> T { if(value < 0) { - return std::numeric_limits::min(); + return rocprim::numeric_limits::min(); } if(static_cast(value) - > static_cast(std::numeric_limits::max())) + > static_cast(rocprim::numeric_limits::max())) { - return std::numeric_limits::max(); + return rocprim::numeric_limits::max(); } } else if(rocprim::is_signed::value && rocprim::is_unsigned::value) { - if(value > std::numeric_limits::max()) + if(value > rocprim::numeric_limits::max()) { - return std::numeric_limits::max(); + return rocprim::numeric_limits::max(); } } else if(rocprim::is_floating_point::value) @@ -468,13 +468,13 @@ auto limit_cast(U value) -> T } else // Both T and U are signed { - if(value < static_cast(std::numeric_limits::min())) + if(value < static_cast(rocprim::numeric_limits::min())) { - return std::numeric_limits::min(); + return rocprim::numeric_limits::min(); } - else if(value > static_cast(std::numeric_limits::max())) + else if(value > static_cast(rocprim::numeric_limits::max())) { - return std::numeric_limits::max(); + return rocprim::numeric_limits::max(); } } return static_cast(value); @@ -559,7 +559,7 @@ std::vector using key_distribution_type = std::conditional_t::value, std::uniform_int_distribution, std::uniform_real_distribution>; - key_distribution_type key_distribution(std::numeric_limits::max()); + key_distribution_type key_distribution(rocprim::numeric_limits::max()); std::vector keys(size); size_t keys_start_index = 0; @@ -597,6 +597,28 @@ std::vector return keys; } +template +inline auto get_random_value(U min, V max, size_t seed_value) + -> std::enable_if_t::value, T> +{ + T result; + engine_type gen(seed_value); + 
generate_random_data_n(&result, 1, min, max, gen); + return result; +} + +template +inline auto get_random_value(T min, T max, size_t seed_value) + -> std::enable_if_t::value, T> +{ + typename T::first_type result_first; + typename T::second_type result_second; + engine_type gen(seed_value); + generate_random_data_n(&result_first, 1, min.x, max.x, gen); + generate_random_data_n(&result_second, 1, min.y, max.y, gen); + return T{result_first, result_second}; +} + template struct make_index_range_impl; diff --git a/benchmark/benchmark_warp_exchange.cpp b/benchmark/benchmark_warp_exchange.cpp index 7ddb48728..5b7ec6013 100644 --- a/benchmark/benchmark_warp_exchange.cpp +++ b/benchmark/benchmark_warp_exchange.cpp @@ -40,8 +40,8 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif struct BlockedToStripedOp @@ -233,8 +233,11 @@ template< unsigned int LogicalWarpSize, class Op > -void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) +void run_benchmark(benchmark::State& state, hipStream_t stream, size_t bytes) { + // Calculate the number of elements + size_t N = bytes / sizeof(T); + constexpr unsigned int trials = 200; constexpr unsigned int items_per_block = BlockSize * ItemsPerThread; const unsigned int size = items_per_block * ((N + items_per_block - 1) / items_per_block); @@ -283,12 +286,12 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t N) .c_str(), \ &run_benchmark, \ stream, \ - size) + bytes) int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -298,7 +301,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, 
argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); @@ -307,7 +310,7 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); // Add benchmarks std::vector benchmarks{ diff --git a/benchmark/benchmark_warp_reduce.cpp b/benchmark/benchmark_warp_reduce.cpp index 0241234c2..1cddacbfd 100644 --- a/benchmark/benchmark_warp_reduce.cpp +++ b/benchmark/benchmark_warp_reduce.cpp @@ -41,8 +41,8 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif template< @@ -146,10 +146,13 @@ template -void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, hipStream_t stream) +void run_benchmark(benchmark::State& state, size_t bytes, const managed_seed& seed, hipStream_t stream) { using flag_type = unsigned char; + // Calculate the number of elements + size_t N = bytes / sizeof(T); + const auto size = BlockSize * ((N + BlockSize - 1)/BlockSize); const auto random_range = limit_random_range(0, 10); @@ -223,7 +226,7 @@ void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, + ",ws:" #WS ",cfg:{bs:" #BS "}}") \ .c_str(), \ run_benchmark, \ - size, \ + bytes, \ seed, \ stream) @@ -235,7 +238,7 @@ void run_benchmark(benchmark::State& state, size_t N, const managed_seed& seed, template void add_benchmarks(std::vector& benchmarks, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { @@ -255,7 +258,7 @@ void add_benchmarks(std::vector& benchmarks, int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, 
"number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -266,7 +269,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -277,14 +280,14 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); // Add benchmarks std::vector benchmarks; - add_benchmarks(benchmarks, size, seed, stream); - add_benchmarks(benchmarks, size, seed, stream); - add_benchmarks(benchmarks, size, seed, stream); + add_benchmarks(benchmarks, bytes, seed, stream); + add_benchmarks(benchmarks, bytes, seed, stream); + add_benchmarks(benchmarks, bytes, seed, stream); // Use manual timing for(auto& b : benchmarks) diff --git a/benchmark/benchmark_warp_scan.cpp b/benchmark/benchmark_warp_scan.cpp index c12066fd5..f7983761a 100644 --- a/benchmark/benchmark_warp_scan.cpp +++ b/benchmark/benchmark_warp_scan.cpp @@ -39,8 +39,8 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 * 4; #endif namespace rp = rocprim; @@ -90,8 +90,11 @@ template< bool Inclusive = true, unsigned int Trials = 100 > -void run_benchmark(benchmark::State& state, hipStream_t stream, size_t size) +void run_benchmark(benchmark::State& state, hipStream_t stream, size_t bytes) { + // Calculate the number of elements + size_t size = bytes / sizeof(T); + // Make sure size is a multiple of BlockSize size = BlockSize * ((size + BlockSize - 1)/BlockSize); // Allocate and fill memory @@ -165,7 +168,7 
@@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t size) .c_str(), \ run_benchmark, \ stream, \ - size) + bytes) #define BENCHMARK_TYPE(type) \ CREATE_BENCHMARK(type, 64, 64, Inclusive), \ @@ -180,7 +183,7 @@ void run_benchmark(benchmark::State& state, hipStream_t stream, size_t size) template void add_benchmarks(std::vector& benchmarks, hipStream_t stream, - size_t size) + size_t bytes) { using custom_double2 = custom_type; using custom_int_double = custom_type; @@ -202,7 +205,7 @@ void add_benchmarks(std::vector& benchmarks, int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -212,7 +215,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); @@ -221,12 +224,12 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); // Add benchmarks std::vector benchmarks; - add_benchmarks(benchmarks, stream, size); //inclusive - add_benchmarks(benchmarks, stream, size); //exclusive + add_benchmarks(benchmarks, stream, bytes); //inclusive + add_benchmarks(benchmarks, stream, bytes); //exclusive // Use manual timing for(auto& b : benchmarks) diff --git a/benchmark/benchmark_warp_sort.cpp b/benchmark/benchmark_warp_sort.cpp index 5d6789f68..a12f1e7b1 100644 --- a/benchmark/benchmark_warp_sort.cpp +++ b/benchmark/benchmark_warp_sort.cpp @@ -41,8 +41,8 @@ #include #include -#ifndef DEFAULT_N -const size_t DEFAULT_N = 1024 * 1024 * 
32; +#ifndef DEFAULT_BYTES +const size_t DEFAULT_BYTES = 1024 * 1024 * 32 *4; #endif namespace rp = rocprim; @@ -95,10 +95,13 @@ template void run_benchmark(benchmark::State& state, - size_t size, + size_t bytes, const managed_seed& seed, hipStream_t stream) { + // Calculate the number of elements + size_t size = bytes / sizeof(Key); + // Make sure size is a multiple of items_per_block constexpr auto items_per_block = BlockSize * ItemsPerThread; size = BlockSize * ((size + items_per_block - 1) / items_per_block); @@ -205,7 +208,7 @@ void run_benchmark(benchmark::State& state, + ",ws:" #WS ",cfg:{bs:" #BS ",ipt:" #IPT "}}") \ .c_str(), \ run_benchmark, \ - size, \ + bytes, \ seed, \ stream) @@ -215,7 +218,7 @@ void run_benchmark(benchmark::State& state, ",cfg:{bs:" #BS ",ipt:" #IPT "}}") \ .c_str(), \ run_benchmark, \ - size, \ + bytes, \ seed, \ stream) @@ -246,7 +249,7 @@ void run_benchmark(benchmark::State& state, int main(int argc, char *argv[]) { cli::Parser parser(argc, argv); - parser.set_optional("size", "size", DEFAULT_N, "number of values"); + parser.set_optional("size", "size", DEFAULT_BYTES, "number of bytes"); parser.set_optional("trials", "trials", -1, "number of iterations"); parser.set_optional("name_format", "name_format", @@ -257,7 +260,7 @@ int main(int argc, char *argv[]) // Parse argv benchmark::Initialize(&argc, argv); - const size_t size = parser.get("size"); + const size_t bytes = parser.get("size"); const int trials = parser.get("trials"); bench_naming::set_format(parser.get("name_format")); const std::string seed_type = parser.get("seed"); @@ -268,7 +271,7 @@ int main(int argc, char *argv[]) // Benchmark info add_common_benchmark_info(); - benchmark::AddCustomContext("size", std::to_string(size)); + benchmark::AddCustomContext("bytes", std::to_string(bytes)); benchmark::AddCustomContext("seed", seed_type); using custom_double2 = custom_type; diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index e00f67290..5199b02ad 100644 
--- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -57,50 +57,10 @@ list(JOIN CXX_FLAGS_LIST " " CMAKE_CXX_FLAGS) set(BUILD_SHARED_LIBS OFF CACHE BOOL "Global flag to cause add_library() to create shared libraries if on." FORCE) # HIP dependency is handled earlier in the project cmake file -# when VerifyCompiler.cmake is included (when not using HIP-CPU). +# when VerifyCompiler.cmake is included. include(FetchContent) -if(USE_HIP_CPU) - if(NOT DEPENDENCIES_FORCE_DOWNLOAD) - find_package(hip_cpu_rt QUIET) - endif() - if(NOT TARGET hip_cpu_rt::hip_cpu_rt) - message(STATUS "HIP-CPU runtime not found. Fetching...") - FetchContent_Declare( - hip-cpu - GIT_REPOSITORY https://github.com/ROCm-Developer-Tools/HIP-CPU.git - GIT_TAG 56f559c93be210bb300dad3673c06d2bb0119d13 # master@2022.07.01 - ) - FetchContent_MakeAvailable(hip-cpu) - if(NOT TARGET hip_cpu_rt::hip_cpu_rt) - add_library(hip_cpu_rt::hip_cpu_rt ALIAS hip_cpu_rt) - endif() - else() - find_package(hip_cpu_rt REQUIRED) - # If we found HIP-CPU as binary, search for transitive dependencies - find_package(Threads REQUIRED) - set(CMAKE_REQUIRED_FLAGS "-std=c++17") - include(CheckCXXSymbolExists) - check_cxx_symbol_exists(__GLIBCXX__ "cstddef" STL_IS_GLIBCXX) - set(STL_DEPENDS_ON_TBB ${STL_IS_GLIBCXX}) - if(STL_DEPENDS_ON_TBB) - find_package(TBB QUIET) - if(NOT TARGET TBB::tbb AND NOT TARGET tbb) - message(STATUS "Thread Building Blocks not found. 
Fetching...") - FetchContent_Declare( - thread-building-blocks - GIT_REPOSITORY https://github.com/oneapi-src/oneTBB.git - GIT_TAG 3df08fe234f23e732a122809b40eb129ae22733f # v2021.5.0 - ) - FetchContent_MakeAvailable(thread-building-blocks) - else() - find_package(TBB REQUIRED) - endif() - endif(STL_DEPENDS_ON_TBB) - endif() -endif(USE_HIP_CPU) - # Test dependencies if(BUILD_TEST) # NOTE1: Google Test has created a mess with legacy FindGTest.cmake and newer GTestConfig.cmake diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index e13a5fb29..474433306 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -36,9 +36,7 @@ function(print_configuration_summary) message(STATUS " Build type : ${CMAKE_BUILD_TYPE}") endif() message(STATUS " Install prefix : ${CMAKE_INSTALL_PREFIX}") - if(NOT USE_HIP_CPU) - message(STATUS " Device targets : ${GPU_TARGETS}") - endif() + message(STATUS " Device targets : ${GPU_TARGETS}") message(STATUS "") message(STATUS " ONLY_INSTALL : ${ONLY_INSTALL}") message(STATUS " BUILD_TEST : ${BUILD_TEST}") @@ -46,5 +44,4 @@ function(print_configuration_summary) message(STATUS " BUILD_NAIVE_BENCHMARK : ${BUILD_NAIVE_BENCHMARK}") message(STATUS " BUILD_EXAMPLE : ${BUILD_EXAMPLE}") message(STATUS " BUILD_DOCS : ${BUILD_DOCS}") - message(STATUS " USE_HIP_CPU : ${USE_HIP_CPU}") endfunction() diff --git a/docs/device_ops/adjacent_difference.rst b/docs/device_ops/adjacent_difference.rst index 670d5783f..1b45f516b 100644 --- a/docs/device_ops/adjacent_difference.rst +++ b/docs/device_ops/adjacent_difference.rst @@ -5,7 +5,7 @@ .. 
_dev-adjacent_difference: ******************************************************************** - Adjacent difference + Adjacent Difference ******************************************************************** Configuring the kernel diff --git a/docs/device_ops/adjacent_find.rst b/docs/device_ops/adjacent_find.rst new file mode 100644 index 000000000..30e6d99d8 --- /dev/null +++ b/docs/device_ops/adjacent_find.rst @@ -0,0 +1,20 @@ +.. meta:: + :description: rocPRIM documentation and API reference library + :keywords: rocPRIM, ROCm, API, documentation + +.. _dev-adjacent_find: + +******************************************************************** + Adjacent Find +******************************************************************** + +Configuring the kernel +======================== + +.. doxygenstruct:: rocprim::adjacent_find_config + +adjacent_find +======================== + +.. doxygenfunction:: rocprim::adjacent_find(void* const temporary_storage, std::size_t& storage_size, InputIteratorType input, OutputIteratorType output, const std::size_t size, const BinaryPred op=BinaryPred{}, const hipStream_t stream=0, const bool debug_synchronous=false) + diff --git a/docs/device_ops/find_end.rst b/docs/device_ops/find_end.rst new file mode 100644 index 000000000..70fa80811 --- /dev/null +++ b/docs/device_ops/find_end.rst @@ -0,0 +1,19 @@ +.. meta:: + :description: rocPRIM documentation and API reference library + :keywords: rocPRIM, ROCm, API, documentation + +.. _dev-find_end: + + +Find end +-------- + +Configuring the kernel +~~~~~~~~~~~~~~~~~~~~~~ + +.. doxygenstruct:: rocprim::search_config + +find_end +~~~~~~~~ + +.. 
doxygenfunction:: rocprim::find_end(void* temporary_storage, size_t& storage_size, InputIterator1 input, InputIterator2 keys, OutputIterator output, size_t size, size_t keys_size, BinaryFunction compare_function = BinaryFunction(), hipStream_t stream = 0, bool debug_synchronous = false) diff --git a/docs/device_ops/index.rst b/docs/device_ops/index.rst index 74db4ee48..701776c13 100644 --- a/docs/device_ops/index.rst +++ b/docs/device_ops/index.rst @@ -16,9 +16,11 @@ * :ref:`dev-partition` * :ref:`dev-run_length` * :ref:`dev-scan` + * :ref:`dev-search_n` * :ref:`dev-select` * :ref:`dev-reduce` * :ref:`dev-adjacent_difference` + * :ref:`dev-adjacent_find` * :ref:`dev-binary_search` * :ref:`dev-histogram` * :ref:`dev-device_copy` @@ -26,3 +28,5 @@ * :ref:`dev-nth_element` * :ref:`dev-partial_sort` * :ref:`dev-find_first_of` + * :ref:`dev-find_end` + * :ref:`dev-search` diff --git a/docs/device_ops/search.rst b/docs/device_ops/search.rst new file mode 100644 index 000000000..add5f0ea9 --- /dev/null +++ b/docs/device_ops/search.rst @@ -0,0 +1,19 @@ +.. meta:: + :description: rocPRIM documentation and API reference library + :keywords: rocPRIM, ROCm, API, documentation + +.. _dev-search: + + +Search +------ + +Configuring the kernel +~~~~~~~~~~~~~~~~~~~~~~ + +.. doxygenstruct:: rocprim::search_config + +search +~~~~~~ + +.. doxygenfunction:: rocprim::search(void* temporary_storage, size_t& storage_size, InputIterator1 input, InputIterator2 keys, OutputIterator output, size_t size, size_t keys_size, BinaryFunction compare_function = BinaryFunction(), hipStream_t stream = 0, bool debug_synchronous = false) diff --git a/docs/device_ops/search_n.rst b/docs/device_ops/search_n.rst new file mode 100644 index 000000000..fba3386ad --- /dev/null +++ b/docs/device_ops/search_n.rst @@ -0,0 +1,19 @@ +.. meta:: + :description: rocPRIM documentation and API reference library + :keywords: rocPRIM, ROCm, API, documentation + +.. 
_dev-search_n: + +******************************************************************** + Search N +******************************************************************** + +Configuring the kernel +======================== + +.. doxygenstruct:: rocprim::search_n + +search_n +======================== + +.. doxygenfunction:: rocprim::search_n(void* const temporary_storage, size_t& storage_size, InputIterator input, OutputIterator output, const size_t size, const size_t count, const typename std::iterator_traits::value_type* value, const BinaryPredicate binary_predicate = BinaryPredicate(), const hipStream_t stream = static_cast(0), const bool debug_synchronous = false) diff --git a/docs/reference/ops_summary.rst b/docs/reference/ops_summary.rst index 9121e2e31..7b9e61221 100644 --- a/docs/reference/ops_summary.rst +++ b/docs/reference/ops_summary.rst @@ -54,6 +54,10 @@ Sequence Search =============== * ``find_first_of`` searches for the first occurrence of any of the provided elements. +* ``adjacent_find`` searches a given sequence for the first occurrence of two consecutive equal elements. +* ``search`` searches for the first occurrence of the sequence. +* ``search_n`` searches for the first occurrence of a sequence of count elements all equal to value. +* ``find_end`` searches for the last occurrence of the sequence. 
Other operations ====================== diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index b270acd5c..584671f4d 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -29,14 +29,18 @@ subtrees: - file: device_ops/partition.rst - file: device_ops/run_length_encoding.rst - file: device_ops/scan.rst + - file: device_ops/search_n.rst - file: device_ops/select.rst - file: device_ops/reduce.rst - file: device_ops/adjacent_difference.rst + - file: device_ops/adjacent_find.rst - file: device_ops/binary_search.rst - file: device_ops/histogram.rst - file: device_ops/device_copy.rst - file: device_ops/memcpy.rst - file: device_ops/find_first_of.rst + - file: device_ops/find_end.rst + - file: device_ops/search.rst - file: block_ops/index.rst subtrees: - entries: diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 2c2ea5975..91fd34d56 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -27,25 +27,10 @@ function(add_rocprim_example EXAMPLE_SOURCE) PRIVATE rocprim_hip ) - if(NOT USE_HIP_CPU) - target_link_libraries(${EXAMPLE_TARGET} - PRIVATE - rocprim_hip - ) - else() - target_link_libraries(${TEST_EXAMPLE_TARGETTARGET} - PRIVATE - rocprim - Threads::Threads - hip_cpu_rt::hip_cpu_rt - ) - if(STL_DEPENDS_ON_TBB) - target_link_libraries(${EXAMPLE_TARGET} - PRIVATE - TBB::tbb - ) - endif() - endif() + target_link_libraries(${EXAMPLE_TARGET} + PRIVATE + rocprim_hip + ) set_target_properties(${EXAMPLE_TARGET} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/example" @@ -63,7 +48,7 @@ function(add_rocprim_example EXAMPLE_SOURCE) foreach( file_i ${third_party_dlls}) add_custom_command( TARGET ${EXAMPLE_TARGET} POST_BUILD COMMAND ${CMAKE_COMMAND} ARGS -E copy_if_different ${file_i} ${PROJECT_BINARY_DIR}/example ) endforeach( file_i ) - endif() + endif() endfunction() # **************************************************************************** diff --git a/rocprim/include/rocprim/block/block_load_func.hpp b/rocprim/include/rocprim/block/block_load_func.hpp index 2dd938ffd..0df327030 100644 --- a/rocprim/include/rocprim/block/block_load_func.hpp +++ b/rocprim/include/rocprim/block/block_load_func.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -146,7 +146,6 @@ void block_load_direct_blocked(unsigned int flat_id, { items[item] = static_cast(out_of_bounds); } - // TODO: Consider using std::fill for HIP-CPU, as uses memset() where appropriate block_load_direct_blocked(flat_id, block_input, items, valid); } diff --git a/rocprim/include/rocprim/block/detail/block_reduce_raking_reduce.hpp b/rocprim/include/rocprim/block/detail/block_reduce_raking_reduce.hpp index bcf252cc9..fa38aec88 100644 --- a/rocprim/include/rocprim/block/detail/block_reduce_raking_reduce.hpp +++ b/rocprim/include/rocprim/block/detail/block_reduce_raking_reduce.hpp @@ -72,15 +72,9 @@ class fast_array sizeof(int32_t))>> for(int i = 0; i < words_no; i++) { const size_t s = std::min(sizeof(int32_t), sizeof(T) - i * sizeof(int32_t)); -#ifdef __HIP_CPU_RT__ - std::memcpy(reinterpret_cast(&result) + i * sizeof(int32_t), - data + index + i * n, - s); -#else __builtin_memcpy(reinterpret_cast(&result) + i * sizeof(int32_t), data + index + i * n, s); -#endif } return result; } @@ -91,15 +85,9 @@ class fast_array sizeof(int32_t))>> for(int i = 0; i < words_no; i++) { const size_t s = std::min(sizeof(int32_t), sizeof(T) - i * sizeof(int32_t)); -#ifdef __HIP_CPU_RT__ - std::memcpy(data + index + i * n, - reinterpret_cast(&value) + i * sizeof(int32_t), - s); -#else __builtin_memcpy(data + index + i * n, reinterpret_cast(&value) + i * sizeof(int32_t), s); -#endif } } diff --git a/rocprim/include/rocprim/common.hpp b/rocprim/include/rocprim/common.hpp new file mode 100644 index 000000000..88897a10b --- /dev/null +++ b/rocprim/include/rocprim/common.hpp @@ -0,0 +1,62 @@ +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + + +#ifndef ROCPRIM_COMMON_HPP_ +#define ROCPRIM_COMMON_HPP_ +namespace detail +{ +#ifndef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + #define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ + do \ + { \ + auto _error = hipGetLastError(); \ + if(_error != hipSuccess) \ + return _error; \ + if(debug_synchronous) \ + { \ + std::cout << name << "(" << size << ")"; \ + auto __error = hipStreamSynchronize(stream); \ + if(__error != hipSuccess) \ + return __error; \ + auto _end = std::chrono::steady_clock::now(); \ + auto _d = std::chrono::duration_cast>(_end - start); \ + std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ + } \ + } \ + while(0) +#endif // ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + +#ifndef ROCPRIM_RETURN_ON_ERROR + #define ROCPRIM_RETURN_ON_ERROR(...) 
\ + do \ + { \ + hipError_t error = (__VA_ARGS__); \ + if(error != hipSuccess) \ + { \ + return error; \ + } \ + } \ + while(0) +#endif // ROCPRIM_RETURN_ON_ERROR + +} // namespace detail + +#endif // ROCPRIM_COMMON_HPP_ diff --git a/rocprim/include/rocprim/config.hpp b/rocprim/include/rocprim/config.hpp index d2dd58e3d..e968f2793 100644 --- a/rocprim/include/rocprim/config.hpp +++ b/rocprim/include/rocprim/config.hpp @@ -182,7 +182,6 @@ #endif /// \brief Clang predefined macro for device code on AMD GPU targets, either 32 or 64. -/// For HIP-CPU with macro is not predefined, and rocPRIM defines it as 64. /// It is undefined behavior to use this macro in host code when compiling with Clang. #ifndef __AMDGCN_WAVEFRONT_SIZE #define __AMDGCN_WAVEFRONT_SIZE 64 diff --git a/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp b/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp new file mode 100644 index 000000000..e16c29597 --- /dev/null +++ b/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp @@ -0,0 +1,461 @@ +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_ADJACENT_FIND_HPP_ +#define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_ADJACENT_FIND_HPP_ + +#include "../../../type_traits.hpp" +#include "../device_config_helper.hpp" + +#include + +/* DO NOT EDIT THIS FILE + * This file is automatically generated by `/scripts/autotune/create_optimization.py`. + * so most likely you want to edit rocprim/device/device_(algo)_config.hpp + */ + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +struct default_adjacent_find_config : default_adjacent_find_config_base::type +{}; + +// Based on input_type = double +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx1030), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> + : adjacent_find_config<128, 8> +{}; + +// Based on input_type = float +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx1030), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> + : adjacent_find_config<512, 4> +{}; + +// Based on input_type = rocprim::half +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx1030), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 2))>> : adjacent_find_config<256, 16> +{}; + +// Based on input_type = int64_t +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx1030), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 8) && 
(sizeof(input_type) > 4))>> + : adjacent_find_config<512, 2> +{}; + +// Based on input_type = int +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx1030), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> + : adjacent_find_config<256, 8> +{}; + +// Based on input_type = short +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx1030), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> + : adjacent_find_config<128, 64> +{}; + +// Based on input_type = int8_t +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx1030), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))>> : adjacent_find_config<128, 32> +{}; + +// Based on input_type = double +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx1100), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> + : adjacent_find_config<512, 2> +{}; + +// Based on input_type = float +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx1100), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> + : adjacent_find_config<128, 4> +{}; + +// Based on input_type = rocprim::half +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx1100), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 2))>> : adjacent_find_config<512, 8> +{}; + +// Based on input_type = int64_t +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx1100), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) 
<= 8) && (sizeof(input_type) > 4))>> + : adjacent_find_config<256, 4> +{}; + +// Based on input_type = int +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx1100), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> + : adjacent_find_config<128, 8> +{}; + +// Based on input_type = short +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx1100), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> + : adjacent_find_config<64, 4> +{}; + +// Based on input_type = int8_t +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx1100), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))>> : adjacent_find_config<128, 32> +{}; + +// Based on input_type = double +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx906), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> + : adjacent_find_config<128, 16> +{}; + +// Based on input_type = float +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx906), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> + : adjacent_find_config<128, 16> +{}; + +// Based on input_type = rocprim::half +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx906), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 2))>> : adjacent_find_config<64, 4> +{}; + +// Based on input_type = int64_t +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx906), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && 
(sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> + : adjacent_find_config<128, 4> +{}; + +// Based on input_type = int +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx906), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> + : adjacent_find_config<128, 16> +{}; + +// Based on input_type = short +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx906), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> + : adjacent_find_config<128, 2> +{}; + +// Based on input_type = int8_t +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx906), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))>> : adjacent_find_config<128, 2> +{}; + +// Based on input_type = double +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx908), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> + : adjacent_find_config<128, 8> +{}; + +// Based on input_type = float +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx908), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> + : adjacent_find_config<512, 4> +{}; + +// Based on input_type = rocprim::half +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx908), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 2))>> : adjacent_find_config<256, 16> +{}; + +// Based on input_type = int64_t +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx908), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && 
(sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> + : adjacent_find_config<512, 2> +{}; + +// Based on input_type = int +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx908), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> + : adjacent_find_config<512, 4> +{}; + +// Based on input_type = short +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx908), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> + : adjacent_find_config<256, 16> +{}; + +// Based on input_type = int8_t +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx908), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))>> : adjacent_find_config<64, 16> +{}; + +// Based on input_type = double +template +struct default_adjacent_find_config< + static_cast(target_arch::unknown), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> + : adjacent_find_config<128, 8> +{}; + +// Based on input_type = float +template +struct default_adjacent_find_config< + static_cast(target_arch::unknown), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> + : adjacent_find_config<512, 4> +{}; + +// Based on input_type = rocprim::half +template +struct default_adjacent_find_config< + static_cast(target_arch::unknown), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 2))>> : adjacent_find_config<256, 16> +{}; + +// Based on input_type = int64_t +template +struct default_adjacent_find_config< + static_cast(target_arch::unknown), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) 
+ && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> + : adjacent_find_config<512, 2> +{}; + +// Based on input_type = int +template +struct default_adjacent_find_config< + static_cast(target_arch::unknown), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> + : adjacent_find_config<64, 64> +{}; + +// Based on input_type = short +template +struct default_adjacent_find_config< + static_cast(target_arch::unknown), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> + : adjacent_find_config<256, 16> +{}; + +// Based on input_type = int8_t +template +struct default_adjacent_find_config< + static_cast(target_arch::unknown), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))>> : adjacent_find_config<64, 16> +{}; + +// Based on input_type = double +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx90a), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> + : adjacent_find_config<128, 8> +{}; + +// Based on input_type = float +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx90a), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> + : adjacent_find_config<512, 4> +{}; + +// Based on input_type = rocprim::half +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx90a), + input_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 2))>> : adjacent_find_config<256, 16> +{}; + +// Based on input_type = int64_t +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx90a), + input_type, + 
std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> + : adjacent_find_config<512, 2> +{}; + +// Based on input_type = int +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx90a), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> + : adjacent_find_config<64, 64> +{}; + +// Based on input_type = short +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx90a), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> + : adjacent_find_config<256, 16> +{}; + +// Based on input_type = int8_t +template +struct default_adjacent_find_config< + static_cast(target_arch::gfx90a), + input_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))>> : adjacent_find_config<64, 16> +{}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_ADJACENT_FIND_HPP_ \ No newline at end of file diff --git a/rocprim/include/rocprim/device/detail/config/device_merge.hpp b/rocprim/include/rocprim/device/detail/config/device_merge.hpp new file mode 100644 index 000000000..83c2367ab --- /dev/null +++ b/rocprim/include/rocprim/device/detail/config/device_merge.hpp @@ -0,0 +1,2435 @@ +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_MERGE_HPP_ +#define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_MERGE_HPP_ + +#include "../../../type_traits.hpp" +#include "../device_config_helper.hpp" + +#include + +/* DO NOT EDIT THIS FILE + * This file is automatically generated by `/scripts/autotune/create_optimization.py`. 
+ * so most likely you want to edit rocprim/device/device_(algo)_config.hpp + */ + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +struct default_merge_config : default_merge_config_base::type +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = double, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<256, 16> +{}; + +// Based on key_type = double, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<256, 16> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<1024, 4> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && 
(std::is_same::value))>> + : merge_config<1024, 4> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 4> +{}; + +// Based on key_type = float, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<1024, 4> +{}; + +// Based on key_type = float, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<1024, 8> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<1024, 4> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : merge_config<1024, 8> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + 
value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : merge_config<1024, 1> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : merge_config<1024, 8> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : merge_config<256, 10> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 10> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : merge_config<256, 11> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_merge_config< + 
static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<256, 16> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<1024, 4> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<1024, 4> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : merge_config<256, 16> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 4> +{}; + +// Based on key_type = int, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && 
(sizeof(value_type) > 2))>> : merge_config<1024, 4> +{}; + +// Based on key_type = int, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<1024, 8> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<1024, 4> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : merge_config<1024, 8> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = short, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<1024, 8> +{}; + +// Based on key_type = short, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + 
std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<1024, 8> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<1024, 8> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : merge_config<256, 11> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : merge_config<1024, 4> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : merge_config<1024, 4> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : merge_config<1024, 8> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct 
default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 10> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx1030), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : merge_config<256, 11> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = double, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 8> +{}; + +// Based on key_type = double, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<512, 8> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : 
merge_config<1024, 4> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : merge_config<1024, 4> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 4> +{}; + +// Based on key_type = float, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<1024, 4> +{}; + +// Based on key_type = float, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<1024, 8> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<1024, 8> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : merge_config<1024, 8> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : merge_config<512, 8> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : merge_config<1024, 8> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : merge_config<1024, 8> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<1024, 8> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : merge_config<512, 16> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + 
key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 8> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<512, 8> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<1024, 4> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : merge_config<1024, 4> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : 
merge_config<1024, 4> +{}; + +// Based on key_type = int, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<1024, 4> +{}; + +// Based on key_type = int, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<1024, 8> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<1024, 8> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : merge_config<1024, 8> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<512, 8> +{}; + +// Based on key_type = short, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + 
std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<1024, 8> +{}; + +// Based on key_type = short, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<1024, 8> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 16> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : merge_config<512, 16> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : merge_config<512, 8> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : merge_config<1024, 8> +{}; + +// Based on key_type = int8_t, value_type = short 
+template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : merge_config<512, 8> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 8> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx1100), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : merge_config<512, 16> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = double, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<256, 5> +{}; + +// Based on key_type = double, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : 
merge_config<256, 5> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 5> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : merge_config<256, 7> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = float, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = float, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<512, 4> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : merge_config<256, 10> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : merge_config<1024, 2> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : merge_config<512, 2> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : merge_config<512, 4> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 10> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct 
default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : merge_config<256, 11> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<256, 5> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<256, 5> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 5> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : 
merge_config<256, 7> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = int, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = int, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<512, 4> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : merge_config<256, 10> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + 
std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<512, 2> +{}; + +// Based on key_type = short, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = short, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<512, 4> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 10> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : merge_config<256, 8> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : merge_config<1024, 2> +{}; + +// Based on key_type = int8_t, 
value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : merge_config<512, 2> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : merge_config<512, 4> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx906), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : merge_config<256, 11> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = double, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; 
+ +// Based on key_type = double, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<256, 5> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 5> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : merge_config<256, 7> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = float, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = float, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && 
(sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<512, 4> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : merge_config<1024, 2> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : merge_config<1024, 2> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : merge_config<256, 4> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx908), + 
key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 4> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : merge_config<256, 8> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<256, 5> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 5> +{}; + +// Based on key_type = 
int64_t, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : merge_config<256, 7> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = int, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = int, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<512, 4> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && 
(sizeof(key_type) > 2) + && (std::is_same::value))>> + : merge_config<256, 10> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = short, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = short, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<512, 4> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : merge_config<512, 8> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx908), + 
key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : merge_config<1024, 2> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : merge_config<1024, 2> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : merge_config<512, 4> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx908), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : merge_config<512, 8> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = double, value_type = int +template +struct default_merge_config< + 
static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = double, value_type = short +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<256, 5> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 5> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : merge_config<256, 7> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = float, value_type = int +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && 
(sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = float, value_type = short +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<512, 4> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : merge_config<1024, 2> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : merge_config<1024, 2> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : merge_config<256, 4> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 4> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : merge_config<256, 8> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<256, 5> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct 
default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 5> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : merge_config<256, 7> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = int, value_type = int +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = int, value_type = short +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<512, 4> +{}; + +// Based on key_type = int, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) 
<= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : merge_config<256, 10> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = short, value_type = int +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = short, value_type = short +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<512, 4> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + 
value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : merge_config<512, 8> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : merge_config<1024, 2> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : merge_config<1024, 2> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : merge_config<512, 4> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::unknown), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : merge_config<512, 8> +{}; + +// Based on key_type = double, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + 
value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = double, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = double, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<256, 5> +{}; + +// Based on key_type = double, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 5> +{}; + +// Based on key_type = double, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : merge_config<256, 7> +{}; + +// Based on key_type = float, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// 
Based on key_type = float, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = float, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<512, 4> +{}; + +// Based on key_type = float, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = float, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = rocprim::half, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : merge_config<1024, 2> +{}; + +// Based on key_type = rocprim::half, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && 
(sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : merge_config<1024, 2> +{}; + +// Based on key_type = rocprim::half, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : merge_config<256, 4> +{}; + +// Based on key_type = rocprim::half, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 4> +{}; + +// Based on key_type = rocprim::half, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))>> + : merge_config<256, 8> +{}; + +// Based on key_type = int64_t, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = int64_t, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = int64_t, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + 
std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<256, 5> +{}; + +// Based on key_type = int64_t, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<256, 5> +{}; + +// Based on key_type = int64_t, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))>> + : merge_config<256, 7> +{}; + +// Based on key_type = int, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = int, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = int, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<512, 4> +{}; + +// Based on 
key_type = int, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = int, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))>> + : merge_config<256, 10> +{}; + +// Based on key_type = short, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))>> : merge_config<1024, 2> +{}; + +// Based on key_type = short, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))>> : merge_config<512, 2> +{}; + +// Based on key_type = short, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))>> : merge_config<512, 4> +{}; + +// Based on key_type = short, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 
2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = short, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))>> + : merge_config<512, 8> +{}; + +// Based on key_type = int8_t, value_type = int64_t +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> + : merge_config<1024, 2> +{}; + +// Based on key_type = int8_t, value_type = int +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> + : merge_config<1024, 2> +{}; + +// Based on key_type = int8_t, value_type = short +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> + : merge_config<512, 4> +{}; + +// Based on key_type = int8_t, value_type = int8_t +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))>> + : merge_config<512, 4> +{}; + +// Based on key_type = int8_t, value_type = empty_type +template +struct default_merge_config< + static_cast(target_arch::gfx90a), + key_type, + value_type, + 
std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))>> + : merge_config<512, 8> +{}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_MERGE_HPP_ \ No newline at end of file diff --git a/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp b/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp new file mode 100644 index 000000000..9b093ae3b --- /dev/null +++ b/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp @@ -0,0 +1,1949 @@ +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_PREDICATED_FLAG_HPP_ +#define ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_PREDICATED_FLAG_HPP_ + +#include "../../../type_traits.hpp" +#include "../device_config_helper.hpp" + +#include + +/* DO NOT EDIT THIS FILE + * This file is automatically generated by `/scripts/autotune/create_optimization.py`. + * so most likely you want to edit rocprim/device/device_(algo)_config.hpp + */ + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +struct default_select_predicated_flag_config : default_partition_config_base::type +{}; + +// Based on data_type = double, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<512, 4> +{}; + +// Based on data_type = double, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<512, 4> +{}; + +// Based on data_type = double, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<512, 4> +{}; + +// Based on data_type = double, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> + : select_config<512, 4> +{}; + +// Based on data_type = float, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<512, 4> +{}; + +// Based on data_type = float, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<512, 4> +{}; + +// Based on data_type = float, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<512, 4> +{}; + +// Based on data_type = float, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> + : select_config<512, 4> +{}; + +// Based on data_type = rocprim::half, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : 
select_config<512, 4> +{}; + +// Based on data_type = rocprim::half, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<512, 8> +{}; + +// Based on data_type = rocprim::half, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<512, 8> +{}; + +// Based on data_type = rocprim::half, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))>> : select_config<512, 16> +{}; + +// Based on data_type = int64_t, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<512, 4> +{}; + +// Based on data_type = int64_t, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<512, 4> +{}; + +// Based on data_type = int64_t, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + 
data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<512, 4> +{}; + +// Based on data_type = int64_t, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 1))>> : select_config<512, 4> +{}; + +// Based on data_type = int, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<512, 4> +{}; + +// Based on data_type = int, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<512, 4> +{}; + +// Based on data_type = int, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<512, 4> +{}; + +// Based on data_type = int, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && 
(sizeof(data_type) > 2) + && (sizeof(flag_type) <= 1))>> : select_config<512, 4> +{}; + +// Based on data_type = short, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<512, 4> +{}; + +// Based on data_type = short, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<512, 8> +{}; + +// Based on data_type = short, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<512, 8> +{}; + +// Based on data_type = short, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 1))>> : select_config<512, 16> +{}; + +// Based on data_type = int8_t, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<512, 4> +{}; + +// Based on data_type = int8_t, flag_type = int +template 
+struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<512, 8> +{}; + +// Based on data_type = int8_t, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<512, 16> +{}; + +// Based on data_type = int8_t, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1030), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> + : select_config<512, 16> +{}; + +// Based on data_type = double, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<512, 4> +{}; + +// Based on data_type = double, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<192, 4> +{}; + +// Based on data_type = double, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && 
(sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<512, 8> +{}; + +// Based on data_type = double, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> + : select_config<512, 8> +{}; + +// Based on data_type = float, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<512, 8> +{}; + +// Based on data_type = float, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<192, 6> +{}; + +// Based on data_type = float, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<192, 8> +{}; + +// Based on data_type = float, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> + : select_config<192, 12> +{}; + +// Based on data_type = rocprim::half, 
flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<512, 8> +{}; + +// Based on data_type = rocprim::half, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<512, 16> +{}; + +// Based on data_type = rocprim::half, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<384, 18> +{}; + +// Based on data_type = rocprim::half, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))>> : select_config<512, 20> +{}; + +// Based on data_type = int64_t, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<512, 4> +{}; + +// Based on data_type = int64_t, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && 
(sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<192, 4> +{}; + +// Based on data_type = int64_t, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<512, 8> +{}; + +// Based on data_type = int64_t, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 1))>> : select_config<512, 8> +{}; + +// Based on data_type = int, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<512, 8> +{}; + +// Based on data_type = int, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<192, 6> +{}; + +// Based on data_type = int, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : 
select_config<192, 8> +{}; + +// Based on data_type = int, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 1))>> : select_config<192, 12> +{}; + +// Based on data_type = short, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<512, 8> +{}; + +// Based on data_type = short, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<512, 16> +{}; + +// Based on data_type = short, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<384, 18> +{}; + +// Based on data_type = short, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 1))>> : select_config<512, 20> +{}; + +// Based on data_type = int8_t, flag_type = int64_t +template +struct default_select_predicated_flag_config< + 
static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<512, 8> +{}; + +// Based on data_type = int8_t, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<512, 16> +{}; + +// Based on data_type = int8_t, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<512, 20> +{}; + +// Based on data_type = int8_t, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx1100), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> + : select_config<512, 20> +{}; + +// Based on data_type = double, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<192, 7> +{}; + +// Based on data_type = double, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 
2))>> : select_config<192, 7> +{}; + +// Based on data_type = double, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<256, 7> +{}; + +// Based on data_type = double, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> + : select_config<256, 7> +{}; + +// Based on data_type = float, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<256, 6> +{}; + +// Based on data_type = float, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<256, 11> +{}; + +// Based on data_type = float, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<256, 10> +{}; + +// Based on data_type = float, flag_type = int8_t +template +struct 
default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> + : select_config<256, 11> +{}; + +// Based on data_type = rocprim::half, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<256, 6> +{}; + +// Based on data_type = rocprim::half, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<256, 12> +{}; + +// Based on data_type = rocprim::half, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<256, 18> +{}; + +// Based on data_type = rocprim::half, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))>> : select_config<256, 18> +{}; + +// Based on data_type = int64_t, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && 
(sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<192, 7> +{}; + +// Based on data_type = int64_t, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<192, 7> +{}; + +// Based on data_type = int64_t, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<256, 7> +{}; + +// Based on data_type = int64_t, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 1))>> : select_config<256, 7> +{}; + +// Based on data_type = int, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<256, 6> +{}; + +// Based on data_type = int, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<256, 11> +{}; + +// Based on data_type = int, 
flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<256, 10> +{}; + +// Based on data_type = int, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 1))>> : select_config<256, 11> +{}; + +// Based on data_type = short, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<256, 6> +{}; + +// Based on data_type = short, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<256, 12> +{}; + +// Based on data_type = short, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<256, 18> +{}; + +// Based on data_type = short, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, 
+ flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 1))>> : select_config<256, 18> +{}; + +// Based on data_type = int8_t, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<256, 6> +{}; + +// Based on data_type = int8_t, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<256, 11> +{}; + +// Based on data_type = int8_t, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<256, 24> +{}; + +// Based on data_type = int8_t, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx906), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> + : select_config<192, 24> +{}; + +// Based on data_type = double, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<128, 6> +{}; + +// Based on data_type = double, 
flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<128, 6> +{}; + +// Based on data_type = double, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<128, 6> +{}; + +// Based on data_type = double, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> + : select_config<128, 7> +{}; + +// Based on data_type = float, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<256, 6> +{}; + +// Based on data_type = float, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<128, 11> +{}; + +// Based on data_type = float, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + 
std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<256, 10> +{}; + +// Based on data_type = float, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> + : select_config<256, 11> +{}; + +// Based on data_type = rocprim::half, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<256, 6> +{}; + +// Based on data_type = rocprim::half, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<256, 12> +{}; + +// Based on data_type = rocprim::half, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<512, 18> +{}; + +// Based on data_type = rocprim::half, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))>> : select_config<256, 18> +{}; + +// Based on data_type = int64_t, 
flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<128, 6> +{}; + +// Based on data_type = int64_t, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<128, 6> +{}; + +// Based on data_type = int64_t, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<128, 6> +{}; + +// Based on data_type = int64_t, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 1))>> : select_config<128, 7> +{}; + +// Based on data_type = int, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<256, 6> +{}; + +// Based on data_type = int, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + 
flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<128, 11> +{}; + +// Based on data_type = int, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<256, 10> +{}; + +// Based on data_type = int, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 1))>> : select_config<256, 11> +{}; + +// Based on data_type = short, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<256, 6> +{}; + +// Based on data_type = short, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<256, 12> +{}; + +// Based on data_type = short, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 
1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<512, 18> +{}; + +// Based on data_type = short, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 1))>> : select_config<256, 18> +{}; + +// Based on data_type = int8_t, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<256, 6> +{}; + +// Based on data_type = int8_t, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<256, 11> +{}; + +// Based on data_type = int8_t, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<256, 24> +{}; + +// Based on data_type = int8_t, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx908), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> + : select_config<256, 24> +{}; + +// Based on data_type = double, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + 
flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<128, 6> +{}; + +// Based on data_type = double, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<128, 6> +{}; + +// Based on data_type = double, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<128, 6> +{}; + +// Based on data_type = double, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> + : select_config<128, 7> +{}; + +// Based on data_type = float, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<256, 6> +{}; + +// Based on data_type = float, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && 
(sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<128, 11> +{}; + +// Based on data_type = float, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<256, 10> +{}; + +// Based on data_type = float, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> + : select_config<256, 11> +{}; + +// Based on data_type = rocprim::half, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<256, 6> +{}; + +// Based on data_type = rocprim::half, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<256, 12> +{}; + +// Based on data_type = rocprim::half, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<512, 18> +{}; + +// Based on data_type = rocprim::half, flag_type = int8_t +template +struct 
default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))>> : select_config<256, 18> +{}; + +// Based on data_type = int64_t, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<128, 6> +{}; + +// Based on data_type = int64_t, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<128, 6> +{}; + +// Based on data_type = int64_t, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<128, 6> +{}; + +// Based on data_type = int64_t, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 1))>> : select_config<128, 7> +{}; + +// Based on data_type = int, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && 
(sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<256, 6> +{}; + +// Based on data_type = int, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<128, 11> +{}; + +// Based on data_type = int, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<256, 10> +{}; + +// Based on data_type = int, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 1))>> : select_config<256, 11> +{}; + +// Based on data_type = short, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<256, 6> +{}; + +// Based on data_type = short, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : 
select_config<256, 12> +{}; + +// Based on data_type = short, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<512, 18> +{}; + +// Based on data_type = short, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 1))>> : select_config<256, 18> +{}; + +// Based on data_type = int8_t, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<256, 6> +{}; + +// Based on data_type = int8_t, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<256, 11> +{}; + +// Based on data_type = int8_t, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<256, 24> +{}; + +// Based on data_type = int8_t, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::unknown), + data_type, + flag_type, 
+ std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> + : select_config<256, 24> +{}; + +// Based on data_type = double, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<128, 6> +{}; + +// Based on data_type = double, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<128, 6> +{}; + +// Based on data_type = double, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<128, 6> +{}; + +// Based on data_type = double, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> + : select_config<128, 7> +{}; + +// Based on data_type = float, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : 
select_config<256, 6> +{}; + +// Based on data_type = float, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<128, 11> +{}; + +// Based on data_type = float, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<256, 10> +{}; + +// Based on data_type = float, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> + : select_config<256, 11> +{}; + +// Based on data_type = rocprim::half, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<256, 6> +{}; + +// Based on data_type = rocprim::half, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<256, 12> +{}; + +// Based on data_type = rocprim::half, flag_type = short +template +struct default_select_predicated_flag_config< + 
static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<512, 18> +{}; + +// Based on data_type = rocprim::half, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))>> : select_config<256, 18> +{}; + +// Based on data_type = int64_t, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<128, 6> +{}; + +// Based on data_type = int64_t, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<128, 6> +{}; + +// Based on data_type = int64_t, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<128, 6> +{}; + +// Based on data_type = int64_t, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 8) && 
(sizeof(data_type) > 4) + && (sizeof(flag_type) <= 1))>> : select_config<128, 7> +{}; + +// Based on data_type = int, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<256, 6> +{}; + +// Based on data_type = int, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<128, 11> +{}; + +// Based on data_type = int, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<256, 10> +{}; + +// Based on data_type = int, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) + && (sizeof(flag_type) <= 1))>> : select_config<256, 11> +{}; + +// Based on data_type = short, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> + : select_config<256, 6> +{}; + +// Based on data_type = short, flag_type 
= int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> + : select_config<256, 12> +{}; + +// Based on data_type = short, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> + : select_config<512, 18> +{}; + +// Based on data_type = short, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) + && (sizeof(flag_type) <= 1))>> : select_config<256, 18> +{}; + +// Based on data_type = int8_t, flag_type = int64_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))>> : select_config<256, 6> +{}; + +// Based on data_type = int8_t, flag_type = int +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))>> : select_config<256, 11> +{}; + +// Based on data_type = int8_t, flag_type = short +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + 
std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))>> : select_config<256, 24> +{}; + +// Based on data_type = int8_t, flag_type = int8_t +template +struct default_select_predicated_flag_config< + static_cast(target_arch::gfx90a), + data_type, + flag_type, + std::enable_if_t<(!bool(rocprim::is_floating_point::value) + && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> + : select_config<256, 24> +{}; + +} // end namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_PREDICATED_FLAG_HPP_ \ No newline at end of file diff --git a/rocprim/include/rocprim/device/detail/device_adjacent_find.hpp b/rocprim/include/rocprim/device/detail/device_adjacent_find.hpp new file mode 100644 index 000000000..737276084 --- /dev/null +++ b/rocprim/include/rocprim/device/detail/device_adjacent_find.hpp @@ -0,0 +1,156 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_ADJACENT_FIND_HPP_
+#define ROCPRIM_DEVICE_DETAIL_DEVICE_ADJACENT_FIND_HPP_
+
+#include "device_config_helper.hpp"
+#include "ordered_block_id.hpp"
+
+#include "../../block/block_load.hpp"
+#include "../../block/block_reduce.hpp"
+#include "../../intrinsics/thread.hpp"
+
+BEGIN_ROCPRIM_NAMESPACE
+
+namespace detail
+{
+namespace adjacent_find
+{
+// Prepares a search. Launch-bounded to a single thread: it writes `size` into
+// *reduce_output and resets `ordered_tile_id`.
+// block_reduce_kernel only publishes values that are below `size`, so `size`
+// acts as the "nothing found" value of the output.
+template
+ROCPRIM_KERNEL __launch_bounds__(1)
+void init_adjacent_find(OutputT*         reduce_output,
+                        ordered_block_id ordered_tile_id,
+                        const size_t     size)
+{
+    // Reset output value.
+    *reduce_output = size;
+
+    // Reset ordered_block_id.
+    ordered_tile_id.reset();
+}
+
+// Tile-by-tile reduction of `transformed_input` with early exit.
+//
+// Each block repeatedly asks `ordered_tile_id` for a tile id, reduces the items of
+// that tile with `op` and, when the reduced value is below `size`, merges it into
+// *reduce_output with atomic_min. A block returns as soon as the value currently
+// stored in *reduce_output is lower than the start of its tile.
+//
+// transformed_input  input of the reduction; size - 1 items are read from it.
+// reduce_output      global result; presumably initialised by init_adjacent_find.
+// size               bound used for tiling and as the "not found" threshold.
+// op                 binary operator used for the thread and block reductions.
+// ordered_tile_id    source of tile ids; presumably reset by init_adjacent_find.
+template
+ROCPRIM_KERNEL
+#ifndef DOXYGEN_DOCUMENTATION_BUILD
+__launch_bounds__(device_params().kernel_config.block_size)
+#endif
+void block_reduce_kernel(TransformedInputIterator transformed_input,
+                         ReduceIndexIterator      reduce_output,
+                         const std::size_t        size,
+                         BinaryPred               op,
+                         OrderedTileIdType        ordered_tile_id)
+{
+    static constexpr adjacent_find_config_params params = device_params();
+    static constexpr unsigned int block_size       = params.kernel_config.block_size;
+    static constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread;
+    static constexpr unsigned int items_per_tile   = block_size * items_per_thread;
+
+    using transformed_input_type =
+        typename std::iterator_traits::value_type;
+    using block_reduce_type = ::rocprim::block_reduce<
+        transformed_input_type,
+        block_size,
+        block_reduce_algorithm::raking_reduce>;
+
+    // NOTE(review): tile_id and global_reduce_output share the same shared-memory
+    // bytes. Thread 0 overwrites them right after ordered_tile_id.get() returns;
+    // this presumably relies on get() being done with tile_id by then - confirm.
+    ROCPRIM_SHARED_MEMORY union
+    {
+        typename decltype(ordered_tile_id)::storage_type tile_id;
+        std::size_t                                      global_reduce_output;
+    } storage;
+
+    // Get initial tile id
+    const unsigned int thread_id = threadIdx.x;
+    std::size_t tile_offset = ordered_tile_id.get(thread_id, storage.tile_id) * items_per_tile;
+
+    // Only size - 1 transformed items exist (see the bounds used below), so stop as
+    // soon as no item is left. With `tile_offset < size` a tile starting exactly at
+    // size - 1 was entered with zero valid items: nothing was loaded and an
+    // indeterminate output_value was then compared against size.
+    while(tile_offset + 1 < size)
+    {
+        // First thread of each block loads the latest global adjacent index found
+        if(thread_id == 0)
+        {
+            storage.global_reduce_output = atomic_load(reduce_output);
+        }
+        syncthreads();
+
+        // Early exit if a previous block or tile found an adjacent pair.
+        // The value is read from shared memory, so the whole block takes the same branch.
+        if(storage.global_reduce_output < tile_offset)
+        {
+            return;
+        }
+
+        // Do block reduction
+        transformed_input_type transformed_input_values[items_per_thread];
+        transformed_input_type output_value;
+
+        if(tile_offset + items_per_tile > size_t{size - 1}) /* Last, incomplete tile */
+        {
+            const std::size_t valid_in_last_iteration = size - 1 - tile_offset;
+            block_load_direct_striped(thread_id,
+                                      transformed_input + tile_offset,
+                                      transformed_input_values,
+                                      valid_in_last_iteration);
+
+            // Thread reductions with boundary check.
+            // Threads at or beyond valid_in_last_iteration keep an unloaded first item;
+            // they are excluded by the valid-items count passed to the block reduction.
+            output_value = transformed_input_values[0];
+            ROCPRIM_UNROLL
+            for(unsigned int i = 1; i < items_per_thread; i++)
+            {
+                if(thread_id + i * block_size < valid_in_last_iteration)
+                {
+                    output_value = op(output_value, transformed_input_values[i]);
+                }
+            }
+            // Reduce thread reductions
+            block_reduce_type().reduce(output_value, // input
+                                       output_value, // output
+                                       std::min(valid_in_last_iteration, std::size_t{block_size}),
+                                       op);
+        }
+        else /* Complete tile */
+        {
+            block_load_direct_striped(thread_id,
+                                      transformed_input + tile_offset,
+                                      transformed_input_values);
+            block_reduce_type().reduce(transformed_input_values, // input
+                                       output_value, // output
+                                       op);
+        }
+
+        // Save reduction's index into output if an adjacent pair is found
+        if(thread_id == 0 && output_value < size)
+        {
+            // Store global minimum
+            atomic_min(reduce_output, output_value);
+        }
+
+        // Get next tile's id
+        tile_offset = ordered_tile_id.get(thread_id, storage.tile_id) * items_per_tile;
+    }
+}
+} // namespace adjacent_find
+} // namespace detail
+
+END_ROCPRIM_NAMESPACE
+ +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_ADJACENT_FIND_HPP_ diff --git a/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp b/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp index d43911a54..a28a1a8a1 100644 --- a/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp +++ b/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp @@ -53,6 +53,7 @@ #include "rocprim/intrinsics.hpp" #include "rocprim/intrinsics/thread.hpp" +#include "rocprim/common.hpp" #include "rocprim/config.hpp" #include @@ -1198,7 +1199,7 @@ ROCPRIM_INLINE static hipError_t batch_memcpy_func(void* temporary_ } if(debug_synchronous) { - hipStreamSynchronize(stream); + ROCPRIM_RETURN_ON_ERROR(hipStreamSynchronize(stream)); } // Launch batch_memcpy_non_blev_kernel. @@ -1216,7 +1217,7 @@ ROCPRIM_INLINE static hipError_t batch_memcpy_func(void* temporary_ } if(debug_synchronous) { - hipStreamSynchronize(stream); + ROCPRIM_RETURN_ON_ERROR(hipStreamSynchronize(stream)); } // Launch batch_memcpy_blev_kernel. 
@@ -1232,7 +1233,7 @@ ROCPRIM_INLINE static hipError_t batch_memcpy_func(void* temporary_ } if(debug_synchronous) { - hipStreamSynchronize(stream); + ROCPRIM_RETURN_ON_ERROR(hipStreamSynchronize(stream)); } return hipSuccess; diff --git a/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/rocprim/include/rocprim/device/detail/device_config_helper.hpp index cadaed26d..87146056f 100644 --- a/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -1093,6 +1093,14 @@ struct find_first_of_config_params kernel_config_params kernel_config{}; }; +struct adjacent_find_config_tag +{}; + +struct adjacent_find_config_params +{ + kernel_config_params kernel_config; +}; + } // namespace detail /// \brief Configuration of device-level find_first_of @@ -1111,6 +1119,24 @@ struct find_first_of_config : public detail::find_first_of_config_params #endif }; +/// \brief Configuration of device-level adjacent_find +/// +/// \tparam BlockSize number of threads in a block. +/// \tparam ItemsPerThread number of items processed by each thread. +template +struct adjacent_find_config : public detail::adjacent_find_config_params +{ + /// \brief Identifies the algorithm associated to the config. 
+ using tag = detail::adjacent_find_config_tag; +#ifndef DOXYGEN_DOCUMENTATION_BUILD + constexpr adjacent_find_config() + : detail::adjacent_find_config_params{ + {BlockSize, ItemsPerThread, ROCPRIM_GRID_SIZE_LIMIT} + } + {} +#endif // DOXYGEN_DOCUMENTATION_BUILD +}; + namespace detail { @@ -1123,6 +1149,134 @@ struct default_find_first_of_config_base using type = find_first_of_config<256, ::rocprim::max(1u, 16u / item_scale)>; }; +template +struct default_adjacent_find_config_base +{ + static constexpr unsigned int item_scale + = ::rocprim::detail::ceiling_div(sizeof(InputT), sizeof(int)); + + using type + = adjacent_find_config::value, + ::rocprim::max(1u, 16u / item_scale)>; +}; + +} // namespace detail + +namespace detail +{ + +struct search_config_params +{ + unsigned int max_shared_key_bytes; + kernel_config_params kernel_config; +}; + +} // namespace detail + +/// \brief Configuration of device-level find_end +/// +/// \tparam BlockSize number of threads in a block. +/// \tparam ItemsPerThread number of items processed by each thread. +/// \tparam MaxSharedKeyBytes maximum number of bytes for which a shared key is used. +template +struct search_config : public detail::search_config_params +{ +#ifndef DOXYGEN_SHOULD_SKIP_THIS + constexpr search_config() + : detail::search_config_params{ + MaxSharedKeyBytes, {BlockSize, ItemsPerThread, ROCPRIM_GRID_SIZE_LIMIT} + } + {} +#endif +}; + +namespace detail +{ +struct search_n_config_params +{ + size_t threshold; + kernel_config_params kernel_config; +}; +} // namespace detail + +/// \brief Configuration of device-level search_n +/// +/// \tparam BlockSize number of threads in a block. +/// \tparam ItemsPerThread number of items processed by each thread. 
+template +struct search_n_config : public detail::search_n_config_params +{ +#ifndef DOXYGEN_DOCUMENTATION_BUILD + constexpr search_n_config() + : detail::search_n_config_params{ + 6, {BlockSize, ItemsPerThread, 0} + } + {} +#endif +}; + +namespace detail +{ + +struct merge_config_params +{ + kernel_config_params kernel_config; +}; + +} // namespace detail + +/** + * \brief Configuration of device-level merge operation. + * + * \tparam BlockSize number of threads in a block. + * \tparam ItemsPerThread number of items processed by each thread per tile. + */ +template +struct merge_config : public detail::merge_config_params +{ +#ifndef DOXYGEN_SHOULD_SKIP_THIS + /// Number of threads in a block. + static constexpr unsigned int block_size = BlockSize; + /// Number of items processed by each thread per tile. + static constexpr unsigned int items_per_thread = ItemsPerThread; + + constexpr merge_config() + : detail::merge_config_params{ + {BlockSize, ItemsPerThread} + } {}; + +#endif +}; + +namespace detail +{ + +template +struct default_merge_config_base +{ + static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( + ::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); + + using type = merge_config::value, + ::rocprim::max(1u, 10u / item_scale)>; +}; + +template +struct default_merge_config_base +{ + static constexpr unsigned int item_scale + = ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); + + using type + = select_type>, + select_type_case>, + select_type_case>, + merge_config::value, + ::rocprim::max(1u, 10u / item_scale)>>; +}; + } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/detail/device_nth_element.hpp b/rocprim/include/rocprim/device/detail/device_nth_element.hpp index 30336a269..2ed21127a 100644 --- a/rocprim/include/rocprim/device/detail/device_nth_element.hpp +++ b/rocprim/include/rocprim/device/detail/device_nth_element.hpp @@ -28,6 +28,7 @@ #include "../../block/block_store.hpp" 
#include "../../config.hpp" +#include "../../common.hpp" #include "../../intrinsics.hpp" #include "../../type_traits.hpp" @@ -48,37 +49,6 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - do \ - { \ - hipError_t _error = hipGetLastError(); \ - if(_error != hipSuccess) \ - return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - hipError_t __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) \ - return __error; \ - auto _end = std::chrono::steady_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } \ - while(0) - -#define RETURN_ON_ERROR(...) \ - do \ - { \ - hipError_t error = (__VA_ARGS__); \ - if(error != hipSuccess) \ - { \ - return error; \ - } \ - } \ - while(0) - struct nth_element_onesweep_lookback_state { // The two most significant bits are used to indicate the status of the prefix - leaving the other 30 bits for the @@ -687,15 +657,16 @@ ROCPRIM_INLINE hipError_t std::cout << "iteration: " << iteration++ << '\n'; } - RETURN_ON_ERROR(hipMemsetAsync(buckets, 0, sizeof(*buckets) * num_buckets, stream)); + ROCPRIM_RETURN_ON_ERROR(hipMemsetAsync(buckets, 0, sizeof(*buckets) * num_buckets, stream)); - RETURN_ON_ERROR( + ROCPRIM_RETURN_ON_ERROR( hipMemsetAsync(equality_buckets, 0, sizeof(*equality_buckets) * num_buckets, stream)); // Reset lookback scan states to zero, indicating empty prefix. 
- RETURN_ON_ERROR(nth_element_onesweep_lookback_state::reset(lookback_states, - num_partitions * num_blocks, - stream)); + ROCPRIM_RETURN_ON_ERROR( + nth_element_onesweep_lookback_state::reset(lookback_states, + num_partitions * num_blocks, + stream)); start_timer(); kernel_find_splitters @@ -730,21 +701,21 @@ ROCPRIM_INLINE hipError_t ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("kernel_copy_buckets", size, start); // Copy the results in keys_buffer back to the keys - RETURN_ON_ERROR(transform(keys_buffer, - keys, - size, - ::rocprim::identity(), - stream, - debug_synchronous)); + ROCPRIM_RETURN_ON_ERROR(transform(keys_buffer, + keys, + size, + ::rocprim::identity(), + stream, + debug_synchronous)); n_th_element_iteration_data h_nth_element_data; - RETURN_ON_ERROR(hipMemcpyAsync(&h_nth_element_data, - nth_element_data, - sizeof(h_nth_element_data), - hipMemcpyDeviceToHost, - stream)); + ROCPRIM_RETURN_ON_ERROR(hipMemcpyAsync(&h_nth_element_data, + nth_element_data, + sizeof(h_nth_element_data), + hipMemcpyDeviceToHost, + stream)); - RETURN_ON_ERROR(hipStreamSynchronize(stream)); + ROCPRIM_RETURN_ON_ERROR(hipStreamSynchronize(stream)); size_t offset = h_nth_element_data.offset; size_t bucket_size = h_nth_element_data.size; @@ -768,8 +739,7 @@ ROCPRIM_INLINE hipError_t return hipSuccess; } -#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR -#undef RETURN_ON_ERROR + } // namespace detail diff --git a/rocprim/include/rocprim/device/detail/device_partition.hpp b/rocprim/include/rocprim/device/detail/device_partition.hpp index 54804bb1b..37135cedf 100644 --- a/rocprim/include/rocprim/device/detail/device_partition.hpp +++ b/rocprim/include/rocprim/device/detail/device_partition.hpp @@ -49,9 +49,10 @@ namespace detail enum class select_method { - flag = 0, - predicate = 1, - unique = 2 + flag = 0, + predicate = 1, + predicated_flag = 2, + unique = 3 }; enum class partition_subalgo @@ -63,6 +64,7 @@ enum class partition_subalgo partition_three_way, select_flag, 
select_predicate, + select_predicated_flag, select_unique, select_unique_by_key }; @@ -114,6 +116,60 @@ ROCPRIM_DEVICE ROCPRIM_INLINE auto ::rocprim::syncthreads(); // sync threads to reuse shared memory } +template +ROCPRIM_DEVICE ROCPRIM_INLINE +auto partition_block_load_flags(InputIterator /* block_predecessor */, + FlagIterator block_flags, + ValueType (& /* values */)[ItemsPerThread], + bool (&is_selected)[ItemsPerThread], + UnaryPredicate predicate, + InequalityOp /* inequality_op */, + StorageType& storage, + const bool /* is_first_block */, + const unsigned int block_thread_id, + const bool is_global_last_block, + const unsigned int valid_in_global_last_block) -> + typename std::enable_if::type +{ + using flag_type = typename std::iterator_traits::value_type; + flag_type flags[ItemsPerThread]; + + if(is_global_last_block) // last block + { + BlockLoadFlagsType().load(block_flags, + flags, + valid_in_global_last_block, + false, + storage.load_flags); + const auto offset = block_thread_id * ItemsPerThread; + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + is_selected[i] = ((offset + i) < valid_in_global_last_block) && predicate(flags[i]); + } + } + else + { + BlockLoadFlagsType().load(block_flags, flags, storage.load_flags); + ROCPRIM_UNROLL + for(unsigned int i = 0; i < ItemsPerThread; i++) + { + is_selected[i] = predicate(flags[i]); + } + } + ::rocprim::syncthreads(); // sync threads to reuse shared memory +} + template::value_type; using value_type = typename std::iterator_traits::value_type; + using flag_type = + typename std::conditional::value_type, + bool>::type; // Block primitives using block_load_key_type = ::rocprim:: block_load; using block_load_value_type = ::rocprim:: block_load; - using block_load_flag_type - = ::rocprim::block_load; + using block_load_flag_type = ::rocprim:: + block_load; using block_scan_offset_type = ::rocprim::block_scan; using block_discontinuity_key_type = ::rocprim::block_discontinuity; diff --git 
a/rocprim/include/rocprim/device/detail/device_search.hpp b/rocprim/include/rocprim/device/detail/device_search.hpp new file mode 100644 index 000000000..06656949f --- /dev/null +++ b/rocprim/include/rocprim/device/detail/device_search.hpp @@ -0,0 +1,440 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+
+#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SEARCH_HPP_
+#define ROCPRIM_DEVICE_DETAIL_DEVICE_SEARCH_HPP_
+
+#include "../../detail/temp_storage.hpp"
+
+#include "../../common.hpp"
+#include "../../config.hpp"
+
+#include "../../intrinsics.hpp"
+#include "../../iterator/reverse_iterator.hpp"
+#include "../config_types.hpp"
+#include "../device_search_config.hpp"
+#include "../device_transform.hpp"
+
+#include
+#include
+
+#include
+#include
+
+BEGIN_ROCPRIM_NAMESPACE
+
+namespace detail
+{
+
+// Global-memory search.
+//
+// Every thread tests items_per_thread consecutive start positions of `input`
+// against the whole key sequence and merges the lowest start index at which all
+// keys compare equal into *output with atomic_min.
+//
+// input, size       sequence that is searched and its length.
+// keys, keys_size   pattern and its length; keys_size must be at least 1 because
+//                   keys_size - 1 is used as the index of the last key.
+// output            lowest match index found so far; pre-set by the caller.
+// compare_function  called as compare_function(input[x], keys[y]).
+template
+ROCPRIM_DEVICE
+void search_kernel_impl(InputIterator1 input,
+                        InputIterator2 keys,
+                        size_t*        output,
+                        size_t         size,
+                        size_t         keys_size,
+                        BinaryFunction compare_function)
+{
+    constexpr search_config_params params = device_params();
+
+    constexpr unsigned int block_size       = params.kernel_config.block_size;
+    constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread;
+    constexpr unsigned int items_per_block  = block_size * items_per_thread;
+
+    const unsigned int flat_id       = rocprim::detail::block_thread_id<0>();
+    const unsigned int flat_block_id = rocprim::detail::block_id<0>();
+
+    // Widen before multiplying: block id times items per block does not fit in
+    // 32 bits once the input has more than 2^32 items.
+    const size_t offset = static_cast<size_t>(flat_block_id) * items_per_block
+                          + static_cast<size_t>(flat_id) * items_per_thread;
+    bool find_pattern = false;
+
+    // Stop if the keys no longer fit behind this thread's first position, or if a
+    // match with a lower index has already been published.
+    if(offset + keys_size > size || offset > atomic_load(output))
+    {
+        return;
+    }
+
+    size_t index = 0;
+    for(size_t id = offset; id < offset + items_per_thread; id++)
+    {
+        size_t i          = 0;
+        size_t current_id = id;
+        // Compare all keys but the last one.
+        for(; i < keys_size - 1 && current_id < size; i++, current_id++)
+        {
+            if(!compare_function(input[current_id], keys[i]))
+            {
+                break;
+            }
+        }
+
+        // If i is the index of the last key and that comparison is true as well,
+        // the pattern is found.
+        if(current_id < size && i == (keys_size - 1)
+           && compare_function(input[current_id], keys[i]))
+        {
+            index        = id;
+            find_pattern = true;
+            break;
+        }
+    }
+
+    // Mask of the lanes in this wave that found the pattern.
+    lane_mask_type match_mask = ballot(find_pattern);
+
+    wave_barrier();
+
+    // Number of lanes with a lower lane id than this one that found the pattern too.
+    const unsigned int matches_in_lower_lanes = masked_bit_count(match_mask);
+
+    // Only the lowest matching lane of a wave issues the atomic; it covers the
+    // lowest positions of the wave.
+    if(find_pattern && (matches_in_lower_lanes == 0))
+    {
+        atomic_min(output, index);
+    }
+}
+
+// Kernel entry point that forwards to search_kernel_impl.
+template
+ROCPRIM_KERNEL
+__launch_bounds__(device_params().kernel_config.block_size)
+void search_kernel(InputIterator1 input,
+                   InputIterator2 keys,
+                   size_t*        output,
+                   size_t         size,
+                   size_t         keys_size,
+                   BinaryFunction compare_function)
+{
+    search_kernel_impl(input, keys, output, size, keys_size, compare_function);
+}
+
+// Shared-memory variant of search_kernel_impl.
+//
+// The block first copies the keys and its own tile of `input` into shared memory
+// and then searches as search_kernel_impl does. Comparisons that run past the
+// staged tile fall back to reading `input` directly.
+//
+// Parameters are the same as for search_kernel_impl.
+// NOTE(review): assumes keys_size fits the shared key buffer (the caller appears
+// to check the key byte size against max_shared_key_bytes) - confirm.
+template
+ROCPRIM_DEVICE
+void search_kernel_shared_impl(InputIterator1 input,
+                               InputIterator2 keys,
+                               size_t*        output,
+                               size_t         size,
+                               size_t         keys_size,
+                               BinaryFunction compare_function)
+{
+    using value_type = typename std::iterator_traits::value_type;
+    using key_type   = typename std::iterator_traits::value_type;
+
+    constexpr search_config_params params = device_params();
+
+    constexpr unsigned int block_size       = params.kernel_config.block_size;
+    constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread;
+    constexpr unsigned int items_per_block  = block_size * items_per_thread;
+    constexpr unsigned int max_shared_key   = params.max_shared_key_bytes / sizeof(key_type);
+
+    const unsigned int flat_id       = rocprim::detail::block_thread_id<0>();
+    const unsigned int flat_block_id = rocprim::detail::block_id<0>();
+
+    // Widen before multiplying so that the block offset cannot wrap in 32 bits.
+    const size_t block_offset = static_cast<size_t>(flat_block_id) * items_per_block;
+    // Position of this thread's first item inside the block's tile.
+    const size_t offset       = flat_id * items_per_thread;
+    bool         find_pattern = false;
+
+    ROCPRIM_SHARED_MEMORY uninitialized_array local_keys_;
+    ROCPRIM_SHARED_MEMORY uninitialized_array local_input_;
+
+    // Check if a key was already found in a place before this block
+    if(block_offset > atomic_load(output))
+    {
+        return;
+    }
+
+    // Load the keys into shared memory, batch_size consecutive keys per thread.
+    const size_t batch_size = ceiling_div(keys_size, block_size);
+    for(size_t i = 0; i < batch_size; i++)
+    {
+        const size_t index = flat_id * batch_size + i;
+        if(index < keys_size)
+        {
+            local_keys_.emplace(index, keys[index]);
+        }
+    }
+
+    using block_load_input = block_load;
+
+    value_type elements[items_per_thread];
+
+    const bool is_complete_block = block_offset + items_per_block <= size;
+
+    // Stage this block's tile of the input in shared memory.
+    if(is_complete_block)
+    {
+        block_load_input().load(input + block_offset, elements);
+        for(size_t i = 0; i < items_per_thread; i++)
+        {
+            const size_t index = flat_id * items_per_thread + i;
+            local_input_.emplace(index, elements[i]);
+        }
+    }
+    else
+    {
+        // Last, partial tile: only items that lie inside the input are stored.
+        block_load_input().load(input + block_offset, elements, size - block_offset);
+        for(size_t i = 0; i < items_per_thread; i++)
+        {
+            const size_t index       = flat_id * items_per_thread + i;
+            const size_t index_value = block_offset + index;
+            if(index_value < size)
+            {
+                local_input_.emplace(index, elements[i]);
+            }
+        }
+    }
+
+    const key_type*   local_keys  = local_keys_.get_unsafe_array();
+    const value_type* local_input = local_input_.get_unsafe_array();
+
+    syncthreads();
+
+    // Stop if the keys no longer fit behind this thread's first position, or if a
+    // match with a lower index has already been published. Both tests use the
+    // global position (block_offset + offset), as in search_kernel_impl; comparing
+    // the block-local offset with the global result made the second test almost
+    // never fire outside of the first block.
+    if(block_offset + offset + keys_size > size || block_offset + offset > atomic_load(output))
+    {
+        return;
+    }
+
+    size_t       index      = 0;
+    // Number of input items from the start of this tile to the end of the input.
+    const size_t check      = size - block_offset;
+    // Number of those items that are staged in shared memory.
+    const size_t check_both = rocprim::min(check, size_t(items_per_block));
+    for(size_t id = offset; id < offset + items_per_thread; id++)
+    {
+        size_t i          = 0;
+        size_t current_id = id;
+        // Values up to items_per_block are in shared memory
+        for(; i < keys_size - 1 && current_id < check_both; i++, current_id++)
+        {
+            if(!compare_function(local_input[current_id], local_keys[i]))
+            {
+                break;
+            }
+        }
+        // Compare values that are not in the shared memory
+        for(; current_id >= items_per_block && i < keys_size - 1 && current_id < check;
+            i++, current_id++)
+        {
+            if(!compare_function(input[current_id + block_offset], local_keys[i]))
+            {
+                break;
+            }
+        }
+
+        // If i is the index of the last key and that comparison is true as well,
+        // the pattern is found.
+        if(current_id + block_offset < size && i == (keys_size - 1)
+           && compare_function(current_id < items_per_block ? local_input[current_id]
+                                                            : input[current_id + block_offset],
+                               local_keys[i]))
+        {
+            index        = id + block_offset;
+            find_pattern = true;
+            // Want to find the first occurrence, do not need to search further.
+            break;
+        }
+    }
+
+    // Mask of the lanes in this wave that found the pattern.
+    lane_mask_type match_mask = ballot(find_pattern);
+
+    wave_barrier();
+
+    // Number of lanes with a lower lane id than this one that found the pattern too.
+    const unsigned int matches_in_lower_lanes = masked_bit_count(match_mask);
+
+    // Only the lowest matching lane of a wave issues the atomic.
+    if(find_pattern && (matches_in_lower_lanes == 0))
+    {
+        atomic_min(output, index);
+    }
+}
+
+// Kernel entry point that forwards to search_kernel_shared_impl.
+template
+ROCPRIM_KERNEL
+__launch_bounds__(device_params().kernel_config.block_size)
+void search_kernel_shared(InputIterator1 input,
+                          InputIterator2 keys,
+                          size_t*        output,
+                          size_t         size,
+                          size_t         keys_size,
+                          BinaryFunction compare_function)
+{
+    search_kernel_shared_impl(input, keys, output, size, keys_size, compare_function);
+}
+
+// Writes `value` to *output.
+template
+ROCPRIM_KERNEL
+void set_output_kernel(T* output, T value)
+{
+    *output = value;
+}
+
+// Converts an index found on the reversed sequences back to an index into the
+// original input. Values that are not below `size` are left untouched.
+template
+ROCPRIM_KERNEL
+void reverse_index_kernel(T* output, T size, T keys_size)
+{
+    // Return the reverse index as long as the index is lower than the size.
+    if(*output < size)
+    {
+        *output = size - keys_size - *output;
+    }
+}
+
+template
+ROCPRIM_INLINE
+hipError_t search_impl(void*          temporary_storage,
+                       size_t&        storage_size,
+                       InputIterator1 input,
+                       InputIterator2 keys,
+                       OutputIterator output,
+                       size_t         size,
+                       size_t         keys_size,
+                       BinaryFunction compare_function,
+                       hipStream_t    stream,
+                       bool           debug_synchronous)
+{
+    using input_type  = typename std::iterator_traits::value_type;
+    using key_type    = typename std::iterator_traits::value_type;
+    using output_type = typename std::iterator_traits::value_type;
+
+    using config = wrapped_search_config;
+
+    target_arch target_arch;
+    ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch));
+
+    const search_config_params params = dispatch_target_arch(target_arch);
+
+    const unsigned int block_size       = params.kernel_config.block_size;
+    const unsigned int items_per_thread = params.kernel_config.items_per_thread;
+    const unsigned int items_per_block  = block_size * items_per_thread;
+
+    const unsigned int shared_key_mem_size_bytes = params.max_shared_key_bytes;
+    // NOTE(review): the size_t product is narrowed to unsigned int here - confirm
+    // that very large key sequences cannot wrap to a small byte count.
+    const unsigned int key_size_bytes = keys_size * sizeof(key_type);
+
+    // Start point for time measurements
+    std::chrono::steady_clock::time_point start;
+
+    const auto
start_timer = [&start, debug_synchronous]() + { + if(debug_synchronous) + { + start = std::chrono::steady_clock::now(); + } + }; + + if(temporary_storage == nullptr) + { + storage_size = sizeof(size_t); + return hipSuccess; + } + + if(keys_size > size) + { + return hipErrorInvalidValue; + } + + size_t* tmp_output = reinterpret_cast(temporary_storage); + + start_timer(); + set_output_kernel<<<1, 1, 0, stream>>>(tmp_output, find_first && keys_size <= 0 ? 0 : size); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("set_output_kernel", 1, start); + + if(size > 0 && keys_size > 0) + { + const unsigned int num_blocks = ceiling_div(size, items_per_block); + if(key_size_bytes < shared_key_mem_size_bytes) + { + if ROCPRIM_IF_CONSTEXPR(find_first) + { + start_timer(); + search_kernel_shared + <<>>(input, + keys, + tmp_output, + size, + keys_size, + compare_function); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_kernel_shared", size, start); + } + else + { + start_timer(); + search_kernel_shared<<>>( + rocprim::make_reverse_iterator(input + size), + rocprim::make_reverse_iterator(keys + keys_size), + tmp_output, + size, + keys_size, + compare_function); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_kernel_shared", size, start); + } + } + else + { + if ROCPRIM_IF_CONSTEXPR(find_first) + { + start_timer(); + search_kernel<<>>(input, + keys, + tmp_output, + size, + keys_size, + compare_function); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_kernel", size, start); + } + else + { + start_timer(); + search_kernel<<>>( + rocprim::make_reverse_iterator(input + size), + rocprim::make_reverse_iterator(keys + keys_size), + tmp_output, + size, + keys_size, + compare_function); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_kernel", size, start); + } + } + + if ROCPRIM_IF_CONSTEXPR(!find_first) + { + start_timer(); + reverse_index_kernel<<<1, 1, 0, stream>>>(tmp_output, size, keys_size); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("reverse_index_kernel", 1, 
start); + } + } + + ROCPRIM_RETURN_ON_ERROR(transform(tmp_output, + output, + 1, + rocprim::identity(), + stream, + debug_synchronous)); + + return hipSuccess; +} + +} // namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DETAIL_DEVICE_SEARCH_HPP_ diff --git a/rocprim/include/rocprim/device/detail/device_search_n.hpp b/rocprim/include/rocprim/device/detail/device_search_n.hpp new file mode 100644 index 000000000..9ecce58df --- /dev/null +++ b/rocprim/include/rocprim/device/detail/device_search_n.hpp @@ -0,0 +1,439 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_SEARCH_N_HPP_ +#define ROCPRIM_DEVICE_DETAIL_DEVICE_SEARCH_N_HPP_ + +#include "../../common.hpp" +#include "../../config.hpp" +#include "../config_types.hpp" +#include "../device_reduce.hpp" +#include "../device_search_n_config.hpp" +#include "../device_transform.hpp" + +#include + +BEGIN_ROCPRIM_NAMESPACE +namespace detail +{ + +inline void search_n_start_timer(std::chrono::steady_clock::time_point& start, + const bool debug_synchronous) +{ + if(debug_synchronous) + { + start = std::chrono::steady_clock::now(); + } +} + +template +ROCPRIM_KERNEL __launch_bounds__(1) +void search_n_init_kernel(SizeType* __restrict__ output, const SizeType target) +{ + *output = target; +} + +/// \brief Supports all forms of search_n operations, +/// but the efficiency is insufficient when `items_per_block` is too large. +template +ROCPRIM_KERNEL +#ifndef DOXYGEN_DOCUMENTATION_BUILD +__launch_bounds__(device_params().kernel_config.block_size) +#endif +void search_n_normal_kernel(InputIterator input, + size_t* __restrict__ output, + const size_t size, + const size_t count, + const typename std::iterator_traits::value_type* value, + const BinaryPredicate binary_predicate) +{ + constexpr auto params = device_params(); + constexpr auto block_size = params.kernel_config.block_size; + constexpr auto items_per_thread = params.kernel_config.items_per_thread; + constexpr auto items_per_block = block_size * items_per_thread; + + const size_t this_thread_start_idx + = (block_id<0>() * items_per_block) + (items_per_thread * block_thread_id<0>()); + + // TODO: This could cause load imbalance among threads + // So maybe there is a better way to do this + if(size < count + this_thread_start_idx) + { // not able to find a sequence equal to or longer than count + return; + } + + size_t remaining_count = count; + size_t sequence_start_idx = this_thread_start_idx; + + const size_t items_this_thread + = std::min(size - this_thread_start_idx, items_per_thread); + 
for(size_t i = this_thread_start_idx; + sequence_start_idx - this_thread_start_idx < items_this_thread + && i + remaining_count <= size; + ++i) + { + if(binary_predicate(input[i], *value)) + { + if(--remaining_count == 0) + { + atomic_min(output, sequence_start_idx); + return; + } + } + else + { + remaining_count = count; + sequence_start_idx = i + 1; + } + } +} + +template +ROCPRIM_KERNEL +#ifndef DOXYGEN_DOCUMENTATION_BUILD +__launch_bounds__(device_params().kernel_config.block_size) +#endif +void search_n_find_heads_kernel( + InputIterator input, + const size_t size, + const typename std::iterator_traits::value_type* value, + const BinaryPredicate binary_predicate, + size_t* __restrict__ unfiltered_heads, + const size_t group_size) +{ + constexpr auto params = device_params(); + constexpr auto block_size = params.kernel_config.block_size; + constexpr auto items_per_thread = params.kernel_config.items_per_thread; + constexpr auto items_per_block = block_size * items_per_thread; + + const size_t this_thread_start_idx + = (block_id<0>() * items_per_block) + (items_per_thread * block_thread_id<0>()); + const size_t items_this_thread + = std::min(this_thread_start_idx < size ? 
size - this_thread_start_idx : 0, + items_per_thread); + + for(size_t i = this_thread_start_idx; i < this_thread_start_idx + items_this_thread; i++) + { + if(binary_predicate(input[i], *value)) + { + if(i == 0) + { // is head // `size - i - 1` is the distance to the end + atomic_min(&(unfiltered_heads[i / group_size]), size - i - 1); + } + else if(!binary_predicate(input[i - 1], *value)) + { // is head + atomic_min(&(unfiltered_heads[i / group_size]), size - i - 1); + } + } + } +} + +template +ROCPRIM_KERNEL +#ifndef DOXYGEN_DOCUMENTATION_BUILD +__launch_bounds__(device_params().kernel_config.block_size) +#endif +void search_n_heads_filter_kernel(const size_t size, + const size_t count, + const size_t* __restrict__ heads, + const size_t heads_size, + size_t* __restrict__ filtered_heads, + size_t* __restrict__ filtered_heads_size) +{ + constexpr auto params = device_params(); + constexpr auto block_size = params.kernel_config.block_size; + constexpr auto items_per_thread = params.kernel_config.items_per_thread; + constexpr auto items_per_block = block_size * items_per_thread; + + const size_t this_thread_start_idx + = (block_id<0>() * items_per_block) + (block_thread_id<0>() * items_per_thread); + const size_t this_thread_end_idx + = std::min(items_per_thread + this_thread_start_idx, heads_size); + for(size_t i = this_thread_start_idx; i < this_thread_end_idx; ++i) + { + const auto cur_val = heads[i]; + if(cur_val == (size_t)-1) + { // this is not a valid head + continue; + } + const size_t this_head = size - cur_val - 1; + if(i + 1 == heads_size) + { // head of last group + if(size - this_head < count) + { // cannot make it to count + continue; + } + } + else if(i + 2 == heads_size) + { // the head before last head (last group might be incomplete so, the head before last head can be invalid) + const auto next_val = heads[i + 1]; + if(((next_val != (size_t)-1) ? 
((size - next_val - 1) - this_head - 1) + : (size - this_head)) + < count) + { // cannot make it to count + continue; + } + } + else + { // other heads + const auto next_val = heads[i + 1]; + if((next_val != (size_t)-1) && (((size - next_val - 1) - this_head - 1) < count)) + { // if next head is invalid, the limit of this head should the next head, else it is possible to make the sequence to count + continue; + } + } + filtered_heads[atomic_add(filtered_heads_size, 1)] = this_head; + } +} + +template +ROCPRIM_KERNEL +#ifndef DOXYGEN_DOCUMENTATION_BUILD +__launch_bounds__(device_params().kernel_config.block_size) +#endif +void search_n_discard_heads_kernel( + InputIterator input, + const size_t size, + const size_t count, + const typename std::iterator_traits::value_type* value, + const BinaryPredicate binary_predicate, + size_t* __restrict__ heads, + size_t* num_heads) +{ + constexpr auto params = device_params(); + constexpr auto block_size = params.kernel_config.block_size; + constexpr auto items_per_thread = params.kernel_config.items_per_thread; + constexpr auto items_per_block = block_size * items_per_thread; + + const size_t heads_size = *num_heads; + if(heads_size == 0) + { + return; // should return + } + + size_t num_blocks_needed = ceiling_div(heads_size * count /*group_size*/, items_per_block); + if(block_id<0>() >= num_blocks_needed) + { + return; + } + + const size_t this_thread_start_idx + = (block_id<0>() * items_per_block) + (block_thread_id<0>() * items_per_thread); + + for(size_t global_idx = this_thread_start_idx; + global_idx < this_thread_start_idx + items_per_block; + global_idx++) + { + const size_t g_id /*group id*/ = global_idx / count /*group_size*/; + if(g_id >= heads_size) + { + return; + } + const size_t check_head + = heads[g_id] + 1; // the `head` is already checked, so we check the next value here + const size_t check_count = count - 1; + const size_t idx = check_head + (global_idx % count); + + if((idx >= size) || (idx >= 
(check_head + check_count))) + { + return; + } + if(!binary_predicate(input[idx], *value)) + { + heads[g_id] = size; + return; + } + } +} + +template +ROCPRIM_INLINE +hipError_t search_n_impl(void* temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + const size_t size, + const size_t count, + const typename std::iterator_traits::value_type* value, + const BinaryPredicate binary_predicate, + const hipStream_t stream, + const bool debug_synchronous) +{ + using input_type = typename std::iterator_traits::value_type; + using output_type = typename std::iterator_traits::value_type; + using config = wrapped_search_n_config; + + if(count > size) + { // size must greater than or equal to count + return hipErrorInvalidValue; + } + + target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + const auto params = dispatch_target_arch(target_arch); + const unsigned int block_size = params.kernel_config.block_size; + const unsigned int items_per_thread = params.kernel_config.items_per_thread; + const unsigned int items_per_block = block_size * items_per_thread; + const unsigned int num_blocks = ceiling_div(size, items_per_block); + + std::chrono::steady_clock::time_point start; + + size_t* tmp_output = reinterpret_cast(temporary_storage); + + if(size == 0 || count <= 0) + { // to be consist to the std::search_n + // calculate size + if(tmp_output == nullptr) + { + storage_size = sizeof(size_t); + return hipSuccess; + } + + // return end or begin + search_n_start_timer(start, debug_synchronous); + search_n_init_kernel<<<1, 1, 0, stream>>>(tmp_output, count <= 0 ? 
0 : size); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_n_init_kernel", 1, start); + ROCPRIM_RETURN_ON_ERROR( + transform(tmp_output, output, 1, identity(), stream, debug_synchronous)); + return hipSuccess; + } + else if(count <= params.threshold) + { // reduce search_n will have a maximum access time of params.threshold + // So if the count is equals to or smaller than params.threshold, `normal_search_n` should be faster + // calculate size + if(tmp_output == nullptr) + { + storage_size = sizeof(size_t); + return hipSuccess; + } + + // do `normal_search_n` + search_n_start_timer(start, debug_synchronous); + search_n_init_kernel<<<1, 1, 0, stream>>>(tmp_output, size); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_n_init_kernel", 1, start); + // TODO: There can be overlapping between threads, this probably can be optimized + search_n_normal_kernel<<>>(input, + tmp_output, + size, + count, + value, + binary_predicate); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_n_normal_kernel", size, start); + ROCPRIM_RETURN_ON_ERROR( + transform(tmp_output, output, 1, identity(), stream, debug_synchronous)); + return hipSuccess; + } + else + { // the count is greater than params.threshold + // group_size is equal to `count` + const size_t num_groups = ceiling_div(size, count /*group_size*/); + size_t reduce_storage_size; + ROCPRIM_RETURN_ON_ERROR(reduce(nullptr, + reduce_storage_size, + reinterpret_cast(0), + output, + size, + num_groups, + minimum{}, + stream, + debug_synchronous)); + size_t front_size + = std::max(sizeof(size_t) + (sizeof(size_t) * num_groups), reduce_storage_size); + if(tmp_output == nullptr) + { + storage_size = front_size + (sizeof(size_t) * num_groups); + return hipSuccess; + } + + const size_t num_blocks_for_heads_filter = ceiling_div(num_groups, items_per_block); + const size_t num_blocks_for_discard_heads + = ceiling_div(num_groups * count, items_per_block); + + auto unfiltered_heads = 
reinterpret_cast(reinterpret_cast(temporary_storage) + + sizeof(size_t)); + auto filtered_heads + = reinterpret_cast(reinterpret_cast(temporary_storage) + front_size); + + search_n_start_timer(start, debug_synchronous); + // initialization + ROCPRIM_RETURN_ON_ERROR(hipMemsetAsync(tmp_output, 0, sizeof(size_t), stream)); + ROCPRIM_RETURN_ON_ERROR( + hipMemsetAsync(unfiltered_heads, -1, sizeof(size_t) * num_groups * 2, stream)); + + // find the thread heads of each group + search_n_find_heads_kernel + <<>>(input, + size, + value, + binary_predicate, + unfiltered_heads, + count /*group_size*/); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_n_find_heads_kernel", size, start); + + // filter heads + // move valid heads into filtered_heads, and set the size of filtered_heads to tmp_output + search_n_heads_filter_kernel + <<>>(size, + count, + unfiltered_heads, + num_groups, + filtered_heads, + tmp_output); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_n_heads_filter_kernel", + num_groups, + start); + + // check if any valid heads make a valid sequence + // max access time for each item is 1 + // TODO: num_blocks_for_discard_heads is actually graeter than the actural valid filtered_heads_size + // so the actural num_blocks_for_discard_heads needed is smaller than the current value + search_n_discard_heads_kernel + <<>>( + input, + size, + count, + value, + binary_predicate, + filtered_heads, + tmp_output); // currently the tmp_output contains the actual size of filtered_heads + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_n_discard_heads_kernel ", + num_groups, + start); + + // calculate the minimum valid head + ROCPRIM_RETURN_ON_ERROR(reduce(temporary_storage, + reduce_storage_size, + filtered_heads, + output, + size, // original value + num_groups, + minimum{}, + stream, + debug_synchronous)); + return hipSuccess; // no needs to call transform, return directly + } +} + +} // namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // 
ROCPRIM_DEVICE_DEVICE_SEARCH_N_HPP_ diff --git a/rocprim/include/rocprim/device/device_adjacent_difference.hpp b/rocprim/include/rocprim/device/device_adjacent_difference.hpp index 4b4703ebe..39949c55d 100644 --- a/rocprim/include/rocprim/device/device_adjacent_difference.hpp +++ b/rocprim/include/rocprim/device/device_adjacent_difference.hpp @@ -29,6 +29,7 @@ #include "device_transform.hpp" #include "../config.hpp" +#include "../common.hpp" #include "../functional.hpp" #include "../detail/temp_storage.hpp" @@ -53,23 +54,6 @@ BEGIN_ROCPRIM_NAMESPACE #ifndef DOXYGEN_SHOULD_SKIP_THIS // Do not document -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) \ - return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - auto __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) \ - return __error; \ - auto _end = std::chrono::high_resolution_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } - namespace detail { template start; + std::chrono::time_point start; if(debug_synchronous) { std::cout << "index: " << i << '\n'; std::cout << "current_size: " << current_size << '\n'; std::cout << "number of blocks: " << current_blocks << '\n'; - start = std::chrono::high_resolution_clock::now(); + start = std::chrono::steady_clock::now(); } hipLaunchKernelGGL(HIP_KERNEL_NAME(adjacent_difference_kernel), dim3(current_blocks), @@ -228,7 +212,7 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, } } // namespace detail - #undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + #endif // DOXYGEN_SHOULD_SKIP_THIS diff --git a/rocprim/include/rocprim/device/device_adjacent_find.hpp b/rocprim/include/rocprim/device/device_adjacent_find.hpp new file mode 100644 index 000000000..47edc0734 --- /dev/null +++ 
b/rocprim/include/rocprim/device/device_adjacent_find.hpp @@ -0,0 +1,299 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_ADJACENT_FIND_HPP_ +#define ROCPRIM_DEVICE_DEVICE_ADJACENT_FIND_HPP_ + +#include "detail/device_adjacent_find.hpp" +#include "detail/device_config_helper.hpp" +#include "device_adjacent_find_config.hpp" +#include "device_reduce.hpp" +#include "device_transform.hpp" + +#include "../common.hpp" +#include "../functional.hpp" +#include "../iterator/counting_iterator.hpp" +#include "../iterator/transform_iterator.hpp" +#include "../iterator/zip_iterator.hpp" +#include "../types/tuple.hpp" + +#include + +BEGIN_ROCPRIM_NAMESPACE + +#ifndef DOXYGEN_DOCUMENTATION_BUILD // Do not document + +namespace detail +{ +template +ROCPRIM_INLINE +hipError_t adjacent_find_impl(void* const temporary_storage, + std::size_t& storage_size, + InputIterator input, + OutputIterator output, + const std::size_t size, + BinaryPred op, + const hipStream_t stream, + const bool debug_synchronous) +{ + // Data types + using input_type = typename std::iterator_traits::value_type; + using op_result_type = bool; + using index_type = std::size_t; + using wrapped_input_type = ::rocprim::tuple; + + // Operations types + using reduce_op_type = ::rocprim::minimum; + + // Use dynamic tile id + using ordered_tile_id_type = detail::ordered_block_id; + + // Kernel launch config + using config = wrapped_adjacent_find_config; + + // Calculate required temporary storage + ordered_tile_id_type::id_type* ordered_tile_id_storage; + index_type* reduce_output = nullptr; + + hipError_t result = detail::temp_storage::partition( + temporary_storage, + storage_size, + detail::temp_storage::make_linear_partition( + detail::temp_storage::make_partition(&ordered_tile_id_storage, + ordered_tile_id_type::get_temp_storage_layout()), + detail::temp_storage::ptr_aligned_array(&reduce_output, sizeof(*reduce_output)))); + + if(result != hipSuccess || temporary_storage == nullptr) + { + return result; + } + + std::chrono::steady_clock::time_point start; + if(debug_synchronous) + { + start = 
std::chrono::steady_clock::now(); + } + + // Launch adjacent_find::init_adjacent_find + auto ordered_tile_id = ordered_tile_id_type::create(ordered_tile_id_storage); + adjacent_find::init_adjacent_find<<<1, 1, 0, stream>>>(reduce_output, ordered_tile_id, size); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR( + "rocprim::detail::adjacent_find::init_adjacent_find", + size, + start); + + if(size > 1) + { + // Wrap adjacent input in zip iterator with idx values + auto iota = ::rocprim::make_counting_iterator(0); + auto wrapped_input + = ::rocprim::make_zip_iterator(::rocprim::make_tuple(input, input + 1, iota)); + + // Transform input + auto wrapped_equal_op = [op, size](const wrapped_input_type& a) -> index_type + { + if(op_result_type(op(::rocprim::get<0>(a), ::rocprim::get<1>(a)))) + { + return ::rocprim::get<2>(a); + } + return size; + }; + auto transformed_input + = ::rocprim::make_transform_iterator(wrapped_input, wrapped_equal_op); + + auto adjacent_find_block_reduce_kernel + = adjacent_find::block_reduce_kernel; + target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + const adjacent_find_config_params params = dispatch_target_arch(target_arch); + const unsigned int block_size = params.kernel_config.block_size; + const unsigned int items_per_thread = params.kernel_config.items_per_thread; + const unsigned int items_per_block = block_size * items_per_thread; + const unsigned int grid_size = (size + items_per_block - 1) / items_per_block; + const unsigned int shared_mem_bytes = 0; /*no dynamic shared mem*/ + + // Get grid size for maximum occupancy, as we may not be able to schedule all the blocks + // at the same time + int min_grid_size = 0; + int optimal_block_size = 0; + ROCPRIM_RETURN_ON_ERROR(hipOccupancyMaxPotentialBlockSize(&min_grid_size, + &optimal_block_size, + adjacent_find_block_reduce_kernel, + shared_mem_bytes, + int(block_size))); + min_grid_size = std::min(static_cast(min_grid_size), grid_size); + + 
if(debug_synchronous) + { + start = std::chrono::steady_clock::now(); + } + + // Launch adjacent_find::block_reduce_kernel + adjacent_find_block_reduce_kernel<<>>( + transformed_input, + reduce_output, + size, + reduce_op_type{}, + ordered_tile_id); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR( + "rocprim::detail::adjacent_find::block_reduce_kernel", + size, + start); + } + + ROCPRIM_RETURN_ON_ERROR(::rocprim::transform(reduce_output, + output, + 1, + ::rocprim::identity(), + stream, + debug_synchronous)); + + return hipSuccess; +} + +} // namespace detail + +#endif // DOXYGEN_DOCUMENTATION_BUILD + +/// \addtogroup devicemodule +/// @{ + +/// \brief Searches the input sequence for the first appearance of a consecutive pair of equal elements. +/// +/// The returned index is either: the index within the input array of the first element of the first +/// pair of consecutive equal elements found or the size of the input array if no such pair is found. +/// Equivalent to the following code +/// \code{.cpp} +/// if(size > 1) +/// { +/// for(std::size_t i = 0; i < size - 1 ; ++i) +/// if (op(input[i], input[i + 1])) +/// return i; +/// } +/// return size; +/// \endcode +/// +/// \par Overview +/// * The contents of the inputs are not altered by the function. +/// * Returns the required size of `temporary_storage` in `storage_size` if `temporary_storage` is a null pointer. +/// * Accepts custom \p op. +/// * Streams in graph capture mode are supported. +/// +/// \tparam Config [optional] Configuration of the primitive, must be `default_config` or `adjacent_find_config`. +/// \tparam InputIteratorType [inferred] Random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIteratorType [inferred] Random-access iterator type of the output index. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. 
+/// \tparam BinaryPred [inferred] Boolean binary operation function object that will be applied to +/// consecutive items to check whether they are equal or not. The signature of the function should be equivalent +/// to the following: +/// bool f(const T& a, const T& b). The signature does not need to have +/// const &, but the function object must not modify the object passed to it. +/// The operator must meet the C++ named requirement \p BinaryPredicate. +/// The default operation used is rocprim::equal_to, where \p T is the type of the elements +/// in the input range obtained with std::iterator_traits::value_type>. +/// +/// \param [in] temporary_storage Pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and the function returns without performing any device computation. +/// \param [in,out] storage_size Reference to a size (in bytes) of `temporary_storage` +/// \param [in] input Iterator to the input range. +/// \param [out] output iterator to the output index. +/// \param [in] size Number of items in the input. +/// \param [in] op [optional] The boolean binary operation to be used by the algorithm. Default is +/// \p ::rocprim::equal_to specialized for the type of the input elements. +/// \param [in] stream [optional] HIP stream object. Default is `0` (the default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors and extra debugging info is printed to the +/// standard output. Default value is `false`. +/// +/// \return `hipSuccess` (0) after a successful search, otherwise the HIP runtime error of +/// type `hipError_t`. +/// +/// \par Example +/// \parblock +/// In this example a device-level adjacent_find operation is performed on integer values. 
+/// +/// \code{.cpp} +/// #include <rocprim/rocprim.hpp> //or <rocprim/device/device_adjacent_find.hpp> +/// +/// // Custom boolean binary function +/// auto equal_op = [](int a, int b) -> bool { return (a - b == 2); }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// std::size_t size; // e.g., 8 +/// int* input; // e.g., [8, 7, 5, 4, 3, 2, 1, 0] +/// std::size_t* output; // output index +/// auto custom_op = equal_op; +/// +/// std::size_t temporary_storage_size_bytes; +/// void* temporary_storage_ptr = nullptr; +/// +/// // Get required size of the temporary storage +/// rocprim::adjacent_find( +/// temporary_storage_ptr, temporary_storage_size_bytes, input, output, size, custom_op); +/// +/// // Allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // Perform adjacent find +/// rocprim::adjacent_find( +/// temporary_storage_ptr, temporary_storage_size_bytes, input, output, size, custom_op); +/// // output: 1 +/// \endcode +/// \endparblock +template::value_type>> +ROCPRIM_INLINE +hipError_t adjacent_find(void* const temporary_storage, + std::size_t& storage_size, + InputIterator input, + OutputIterator output, + const std::size_t size, + BinaryPred op = BinaryPred{}, + const hipStream_t stream = 0, + const bool debug_synchronous = false) +{ + return detail::adjacent_find_impl(temporary_storage, + storage_size, + input, + output, + size, + op, + stream, + debug_synchronous); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_ADJACENT_FIND_HPP_ diff --git a/rocprim/include/rocprim/device/device_adjacent_find_config.hpp b/rocprim/include/rocprim/device/device_adjacent_find_config.hpp new file mode 100644 index 000000000..5a2990849 --- /dev/null +++ b/rocprim/include/rocprim/device/device_adjacent_find_config.hpp @@ -0,0 +1,105 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_ADJACENT_FIND_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_ADJACENT_FIND_CONFIG_HPP_ + +#include "config_types.hpp" + +#include "detail/config/device_adjacent_find.hpp" +#include "detail/device_config_helper.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +template +struct wrapped_adjacent_find_config +{ + static_assert(std::is_same::value, + "Config must be a specialization of struct template adjacent_find_config"); + + template + struct architecture_config + { + static constexpr adjacent_find_config_params params = Config{}; + }; +}; + +// Generic for default config: instantiate base config. 
+template +struct wrapped_adjacent_find_impl +{ + template + struct architecture_config + { + static constexpr adjacent_find_config_params params = + typename default_adjacent_find_config_base::type{}; + }; +}; + +// Specialization for default config if types are arithmetic or half/bfloat16-precision +// floating point types: instantiate the tuned config. +template +struct wrapped_adjacent_find_impl::value>> +{ + template + struct architecture_config + { + static constexpr adjacent_find_config_params params + = default_adjacent_find_config(Arch), Type>(); + }; +}; + +// Specialization for default config. +template +struct wrapped_adjacent_find_config : wrapped_adjacent_find_impl +{}; + +#ifndef DOXYGEN_DOCUMENTATION_BUILD +template +template +constexpr adjacent_find_config_params + wrapped_adjacent_find_config::architecture_config::params; + +template +template +constexpr adjacent_find_config_params + wrapped_adjacent_find_impl::architecture_config::params; + +template +template +constexpr adjacent_find_config_params wrapped_adjacent_find_impl< + Type, + std::enable_if_t::value>>::architecture_config::params; +#endif // DOXYGEN_DOCUMENTATION_BUILD + +} // namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DEVICE_ADJACENT_FIND_CONFIG_HPP_ diff --git a/rocprim/include/rocprim/device/device_binary_search.hpp b/rocprim/include/rocprim/device/device_binary_search.hpp index 36b79bc17..312f42f48 100644 --- a/rocprim/include/rocprim/device/device_binary_search.hpp +++ b/rocprim/include/rocprim/device/device_binary_search.hpp @@ -142,8 +142,8 @@ struct is_default_or_has_tag /// \param [in] stream - [optional] HIP stream object. Default is `0` (default stream). /// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel launch is /// forced in order to check for errors. 
-/// \return `hipSuccess` (`0)` after a successful search; otherwise a HIP runtime error of -/// type `hipError_t.` +/// \return `hipSuccess` (0) after a successful search; otherwise a HIP runtime error of +/// type `hipError_t`. /// /// \par Example /// \parblock diff --git a/rocprim/include/rocprim/device/device_find_end.hpp b/rocprim/include/rocprim/device/device_find_end.hpp new file mode 100644 index 000000000..1e35f3426 --- /dev/null +++ b/rocprim/include/rocprim/device/device_find_end.hpp @@ -0,0 +1,155 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_FIND_END_HPP_ +#define ROCPRIM_DEVICE_DEVICE_FIND_END_HPP_ + +#include "../config.hpp" + +#include "config_types.hpp" +#include "detail/device_search.hpp" + +#include + +#include +#include + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +/// \brief Searches for the last occurrence of the sequence. +/// +/// Searches the input for the last occurrence of a sequence, according to a particular +/// comparison function. If found, the index of the first item of the found sequence +/// in the input is returned. Otherwise, returns the size of the input. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the function. +/// * Returns the required size of `temporary_storage` in `storage_size` +/// if `temporary_storage` is a null pointer. +/// * Accepts custom compare_functions for find_end across the device. +/// * Streams in graph capture mode are supported. +/// +/// \tparam Config [optional] configuration of the primitive, must be `default_config` or `search_config`. +/// \tparam InputIterator1 [inferred] random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam InputIterator2 [inferred] random-access iterator type of the key range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIterator [inferred] random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction [inferred] Type of binary function that accepts two arguments of the +/// type `InputIterator1` and returns a value convertible to bool. +/// Default type is `rocprim::less<>`. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. 
When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and function returns without performing the find_end. +/// \param [in,out] storage_size reference to a size (in bytes) of `temporary_storage`. +/// \param [in] input iterator to the input range. +/// \param [in] keys iterator to the key range. +/// \param [out] output iterator to the output range. The output is one element. +/// \param [in] size number of elements in the input range. +/// \param [in] keys_size number of elements in the key range. +/// \param [in] compare_function binary operation function object that will be used for comparison. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The comparator must meet the C++ named requirement BinaryPredicate. +/// The default value is `BinaryFunction()`. +/// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is `false`. +/// +/// \returns `hipSuccess` (`0`) after successful search; otherwise a HIP runtime error of +/// type `hipError_t`. +/// +/// \par Example +/// \parblock +/// In this example a device-level find_end is performed where input values are +/// represented by an array of unsigned integers and the key is also an array +/// of unsigned integers. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) 
+/// size_t size; // e.g., 10 +/// size_t key_size; // e.g., 3 +/// unsigned int * input; // e.g., [ 6, 3, 5, 4, 1, 8, 2, 5, 4, 1 ] +/// unsigned int * key; // e.g., [ 5, 4, 1 ] +/// unsigned int * output; // e.g., empty array of size 1 +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::find_end( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, key, output, size, key_size +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform find_end +/// rocprim::find_end( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, key, output, size, key_size +/// ); +/// // output: [ 7 ] +/// \endcode +/// \endparblock +template::value_type>> +ROCPRIM_INLINE +hipError_t find_end(void* temporary_storage, + size_t& storage_size, + InputIterator1 input, + InputIterator2 keys, + OutputIterator output, + size_t size, + size_t keys_size, + BinaryFunction compare_function = BinaryFunction(), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::search_impl(temporary_storage, + storage_size, + input, + keys, + output, + size, + keys_size, + compare_function, + stream, + debug_synchronous); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_FIND_END_HPP_ diff --git a/rocprim/include/rocprim/device/device_find_first_of.hpp b/rocprim/include/rocprim/device/device_find_first_of.hpp index 203547004..5b04b1617 100644 --- a/rocprim/include/rocprim/device/device_find_first_of.hpp +++ b/rocprim/include/rocprim/device/device_find_first_of.hpp @@ -39,25 +39,6 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - do \ - { \ - hipError_t _error = hipGetLastError(); \ - if(_error != hipSuccess) \ - return _error; \ - if(debug_synchronous) \ 
- { \ - std::cout << name << "(" << size << ")"; \ - hipError_t __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) \ - return __error; \ - auto _end = std::chrono::steady_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } \ - while(0) - template struct find_first_of_impl_kernels { @@ -88,7 +69,8 @@ struct find_first_of_impl_kernels constexpr unsigned int items_per_block = block_size * items_per_thread; constexpr unsigned int identity = std::numeric_limits::max(); - using type = typename std::iterator_traits::value_type; + using type = + typename std::remove_const_t::value_type>; using key_type = typename std::iterator_traits::value_type; const unsigned int thread_id = ::rocprim::detail::block_thread_id<0>(); @@ -286,8 +268,6 @@ hipError_t find_first_of_impl(void* temporary_storage, return transform(tmp_output, output, 1, ::rocprim::identity(), stream, debug_synchronous); } -#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR - } // namespace detail /// \addtogroup devicemodule diff --git a/rocprim/include/rocprim/device/device_histogram.hpp b/rocprim/include/rocprim/device/device_histogram.hpp index 170c54398..c18337b05 100644 --- a/rocprim/include/rocprim/device/device_histogram.hpp +++ b/rocprim/include/rocprim/device/device_histogram.hpp @@ -27,6 +27,7 @@ #include #include "../config.hpp" +#include "../common.hpp" #include "../detail/various.hpp" #include "../functional.hpp" @@ -123,23 +124,6 @@ ROCPRIM_KERNEL __launch_bounds__( bins_bits); } -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) \ - return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - auto __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) \ - return __error; \ - auto _end = std::chrono::high_resolution_clock::now(); \ - auto _d = 
std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } - template), dim3(::rocprim::detail::ceiling_div(max_bins, block_size)), @@ -241,7 +225,7 @@ inline hipError_t histogram_impl(void* temporary_storage, { if(debug_synchronous) { - start = std::chrono::high_resolution_clock::now(); + start = std::chrono::steady_clock::now(); } auto kernel = HIP_KERNEL_NAME(histogram_shared_kernel), @@ -436,7 +420,7 @@ inline hipError_t histogram_range_impl(void* temporary_storage, debug_synchronous); } -#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + } // namespace detail @@ -686,6 +670,13 @@ inline hipError_t histogram_even(void* temporary_storage, /// \returns \p hipSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of /// type \p hipError_t. /// +/// \par Notes +/// * Currently the \p Channels template parameter has no strict restriction on its value. However, +/// internally a vector type of elements of type \p SampleIterator and length \p Channels is used +/// to represent the input items, so the amount of local memory available will limit the range of +/// possible values for this template parameter. +/// * \p ActiveChannels must be less than or equal to \p Channels. +/// /// \par Example /// \parblock /// In this example histograms for 3 channels (RGB) are computed on an array of 8-bit RGBA samples. @@ -800,6 +791,13 @@ inline hipError_t multi_histogram_even(void* temporary_storage, /// \returns \p hipSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of /// type \p hipError_t. /// +/// \par Notes +/// * Currently the \p Channels template parameter has no strict restriction on its value. However, +/// internally a vector type of elements of type \p SampleIterator and length \p Channels is used +/// to represent the input items, so the amount of local memory available will limit the range of +/// possible values for this template parameter. 
+/// * \p ActiveChannels must be less than or equal to \p Channels. +/// /// \par Example /// \parblock /// In this example histograms for 3 channels (RGB) are computed on an array of 8-bit RGBA samples. @@ -1107,6 +1105,13 @@ inline hipError_t histogram_range(void* temporary_storage, /// \returns \p hipSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of /// type \p hipError_t. /// +/// \par Notes +/// * Currently the \p Channels template parameter has no strict restriction on its value. However, +/// internally a vector type of elements of type \p SampleIterator and length \p Channels is used +/// to represent the input items, so the amount of local memory available will limit the range of +/// possible values for this template parameter. +/// * \p ActiveChannels must be less than or equal to \p Channels. +/// /// \par Example /// \parblock /// In this example histograms for 3 channels (RGB) are computed on an array of 8-bit RGBA samples. @@ -1216,6 +1221,13 @@ inline hipError_t multi_histogram_range(void* temporary_storage, /// \returns \p hipSuccess (\p 0) after successful histogram operation; otherwise a HIP runtime error of /// type \p hipError_t. /// +/// \par Notes +/// * Currently the \p Channels template parameter has no strict restriction on its value. However, +/// internally a vector type of elements of type \p SampleIterator and length \p Channels is used +/// to represent the input items, so the amount of local memory available will limit the range of +/// possible values for this template parameter. +/// * \p ActiveChannels must be less than or equal to \p Channels. +/// /// \par Example /// \parblock /// In this example histograms for 3 channels (RGB) are computed on an array of 8-bit RGBA samples. 
diff --git a/rocprim/include/rocprim/device/device_merge.hpp b/rocprim/include/rocprim/device/device_merge.hpp index 5ef72dd15..5dcb00159 100644 --- a/rocprim/include/rocprim/device/device_merge.hpp +++ b/rocprim/include/rocprim/device/device_merge.hpp @@ -26,6 +26,7 @@ #include #include "../config.hpp" +#include "../common.hpp" #include "../detail/temp_storage.hpp" #include "../detail/various.hpp" @@ -40,21 +41,20 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class IndexIterator, - class KeysInputIterator1, - class KeysInputIterator2, - class BinaryFunction -> +template ROCPRIM_KERNEL -__launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) -void partition_kernel(IndexIterator index, +__launch_bounds__(device_params().kernel_config.block_size) +void partition_kernel(IndexIterator index, KeysInputIterator1 keys_input1, KeysInputIterator2 keys_input2, - const size_t input1_size, - const size_t input2_size, + const size_t input1_size, + const size_t input2_size, const unsigned int spacing, - BinaryFunction compare_function) + BinaryFunction compare_function) { partition_kernel_impl( index, keys_input1, keys_input2, input1_size, input2_size, @@ -62,53 +62,42 @@ void partition_kernel(IndexIterator index, ); } -template< - unsigned int BlockSize, - unsigned int ItemsPerThread, - class IndexIterator, - class KeysInputIterator1, - class KeysInputIterator2, - class KeysOutputIterator, - class ValuesInputIterator1, - class ValuesInputIterator2, - class ValuesOutputIterator, - class BinaryFunction -> +template ROCPRIM_KERNEL -__launch_bounds__(BlockSize) -void merge_kernel(IndexIterator index, - KeysInputIterator1 keys_input1, - KeysInputIterator2 keys_input2, - KeysOutputIterator keys_output, +__launch_bounds__(device_params().kernel_config.block_size) +void merge_kernel(IndexIterator index, + KeysInputIterator1 keys_input1, + KeysInputIterator2 keys_input2, + KeysOutputIterator keys_output, ValuesInputIterator1 values_input1, ValuesInputIterator2 values_input2, 
ValuesOutputIterator values_output, - const size_t input1_size, - const size_t input2_size, - BinaryFunction compare_function) + const size_t input1_size, + const size_t input2_size, + BinaryFunction compare_function) { - merge_kernel_impl( - index, keys_input1, keys_input2, keys_output, - values_input1, values_input2, values_output, - input1_size, input2_size, compare_function - ); + static constexpr merge_config_params params = device_params(); + merge_kernel_impl( + index, + keys_input1, + keys_input2, + keys_output, + values_input1, + values_input2, + values_output, + input1_size, + input2_size, + compare_function); } -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - auto __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) return __error; \ - auto _end = std::chrono::high_resolution_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } - template< class Config, class KeysInputIterator1, @@ -138,16 +127,20 @@ hipError_t merge_impl(void * temporary_storage, using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; - // Get default config if Config is default_config - using config = detail::default_or_custom_config< - Config, - detail::default_merge_config - >; + using config = wrapped_merge_config; + + detail::target_arch target_arch; + hipError_t result = detail::host_target_arch(stream, target_arch); + if(result != hipSuccess) + { + return result; + } + const merge_config_params params = detail::dispatch_target_arch(target_arch); - static constexpr unsigned int block_size = config::block_size; - static constexpr unsigned int half_block = block_size / 2; - static constexpr unsigned int items_per_thread = 
config::items_per_thread; - static constexpr auto items_per_block = block_size * items_per_thread; + const unsigned int block_size = params.kernel_config.block_size; + const unsigned int half_block = block_size / 2; + const unsigned int items_per_thread = params.kernel_config.items_per_thread; + const auto items_per_block = block_size * items_per_thread; const unsigned int partitions = ((input1_size + input2_size) + items_per_block - 1) / items_per_block; @@ -167,7 +160,7 @@ hipError_t merge_impl(void * temporary_storage, return hipSuccess; // Start point for time measurements - std::chrono::high_resolution_clock::time_point start; + std::chrono::steady_clock::time_point start; auto number_of_blocks = partitions; if(debug_synchronous) @@ -179,29 +172,43 @@ hipError_t merge_impl(void * temporary_storage, const unsigned partition_blocks = ((partitions + 1) + half_block - 1) / half_block; - if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - hipLaunchKernelGGL( - HIP_KERNEL_NAME(detail::partition_kernel), - dim3(partition_blocks), dim3(half_block), 0, stream, - index, keys_input1, keys_input2, input1_size, input2_size, - items_per_block, compare_function - ); + if(debug_synchronous) start = std::chrono::steady_clock::now(); + hipLaunchKernelGGL(HIP_KERNEL_NAME(detail::partition_kernel), + dim3(partition_blocks), + dim3(half_block), + 0, + stream, + index, + keys_input1, + keys_input2, + input1_size, + input2_size, + items_per_block, + compare_function); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("partition_kernel", input1_size, start); - if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - hipLaunchKernelGGL( - HIP_KERNEL_NAME(detail::merge_kernel), - dim3(number_of_blocks), dim3(block_size), 0, stream, - index, keys_input1, keys_input2, keys_output, - values_input1, values_input2, values_output, - input1_size, input2_size, compare_function - ); + if(debug_synchronous) start = std::chrono::steady_clock::now(); + 
hipLaunchKernelGGL(HIP_KERNEL_NAME(detail::merge_kernel), + dim3(number_of_blocks), + dim3(block_size), + 0, + stream, + index, + keys_input1, + keys_input2, + keys_output, + values_input1, + values_input2, + values_output, + input1_size, + input2_size, + compare_function); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("merge_kernel", input1_size, start); return hipSuccess; } -#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + } // end of detail namespace diff --git a/rocprim/include/rocprim/device/device_merge_config.hpp b/rocprim/include/rocprim/device/device_merge_config.hpp index b1df2f44c..5bf72abe7 100644 --- a/rocprim/include/rocprim/device/device_merge_config.hpp +++ b/rocprim/include/rocprim/device/device_merge_config.hpp @@ -26,6 +26,8 @@ #include "../config.hpp" #include "../detail/various.hpp" #include "../functional.hpp" +#include "detail/config/device_merge.hpp" +#include "detail/device_config_helper.hpp" #include "config_types.hpp" @@ -34,135 +36,45 @@ BEGIN_ROCPRIM_NAMESPACE -/// \brief Configuration of device-level merge primitives. 
-template -using merge_config = kernel_config; - namespace detail { -template -struct merge_config_803 +// generic struct that instantiates custom configurations +template +struct wrapped_merge_config { - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - - // TODO Tune when merge-by-key is ready - using type = merge_config::value, - ::rocprim::max(1u, 10u / item_scale)>; + template + struct architecture_config + { + static constexpr merge_config_params params = Config(); + }; }; -template -struct merge_config_803 +// specialized for rocprim::default_config, which instantiates the default_ALGO_config +template +struct wrapped_merge_config { - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); - - using type - = select_type>, - select_type_case>, - select_type_case>, - merge_config::value, - ::rocprim::max(1u, 10u / item_scale)>>; + template + struct architecture_config + { + static constexpr merge_config_params params + = default_merge_config(Arch), KeyType, ValueType>{}; + }; }; -template -struct merge_config_900 -{ - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - - // TODO Tune when merge-by-key is ready - using type = merge_config::value, - ::rocprim::max(1u, 10u / item_scale)>; -}; +#ifndef DOXYGEN_DOCUMENTATION_BUILD +template +template +constexpr merge_config_params + wrapped_merge_config::architecture_config::params; -template -struct merge_config_900 -{ - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); - - using type - = select_type>, - select_type_case>, - select_type_case>, - merge_config::value, - ::rocprim::max(1u, 10u / item_scale)>>; -}; - -// TODO: We need to update these parameters template -struct merge_config_90a -{ - static constexpr unsigned int item_scale = - 
::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - - // TODO Tune when merge-by-key is ready - using type = merge_config::value, - ::rocprim::max(1u, 10u / item_scale)>; -}; - -template -struct merge_config_90a -{ - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); - - using type - = select_type>, - select_type_case>, - select_type_case>, - merge_config::value, - ::rocprim::max(1u, 10u / item_scale)>>; -}; - -// TODO: We need to update these parameters -template -struct merge_config_1030 -{ - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - - // TODO Tune when merge-by-key is ready - using type = merge_config::value, - ::rocprim::max(1u, 10u / item_scale)>; -}; - -template -struct merge_config_1030 -{ - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); - - using type - = select_type>, - select_type_case>, - select_type_case>, - merge_config::value, - ::rocprim::max(1u, 10u / item_scale)>>; -}; - -template -struct default_merge_config - : select_arch< - TargetArch, - select_arch_case<803, merge_config_803>, - select_arch_case<900, merge_config_900>, - select_arch_case>, - select_arch_case<1030, merge_config_1030>, - merge_config_900 - > { }; +template +constexpr merge_config_params + wrapped_merge_config::architecture_config::params; +#endif // DOXYGEN_DOCUMENTATION_BUILD -} // end namespace detail +} // namespace detail END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/device_merge_sort.hpp b/rocprim/include/rocprim/device/device_merge_sort.hpp index 8ab10499d..872d4027a 100644 --- a/rocprim/include/rocprim/device/device_merge_sort.hpp +++ b/rocprim/include/rocprim/device/device_merge_sort.hpp @@ -26,6 +26,7 @@ #include #include "../config.hpp" +#include "../common.hpp" #include "../detail/temp_storage.hpp" #include 
"../detail/various.hpp" @@ -136,21 +137,6 @@ void device_block_merge_mergepath_kernel(KeysInputIterator keys_input, merge_partitions); } -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - auto __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) return __error; \ - auto _end = std::chrono::high_resolution_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } - template ROCPRIM_KERNEL __launch_bounds__( device_params() @@ -323,7 +309,7 @@ inline hipError_t merge_sort_block_merge( } // Start point for time measurements - std::chrono::high_resolution_clock::time_point start; + std::chrono::steady_clock::time_point start; bool temporary_store = true; for(OffsetT block = sorted_block_size; block < size; block *= 2) @@ -338,7 +324,7 @@ inline hipError_t merge_sort_block_merge( if(use_mergepath && block >= merge_mergepath_items_per_block) { if(debug_synchronous) - start = std::chrono::high_resolution_clock::now(); + start = std::chrono::steady_clock::now(); hipLaunchKernelGGL( HIP_KERNEL_NAME(device_block_merge_mergepath_partition_kernel), dim3(merge_partition_number_of_blocks), @@ -357,7 +343,7 @@ inline hipError_t merge_sort_block_merge( start); if(debug_synchronous) - start = std::chrono::high_resolution_clock::now(); + start = std::chrono::steady_clock::now(); hipLaunchKernelGGL(HIP_KERNEL_NAME(device_block_merge_mergepath_kernel), calculate_grid_dim(merge_mergepath_number_of_blocks, merge_mergepath_block_size), @@ -380,7 +366,7 @@ inline hipError_t merge_sort_block_merge( else { if(debug_synchronous) - start = std::chrono::high_resolution_clock::now(); + start = std::chrono::steady_clock::now(); // As this kernel is only called with small sizes, it is safe to use 32-bit 
integers // for size and block. hipLaunchKernelGGL(HIP_KERNEL_NAME(device_block_merge_oddeven_kernel), @@ -397,7 +383,7 @@ inline hipError_t merge_sort_block_merge( compare_function); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("device_block_merge_oddeven_kernel", size, - start) + start); } return hipSuccess; }; @@ -483,9 +469,9 @@ inline hipError_t merge_sort_block_sort(KeysInputIterator keys_input, } // Start point for time measurements - std::chrono::high_resolution_clock::time_point start; + std::chrono::steady_clock::time_point start; if(debug_synchronous) - start = std::chrono::high_resolution_clock::now(); + start = std::chrono::steady_clock::now(); hipLaunchKernelGGL( HIP_KERNEL_NAME(block_sort_kernel), @@ -636,7 +622,7 @@ inline hipError_t merge_sort_impl( return hipSuccess; } -#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + } // end of detail namespace diff --git a/rocprim/include/rocprim/device/device_partial_sort.hpp b/rocprim/include/rocprim/device/device_partial_sort.hpp index d0d0dc4e0..857a8c4a6 100644 --- a/rocprim/include/rocprim/device/device_partial_sort.hpp +++ b/rocprim/include/rocprim/device/device_partial_sort.hpp @@ -31,6 +31,7 @@ #include "device_merge_sort.hpp" #include "device_nth_element.hpp" #include "device_partial_sort_config.hpp" +#include "device_radix_sort.hpp" #include "device_transform.hpp" #include @@ -38,6 +39,7 @@ #include #include +#include BEGIN_ROCPRIM_NAMESPACE @@ -47,16 +49,108 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -#define RETURN_ON_ERROR(...) 
\ - do \ - { \ - hipError_t error = (__VA_ARGS__); \ - if(error != hipSuccess) \ - { \ - return error; \ - } \ - } \ - while(0) +template +struct radix_sort_condition_checker +{ + using key_type = typename std::iterator_traits::value_type; + + static constexpr bool is_custom_decomposer + = !std::is_same::value; + static constexpr bool descending + = std::is_same>::value; + static constexpr bool ascending = std::is_same>::value; + static constexpr bool is_radix_key_fundamental = detail::radix_key_fundamental::value; + static constexpr bool use_radix_sort + = (is_radix_key_fundamental || is_custom_decomposer) && (descending || ascending); +}; + +// Primary template for SortImpl +template> +struct SortImpl +{ + static ROCPRIM_INLINE + hipError_t algo_sort(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + const size_t size, + BinaryFunction compare_function, + const hipStream_t stream, + bool debug_synchronous, + typename std::iterator_traits::value_type* keys_buffer, + Decomposer /*decomposer*/) + { + // Merge sort implementation + return detail::merge_sort_impl( + temporary_storage, + storage_size, + keys_input, + keys_output, + static_cast(nullptr), + static_cast(nullptr), + size, + compare_function, + stream, + debug_synchronous, + keys_buffer); + } +}; + +// Specialization for radix sort +template +struct SortImpl +{ + static ROCPRIM_INLINE + hipError_t + algo_sort(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysOutputIterator keys_output, + const size_t size, + BinaryFunction /*compare_function*/, + const hipStream_t stream, + bool debug_synchronous, + typename std::iterator_traits::value_type* /*keys_buffer*/, + Decomposer decomposer = {}) + { + // Radix sort implementation + bool ignored; + return detail::radix_sort_impl( + temporary_storage, + storage_size, + keys_input, + nullptr, + keys_output, + static_cast(nullptr), + static_cast(nullptr), + 
static_cast(nullptr), + size, + ignored, + decomposer, + 0, + sizeof(typename std::iterator_traits::value_type) * 8, + stream, + debug_synchronous); + } +}; template struct partial_sort_nth_element_helper @@ -65,42 +159,49 @@ struct partial_sort_nth_element_helper class KeysInputIterator, class KeysInputIteratorNthElement, class KeysOutputIterator, - class BinaryFunction> + class BinaryFunction, + class Decomposer = rocprim::identity_decomposer> ROCPRIM_INLINE hipError_t - merge_sort_impl(void* temporary_storage, - size_t& storage_size, - KeysInputIterator keys_input, - KeysInputIteratorNthElement keys_input_nth_element, - KeysOutputIterator keys_output, - const size_t size, - BinaryFunction compare_function, - const hipStream_t stream, - bool debug_synchronous, - typename std::iterator_traits::value_type* keys_buffer - = nullptr) + algo_sort_impl(void* temporary_storage, + size_t& storage_size, + KeysInputIterator keys_input, + KeysInputIteratorNthElement /*keys_input_nth_element*/, + KeysOutputIterator keys_output, + const size_t size, + BinaryFunction compare_function, + const hipStream_t stream, + bool debug_synchronous, + typename std::iterator_traits::value_type* keys_buffer + = nullptr, + Decomposer decomposer = {}) { - (void)keys_input_nth_element; - return detail::merge_sort_impl(temporary_storage, - storage_size, - keys_input, - keys_output, - static_cast(nullptr), - static_cast(nullptr), - size, - compare_function, - stream, - debug_synchronous, - keys_buffer); + using checker = radix_sort_condition_checker; + return SortImpl::algo_sort(temporary_storage, + storage_size, + keys_input, + keys_output, + size, + compare_function, + stream, + debug_synchronous, + keys_buffer, + decomposer); } template ROCPRIM_INLINE hipError_t nth_element_impl( - void* temporary_storage, - size_t& storage_size, - KeysIterator keys, - KeysIteratorNthElement keys_nth_element, + void* temporary_storage, + size_t& storage_size, + KeysIterator keys, + KeysIteratorNthElement 
/*keys_nth_element*/, size_t nth, size_t size, BinaryFunction compare_function, @@ -108,7 +209,6 @@ struct partial_sort_nth_element_helper bool debug_synchronous, typename std::iterator_traits::value_type* keys_double_buffer) { - (void)keys_nth_element; return detail::nth_element_impl(temporary_storage, storage_size, keys, @@ -124,45 +224,51 @@ struct partial_sort_nth_element_helper template<> struct partial_sort_nth_element_helper { - template + class BinaryFunction, + class Decomposer = rocprim::identity_decomposer> ROCPRIM_INLINE hipError_t - merge_sort_impl(void* temporary_storage, - size_t& storage_size, - KeysInputIterator keys_input, - KeysInputIteratorNthElement keys_input_nth_element, - KeysOutputIterator keys_output, - const size_t size, - BinaryFunction compare_function, - const hipStream_t stream, - bool debug_synchronous, - typename std::iterator_traits::value_type* keys_buffer - = nullptr) + algo_sort_impl(void* temporary_storage, + size_t& storage_size, + KeysInputIterator /*keys_input*/, + KeysInputIteratorNthElement keys_input_nth_element, + KeysOutputIterator keys_output, + const size_t size, + BinaryFunction compare_function, + const hipStream_t stream, + bool debug_synchronous, + typename std::iterator_traits::value_type* keys_buffer + = nullptr, + Decomposer decomposer = {}) { - (void)keys_input; - return detail::merge_sort_impl(temporary_storage, + using checker = radix_sort_condition_checker; + return SortImpl::algo_sort(temporary_storage, storage_size, keys_input_nth_element, keys_output, - static_cast(nullptr), - static_cast(nullptr), size, compare_function, stream, debug_synchronous, - keys_buffer); + keys_buffer, + decomposer); } template ROCPRIM_INLINE hipError_t nth_element_impl( - void* temporary_storage, - size_t& storage_size, - KeysIterator keys, + void* temporary_storage, + size_t& storage_size, + KeysIterator /*keys*/, KeysIteratorNthElement keys_nth_element, size_t nth, size_t size, @@ -171,7 +277,6 @@ struct 
partial_sort_nth_element_helper bool debug_synchronous, typename std::iterator_traits::value_type* keys_double_buffer) { - (void)keys; return detail::nth_element_impl(temporary_storage, storage_size, keys_nth_element, @@ -188,7 +293,8 @@ template + class Decomposer = rocprim::identity_decomposer, + bool inplace = true> hipError_t partial_sort_impl(void* temporary_storage, size_t& storage_size, KeysInputIterator keys_in, @@ -197,12 +303,12 @@ hipError_t partial_sort_impl(void* temporary_storage, size_t size, BinaryFunction compare_function, hipStream_t stream, - bool debug_synchronous) + bool debug_synchronous, + Decomposer decomposer = {}) { using key_type = typename std::iterator_traits::value_type; using input_reference_type = typename std::iterator_traits::reference; using config = default_or_custom_config>; - using config_merge_sort = typename config::merge_sort; using config_nth_element = typename config::nth_element; static_assert(!std::is_const>::value || !inplace, @@ -220,14 +326,14 @@ hipError_t partial_sort_impl(void* temporary_storage, key_type* keys_buffer_placeholder = reinterpret_cast(1); void* temporary_storage_nth_element = nullptr; - void* temporary_storage_merge_sort = nullptr; + void* temporary_storage_algo_sort = nullptr; key_type* keys_buffer = nullptr; key_type* keys_output_nth_element = nullptr; const bool full_sort = middle + 1 == size; if(!full_sort) { - RETURN_ON_ERROR( + ROCPRIM_RETURN_ON_ERROR( helper.template nth_element_impl(nullptr, storage_size_nth_element, keys_in, @@ -239,19 +345,20 @@ hipError_t partial_sort_impl(void* temporary_storage, debug_synchronous, keys_buffer_placeholder)); } - size_t storage_size_merge_sort{}; - - RETURN_ON_ERROR(helper.template merge_sort_impl( - nullptr, - storage_size_merge_sort, - keys_in, - keys_output_nth_element, - keys_out, - (!inplace || full_sort) ? 
middle + 1 : middle, - compare_function, - stream, - debug_synchronous, - keys_buffer_placeholder)); // keys_buffer + size_t storage_size_algo_sort{}; + + ROCPRIM_RETURN_ON_ERROR( + helper.template algo_sort_impl(nullptr, + storage_size_algo_sort, + keys_in, + keys_output_nth_element, + keys_out, + (!inplace || full_sort) ? middle + 1 : middle, + compare_function, + stream, + debug_synchronous, + keys_buffer_placeholder, // keys_buffer + decomposer)); const hipError_t partition_result = temp_storage::partition( temporary_storage, @@ -260,7 +367,7 @@ hipError_t partial_sort_impl(void* temporary_storage, temp_storage::ptr_aligned_array(&keys_buffer, size), temp_storage::ptr_aligned_array(&keys_output_nth_element, inplace ? 0 : size), temp_storage::make_partition(&temporary_storage_nth_element, storage_size_nth_element), - temp_storage::make_partition(&temporary_storage_merge_sort, storage_size_merge_sort))); + temp_storage::make_partition(&temporary_storage_algo_sort, storage_size_algo_sort))); if(partition_result != hipSuccess || temporary_storage == nullptr) { @@ -274,17 +381,17 @@ hipError_t partial_sort_impl(void* temporary_storage, if(!inplace) { - RETURN_ON_ERROR(transform(keys_in, - keys_output_nth_element, - size, - rocprim::identity(), - stream, - debug_synchronous)); + ROCPRIM_RETURN_ON_ERROR(transform(keys_in, + keys_output_nth_element, + size, + rocprim::identity(), + stream, + debug_synchronous)); } if(!full_sort) { - RETURN_ON_ERROR( + ROCPRIM_RETURN_ON_ERROR( helper.template nth_element_impl(temporary_storage_nth_element, storage_size_nth_element, keys_in, @@ -301,27 +408,27 @@ hipError_t partial_sort_impl(void* temporary_storage, { if(!inplace) { - RETURN_ON_ERROR(transform(keys_output_nth_element, - keys_out, - 1, - rocprim::identity(), - stream, - debug_synchronous)); + ROCPRIM_RETURN_ON_ERROR(transform(keys_output_nth_element, + keys_out, + 1, + rocprim::identity(), + stream, + debug_synchronous)); } return hipSuccess; } - return helper.template 
merge_sort_impl(temporary_storage_merge_sort, - storage_size_merge_sort, - keys_in, - keys_output_nth_element, - keys_out, - (!inplace || full_sort) ? middle + 1 - : middle, - compare_function, - stream, - debug_synchronous, - keys_buffer); // keys_buffer + return helper.template algo_sort_impl(temporary_storage_algo_sort, + storage_size_algo_sort, + keys_in, + keys_output_nth_element, + keys_out, + (!inplace || full_sort) ? middle + 1 : middle, + compare_function, + stream, + debug_synchronous, + keys_buffer, // keys_buffer + decomposer); } } // namespace detail @@ -334,6 +441,8 @@ hipError_t partial_sort_impl(void* temporary_storage, /// if `temporary_storage` is a null pointer. /// * Accepts custom compare_functions for partial_sort_copy across the device. /// * Streams in graph capture mode are not supported +/// * When possible, partial_sort_copy will use radix_sort as the sorting algorithm. If radix sort is not applicable, it will fall back to merge_sort. +/// If a custom decomposer is provided, partial_sort_copy will use radix_sort. /// /// \par Stability /// \p partial_sort_copy is not stable: it doesn't necessarily preserve the relative ordering @@ -351,6 +460,7 @@ hipError_t partial_sort_impl(void* temporary_storage, /// requirements of a C++ InputIterator concept. It can be a simple pointer type. /// \tparam CompareFunction [inferred] Type of binary function that accepts two arguments of the /// type `KeysIterator` and returns a value convertible to bool. Default type is `::rocprim::less<>.` +/// \tparam Decomposer The type of the decomposer functor. Default is ::rocprim::identity_decomposer. /// /// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to @@ -370,6 +480,8 @@ hipError_t partial_sort_impl(void* temporary_storage, /// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). 
/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors. Default value is `false`. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. /// /// \returns `hipSuccess` (`0`) after successful rearrangement; otherwise a HIP runtime error of /// type `hipError_t`. @@ -411,7 +523,8 @@ template::value_type>> + = ::rocprim::less::value_type>, + class Decomposer = ::rocprim::identity_decomposer> hipError_t partial_sort_copy(void* temporary_storage, size_t& storage_size, KeysInputIterator keys_input, @@ -420,7 +533,8 @@ hipError_t partial_sort_copy(void* temporary_storage, size_t size, BinaryFunction compare_function = BinaryFunction(), hipStream_t stream = 0, - bool debug_synchronous = false) + bool debug_synchronous = false, + Decomposer decomposer = Decomposer()) { using key_type = typename std::iterator_traits::value_type; static_assert( @@ -428,17 +542,21 @@ hipError_t partial_sort_copy(void* temporary_storage, typename std::iterator_traits::value_type>::value, "KeysInputIterator and KeysOutputIterator must have the same value_type"); - return detail:: - partial_sort_impl( - temporary_storage, - storage_size, - keys_input, - keys_output, - middle, - size, - compare_function, - stream, - debug_synchronous); + return detail::partial_sort_impl(temporary_storage, + storage_size, + keys_input, + keys_output, + middle, + size, + compare_function, + stream, + debug_synchronous, + decomposer); } /// \brief Rearranges elements such that the range [0, middle) contains the sorted middle smallest elements in the range [0, size). @@ -449,6 +567,8 @@ hipError_t partial_sort_copy(void* temporary_storage, /// if `temporary_storage` is a null pointer. /// * Accepts custom compare_functions for partial_sort across the device. 
/// * Streams in graph capture mode are not supported +/// * When possible, partial_sort will use radix_sort as the sorting algorithm. If radix sort is not applicable, it will fall back to merge_sort. +/// If a custom decomposer is provided, partial_sort will use radix_sort. /// /// \par Stability /// \p partial_sort is not stable: it doesn't necessarily preserve the relative ordering @@ -464,6 +584,7 @@ hipError_t partial_sort_copy(void* temporary_storage, /// requirements of a C++ InputIterator concept. It can be a simple pointer type. /// \tparam CompareFunction [inferred] Type of binary function that accepts two arguments of the /// type `KeysIterator` and returns a value convertible to bool. Default type is `::rocprim::less<>.` +/// \tparam Decomposer The type of the decomposer functor. Default is ::rocprim::identity_decomposer. /// /// \param [in] temporary_storage pointer to a device-accessible temporary storage. When /// a null pointer is passed, the required allocation size (in bytes) is written to @@ -481,6 +602,8 @@ hipError_t partial_sort_copy(void* temporary_storage, /// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). /// \param [in] debug_synchronous [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors. Default value is `false`. +/// \param [in] decomposer decomposer functor that produces a tuple of references from the +/// input key type. /// /// \returns `hipSuccess` (`0`) after successful rearrangement; otherwise a HIP runtime error of /// type `hipError_t`. 
@@ -520,7 +643,8 @@ hipError_t partial_sort_copy(void* temporary_storage, template::value_type>> + = ::rocprim::less::value_type>, + class Decomposer = ::rocprim::identity_decomposer> hipError_t partial_sort(void* temporary_storage, size_t& storage_size, KeysIterator keys, @@ -528,24 +652,26 @@ hipError_t partial_sort(void* temporary_storage, size_t size, BinaryFunction compare_function = BinaryFunction(), hipStream_t stream = 0, - bool debug_synchronous = false) + bool debug_synchronous = false, + Decomposer decomposer = {}) { - return detail::partial_sort_impl(temporary_storage, - storage_size, - keys, - keys, - middle, - size, - compare_function, - stream, - debug_synchronous); + return detail:: + partial_sort_impl( + temporary_storage, + storage_size, + keys, + keys, + middle, + size, + compare_function, + stream, + debug_synchronous, + decomposer); } /// @} // end of group devicemodule -#undef RETURN_ON_ERROR - END_ROCPRIM_NAMESPACE #endif // ROCPRIM_DEVICE_DEVICE_PARTIAL_SORT_HPP_ diff --git a/rocprim/include/rocprim/device/device_partial_sort_config.hpp b/rocprim/include/rocprim/device/device_partial_sort_config.hpp index 29aa4eb56..4230f6eb0 100644 --- a/rocprim/include/rocprim/device/device_partial_sort_config.hpp +++ b/rocprim/include/rocprim/device/device_partial_sort_config.hpp @@ -36,20 +36,27 @@ BEGIN_ROCPRIM_NAMESPACE /// Must be \p nth_element_config or \p default_config. /// \tparam MergeSortConfig - configuration of device-level merge sort operation. /// Must be \p merge_sort_config or \p default_config. -template +/// \tparam RadixSortConfig - configuration of device-level radix sort operation. +/// Must be \p radix_sort_config or \p default_config. +template struct partial_sort_config { /// \brief Configuration of device-level nth element operation. using nth_element = NthElementConfig; /// \brief Configuration of device-level merge sort operation. 
using merge_sort = MergeSortConfig; + /// \brief Configuration of device-level radix sort operation. + using radix_sort = RadixSortConfig; }; namespace detail { template -using default_partial_sort_config = partial_sort_config; +using default_partial_sort_config + = partial_sort_config; } // end namespace detail diff --git a/rocprim/include/rocprim/device/device_partition.hpp b/rocprim/include/rocprim/device/device_partition.hpp index 70acb4a9b..b2494c2cc 100644 --- a/rocprim/include/rocprim/device/device_partition.hpp +++ b/rocprim/include/rocprim/device/device_partition.hpp @@ -27,6 +27,7 @@ #include #include "../config.hpp" +#include "../common.hpp" #include "../detail/temp_storage.hpp" #include "../detail/various.hpp" #include "../functional.hpp" @@ -95,25 +96,11 @@ ROCPRIM_KERNEL std::cout << name << "(" << size << ")"; \ auto error = hipStreamSynchronize(stream); \ if(error != hipSuccess) return error; \ - auto end = std::chrono::high_resolution_clock::now(); \ + auto end = std::chrono::steady_clock::now(); \ auto d = std::chrono::duration_cast>(end - start); \ std::cout << " " << d.count() * 1000 << " ms" << '\n'; \ } -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - auto __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) return __error; \ - auto _end = std::chrono::high_resolution_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } template +struct wrapped_partition_config +{ + template + struct architecture_config + { + static constexpr partition_config_params params + = default_select_predicated_flag_config(Arch), + KeyType, + ValueType>{}; + }; +}; + template struct wrapped_partition_config::architecture_config::params; +template +template +constexpr 
partition_config_params + wrapped_partition_config::architecture_config::params; + template template constexpr partition_config_params diff --git a/rocprim/include/rocprim/device/device_radix_sort.hpp b/rocprim/include/rocprim/device/device_radix_sort.hpp index d57e89d2f..d7fbc92d2 100644 --- a/rocprim/include/rocprim/device/device_radix_sort.hpp +++ b/rocprim/include/rocprim/device/device_radix_sort.hpp @@ -27,6 +27,7 @@ #include #include "../config.hpp" +#include "../common.hpp" #include "../detail/temp_storage.hpp" #include "../detail/various.hpp" @@ -49,25 +50,6 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -#ifndef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR - -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - auto __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) return __error; \ - auto _end = std::chrono::high_resolution_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } - -#endif - template constexpr auto tuple_bit_size_impl() -> std::enable_if_t::value, size_t> @@ -177,13 +159,13 @@ hipError_t radix_sort_onesweep_global_offsets(KeysInputIterator keys_input, if(error != hipSuccess) return error; - std::chrono::high_resolution_clock::time_point start; + std::chrono::steady_clock::time_point start; if(debug_synchronous) { std::cout << "blocks " << blocks << '\n'; std::cout << "full_blocks " << full_blocks << '\n'; - start = std::chrono::high_resolution_clock::now(); + start = std::chrono::steady_clock::now(); } // Compute a histogram for each digit. @@ -204,7 +186,7 @@ hipError_t radix_sort_onesweep_global_offsets(KeysInputIterator keys_input, // Scan each histogram separately to get the final offsets. 
if(debug_synchronous) { - start = std::chrono::high_resolution_clock::now(); + start = std::chrono::steady_clock::now(); } hipLaunchKernelGGL(HIP_KERNEL_NAME(onesweep_scan_histograms_kernel), @@ -332,7 +314,7 @@ hipError_t radix_sort_onesweep_iteration( if(error != hipSuccess) return error; - std::chrono::high_resolution_clock::time_point start; + std::chrono::steady_clock::time_point start; if(debug_synchronous) { std::cout << "radix_bits " << params.radix_bits_per_place << '\n'; @@ -347,7 +329,7 @@ hipError_t radix_sort_onesweep_iteration( std::cout << "offset " << offset << '\n'; std::cout << "blocks " << blocks << '\n'; std::cout << "full_blocks " << full_blocks << '\n'; - start = std::chrono::high_resolution_clock::now(); + start = std::chrono::steady_clock::now(); } if(from_input && to_output) @@ -765,7 +747,7 @@ hipError_t } } -#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + } // end namespace detail diff --git a/rocprim/include/rocprim/device/device_reduce.hpp b/rocprim/include/rocprim/device/device_reduce.hpp index b5fc28491..6703e29ee 100644 --- a/rocprim/include/rocprim/device/device_reduce.hpp +++ b/rocprim/include/rocprim/device/device_reduce.hpp @@ -29,6 +29,7 @@ #include "config_types.hpp" #include "../config.hpp" +#include "../common.hpp" #include "../detail/temp_storage.hpp" #include "../detail/various.hpp" @@ -80,32 +81,6 @@ void block_reduce_kernel(InputIterator input, std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ } -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - auto __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) return __error; \ - auto _end = std::chrono::steady_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } - -#define 
ROCPRIM_RETURN_ON_ERROR(...) \ - do \ - { \ - hipError_t error = (__VA_ARGS__); \ - if(error != hipSuccess) \ - { \ - return error; \ - } \ - } \ - while(0) - #define SINGLE_REDUCE_KERNEL(fit_larger, fit_items) \ do \ { \ @@ -313,8 +288,6 @@ hipError_t reduce_impl(void * temporary_storage, } #undef SINGLE_REDUCE_KERNEL -#undef ROCPRIM_RETURN_ON_ERROR -#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR #undef ROCPRIM_DETAIL_HIP_SYNC } // namespace detail diff --git a/rocprim/include/rocprim/device/device_reduce_by_key.hpp b/rocprim/include/rocprim/device/device_reduce_by_key.hpp index eba9d2c38..64b577b48 100644 --- a/rocprim/include/rocprim/device/device_reduce_by_key.hpp +++ b/rocprim/include/rocprim/device/device_reduce_by_key.hpp @@ -31,6 +31,7 @@ #include "detail/lookback_scan_state.hpp" #include "../config.hpp" +#include "../common.hpp" #include "../detail/temp_storage.hpp" #include "../detail/various.hpp" #include "../functional.hpp" @@ -137,25 +138,6 @@ ROCPRIM_KERNEL number_of_tiles_launch); } -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - do \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) \ - return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - auto __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) \ - return __error; \ - auto _end = std::chrono::high_resolution_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } \ - while(false) - template #include "../config.hpp" +#include "../common.hpp" #include "../detail/various.hpp" #include "../iterator/constant_iterator.hpp" @@ -43,25 +44,6 @@ BEGIN_ROCPRIM_NAMESPACE /// \addtogroup devicemodule /// @{ -namespace detail -{ - -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - { \ - if(error != hipSuccess) return error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size 
<< ")"; \ - auto error = hipStreamSynchronize(stream); \ - if(error != hipSuccess) return error; \ - auto end = std::chrono::high_resolution_clock::now(); \ - auto d = std::chrono::duration_cast>(end - start); \ - std::cout << " " << d.count() * 1000 << " ms" << '\n'; \ - } \ - } - -} // end detail namespace - /// \brief Parallel run-length encoding for device level. /// /// run_length_encode function performs a device-wide run-length encoding of runs (groups) @@ -357,9 +339,9 @@ hipError_t run_length_encode_non_trivial_runs(void * temporary_storage, ptr += counts_tmp_bytes; all_runs_count_tmp = reinterpret_cast(ptr); - std::chrono::high_resolution_clock::time_point start; + std::chrono::steady_clock::time_point start; - if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + if(debug_synchronous) start = std::chrono::steady_clock::now(); error = ::rocprim::reduce_by_key( temporary_storage, reduce_by_key_bytes, input, @@ -376,7 +358,7 @@ hipError_t run_length_encode_non_trivial_runs(void * temporary_storage, reduce_op, ::rocprim::equal_to(), stream, debug_synchronous ); - ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("rocprim::reduce_by_key", size, start) + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("rocprim::reduce_by_key", size, start); // Read count of all runs (including trivial runs) count_type all_runs_count; @@ -389,7 +371,7 @@ hipError_t run_length_encode_non_trivial_runs(void * temporary_storage, return error; // Select non-trivial runs - if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + if(debug_synchronous) start = std::chrono::steady_clock::now(); error = ::rocprim::select( temporary_storage, select_bytes, ::rocprim::make_zip_iterator(::rocprim::make_tuple(offsets_tmp, counts_tmp)), @@ -399,12 +381,12 @@ hipError_t run_length_encode_non_trivial_runs(void * temporary_storage, non_trivial_runs_select_op, stream, debug_synchronous ); - ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("rocprim::select", all_runs_count, 
start) + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("rocprim::select", all_runs_count, start); return hipSuccess; } -#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + /// @} // end of group devicemodule diff --git a/rocprim/include/rocprim/device/device_scan.hpp b/rocprim/include/rocprim/device/device_scan.hpp index 498a3f580..0633222e5 100644 --- a/rocprim/include/rocprim/device/device_scan.hpp +++ b/rocprim/include/rocprim/device/device_scan.hpp @@ -26,6 +26,7 @@ #include #include "../config.hpp" +#include "../common.hpp" #include "../detail/temp_storage.hpp" #include "../detail/various.hpp" #include "../functional.hpp" @@ -159,26 +160,11 @@ ROCPRIM_KERNEL std::cout << name << "(" << size << ")"; \ auto error = hipStreamSynchronize(stream); \ if(error != hipSuccess) return error; \ - auto end = std::chrono::high_resolution_clock::now(); \ + auto end = std::chrono::steady_clock::now(); \ auto d = std::chrono::duration_cast>(end - start); \ std::cout << " " << d.count() * 1000 << " ms" << '\n'; \ } -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - auto __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) return __error; \ - auto _end = std::chrono::high_resolution_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } - template 1); }); - ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("lookback_scan_kernel", current_size, start) + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("lookback_scan_kernel", current_size, start); // Swap the last_elements if(number_of_launch > 1) @@ -384,7 +370,7 @@ inline auto scan_impl(void* temporary_storage, std::cout << "block_size " << block_size << '\n'; std::cout << "number of blocks " << number_of_blocks << '\n'; std::cout << "items_per_block 
" << items_per_block << '\n'; - start = std::chrono::high_resolution_clock::now(); + start = std::chrono::steady_clock::now(); } single_scan_kernel().kernel_config.block_si previous_last_value); } -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - do \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) \ - return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - auto __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) \ - return __error; \ - auto _end = std::chrono::high_resolution_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } while(false) template + +#include +#include + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +/// \brief Searches for the first occurrence of the sequence. +/// +/// Searches the input for the first occurence of a sequence, according to a particular +/// comparison function. If found, the index of the first item of the found sequence +/// in the input is returned. Otherwise, returns the size of the input. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the function. +/// * Returns the required size of `temporary_storage` in `storage_size` +/// if `temporary_storage` is a null pointer. +/// * Accepts custom compare_functions for search across the device. +/// * Streams in graph capture mode are supported +/// +/// \tparam Config [optional] configuration of the primitive, must be `default_config` or `search_config`. +/// \tparam InputIterator1 [inferred] random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam InputIterator2 [inferred] random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. 
+/// \tparam OutputIterator [inferred] random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam BinaryFunction [inferred] Type of binary function that accepts two arguments of the +/// type `InputIterator1` and returns a value convertible to bool. +/// Default type is `rocprim::less<>.` +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and the function returns without performing the search. +/// \param [in,out] storage_size reference to a size (in bytes) of `temporary_storage`. +/// \param [in] input iterator to the input range. +/// \param [in] keys iterator to the key range. +/// \param [out] output iterator to the output range. The output is one element. +/// \param [in] size number of elements in the input range. +/// \param [in] keys_size number of elements in the key range. +/// \param [in] compare_function binary operation function object that will be used for comparison. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but function object must not modify the objects passed to it. +/// The comparator must meet the C++ named requirement BinaryPredicate. +/// The default value is `BinaryFunction()`. +/// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is `false`. +/// +/// \returns `hipSuccess` (`0`) after successful search; otherwise a HIP runtime error of +/// type `hipError_t`.
+/// +/// \par Example +/// \parblock +/// In this example a device-level search is performed where input values are +/// represented by an array of unsigned integers and the key is also an array +/// of unsigned integers. +/// +/// \code{.cpp} +/// #include +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t size; // e.g., 10 +/// size_t key_size; // e.g., 3 +/// unsigned int * input; // e.g., [ 6, 3, 5, 4, 1, 8, 2, 5, 4, 1 ] +/// unsigned int * key; // e.g., [ 5, 4, 1 ] +/// unsigned int * output; // e.g., empty array of size 1 +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::search( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, key, output, size, key_size +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform search +/// rocprim::search( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, key, output, size, key_size +/// ); +/// // output: [ 2 ] +/// \endcode +/// \endparblock +template::value_type>> +ROCPRIM_INLINE +hipError_t search(void* temporary_storage, + size_t& storage_size, + InputIterator1 input, + InputIterator2 keys, + OutputIterator output, + size_t size, + size_t keys_size, + BinaryFunction compare_function = BinaryFunction(), + hipStream_t stream = 0, + bool debug_synchronous = false) +{ + return detail::search_impl(temporary_storage, + storage_size, + input, + keys, + output, + size, + keys_size, + compare_function, + stream, + debug_synchronous); +} + +/// @} +// end of group devicemodule + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_SEARCH_HPP_ diff --git a/rocprim/include/rocprim/device/device_search_config.hpp b/rocprim/include/rocprim/device/device_search_config.hpp new file mode 100644 index 000000000..c6b23ea42 --- /dev/null +++ 
b/rocprim/include/rocprim/device/device_search_config.hpp @@ -0,0 +1,77 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_FIND_END_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_FIND_END_CONFIG_HPP_ + +#include "config_types.hpp" + +#include "detail/device_config_helper.hpp" + +/// \addtogroup primitivesmodule_deviceconfigs +/// @{ + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +// generic struct that instantiates custom configurations +template +struct wrapped_search_config +{ + template + struct architecture_config + { + static constexpr search_config_params params = Config{}; + }; +}; + +// specialized for rocprim::default_config, which instantiates the default_search_config +template +struct wrapped_search_config +{ + template + struct architecture_config + { + static constexpr search_config_params params = {2048, kernel_config<256, 4>()}; + }; +}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS +template +template +constexpr search_config_params + wrapped_search_config::architecture_config::params; + +template +template +constexpr search_config_params + wrapped_search_config::architecture_config::params; +#endif // DOXYGEN_SHOULD_SKIP_THIS + +} // namespace detail + +END_ROCPRIM_NAMESPACE + +/// @} +// end of group primitivesmodule_deviceconfigs + +#endif // ROCPRIM_DEVICE_DEVICE_FIND_END_CONFIG_HPP_ diff --git a/rocprim/include/rocprim/device/device_search_n.hpp b/rocprim/include/rocprim/device/device_search_n.hpp new file mode 100644 index 000000000..168b4513b --- /dev/null +++ b/rocprim/include/rocprim/device/device_search_n.hpp @@ -0,0 +1,111 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_DEVICE_DEVICE_SEARCH_N_HPP_ +#define ROCPRIM_DEVICE_DEVICE_SEARCH_N_HPP_ + +#include "../config.hpp" +#include "config_types.hpp" +#include "detail/device_search_n.hpp" +#include "device_search_n_config.hpp" + +#include +#include +#include + +BEGIN_ROCPRIM_NAMESPACE + +/// \addtogroup devicemodule +/// @{ + +/// \brief Searches for the first occurrence of a sequence of \p count elements all equal to \p value. +/// +/// The equality of the elements of the sequence and the given value is determined according to a +/// given comparison function. If found, the index of the first item of the found sequence +/// in the input is returned. Otherwise, returns the size of the input. +/// +/// \par Overview +/// * The contents of the inputs are not altered by the function. 
+/// * Returns the required size of `temporary_storage` in `storage_size` +/// if `temporary_storage` is a null pointer. +/// * Accepts custom compare_functions for search across the device. +/// * Streams in graph capture mode are supported. +/// +/// \tparam Config [optional] configuration of the primitive. It must be `default_config` or `search_n_config`. +/// \tparam InputIterator [inferred] random-access iterator type of the input range. Must meet the +/// requirements of a C++ InputIterator concept. It can be a simple pointer type. +/// \tparam OutputIterator [inferred] random-access iterator type of the output range. Must meet the +/// requirements of a C++ OutputIterator concept. It can be a simple pointer type. +/// \tparam BinaryPredicate [inferred] Type of binary function that accepts two arguments of +/// the value type of `InputIterator` and returns a value convertible to bool. Default type is `rocprim::equal_to<>`. +/// +/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// `storage_size` and the function returns without performing the search. +/// \param [in,out] storage_size reference to a size (in bytes) of `temporary_storage`. +/// \param [in] input iterator to the input range. +/// \param [out] output iterator to the output range. The output is one element. +/// \param [in] size number of elements in the input range. +/// \param [in] count number of elements in the sequence. Must be less than or equal to \p size, otherwise `hipErrorInvalidValue` will be returned. +/// \param [in] value pointer to the value of the elements to search for. +/// \param [in] binary_predicate binary operation function object that will be used for comparison. +/// The signature of the function should be equivalent to the following: +/// bool f(const T &a, const T &b);. The signature does not need to have +/// const &, but the function object must not modify the objects passed to it. 
+/// The comparator must meet the C++ requirements of BinaryPredicate. +/// The default value is `BinaryPredicate()`. +/// \param [in] stream [optional] HIP stream object. Default is `0` (default stream). +/// \param [in] debug_synchronous [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. Default value is `false`. +/// +/// \returns `hipSuccess` (`0`) after successful search; otherwise a HIP runtime error of +/// type `hipError_t`. +template::value_type>> +ROCPRIM_INLINE +hipError_t search_n(void* temporary_storage, + size_t& storage_size, + InputIterator input, + OutputIterator output, + const size_t size, + const size_t count, + const typename std::iterator_traits::value_type* value, + const BinaryPredicate binary_predicate = BinaryPredicate(), + const hipStream_t stream = static_cast(0), + const bool debug_synchronous = false) +{ + return detail::search_n_impl(temporary_storage, + storage_size, + input, + output, + size, + count, + value, + binary_predicate, + stream, + debug_synchronous); +} + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_SEARCH_N_HPP_ diff --git a/rocprim/include/rocprim/device/device_search_n_config.hpp b/rocprim/include/rocprim/device/device_search_n_config.hpp new file mode 100644 index 000000000..1aa14d039 --- /dev/null +++ b/rocprim/include/rocprim/device/device_search_n_config.hpp @@ -0,0 +1,71 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#ifndef ROCPRIM_DEVICE_DEVICE_SEARCH_N_CONFIG_HPP_ +#define ROCPRIM_DEVICE_DEVICE_SEARCH_N_CONFIG_HPP_ + +#include "config_types.hpp" + +#include "detail/device_config_helper.hpp" + +BEGIN_ROCPRIM_NAMESPACE + +namespace detail +{ + +// generic struct that instantiates custom configurations +template +struct wrapped_search_n_config +{ + template + struct architecture_config + { + static constexpr search_n_config_params params = Config{}; + }; +}; + +// specialized for rocprim::default_config, which instantiates the default_search_n_config +template +struct wrapped_search_n_config +{ + template + struct architecture_config + { + static constexpr search_n_config_params params = {6, kernel_config<256, 4>()}; + }; +}; + +#ifndef DOXYGEN_DOCUMENTATION_BUILD +template +template +constexpr search_n_config_params + wrapped_search_n_config::architecture_config::params; + +template +template +constexpr search_n_config_params + wrapped_search_n_config::architecture_config::params; +#endif // DOXYGEN_DOCUMENTATION_BUILD + +} // namespace detail + +END_ROCPRIM_NAMESPACE + +#endif // ROCPRIM_DEVICE_DEVICE_SEARCH_N_HPP_ diff --git a/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp b/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp index 1da25f003..e0c21cb7a 100644 --- a/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp +++ b/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp @@ -28,6 +28,7 @@ #include #include "../config.hpp" +#include "../common.hpp" #include "../detail/various.hpp" #include "config_types.hpp" @@ -211,21 +212,6 @@ ROCPRIM_KERNEL __launch_bounds__( end_bit); } -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - auto __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) return __error; \ - auto _end = 
std::chrono::high_resolution_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } - struct Partitioner { bool three_way_partitioning; @@ -505,8 +491,8 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, } if(large_segment_count > 0) { - std::chrono::high_resolution_clock::time_point start; - if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + std::chrono::steady_clock::time_point start; + if(debug_synchronous) start = std::chrono::steady_clock::now(); hipLaunchKernelGGL(HIP_KERNEL_NAME(segmented_sort_large_kernel), dim3(large_segment_count), dim3(params.kernel_config.block_size), @@ -527,15 +513,15 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, end_bit); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:large_segments", large_segment_count, - start) + start); } if(three_way_partitioning && medium_segment_count > 0) { const auto medium_segment_grid_size = ::rocprim::detail::ceiling_div(medium_segment_count, medium_segments_per_block); - std::chrono::high_resolution_clock::time_point start; + std::chrono::steady_clock::time_point start; if(debug_synchronous) - start = std::chrono::high_resolution_clock::now(); + start = std::chrono::steady_clock::now(); hipLaunchKernelGGL(HIP_KERNEL_NAME(segmented_sort_medium_kernel), dim3(medium_segment_grid_size), dim3(params.warp_sort_config.block_size_medium), @@ -556,14 +542,14 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, end_bit); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:medium_segments", medium_segment_count, - start) + start); } if(small_segment_count > 0) { const auto small_segment_grid_size = ::rocprim::detail::ceiling_div(small_segment_count, small_segments_per_block); - std::chrono::high_resolution_clock::time_point start; - if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + 
std::chrono::steady_clock::time_point start; + if(debug_synchronous) start = std::chrono::steady_clock::now(); hipLaunchKernelGGL(HIP_KERNEL_NAME(segmented_sort_small_kernel), dim3(small_segment_grid_size), dim3(params.warp_sort_config.block_size_small), @@ -584,13 +570,13 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, end_bit); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:small_segments", small_segment_count, - start) + start); } } else { - std::chrono::high_resolution_clock::time_point start; - if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + std::chrono::steady_clock::time_point start; + if(debug_synchronous) start = std::chrono::steady_clock::now(); hipLaunchKernelGGL(HIP_KERNEL_NAME(segmented_sort_kernel), dim3(segments), dim3(params.kernel_config.block_size), @@ -608,12 +594,12 @@ hipError_t segmented_radix_sort_impl(void * temporary_storage, iterations, begin_bit, end_bit); - ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort", segments, start) + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort", segments, start); } return hipSuccess; } -#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + } // end namespace detail diff --git a/rocprim/include/rocprim/device/device_segmented_reduce.hpp b/rocprim/include/rocprim/device/device_segmented_reduce.hpp index 614ba7b56..b0c5826d0 100644 --- a/rocprim/include/rocprim/device/device_segmented_reduce.hpp +++ b/rocprim/include/rocprim/device/device_segmented_reduce.hpp @@ -26,6 +26,7 @@ #include #include "../config.hpp" +#include "../common.hpp" #include "../detail/various.hpp" #include "../functional.hpp" @@ -63,21 +64,6 @@ ROCPRIM_KERNEL __launch_bounds__( ); } -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - auto __error = hipStreamSynchronize(stream); 
\ - if(__error != hipSuccess) return __error; \ - auto _end = std::chrono::high_resolution_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } - template< class Config, class InputIterator, @@ -126,9 +112,9 @@ hipError_t segmented_reduce_impl(void * temporary_storage, if( segments == 0u ) return hipSuccess; - std::chrono::high_resolution_clock::time_point start; + std::chrono::steady_clock::time_point start; - if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + if(debug_synchronous) start = std::chrono::steady_clock::now(); hipLaunchKernelGGL( HIP_KERNEL_NAME(segmented_reduce_kernel), dim3(segments), dim3(block_size), 0, stream, @@ -141,7 +127,7 @@ hipError_t segmented_reduce_impl(void * temporary_storage, return hipSuccess; } -#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + } // end of detail namespace diff --git a/rocprim/include/rocprim/device/device_segmented_scan.hpp b/rocprim/include/rocprim/device/device_segmented_scan.hpp index 33df4b323..833650390 100644 --- a/rocprim/include/rocprim/device/device_segmented_scan.hpp +++ b/rocprim/include/rocprim/device/device_segmented_scan.hpp @@ -26,6 +26,7 @@ #include #include "../config.hpp" +#include "../common.hpp" #include "../detail/various.hpp" #include "../iterator/zip_iterator.hpp" @@ -46,6 +47,31 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { +template +struct transform_op_t +{ + InputIterator input; + HeadFlagIterator head_flags; + ResultType initial_value_converted; + size_t size; + + ROCPRIM_DEVICE + auto operator()(const size_t i) const + { + FlagType flag(false); + if(i + 1 < size) + { + flag = head_flags[i + 1]; + } + ResultType value = initial_value_converted; + if(!flag) + { + value = input[i]; + } + return rocprim::make_tuple(value, flag); + } +}; + template< bool Exclusive, class Config, @@ -71,21 +97,6 @@ void segmented_scan_kernel(InputIterator input, ); } -#define 
ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - auto __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) return __error; \ - auto _end = std::chrono::high_resolution_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } - template< bool Exclusive, class Config, @@ -134,8 +145,8 @@ hipError_t segmented_scan_impl(void * temporary_storage, if( segments == 0u ) return hipSuccess; - std::chrono::high_resolution_clock::time_point start; - if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); + std::chrono::steady_clock::time_point start; + if(debug_synchronous) start = std::chrono::steady_clock::now(); hipLaunchKernelGGL( HIP_KERNEL_NAME(segmented_scan_kernel), dim3(segments), dim3(block_size), 0, stream, @@ -147,7 +158,7 @@ hipError_t segmented_scan_impl(void * temporary_storage, return hipSuccess; } -#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + } // end of detail namespace @@ -601,39 +612,26 @@ hipError_t segmented_exclusive_scan(void * temporary_storage, detail::headflag_scan_op_wrapper< result_type, flag_type, BinaryFunction >; + using transform_op + = detail::transform_op_t; const result_type initial_value_converted = static_cast(initial_value); // Flag the last item of each segment as the next segment's head, use initial_value as its value, // then run exclusive scan return exclusive_scan( - temporary_storage, storage_size, + temporary_storage, + storage_size, rocprim::make_transform_iterator( rocprim::make_counting_iterator(0), - [input, head_flags, initial_value_converted, size] - ROCPRIM_DEVICE - (const size_t i) - { - flag_type flag(false); - if(i + 1 < size) - { - flag = head_flags[i + 1]; - } - result_type value = initial_value_converted; - 
if(!flag) - { - value = input[i]; - } - return rocprim::make_tuple(value, flag); - } - ), + transform_op{input, head_flags, initial_value_converted, size}), rocprim::make_zip_iterator(rocprim::make_tuple(output, rocprim::make_discard_iterator())), - rocprim::make_tuple(initial_value_converted, flag_type(true)), // init value is a head of the first segment + rocprim::make_tuple(initial_value_converted, + flag_type(true)), // init value is a head of the first segment size, headflag_scan_op_wrapper_type(scan_op), stream, - debug_synchronous - ); + debug_synchronous); } /// @} diff --git a/rocprim/include/rocprim/device/device_select.hpp b/rocprim/include/rocprim/device/device_select.hpp index e021ff005..3f8267521 100644 --- a/rocprim/include/rocprim/device/device_select.hpp +++ b/rocprim/include/rocprim/device/device_select.hpp @@ -199,9 +199,13 @@ hipError_t select(void * temporary_storage, /// \param [out] selected_count_output - iterator to the total number of selected values (length of \p output). /// \param [in] size - number of element in the input range. /// \param [in] predicate - unary function object that will be used for selecting values. -/// The signature of the function should be equivalent to the following: +/// The predicate must meet the C++ named requirement \p Predicate: +/// - The result of applying the predicate must be convertible to bool. +/// - The predicate must accept const object arguments, with the same behavior regardless of +/// whether its arguments are const or non-const. +/// In practice, the signature of the function should be equivalent to the following: /// bool f(const T &a);. The signature does not need to have -/// const &, but function object must not modify the object passed to it. +/// const &, but the function object must not modify the object passed to it. /// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream). 
/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel /// launch is forced in order to check for errors. The default value is \p false. @@ -209,7 +213,7 @@ hipError_t select(void * temporary_storage, /// \par Example /// \parblock /// In this example a device-level select operation is performed on an array of -/// integer values, only even values are selected. +/// integer values. Only even values are selected. /// /// \code{.cpp} /// #include @@ -217,7 +221,7 @@ hipError_t select(void * temporary_storage, /// auto predicate = /// [] __device__ (int a) -> bool /// { -/// return (a%2) == 0; +/// return (a % 2) == 0; /// }; /// /// // Prepare input and output (declare pointers, allocate device memory etc.) @@ -296,6 +300,142 @@ hipError_t select(void * temporary_storage, predicate); } +/// \brief Parallel select primitive for device level using a range of pre-selected flags. +/// +/// Performs a device-wide selection based on input \p flags to which a selection operator is +/// applied before doing the selection. If a value \p x from \p input should be selected and +/// copied into \p output range the corresponding item from the \p flags range should be set +/// to such value that, after applying the predicate(x) to it, it can be implicitly +/// converted to \p true (\p bool type). +/// +/// \par Overview +/// * Returns the required size of \p temporary_storage in \p storage_size +/// if \p temporary_storage is a null pointer. +/// * Range specified by \p input and \p flags must have at least \p size elements. +/// * Range specified by \p output must have at least so many elements, that all selected +/// values can be copied into it. +/// * Range specified by \p selected_count_output must have at least 1 element. +/// +/// \tparam Config - [optional] Configuration of the primitive, must be `default_config` or `select_config`. +/// \tparam InputIterator - random-access iterator type of the input range. 
It can be +/// a simple pointer type. +/// \tparam FlagIterator - random-access iterator type of the flag range. It can be +/// a simple pointer type. +/// \tparam OutputIterator - random-access iterator type of the output range. It can be +/// a simple pointer type. +/// \tparam SelectedCountOutputIterator - random-access iterator type of the selected_count_output +/// value. It can be a simple pointer type. +/// \tparam UnaryPredicate - type of a unary selection predicate. +/// +/// \param [in] temporary_storage - pointer to a device-accessible temporary storage. When +/// a null pointer is passed, the required allocation size (in bytes) is written to +/// \p storage_size and the function returns without performing the select operation. +/// \param [in,out] storage_size - reference to a size (in bytes) of \p temporary_storage. +/// \param [in] input - iterator to the first element in the range to select values from. +/// \param [in] flags - iterator to the selection flag corresponding to the first element from \p input range. +/// \param [out] output - iterator to the first element in the output range. +/// \param [out] selected_count_output - iterator to the total number of selected values (length of \p output). +/// \param [in] size - number of elements in the input range. +/// \param [in] predicate - unary function object that will be used for selecting flags. +/// The predicate must meet the C++ named requirement \p Predicate: +/// - The result of applying the predicate must be convertible to bool. +/// - The predicate must accept const object arguments, with the same behavior regardless of +/// whether its arguments are const or non-const. +/// In practice, the signature of the function should be equivalent to the following: +/// bool f(const T &a);. The signature does not need to have +/// const &, but the function object must not modify the object passed to it. +/// \param [in] stream - [optional] HIP stream object. The default is \p 0 (default stream). 
+/// \param [in] debug_synchronous - [optional] If true, synchronization after every kernel +/// launch is forced in order to check for errors. The default value is \p false. +/// +/// \par Example +/// \parblock +/// In this example a device-level select operation is performed on an array of +/// integer values. Only values with even flags are selected. +/// +/// \code{.cpp} +/// #include +/// +/// auto predicate = +/// [] __device__ (int a) -> bool +/// { +/// return (a % 2) == 0; +/// }; +/// +/// // Prepare input and output (declare pointers, allocate device memory etc.) +/// size_t input_size; // e.g., 8 +/// int * input; // e.g., [1, 2, 3, 4, 5, 6, 7, 8] +/// int * flags; // e.g., [0, 1, 2, 3, 4, 5, 6, 7] +/// int * output; // empty array of 8 elements +/// size_t * output_count; // empty array of 1 element +/// +/// size_t temporary_storage_size_bytes; +/// void * temporary_storage_ptr = nullptr; +/// // Get required size of the temporary storage +/// rocprim::select( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, flags, output, output_count, +/// input_size, predicate +/// ); +/// +/// // allocate temporary storage +/// hipMalloc(&temporary_storage_ptr, temporary_storage_size_bytes); +/// +/// // perform selection +/// rocprim::select( +/// temporary_storage_ptr, temporary_storage_size_bytes, +/// input, flags, output, output_count, +/// input_size, predicate +/// ); +/// // output: [1, 3, 5, 7] +/// // output_count: 4 +/// \endcode +/// \endparblock +template +inline hipError_t select(void* temporary_storage, + size_t& storage_size, + InputIterator input, + FlagIterator flags, + OutputIterator output, + SelectedCountOutputIterator selected_count_output, + const size_t size, + UnaryPredicate predicate, + const hipStream_t stream = 0, + const bool debug_synchronous = false) +{ + // Dummy inequality operation + using inequality_op_type = ::rocprim::empty_type; + using offset_type = unsigned int; + rocprim::empty_type* const no_values = 
nullptr; // key only + + using output_key_iterator_tuple = tuple; + output_key_iterator_tuple output_tuple{output, ::rocprim::empty_type()}; + + using output_value_iterator_tuple = tuple<::rocprim::empty_type*, ::rocprim::empty_type*>; + const output_value_iterator_tuple no_output_values{nullptr, nullptr}; // key only + + return detail::partition_impl(temporary_storage, + storage_size, + input, + no_values, + flags, + output_tuple, + no_output_values, + selected_count_output, + size, + inequality_op_type(), + stream, + debug_synchronous, + predicate); +} + /// \brief Device-level parallel unique primitive. /// /// From given \p input range unique primitive eliminates all but the first element from every diff --git a/rocprim/include/rocprim/device/device_transform.hpp b/rocprim/include/rocprim/device/device_transform.hpp index 5f54bd8fb..8beee4f5a 100644 --- a/rocprim/include/rocprim/device/device_transform.hpp +++ b/rocprim/include/rocprim/device/device_transform.hpp @@ -27,6 +27,7 @@ #include #include "../config.hpp" +#include "../common.hpp" #include "../detail/various.hpp" #include "../iterator/zip_iterator.hpp" #include "../types/tuple.hpp" @@ -56,21 +57,6 @@ ROCPRIM_KERNEL ResultType>(input, size, output, transform_op); } -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - _error = hipStreamSynchronize(stream); \ - if(_error != hipSuccess) return _error; \ - auto _end = std::chrono::high_resolution_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } - } // end of detail namespace /// \brief Parallel transform primitive for device level. 
@@ -159,7 +145,7 @@ inline hipError_t transform(InputIterator input, const auto items_per_block = block_size * items_per_thread; // Start point for time measurements - std::chrono::high_resolution_clock::time_point start; + std::chrono::steady_clock::time_point start; const auto size_limit = params.kernel_config.size_limit; const auto number_of_blocks_limit = ::rocprim::max(size_limit / items_per_block, 1); @@ -182,7 +168,7 @@ inline hipError_t transform(InputIterator input, const auto current_blocks = (current_size + items_per_block - 1) / items_per_block; if(debug_synchronous) - start = std::chrono::high_resolution_clock::now(); + start = std::chrono::steady_clock::now(); hipLaunchKernelGGL(HIP_KERNEL_NAME(detail::transform_kernel), dim3(current_blocks), dim3(block_size), @@ -280,7 +266,7 @@ hipError_t transform(InputIterator1 input1, ); } -#undef ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR + END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp b/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp index b357bba4a..b05c5cc29 100644 --- a/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp +++ b/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp @@ -21,6 +21,7 @@ #ifndef ROCPRIM_DEVICE_SPECIALIZATION_DEVICE_RADIX_SINGLE_SORT_HPP_ #define ROCPRIM_DEVICE_SPECIALIZATION_DEVICE_RADIX_SINGLE_SORT_HPP_ +#include "../../common.hpp" #include "../detail/device_radix_sort.hpp" #include "../device_radix_sort_config.hpp" @@ -29,23 +30,6 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -#define ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR(name, size, start) \ - { \ - auto _error = hipGetLastError(); \ - if(_error != hipSuccess) \ - return _error; \ - if(debug_synchronous) \ - { \ - std::cout << name << "(" << size << ")"; \ - auto __error = hipStreamSynchronize(stream); \ - if(__error != hipSuccess) \ - return __error; \ - auto _end = 
std::chrono::high_resolution_clock::now(); \ - auto _d = std::chrono::duration_cast>(_end - start); \ - std::cout << " " << _d.count() * 1000 << " ms" << '\n'; \ - } \ - } - template @@ -138,7 +122,7 @@ inline hipError_t radix_sort_block_sort(KeysInputIterator keys_input, decomposer, bit, current_radix_bits); - ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("radix_sort_block_sort_kernel", size, start) + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("radix_sort_block_sort_kernel", size, start); return hipSuccess; } diff --git a/rocprim/include/rocprim/intrinsics/atomic.hpp b/rocprim/include/rocprim/intrinsics/atomic.hpp index f0daea443..03cd5333a 100644 --- a/rocprim/include/rocprim/intrinsics/atomic.hpp +++ b/rocprim/include/rocprim/intrinsics/atomic.hpp @@ -57,6 +57,30 @@ namespace detail return ::atomicAdd(address, value); } + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int atomic_wrapinc(unsigned int* address, unsigned int value) + { + return ::atomicInc(address, value); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int atomic_max(unsigned int* address, unsigned int value) + { + return ::atomicMax(address, value); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned long atomic_max(unsigned long* address, unsigned long value) + { + return ::atomicMax(address, value); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned long long atomic_max(unsigned long long* address, unsigned long long value) + { + return ::atomicMax(address, value); + } + ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int atomic_min(unsigned int* address, unsigned int value) { @@ -76,13 +100,27 @@ namespace detail } ROCPRIM_DEVICE ROCPRIM_INLINE - unsigned int atomic_wrapinc(unsigned int* address, unsigned int value) + unsigned int atomic_cas(unsigned int* address, unsigned int compare, unsigned int value) { - return ::atomicInc(address, value); + return ::atomicCAS(address, compare, value); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned long atomic_cas(unsigned long* address, unsigned long compare, unsigned 
long value) + { + return ::atomicCAS(address, compare, value); } ROCPRIM_DEVICE ROCPRIM_INLINE - unsigned int atomic_exch(unsigned int * address, unsigned int value) + unsigned long long atomic_cas(unsigned long long* address, + unsigned long long compare, + unsigned long long value) + { + return ::atomicCAS(address, compare, value); + } + + ROCPRIM_DEVICE ROCPRIM_INLINE + unsigned int atomic_exch(unsigned int* address, unsigned int value) { return ::atomicExch(address, value); } @@ -118,7 +156,8 @@ namespace detail return __hip_atomic_load(address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } - ROCPRIM_DEVICE ROCPRIM_INLINE void atomic_store(unsigned char* address, unsigned char value) + ROCPRIM_DEVICE ROCPRIM_INLINE + void atomic_store(unsigned char* address, unsigned char value) { __hip_atomic_store(address, value, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } diff --git a/rocprim/include/rocprim/intrinsics/thread.hpp b/rocprim/include/rocprim/intrinsics/thread.hpp index 81a61da58..a97cb021a 100644 --- a/rocprim/include/rocprim/intrinsics/thread.hpp +++ b/rocprim/include/rocprim/intrinsics/thread.hpp @@ -75,12 +75,7 @@ unsigned int flat_tile_size() ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int lane_id() { -#ifndef __HIP_CPU_RT__ return ::__lane_id(); -#else - using namespace hip::detail; - return id(Fiber::this_fiber()) % device_warp_size(); -#endif } /// \brief Returns flat (linear, 1D) thread identifier in a multidimensional block (tile). 
diff --git a/rocprim/include/rocprim/intrinsics/warp.hpp b/rocprim/include/rocprim/intrinsics/warp.hpp index 66d2e359c..a3d47ac35 100644 --- a/rocprim/include/rocprim/intrinsics/warp.hpp +++ b/rocprim/include/rocprim/intrinsics/warp.hpp @@ -50,18 +50,11 @@ ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int masked_bit_count(lane_mask_type x, unsigned int add = 0) { int c; -#ifndef __HIP_CPU_RT__ - #if ROCPRIM_WAVEFRONT_SIZE == 32 +#if ROCPRIM_WAVEFRONT_SIZE == 32 c = ::__builtin_amdgcn_mbcnt_lo(x, add); - #else +#else c = ::__builtin_amdgcn_mbcnt_lo(static_cast(x), add); c = ::__builtin_amdgcn_mbcnt_hi(static_cast(x >> 32), c); - #endif -#else - using namespace hip::detail; - const auto tidx{id(Fiber::this_fiber()) % device_warp_size()}; - std::bitset bits{x >> (device_warp_size() - tidx)}; - c = static_cast(bits.count()) + add; #endif return c; } @@ -72,37 +65,13 @@ namespace detail ROCPRIM_DEVICE ROCPRIM_INLINE int warp_any(int predicate) { -#ifndef __HIP_CPU_RT__ return ::__any(predicate); -#else - using namespace hip::detail; - const auto tidx{id(Fiber::this_fiber()) % device_warp_size()}; - auto& lds{Tile::scratchpad, 1>()[0]}; - - lds[tidx] = static_cast(predicate); - - barrier(Tile::this_tile()); - - return lds.any(); -#endif } ROCPRIM_DEVICE ROCPRIM_INLINE int warp_all(int predicate) { -#ifndef __HIP_CPU_RT__ return ::__all(predicate); -#else - using namespace hip::detail; - const auto tidx{id(Fiber::this_fiber()) % device_warp_size()}; - auto& lds{Tile::scratchpad, 1>()[0]}; - - lds[tidx] = static_cast(predicate); - - barrier(Tile::this_tile()); - - return lds.all(); -#endif } } // end detail namespace diff --git a/rocprim/include/rocprim/intrinsics/warp_shuffle.hpp b/rocprim/include/rocprim/intrinsics/warp_shuffle.hpp index 51de7dc3a..7f0a86c8b 100644 --- a/rocprim/include/rocprim/intrinsics/warp_shuffle.hpp +++ b/rocprim/include/rocprim/intrinsics/warp_shuffle.hpp @@ -67,18 +67,10 @@ warp_shuffle_op(const T& input, ShuffleOp&& op) for(int i = 0; i < words_no; 
i++) { const size_t s = std::min(sizeof(int), sizeof(T) - i * sizeof(int)); - int word; -#ifdef __HIP_CPU_RT__ - std::memcpy(&word, reinterpret_cast(&input) + i * sizeof(int), s); -#else + int word; __builtin_memcpy(&word, reinterpret_cast(&input) + i * sizeof(int), s); -#endif word = op(word); -#ifdef __HIP_CPU_RT__ - std::memcpy(reinterpret_cast(&output) + i * sizeof(int), &word, s); -#else __builtin_memcpy(reinterpret_cast(&output) + i * sizeof(int), &word, s); -#endif } return output; @@ -99,13 +91,8 @@ T warp_move_dpp(const T& input) // __builtin_amdgcn_update_dpp, hence fail to parse the template altogether. (Except MSVC // because even using /permissive- they somehow still do delayed parsing of the body of // function templates, even though they pinky-swear they don't.) -#if !defined(__HIP_CPU_RT__) return ::__builtin_amdgcn_mov_dpp(v, dpp_ctrl, row_mask, bank_mask, bound_ctrl); -#else - return v; -#endif - } - ); + }); } /// \brief Swizzle for any data type. diff --git a/rocprim/include/rocprim/iterator.hpp b/rocprim/include/rocprim/iterator.hpp index c5215222d..a3fb040fa 100644 --- a/rocprim/include/rocprim/iterator.hpp +++ b/rocprim/include/rocprim/iterator.hpp @@ -29,9 +29,7 @@ #include "iterator/counting_iterator.hpp" #include "iterator/discard_iterator.hpp" #include "iterator/predicate_iterator.hpp" -#ifndef __HIP_CPU_RT__ #include "iterator/texture_cache_iterator.hpp" -#endif #include "iterator/transform_iterator.hpp" #include "iterator/zip_iterator.hpp" diff --git a/rocprim/include/rocprim/iterator/reverse_iterator.hpp b/rocprim/include/rocprim/iterator/reverse_iterator.hpp index 65b133e14..5ae685514 100644 --- a/rocprim/include/rocprim/iterator/reverse_iterator.hpp +++ b/rocprim/include/rocprim/iterator/reverse_iterator.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -42,6 +42,13 @@ BEGIN_ROCPRIM_NAMESPACE /// * Use it to iterate over the elements of a container in reverse. /// /// \tparam SourceIterator - type of the wrapped iterator. +template +class reverse_iterator; + +template +ROCPRIM_HOST_DEVICE +constexpr reverse_iterator make_reverse_iterator(SourceIterator source_iterator); + template class reverse_iterator { @@ -62,20 +69,43 @@ class reverse_iterator /// The category of the iterator. using iterator_category = std::random_access_iterator_tag; + /// \brief Constructs a new default reverse_iterator. + ROCPRIM_HOST_DEVICE constexpr reverse_iterator() : source_iterator_(nullptr) + {} + /// \brief Constructs a new reverse_iterator using the supplied source. + [[deprecated("The initialisation constructor of 'rocprim::reverse_iterator' will be " + "marked explicit in ROCm 7.0. Use 'rocprim::make_reverse_iterator' " + "instead.")]] ROCPRIM_HOST_DEVICE constexpr /*explicit*/ + reverse_iterator(SourceIterator source_iterator) + : source_iterator_(source_iterator) + {} + + /// \brief Constructs a new reverse_iterator using that of the supplied source. 
+ template + ROCPRIM_HOST_DEVICE constexpr reverse_iterator( + const reverse_iterator& source_reverse_iterator) + : source_iterator_(source_reverse_iterator.base()) + {} + +#ifndef DOXYGEN_SHOULD_SKIP_THIS ROCPRIM_HOST_DEVICE - reverse_iterator(SourceIterator source_iterator) : source_iterator_(source_iterator) {} + constexpr SourceIterator base() const + { + return source_iterator_; + } - #ifndef DOXYGEN_SHOULD_SKIP_THIS ROCPRIM_HOST_DEVICE - reverse_iterator& operator++() + constexpr reverse_iterator& + operator++() { --source_iterator_; return *this; } ROCPRIM_HOST_DEVICE - reverse_iterator operator++(int) + constexpr reverse_iterator + operator++(int) { reverse_iterator old = *this; --source_iterator_; @@ -83,14 +113,16 @@ class reverse_iterator } ROCPRIM_HOST_DEVICE - reverse_iterator& operator--() + constexpr reverse_iterator& + operator--() { ++source_iterator_; return *this; } ROCPRIM_HOST_DEVICE - reverse_iterator operator--(int) + constexpr reverse_iterator + operator--(int) { reverse_iterator old = *this; ++source_iterator_; @@ -98,86 +130,99 @@ class reverse_iterator } ROCPRIM_HOST_DEVICE - reference operator*() + constexpr reference + operator*() const { return *(source_iterator_ - static_cast(1)); } ROCPRIM_HOST_DEVICE - reference operator[](difference_type distance) + constexpr reference + operator[](difference_type distance) const { reverse_iterator i = (*this) + distance; return *i; } ROCPRIM_HOST_DEVICE - reverse_iterator operator+(difference_type distance) const + constexpr reverse_iterator + operator+(difference_type distance) const { - return reverse_iterator(source_iterator_ - distance); + return rocprim::make_reverse_iterator(source_iterator_ - distance); } ROCPRIM_HOST_DEVICE - reverse_iterator& operator+=(difference_type distance) + constexpr reverse_iterator& + operator+=(difference_type distance) { source_iterator_ -= distance; return *this; } ROCPRIM_HOST_DEVICE - reverse_iterator operator-(difference_type distance) const + constexpr 
reverse_iterator + operator-(difference_type distance) const { - return reverse_iterator(source_iterator_ + distance); + return rocprim::make_reverse_iterator(source_iterator_ + distance); } ROCPRIM_HOST_DEVICE - reverse_iterator& operator-=(difference_type distance) + constexpr reverse_iterator& + operator-=(difference_type distance) { source_iterator_ += distance; return *this; } ROCPRIM_HOST_DEVICE - difference_type operator-(reverse_iterator other) const + constexpr difference_type + operator-(reverse_iterator other) const { return other.source_iterator_ - source_iterator_; } ROCPRIM_HOST_DEVICE - bool operator==(reverse_iterator other) const + constexpr bool + operator==(reverse_iterator other) const { return source_iterator_ == other.source_iterator_; } ROCPRIM_HOST_DEVICE - bool operator!=(reverse_iterator other) const + constexpr bool + operator!=(reverse_iterator other) const { return source_iterator_ != other.source_iterator_; } ROCPRIM_HOST_DEVICE - bool operator<(reverse_iterator other) const + constexpr bool + operator<(reverse_iterator other) const { return other.source_iterator_ < source_iterator_; } ROCPRIM_HOST_DEVICE - bool operator<=(reverse_iterator other) const + constexpr bool + operator<=(reverse_iterator other) const { return other.source_iterator_ <= source_iterator_; } ROCPRIM_HOST_DEVICE - bool operator>(reverse_iterator other) const + constexpr bool + operator>(reverse_iterator other) const { return other.source_iterator_ > source_iterator_; } ROCPRIM_HOST_DEVICE - bool operator>=(reverse_iterator other) const + constexpr bool + operator>=(reverse_iterator other) const { return other.source_iterator_ >= source_iterator_; } - #endif // DOXYGEN_SHOULD_SKIP_THIS +#endif // DOXYGEN_SHOULD_SKIP_THIS private: SourceIterator source_iterator_; @@ -185,8 +230,9 @@ class reverse_iterator #ifndef DOXYGEN_SHOULD_SKIP_THIS template -ROCPRIM_HOST_DEVICE reverse_iterator - operator+(typename reverse_iterator::difference_type distance, 
+ROCPRIM_HOST_DEVICE +constexpr reverse_iterator + operator+(typename reverse_iterator::difference_type distance, const reverse_iterator& iterator) { return iterator + distance; @@ -200,10 +246,13 @@ ROCPRIM_HOST_DEVICE reverse_iterator /// \param source_iterator - the iterator to wrap in the created \p reverse_iterator. /// \return A \p reverse_iterator that wraps \p source_iterator. template -ROCPRIM_HOST_DEVICE reverse_iterator - make_reverse_iterator(SourceIterator source_iterator) +ROCPRIM_HOST_DEVICE +constexpr reverse_iterator make_reverse_iterator(SourceIterator source_iterator) { +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" return reverse_iterator(source_iterator); +#pragma clang diagnostic pop } END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/rocprim.hpp b/rocprim/include/rocprim/rocprim.hpp index 843a3d172..2cd772532 100644 --- a/rocprim/include/rocprim/rocprim.hpp +++ b/rocprim/include/rocprim/rocprim.hpp @@ -60,8 +60,10 @@ #include "block/config.hpp" #include "device/device_adjacent_difference.hpp" +#include "device/device_adjacent_find.hpp" #include "device/device_binary_search.hpp" #include "device/device_copy.hpp" +#include "device/device_find_end.hpp" #include "device/device_find_first_of.hpp" #include "device/device_histogram.hpp" #include "device/device_memcpy.hpp" @@ -76,6 +78,8 @@ #include "device/device_run_length_encode.hpp" #include "device/device_scan.hpp" #include "device/device_scan_by_key.hpp" +#include "device/device_search.hpp" +#include "device/device_search_n.hpp" #include "device/device_segmented_radix_sort.hpp" #include "device/device_segmented_reduce.hpp" #include "device/device_segmented_scan.hpp" diff --git a/rocprim/include/rocprim/thread/radix_key_codec.hpp b/rocprim/include/rocprim/thread/radix_key_codec.hpp index c83767890..a8b6026cb 100644 --- a/rocprim/include/rocprim/thread/radix_key_codec.hpp +++ b/rocprim/include/rocprim/thread/radix_key_codec.hpp @@ -227,10 
+227,12 @@ using radix_key_fundamental = typename has_bit_key_type> static_assert(radix_key_fundamental::value, "'int' should be fundamental"); static_assert(!radix_key_fundamental::value, "'int*' should not be fundamental"); -static_assert(radix_key_fundamental<__int128_t>::value, "'__int128_t' should be fundamental"); -static_assert(radix_key_fundamental<__uint128_t>::value, "'__uint128_t' should be fundamental"); -static_assert(!radix_key_fundamental<__int128_t*>::value, - "'__int128_t*' should not be fundamental"); +static_assert(radix_key_fundamental::value, + "'rocprim::int128_t' should be fundamental"); +static_assert(radix_key_fundamental::value, + "'rocprim::uint128_t' should be fundamental"); +static_assert(!radix_key_fundamental::value, + "'rocprim::int128_t*' should not be fundamental"); } // namespace detail diff --git a/rocprim/include/rocprim/thread/thread_load.hpp b/rocprim/include/rocprim/thread/thread_load.hpp index 2f72ea7ee..c29cf5832 100644 --- a/rocprim/include/rocprim/thread/thread_load.hpp +++ b/rocprim/include/rocprim/thread/thread_load.hpp @@ -101,8 +101,6 @@ ROCPRIM_DEVICE __forceinline__ T AsmThreadLoad(void * ptr) ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_load_dwordx2, v, wait_inst, wait_cmd); \ ROCPRIM_ASM_THREAD_LOAD(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_load_dwordx2, v, wait_inst, wait_cmd); -// [HIP-CPU] MSVC: erronous inline assembly specification (Triggers error C2059: syntax error: 'volatile') -#ifndef __HIP_CPU_RT__ #if defined(__gfx940__) || defined(__gfx941__) ROCPRIM_ASM_THREAD_LOAD_GROUP(load_ca, "sc0", "s_waitcnt", ""); ROCPRIM_ASM_THREAD_LOAD_GROUP(load_cg, "sc1", "s_waitcnt", ""); @@ -128,7 +126,6 @@ ROCPRIM_ASM_THREAD_LOAD_GROUP(load_volatile, "glc", "s_waitcnt", "vmcnt"); // TODO find correct modifiers to match these ROCPRIM_ASM_THREAD_LOAD_GROUP(load_ldg, "", "s_waitcnt", ""); ROCPRIM_ASM_THREAD_LOAD_GROUP(load_cs, "", "s_waitcnt", ""); -#endif // 
__HIP_CPU_RT__ #endif @@ -160,13 +157,7 @@ template template [[deprecated("Use a dereference instead.")]] ROCPRIM_DEVICE ROCPRIM_INLINE T thread_load(T* ptr) { -#ifndef __HIP_CPU_RT__ return detail::AsmThreadLoad(ptr); -#else - T retval; - std::memcpy(&retval, ptr, sizeof(T)); - return retval; -#endif } /// @} diff --git a/rocprim/include/rocprim/thread/thread_store.hpp b/rocprim/include/rocprim/thread/thread_store.hpp index 0ff7f4148..7f8ba7600 100644 --- a/rocprim/include/rocprim/thread/thread_store.hpp +++ b/rocprim/include/rocprim/thread/thread_store.hpp @@ -103,8 +103,6 @@ ROCPRIM_DEVICE __forceinline__ void AsmThreadStore(void * ptr, T val) ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, uint64_t, uint64_t, flat_store_dwordx2, v, wait_inst, wait_cmd); \ ROCPRIM_ASM_THREAD_STORE(cache_modifier, llvm_cache_modifier, double, uint64_t, flat_store_dwordx2, v, wait_inst, wait_cmd); -// [HIP-CPU] MSVC: erronous inline assembly specification (Triggers error C2059: syntax error: 'volatile') -#ifndef __HIP_CPU_RT__ #if defined(__gfx940__) || defined(__gfx941__) ROCPRIM_ASM_THREAD_STORE_GROUP(store_wb, "sc0 sc1", "s_waitcnt", ""); // TODO: gfx942 validation ROCPRIM_ASM_THREAD_STORE_GROUP(store_cg, "sc0 sc1", "s_waitcnt", ""); @@ -128,7 +126,6 @@ ROCPRIM_ASM_THREAD_STORE_GROUP(store_volatile, "glc", "s_waitcnt", "vmcnt"); #endif // TODO find correct modifiers to match these ROCPRIM_ASM_THREAD_STORE_GROUP(store_cs, "", "s_waitcnt", ""); -#endif // __HIP_CPU_RT__ #endif @@ -159,11 +156,7 @@ template [[deprecated("Use a dereference instead.")]] ROCPRIM_DEVICE ROCPRIM_INLINE void thread_store(T* ptr, T val) { -#ifndef __HIP_CPU_RT__ detail::AsmThreadStore(ptr, val); -#else - std::memcpy(ptr, &val, sizeof(T)); -#endif } /// @} diff --git a/rocprim/include/rocprim/type_traits.hpp b/rocprim/include/rocprim/type_traits.hpp index 91bcb14dc..f8368a4ba 100644 --- a/rocprim/include/rocprim/type_traits.hpp +++ b/rocprim/include/rocprim/type_traits.hpp @@ -52,8 +52,8 @@ 
struct is_integral : std::integral_constant< bool, std::is_integral::value - || std::is_same<__int128_t, typename std::remove_cv::type>::value - || std::is_same<__uint128_t, typename std::remove_cv::type>::value> + || std::is_same<::rocprim::int128_t, typename std::remove_cv::type>::value + || std::is_same<::rocprim::uint128_t, typename std::remove_cv::type>::value> {}; /// \brief Extension of `std::is_arithmetic`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers. @@ -64,8 +64,8 @@ struct is_arithmetic std::is_arithmetic::value || std::is_same<::rocprim::half, typename std::remove_cv::type>::value || std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value - || std::is_same<__int128_t, typename std::remove_cv::type>::value - || std::is_same<__uint128_t, typename std::remove_cv::type>::value> + || std::is_same<::rocprim::int128_t, typename std::remove_cv::type>::value + || std::is_same<::rocprim::uint128_t, typename std::remove_cv::type>::value> {}; /// \brief Extension of `std::is_fundamental`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers. @@ -76,8 +76,8 @@ struct is_fundamental std::is_fundamental::value || std::is_same<::rocprim::half, typename std::remove_cv::type>::value || std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value - || std::is_same<__int128_t, typename std::remove_cv::type>::value - || std::is_same<__uint128_t, typename std::remove_cv::type>::value> + || std::is_same<::rocprim::int128_t, typename std::remove_cv::type>::value + || std::is_same<::rocprim::uint128_t, typename std::remove_cv::type>::value> {}; /// \brief Extension of `std::is_unsigned`, which includes support for 128-bit integers. 
@@ -86,7 +86,7 @@ struct is_unsigned : std::integral_constant< bool, std::is_unsigned::value - || std::is_same<__uint128_t, typename std::remove_cv::type>::value> + || std::is_same<::rocprim::uint128_t, typename std::remove_cv::type>::value> {}; /// \brief Extension of `std::is_signed`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers. @@ -97,7 +97,7 @@ struct is_signed std::is_signed::value || std::is_same<::rocprim::half, typename std::remove_cv::type>::value || std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value - || std::is_same<__int128_t, typename std::remove_cv::type>::value> + || std::is_same<::rocprim::int128_t, typename std::remove_cv::type>::value> {}; /// \brief Extension of `std::is_scalar`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers. @@ -108,8 +108,8 @@ struct is_scalar std::is_scalar::value || std::is_same<::rocprim::half, typename std::remove_cv::type>::value || std::is_same<::rocprim::bfloat16, typename std::remove_cv::type>::value - || std::is_same<__int128_t, typename std::remove_cv::type>::value - || std::is_same<__uint128_t, typename std::remove_cv::type>::value> + || std::is_same<::rocprim::int128_t, typename std::remove_cv::type>::value + || std::is_same<::rocprim::uint128_t, typename std::remove_cv::type>::value> {}; /// \brief Extension of `std::make_unsigned`, which includes support for 128-bit integers. 
@@ -119,20 +119,20 @@ struct make_unsigned : std::make_unsigned #ifndef DOXYGEN_SHOULD_SKIP_THIS // skip specialized versions template<> -struct make_unsigned<__int128_t> +struct make_unsigned<::rocprim::int128_t> { - using type = __uint128_t; + using type = ::rocprim::uint128_t; }; template<> -struct make_unsigned<__uint128_t> +struct make_unsigned<::rocprim::uint128_t> { - using type = __uint128_t; + using type = ::rocprim::uint128_t; }; #endif -static_assert(std::is_same::type, __uint128_t>::value, - "'__int128_t' needs to implement 'make_unsigned' trait."); +static_assert(std::is_same::type, ::rocprim::uint128_t>::value, + "'rocprim::int128_t' needs to implement 'make_unsigned' trait."); /// \brief Extension of `std::is_compound`, which includes support for \ref rocprim::half , \ref rocprim::bfloat16 and 128-bit integers. template @@ -142,6 +142,58 @@ struct is_compound !is_fundamental::value > {}; +/// \brief Extension of `std::numeric_limits`, which includes support for 128-bit integers. 
+template +struct numeric_limits : std::numeric_limits +{}; + +#ifndef DOXYGEN_SHOULD_SKIP_THIS // skip specialized versions +template<> +struct numeric_limits : std::numeric_limits +{ + static constexpr int digits = 128; + static constexpr int digits10 = 38; + + static constexpr rocprim::uint128_t max() + { + return rocprim::int128_t{-1}; + } + + static constexpr rocprim::uint128_t min() + { + return rocprim::uint128_t{0}; + } + + static constexpr rocprim::uint128_t lowest() + { + return min(); + } +}; + +template<> +struct numeric_limits : std::numeric_limits +{ + static constexpr int digits = 127; + static constexpr int digits10 = 38; + + static constexpr rocprim::int128_t max() + { + return numeric_limits::max() >> 1; + } + + static constexpr rocprim::int128_t min() + { + return -numeric_limits::max() - 1; + } + + static constexpr rocprim::int128_t lowest() + { + return min(); + } +}; + +#endif // DOXYGEN_SHOULD_SKIP_THIS + /// \brief Used to retrieve a type that can be treated as unsigned version of the template parameter. /// \tparam T - The signed type to find an unsigned equivalent for. /// \tparam size - the desired size (in bytes) of the unsigned type @@ -181,7 +233,7 @@ struct get_unsigned_bits_type template struct get_unsigned_bits_type { - typedef __uint128_t unsigned_type; + typedef ::rocprim::uint128_t unsigned_type; }; #endif // DOXYGEN_SHOULD_SKIP_THIS diff --git a/rocprim/include/rocprim/types.hpp b/rocprim/include/rocprim/types.hpp index 5b4dff662..fa1e2f40e 100644 --- a/rocprim/include/rocprim/types.hpp +++ b/rocprim/include/rocprim/types.hpp @@ -56,20 +56,17 @@ struct make_vector_type /// template parameter should not be used. struct empty_type {}; -/// \brief Binary operator that takes two instances of empty_type, usually used -/// as nop replacement for the HIP-CPU back-end -struct empty_binary_op -{ - /// \brief Invocation operator. 
- constexpr empty_type operator()(const empty_type&, const empty_type&) const { return empty_type{}; } -}; - /// \brief A decomposer that must be passed to the radix sort algorithms when /// sorting keys that are arithmetic types. /// To sort custom types, a custom decomposer should be provided. struct identity_decomposer {}; +/// \brief 128 bit unsigned integer +using uint128_t = __uint128_t; +/// \brief 128 bit signed integer +using int128_t = __int128_t; + /// \brief Half-precision floating point type using half = ::__half; /// \brief bfloat16 floating point type @@ -87,19 +84,10 @@ using lane_mask_type = unsigned long long int; #endif /// \brief Native half-precision floating point type -#ifdef __HIP_CPU_RT__ -using native_half = half; -#else using native_half = _Float16; -#endif /// \brief native bfloat16 type -#ifdef __HIP_CPU_RT__ -// TODO: Find a better type -using native_bfloat16 = bfloat16; -#else using native_bfloat16 = bfloat16; -#endif END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/warp/detail/warp_segment_bounds.hpp b/rocprim/include/rocprim/warp/detail/warp_segment_bounds.hpp index 91e95396a..7032d7f8e 100644 --- a/rocprim/include/rocprim/warp/detail/warp_segment_bounds.hpp +++ b/rocprim/include/rocprim/warp/detail/warp_segment_bounds.hpp @@ -52,23 +52,10 @@ ROCPRIM_DEVICE ROCPRIM_INLINE auto last_in_warp_segment(Flag flag) -> // Make sure last item in logical warp is marked as a tail warp_flags |= lane_mask_type(1) << (WarpSize - 1U); // Calculate logical lane id of the last valid value in the segment -#ifndef __HIP_CPU_RT__ - #if ROCPRIM_WAVEFRONT_SIZE == 32 +#if ROCPRIM_WAVEFRONT_SIZE == 32 return ::__ffs(warp_flags) - 1; - #else - return ::__ffsll(warp_flags) - 1; - #endif -#else -#if _MSC_VER - // TODO: verify correctness - unsigned long tmp = 0; - _BitScanReverse64(&tmp, warp_flags); - return 1u << tmp; -#elif __GNUC__ - return __builtin_ctzl(warp_flags); #else - static_assert(false, "Look for GCC/Clang implementation"); -#endif + 
return ::__ffsll(warp_flags) - 1; #endif } diff --git a/scripts/autotune/create_optimization.py b/scripts/autotune/create_optimization.py index 933e8f08d..093f7b79b 100755 --- a/scripts/autotune/create_optimization.py +++ b/scripts/autotune/create_optimization.py @@ -295,10 +295,10 @@ def skip_entry(config_selection_type: SelectionType, fallback_entry: FallbackCas empty_fallback = FallbackCase(None, EMPTY_TYPENAME, 0, 0, False) # If a type is optional, also generate the fallbacks where the type is empty. - fallback_entries_0: List[FallbackCase] = self.fallback_entries + fallback_entries_0: List[FallbackCase] = self.fallback_entries.copy() if config_selection_types[0].is_optional: fallback_entries_0.append(empty_fallback) - fallback_entries_1: List[FallbackCase] = self.fallback_entries + fallback_entries_1: List[FallbackCase] = self.fallback_entries.copy() if config_selection_types[1].is_optional: fallback_entries_1.append(empty_fallback) @@ -523,6 +523,14 @@ class AlgorithmDeviceAdjacentDifferenceInplace(Algorithm): def __init__(self, fallback_entries): Algorithm.__init__(self, fallback_entries) +class AlgorithmDeviceAdjacentFind(Algorithm): + algorithm_name = "device_adjacent_find" + cpp_configuration_template_name = "adjacent_find_config_template" + config_selection_params = [ + SelectionType(name="input_type", is_optional=False, select_on_size_only=False)] + def __init__(self, fallback_entries): + Algorithm.__init__(self, fallback_entries) + class AlgorithmDeviceSegmentedRadixSort(Algorithm): algorithm_name = "device_segmented_radix_sort" cpp_configuration_template_name = "segmented_radix_sort_config_template" @@ -596,6 +604,15 @@ class AlgorithmDeviceSelectPredicate(Algorithm): def __init__(self, fallback_entries): Algorithm.__init__(self, fallback_entries) +class AlgorithmDeviceSelectPredicatedFlag(Algorithm): + algorithm_name = "device_select_predicated_flag" + cpp_configuration_template_name = "select_predicated_flag_config_template" + 
config_selection_params = [ + SelectionType(name="data_type", is_optional=False, select_on_size_only=False), + SelectionType(name="flag_type", is_optional=False, select_on_size_only=True)] + def __init__(self, fallback_entries): + Algorithm.__init__(self, fallback_entries) + class AlgorithmDeviceSelectUnique(Algorithm): algorithm_name = "device_select_unique" cpp_configuration_template_name = "select_unique_config_template" @@ -630,6 +647,15 @@ class AlgorithmDeviceFindFirstOf(Algorithm): def __init__(self, fallback_entries): Algorithm.__init__(self, fallback_entries) +class AlgorithmDeviceMerge(Algorithm): + algorithm_name = "device_merge" + cpp_configuration_template_name = "merge_config_template" + config_selection_params = [ + SelectionType(name="key_type", is_optional=False, select_on_size_only=False), + SelectionType(name="value_type", is_optional=True, select_on_size_only=True)] + def __init__(self, fallback_entries): + Algorithm.__init__(self, fallback_entries) + def filt_algo_regex(e: FallbackCase, algorithm_name): if e.algo_regex: return re.match(e.algo_regex, algorithm_name) is not None @@ -663,6 +689,8 @@ def create_algorithm(algorithm_name: str, fallback_entries: List[FallbackCase]): return AlgorithmDeviceAdjacentDifference(fallback_entries) elif algorithm_name == 'device_adjacent_difference_inplace': return AlgorithmDeviceAdjacentDifferenceInplace(fallback_entries) + elif algorithm_name == 'device_adjacent_find': + return AlgorithmDeviceAdjacentFind(fallback_entries) elif algorithm_name == 'device_segmented_radix_sort': return AlgorithmDeviceSegmentedRadixSort(fallback_entries) elif algorithm_name == 'device_transform': @@ -681,6 +709,8 @@ def create_algorithm(algorithm_name: str, fallback_entries: List[FallbackCase]): return AlgorithmDeviceSelectFlag(fallback_entries) elif algorithm_name == 'device_select_predicate': return AlgorithmDeviceSelectPredicate(fallback_entries) + elif algorithm_name == 'device_select_predicated_flag': + return 
AlgorithmDeviceSelectPredicatedFlag(fallback_entries) elif algorithm_name == 'device_select_unique': return AlgorithmDeviceSelectUnique(fallback_entries) elif algorithm_name == 'device_select_unique_by_key': @@ -689,6 +719,8 @@ def create_algorithm(algorithm_name: str, fallback_entries: List[FallbackCase]): return AlgorithmDeviceReduceByKey(fallback_entries) elif algorithm_name == 'device_find_first_of': return AlgorithmDeviceFindFirstOf(fallback_entries) + elif algorithm_name == 'device_merge': + return AlgorithmDeviceMerge(fallback_entries) else: raise(NotSupportedError(f'Algorithm "{algorithm_name}" is not supported (yet)')) @@ -806,6 +838,7 @@ def main(): parser.add_argument("-p", "--out_basedir", type=str, help="Base dir for the output files, for each algorithm a new file will be created in this directory", required=True) parser.add_argument("-c", "--fallback_configuration", type=argparse.FileType('r'), default=os.path.join(current_dir, "fallback_config.json"), help="Configuration for fallbacks for not tested datatypes") args = parser.parse_args() + #import pdb; pdb.set_trace() benchmark_manager = BenchmarkDataManager(args.fallback_configuration) diff --git a/scripts/autotune/templates/adjacent_find_config_template b/scripts/autotune/templates/adjacent_find_config_template new file mode 100644 index 000000000..8ef25c83e --- /dev/null +++ b/scripts/autotune/templates/adjacent_find_config_template @@ -0,0 +1,20 @@ +{% extends "config_template" %} + +{% macro get_header_guard() %} +ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_ADJACENT_FIND_HPP_ +{%- endmacro %} + +{% macro kernel_configuration(measurement) -%} +adjacent_find_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +{%- endmacro %} + +{% macro general_case() -%} +template +struct default_adjacent_find_config : default_adjacent_find_config_base::type +{}; +{%- endmacro %} + +{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} +// 
Based on {{ based_on_type }} +template struct default_adjacent_find_config({{ benchmark_of_architecture.name }}), input_type, {{ fallback_selection_criteria }}> : +{%- endmacro %} diff --git a/scripts/autotune/templates/merge_config_template b/scripts/autotune/templates/merge_config_template new file mode 100644 index 000000000..49ce9b233 --- /dev/null +++ b/scripts/autotune/templates/merge_config_template @@ -0,0 +1,25 @@ +{% extends "config_template" %} + +{% macro get_header_guard() %} +ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_MERGE_HPP_ +{%- endmacro %} + +{% macro kernel_configuration(measurement) -%} +merge_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +{%- endmacro %} + +{% macro general_case() -%} +template +struct default_merge_config : default_merge_base::type +{}; +{%- endmacro %} + +{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} +// Based on {{ based_on_type }} +template +struct default_merge_config< + static_cast({{ benchmark_of_architecture.name }}), + key_type, + value_type, + {{ fallback_selection_criteria }}> : +{%- endmacro %} \ No newline at end of file diff --git a/scripts/autotune/templates/partition_flag_config_template b/scripts/autotune/templates/partition_flag_config_template index 3890c5cdc..4495e5de7 100644 --- a/scripts/autotune/templates/partition_flag_config_template +++ b/scripts/autotune/templates/partition_flag_config_template @@ -10,7 +10,7 @@ select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { {% macro general_case() -%} template -struct default_partition_flag_config : default_partition_config_base::type +struct default_partition_flag_config : default_partition_config_base::type {}; {%- endmacro %} diff --git a/scripts/autotune/templates/partition_predicate_config_template b/scripts/autotune/templates/partition_predicate_config_template index 30ff6abf4..021a343ae 100644 --- 
a/scripts/autotune/templates/partition_predicate_config_template +++ b/scripts/autotune/templates/partition_predicate_config_template @@ -10,7 +10,7 @@ select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { {% macro general_case() -%} template -struct default_partition_predicate_config : default_partition_config_base::type +struct default_partition_predicate_config : default_partition_config_base::type {}; {%- endmacro %} diff --git a/scripts/autotune/templates/partition_three_way_config_template b/scripts/autotune/templates/partition_three_way_config_template index 541e7d046..01f06bc8b 100644 --- a/scripts/autotune/templates/partition_three_way_config_template +++ b/scripts/autotune/templates/partition_three_way_config_template @@ -10,7 +10,7 @@ select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { {% macro general_case() -%} template -struct default_partition_three_way_config : default_partition_config_base::type +struct default_partition_three_way_config : default_partition_config_base::type {}; {%- endmacro %} diff --git a/scripts/autotune/templates/partition_two_way_flag_config_template b/scripts/autotune/templates/partition_two_way_flag_config_template index c01ef312c..16826143e 100644 --- a/scripts/autotune/templates/partition_two_way_flag_config_template +++ b/scripts/autotune/templates/partition_two_way_flag_config_template @@ -10,7 +10,7 @@ select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { {% macro general_case() -%} template -struct default_partition_two_way_flag_config : default_partition_config_base::type +struct default_partition_two_way_flag_config : default_partition_config_base::type {}; {%- endmacro %} diff --git a/scripts/autotune/templates/partition_two_way_predicate_config_template b/scripts/autotune/templates/partition_two_way_predicate_config_template index 2db3134f0..ac11f1459 100644 --- a/scripts/autotune/templates/partition_two_way_predicate_config_template 
+++ b/scripts/autotune/templates/partition_two_way_predicate_config_template @@ -10,7 +10,7 @@ select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { {% macro general_case() -%} template -struct default_partition_two_way_predicate_config : default_partition_config_base::type +struct default_partition_two_way_predicate_config : default_partition_config_base::type {}; {%- endmacro %} diff --git a/scripts/autotune/templates/select_flag_config_template b/scripts/autotune/templates/select_flag_config_template index 5f095dc72..36a7055e4 100644 --- a/scripts/autotune/templates/select_flag_config_template +++ b/scripts/autotune/templates/select_flag_config_template @@ -10,7 +10,7 @@ select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { {% macro general_case() -%} template -struct default_select_flag_config : default_partition_config_base::type +struct default_select_flag_config : default_partition_config_base::type {}; {%- endmacro %} diff --git a/scripts/autotune/templates/select_predicate_config_template b/scripts/autotune/templates/select_predicate_config_template index c3a9102bb..137ba5ee3 100644 --- a/scripts/autotune/templates/select_predicate_config_template +++ b/scripts/autotune/templates/select_predicate_config_template @@ -10,7 +10,7 @@ select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { {% macro general_case() -%} template -struct default_select_predicate_config : default_partition_config_base::type +struct default_select_predicate_config : default_partition_config_base::type {}; {%- endmacro %} diff --git a/scripts/autotune/templates/select_predicated_flag_config_template b/scripts/autotune/templates/select_predicated_flag_config_template new file mode 100644 index 000000000..2ba3565ae --- /dev/null +++ b/scripts/autotune/templates/select_predicated_flag_config_template @@ -0,0 +1,20 @@ +{% extends "config_template" %} + +{% macro get_header_guard() %} 
+ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_PREDICATED_FLAG_HPP_ +{%- endmacro %} + +{% macro kernel_configuration(measurement) -%} +select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +{%- endmacro %} + +{% macro general_case() -%} +template +struct default_select_predicated_flag_config : default_partition_config_base::type +{}; +{%- endmacro %} + +{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} +// Based on {{ based_on_type }} +template struct default_select_predicated_flag_config({{ benchmark_of_architecture.name }}), data_type, flag_type, {{ fallback_selection_criteria }}> : +{%- endmacro %} diff --git a/scripts/autotune/templates/select_unique_by_key_config_template b/scripts/autotune/templates/select_unique_by_key_config_template index f7c676ad8..2177f95af 100644 --- a/scripts/autotune/templates/select_unique_by_key_config_template +++ b/scripts/autotune/templates/select_unique_by_key_config_template @@ -10,7 +10,7 @@ select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { {% macro general_case() -%} template -struct default_select_unique_by_key_config : default_partition_config_base::type +struct default_select_unique_by_key_config : default_partition_config_base::type {}; {%- endmacro %} diff --git a/scripts/autotune/templates/select_unique_config_template b/scripts/autotune/templates/select_unique_config_template index 3214bedb3..9bb8aec9f 100644 --- a/scripts/autotune/templates/select_unique_config_template +++ b/scripts/autotune/templates/select_unique_config_template @@ -10,7 +10,7 @@ select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { {% macro general_case() -%} template -struct default_select_unique_config : default_partition_config_base::type +struct default_select_unique_config : default_partition_config_base::type {}; {%- endmacro %} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 
5205cffa3..099a80678 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -68,25 +68,10 @@ function(add_hip_test TEST_NAME TEST_SOURCES) GTest::GTest GTest::Main ) - if(NOT USE_HIP_CPU) - target_link_libraries(${TEST_TARGET} - PRIVATE - rocprim_hip - ) - else() - target_link_libraries(${TEST_TARGET} - PRIVATE - rocprim - Threads::Threads - hip_cpu_rt::hip_cpu_rt - ) - if(STL_DEPENDS_ON_TBB) - target_link_libraries(${TEST_TARGET} - PRIVATE - TBB::tbb - ) - endif() - endif() + target_link_libraries(${TEST_TARGET} + PRIVATE + rocprim_hip + ) target_compile_options(${TEST_TARGET} PRIVATE @@ -125,6 +110,7 @@ endfunction() add_hip_test("hip.device_api" hip/test_hip_api.cpp) add_hip_test("hip.async_copy" hip/test_hip_async_copy.cpp) add_hip_test("hip.ordered_block_id" hip/test_ordered_block_id.cpp) + # rocPRIM test add_subdirectory(rocprim) diff --git a/test/common_test_header.hpp b/test/common_test_header.hpp index 1e239b08c..b6e2a1ac1 100755 --- a/test/common_test_header.hpp +++ b/test/common_test_header.hpp @@ -42,9 +42,6 @@ // HIP API #include #include -#ifndef __HIP_CPU_RT__ -#include -#endif // GoogleTest-compatible HIP_CHECK macro. FAIL is called to log the Google Test trace. // The lambda is invoked immediately as assertions that generate a fatal failure can @@ -62,6 +59,22 @@ } #endif +#define HIP_CHECK_MEMORY(condition) \ + { \ + hipError_t error = condition; \ + if(error == hipErrorOutOfMemory) \ + { \ + std::cout << "Out of memory. 
Skipping size = " << size << std::endl; \ + break; \ + } \ + if(error != hipSuccess) \ + { \ + std::cout << "HIP error: " << hipGetErrorString(error) << " line: " << __LINE__ \ + << std::endl; \ + exit(error); \ + } \ + } + #if(defined(__GNUC__) || defined(__clang__)) && (defined(__GLIBCXX__) || defined(_LIBCPP_VERSION)) #define ROCPRIM_HAS_INT128_SUPPORT 1 #else diff --git a/test/extra/CMakeLists.txt b/test/extra/CMakeLists.txt index d5e9e2576..23a9a6002 100644 --- a/test/extra/CMakeLists.txt +++ b/test/extra/CMakeLists.txt @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -46,11 +46,19 @@ include(VerifyCompiler) find_package(rocprim REQUIRED CONFIG PATHS "/opt/rocm/rocprim") # Build CXX flags -set(CMAKE_CXX_STANDARD 14) +if (NOT DEFINED CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -Werror") +if (CMAKE_CXX_STANDARD EQUAL 14) + message(WARNING "C++14 will be deprecated in the next major release") +elseif(NOT CMAKE_CXX_STANDARD EQUAL 17) + message(FATAL_ERROR "Only C++14 and C++17 are supported") +endif() + # Enable testing (ctest) enable_testing() diff --git a/test/hip/test_hip_api.cpp b/test/hip/test_hip_api.cpp index 844bb40f1..5002087d4 100644 --- a/test/hip/test_hip_api.cpp +++ b/test/hip/test_hip_api.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -20,17 +20,19 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. +#include "../rocprim/test_utils_device_ptr.hpp" #include "common_test_header.hpp" template -__device__ T ax(const T a, const T x) +__device__ +T ax(const T a, const T x) { return x * a; } -template +template __global__ -void saxpy_kernel(const T * x, T * y, const T a, const size_t size) +void saxpy_kernel(const T* x, T* y, const T a, const size_t size) { const unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; if(i < size) @@ -47,47 +49,25 @@ TEST(HIPTests, Saxpy) const size_t N = 100; - const float a = 100.0f; + const float a = 100.0f; std::vector x(N, 2.0f); std::vector y(N, 1.0f); - float * d_x; - float * d_y; - HIP_CHECK(test_common_utils::hipMallocHelper(&d_x, N * sizeof(float))); - HIP_CHECK(test_common_utils::hipMallocHelper(&d_y, N * sizeof(float))); - HIP_CHECK( - hipMemcpy( - d_x, x.data(), - N * sizeof(float), - hipMemcpyHostToDevice - ) - ); - HIP_CHECK( - hipMemcpy( - d_y, y.data(), - N * sizeof(float), - hipMemcpyHostToDevice - ) - ); - HIP_CHECK(hipDeviceSynchronize()); + test_utils::device_ptr d_x(x); + test_utils::device_ptr d_y(y); - hipLaunchKernelGGL( - HIP_KERNEL_NAME(saxpy_kernel), - dim3((N + 255)/256), dim3(256), 0, 0, - d_x, d_y, a, N - ); + hipLaunchKernelGGL(HIP_KERNEL_NAME(saxpy_kernel), + dim3((N + 255) / 256), + dim3(256), + 0, + 0, + d_x.get(), + d_y.get(), + a, + N); HIP_CHECK(hipGetLastError()); - HIP_CHECK( - hipMemcpy( - y.data(), d_y, - N * sizeof(float), - hipMemcpyDeviceToHost - ) - ); - HIP_CHECK(hipDeviceSynchronize()); - HIP_CHECK(hipFree(d_x)); - HIP_CHECK(hipFree(d_y)); + y = d_y.load(); for(size_t i = 0; i < N; i++) { diff --git a/test/hip/test_ordered_block_id.cpp b/test/hip/test_ordered_block_id.cpp index f4d9e22ef..11f688f04 100644 --- 
a/test/hip/test_ordered_block_id.cpp +++ b/test/hip/test_ordered_block_id.cpp @@ -21,8 +21,8 @@ #include #include +#include "../rocprim/test_utils_device_ptr.hpp" #include "common_test_header.hpp" - // required rocprim headers #include @@ -50,18 +50,14 @@ void test_kernel(unsigned int* flags) __host__ bool test_func(int block_count, int thread_count) { - unsigned int* d_flags; - std::vector h_vec(block_count); - HIP_CHECK(hipMalloc(&d_flags, block_count * sizeof(unsigned int))); - test_kernel<<>>(d_flags); + test_utils::device_ptr d_flags(block_count); + + test_kernel<<>>(d_flags.get()); HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); - HIP_CHECK(hipMemcpy(h_vec.data(), - d_flags, - block_count * sizeof(unsigned int), - hipMemcpyDeviceToHost)); - HIP_CHECK(hipFree(d_flags)); + + auto h_vec = d_flags.load(); for(const auto i : h_vec) { if(i != 1) diff --git a/test/rocprim/CMakeLists.txt b/test/rocprim/CMakeLists.txt index 6fe81fb06..529b65d33 100644 --- a/test/rocprim/CMakeLists.txt +++ b/test/rocprim/CMakeLists.txt @@ -50,25 +50,10 @@ function(add_rocprim_test_internal TEST_NAME TEST_SOURCES TEST_TARGET) GTest::GTest GTest::Main ) - if(NOT USE_HIP_CPU) - target_link_libraries(${TEST_TARGET} - PRIVATE - rocprim_hip - ) - else() - target_link_libraries(${TEST_TARGET} - PRIVATE - rocprim - Threads::Threads - hip_cpu_rt::hip_cpu_rt - ) - if(STL_DEPENDS_ON_TBB) - target_link_libraries(${TEST_TARGET} - PRIVATE - TBB::tbb - ) - endif() - endif() + target_link_libraries(${TEST_TARGET} + PRIVATE + rocprim_hip + ) target_compile_options(${TEST_TARGET} PRIVATE @@ -263,6 +248,8 @@ add_rocprim_test("rocprim.device_batch_memcpy" test_device_batch_memcpy.cpp) add_rocprim_test("rocprim.device_binary_search" test_device_binary_search.cpp) add_rocprim_test("rocprim.device_find_first_of" test_device_find_first_of.cpp) add_rocprim_test("rocprim.device_adjacent_difference" test_device_adjacent_difference.cpp) +add_rocprim_test("rocprim.device_adjacent_find" 
test_device_adjacent_find.cpp) +add_rocprim_test("rocprim.device_find_end" test_device_find_end.cpp) add_rocprim_test("rocprim.device_histogram" test_device_histogram.cpp) add_rocprim_test("rocprim.device_merge" test_device_merge.cpp) add_rocprim_test("rocprim.device_merge_sort" test_device_merge_sort.cpp) @@ -274,7 +261,9 @@ add_rocprim_test("rocprim.device_reduce_by_key" test_device_reduce_by_key.cpp) add_rocprim_test("rocprim.device_reduce" test_device_reduce.cpp) add_rocprim_test("rocprim.device_run_length_encode" test_device_run_length_encode.cpp) add_rocprim_test("rocprim.device_scan" test_device_scan.cpp) +add_rocprim_test("rocprim.device_search" test_device_search.cpp) add_rocprim_test_parallel("rocprim.device_segmented_radix_sort" test_device_segmented_radix_sort.cpp.in) +add_rocprim_test("rocprim.device_search_n" test_device_search_n.cpp) add_rocprim_test("rocprim.device_segmented_reduce" test_device_segmented_reduce.cpp) add_rocprim_test("rocprim.device_segmented_scan" test_device_segmented_scan.cpp) add_rocprim_test("rocprim.device_select" test_device_select.cpp) @@ -284,9 +273,7 @@ add_rocprim_test("rocprim.lookback_reproducibility" test_lookback_reproducibilit add_rocprim_test("rocprim.radix_key_codec" test_radix_key_codec.cpp) add_rocprim_test("rocprim.predicate_iterator" test_predicate_iterator.cpp) add_rocprim_test("rocprim.reverse_iterator" test_reverse_iterator.cpp) -if(NOT USE_HIP_CPU) add_rocprim_test("rocprim.texture_cache_iterator" test_texture_cache_iterator.cpp) -endif() add_rocprim_test("rocprim.thread" test_thread.cpp) add_rocprim_test("rocprim.thread_algos" test_thread_algos.cpp) add_rocprim_test("rocprim.transform_iterator" test_transform_iterator.cpp) diff --git a/test/rocprim/indirect_iterator.hpp b/test/rocprim/indirect_iterator.hpp index 17688dae4..deaab5081 100644 --- a/test/rocprim/indirect_iterator.hpp +++ b/test/rocprim/indirect_iterator.hpp @@ -112,7 +112,7 @@ class indirect_iterator ROCPRIM_HOST_DEVICE inline reference 
operator*() const { - return *ptr_; + return reference{*ptr_}; } ROCPRIM_HOST_DEVICE inline reference operator[](difference_type n) const diff --git a/test/rocprim/test_arg_index_iterator.cpp b/test/rocprim/test_arg_index_iterator.cpp index e2feb5d1d..32022d0d2 100644 --- a/test/rocprim/test_arg_index_iterator.cpp +++ b/test/rocprim/test_arg_index_iterator.cpp @@ -145,7 +145,8 @@ TYPED_TEST(RocprimArgIndexIteratorTests, ReduceArgMinimum) Iterator d_iter(d_input); arg_min reduce_op; - const key_value max(std::numeric_limits::max(), std::numeric_limits::max()); + const key_value max(test_utils::numeric_limits::max(), + test_utils::numeric_limits::max()); // Calculate expected results on host Iterator x(input.data()); @@ -195,8 +196,8 @@ TYPED_TEST(RocprimArgIndexIteratorTests, ReduceArgMinimum) test_utils::assert_eq(output[0].key, expected.key); test_utils::assert_eq(output[0].value, expected.value); - hipFree(d_input); - hipFree(d_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); } } diff --git a/test/rocprim/test_block_radix_rank.hpp b/test/rocprim/test_block_radix_rank.hpp index 301c96c15..c05059d1e 100644 --- a/test/rocprim/test_block_radix_rank.hpp +++ b/test/rocprim/test_block_radix_rank.hpp @@ -139,18 +139,11 @@ void test_block_radix_rank() SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); // Generate data - std::vector keys_input; - if(rocprim::is_floating_point::value) - { - keys_input = test_utils::get_random_data(size, T(-1000), T(+1000), seed_value); - } - else - { - keys_input = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), - seed_value); - } + std::vector keys_input + = test_utils::get_random_data(size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + seed_value); // Calculated expected results on host std::vector expected(size); diff --git 
a/test/rocprim/test_block_radix_sort.kernels.hpp b/test/rocprim/test_block_radix_sort.kernels.hpp index b0b7919ab..83dc8a0c2 100644 --- a/test/rocprim/test_block_radix_sort.kernels.hpp +++ b/test/rocprim/test_block_radix_sort.kernels.hpp @@ -198,18 +198,11 @@ auto test_block_radix_sort() -> typename std::enable_if::type // Generate data auto keys_output = std::make_unique(size); - if(rocprim::is_floating_point::value) - { - test_utils::generate_random_data_n(keys_output.get(), size, -100, +100, rng_engine); - } - else - { - test_utils::generate_random_data_n(keys_output.get(), - size, - test_utils::numeric_limits::min(), - test_utils::numeric_limits::max(), - rng_engine); - } + test_utils::generate_random_data_n(keys_output.get(), + size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + rng_engine); // Calculate expected results on host std::vector expected(keys_output.get(), keys_output.get() + size); @@ -304,18 +297,11 @@ auto test_block_radix_sort() -> typename std::enable_if::type // Generate data auto keys_output = std::make_unique(size); - if(rocprim::is_floating_point::value) - { - test_utils::generate_random_data_n(keys_output.get(), size, -100, +100, rng_engine); - } - else - { - test_utils::generate_random_data_n(keys_output.get(), - size, - test_utils::numeric_limits::min(), - test_utils::numeric_limits::max(), - rng_engine); - } + test_utils::generate_random_data_n(keys_output.get(), + size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + rng_engine); std::vector values_output(size); std::iota(values_output.begin(), values_output.end(), 0u); diff --git a/test/rocprim/test_block_run_length_decode.cpp b/test/rocprim/test_block_run_length_decode.cpp index 4afd78166..8afd944bf 100644 --- a/test/rocprim/test_block_run_length_decode.cpp +++ b/test/rocprim/test_block_run_length_decode.cpp @@ -167,7 +167,7 @@ TYPED_TEST(HipcubBlockRunLengthDecodeTest, TestDecode) size_t num_runs = runs_per_thread * 
block_size; constexpr LengthT max_run_length = static_cast( - std::min(1000ll, static_cast(std::numeric_limits::max()))); + std::min(1000ll, static_cast(test_utils::numeric_limits::max()))); auto run_items = std::vector(num_runs); run_items[0] = test_utils::get_random_value(test_utils::numeric_limits::min(), @@ -200,8 +200,8 @@ TYPED_TEST(HipcubBlockRunLengthDecodeTest, TestDecode) const auto empty_run_items = test_utils::get_random_data(num_trailing_empty_runs, - std::numeric_limits::min(), - std::numeric_limits::max(), + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), seed_value); run_items.insert(run_items.end(), empty_run_items.begin(), empty_run_items.end()); run_lengths.insert(run_lengths.end(), num_trailing_empty_runs, static_cast(0)); diff --git a/test/rocprim/test_constant_iterator.cpp b/test/rocprim/test_constant_iterator.cpp index 1062c8e45..dfdf132e1 100644 --- a/test/rocprim/test_constant_iterator.cpp +++ b/test/rocprim/test_constant_iterator.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -117,7 +117,7 @@ TYPED_TEST(RocprimConstantIteratorTests, Transform) // Validating results test_utils::assert_near(output, expected, test_utils::precision); - hipFree(d_output); + HIP_CHECK(hipFree(d_output)); } } diff --git a/test/rocprim/test_counting_iterator.cpp b/test/rocprim/test_counting_iterator.cpp index b35e2be20..7273eed72 100644 --- a/test/rocprim/test_counting_iterator.cpp +++ b/test/rocprim/test_counting_iterator.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -126,7 +126,7 @@ TYPED_TEST(RocprimCountingIteratorTests, Transform) ASSERT_EQ(output[i], expected[i]) << "where index = " << i; } - hipFree(d_output); + HIP_CHECK(hipFree(d_output)); } } diff --git a/test/rocprim/test_device_adjacent_difference.cpp b/test/rocprim/test_device_adjacent_difference.cpp index 00b82256f..f0df7d644 100644 --- a/test/rocprim/test_device_adjacent_difference.cpp +++ b/test/rocprim/test_device_adjacent_difference.cpp @@ -439,7 +439,7 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, AdjacentDifference) ASSERT_NO_FATAL_FAILURE(run_and_verify(output_it, d_output)); - hipFree(d_output); + HIP_CHECK(hipFree(d_output)); } // if api_variant is not no_alias we should check the inplace function call @@ -453,8 +453,8 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, AdjacentDifference) HIP_CHECK(hipStreamDestroy(stream)); } - hipFree(d_temp_storage); - hipFree(d_input); + HIP_CHECK(hipFree(d_temp_storage)); + HIP_CHECK(hipFree(d_input)); } } } @@ -714,9 +714,9 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) ASSERT_EQ(incorrect_flag, 0); ASSERT_EQ(counter, rocprim::detail::ceiling_div(size, sampling_rate)); - hipFree(d_temp_storage); - hipFree(d_incorrect_flag); - hipFree(d_counter); + HIP_CHECK(hipFree(d_temp_storage)); + HIP_CHECK(hipFree(d_incorrect_flag)); + HIP_CHECK(hipFree(d_counter)); if(TestFixture::use_graphs) { diff --git a/test/rocprim/test_device_adjacent_find.cpp b/test/rocprim/test_device_adjacent_find.cpp new file mode 100644 index 000000000..50ee889a9 --- /dev/null +++ b/test/rocprim/test_device_adjacent_find.cpp @@ -0,0 +1,237 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +#include "../common_test_header.hpp" + +#include "indirect_iterator.hpp" +#include "test_utils_custom_test_types.hpp" +#include "test_utils_types.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +// Params for tests +template, + class Config = rocprim::default_config, + bool UseGraphs = false, + bool UseIndirectIterator = false> +struct DeviceAdjacentFindParams +{ + using input_type = InputType; + using output_type = OutputType; + using op_type = OpType; + using config = Config; + static constexpr bool use_graphs = UseGraphs; + static constexpr bool use_indirect_iterator = UseIndirectIterator; +}; + +template +class RocprimDeviceAdjacentFindTests : public ::testing::Test +{ +public: + using input_type = typename Params::input_type; + using output_type = typename Params::output_type; + using op_type = typename Params::op_type; + using config = typename Params::config; + static constexpr bool use_graphs = Params::use_graphs; + static constexpr bool use_indirect_iterator = Params::use_indirect_iterator; + static constexpr bool debug_synchronous = false; +}; + +// Custom types +using custom_int2 = test_utils::custom_test_type; +using custom_double2 = test_utils::custom_test_type; +using custom_int64_array = test_utils::custom_test_array_type; + +// Custom configs +using custom_config_0 = rocprim::adjacent_find_config<128, 4>; + +using RocprimDeviceAdjacentFindTestsParams = ::testing::Types< + // Tests with default configuration + DeviceAdjacentFindParams, + DeviceAdjacentFindParams, + DeviceAdjacentFindParams, + DeviceAdjacentFindParams, + DeviceAdjacentFindParams, + DeviceAdjacentFindParams, + // Tests for custom types + DeviceAdjacentFindParams, + DeviceAdjacentFindParams, + DeviceAdjacentFindParams, + // Tests for supported config structs + DeviceAdjacentFindParams, + custom_config_0>, + // Tests for hipGraph support + DeviceAdjacentFindParams, + rocprim::default_config, + true>, + // Tests for when 
output's value_type is void + DeviceAdjacentFindParams, + rocprim::default_config, + false, + true>>; + +TYPED_TEST_SUITE(RocprimDeviceAdjacentFindTests, RocprimDeviceAdjacentFindTestsParams); + +TYPED_TEST(RocprimDeviceAdjacentFindTests, AdjacentFind) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using T = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + using op_type = typename TestFixture::op_type; + static constexpr bool use_indirect_iterator = TestFixture::use_indirect_iterator; + const bool debug_synchronous = TestFixture::debug_synchronous; + using Config = typename TestFixture::config; + + op_type op{}; + + for(std::size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + const unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(auto size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + if(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + SCOPED_TRACE(testing::Message() << "with size = " << size); + + // Get random index for first adjacent pair + std::size_t first_adj_index = 0; + if(size > 1) + { + first_adj_index + = std::min(test_utils::get_random_value( + 0, + static_cast(test_utils::numeric_limits::max()), + seed_value), + size - 2); + } + SCOPED_TRACE(testing::Message() << "with first_adj_index = " << first_adj_index); + + // Generate input values + std::vector input(size); + std::iota(input.begin(), input.begin() + first_adj_index, 0); + std::fill(input.begin(), input.end(), first_adj_index); + + T* d_input; + output_type* d_output; + HIP_CHECK( + 
test_common_utils::hipMallocHelper(&d_input, input.size() * sizeof(*d_input))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, sizeof(*d_output))); + HIP_CHECK(hipMemcpy(d_input, + input.data(), + input.size() * sizeof(*d_input), + hipMemcpyHostToDevice)); + + const auto output_it + = test_utils::wrap_in_identity_iterator(d_output); + + // Allocate temporary storage + std::size_t tmp_storage_size; + void* d_tmp_storage = nullptr; + HIP_CHECK(::rocprim::adjacent_find(d_tmp_storage, + tmp_storage_size, + d_input, + output_it, + size, + op, + stream, + debug_synchronous)); + ASSERT_GT(tmp_storage_size, 0); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_tmp_storage, tmp_storage_size)); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + // Run + HIP_CHECK(::rocprim::adjacent_find(d_tmp_storage, + tmp_storage_size, + d_input, + output_it, + size, + op, + stream, + debug_synchronous)); + + if(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + // Allocate memory for output and copy to host side + output_type output; + HIP_CHECK(hipMemcpy(&output, d_output, sizeof(*d_output), hipMemcpyDeviceToHost)); + + // Calculate expected results on host + const auto expected + = (size > 1) ? 
std::adjacent_find(input.cbegin(), input.cend(), op) - input.begin() + : size; + + // Check if output values are as expected + ASSERT_EQ(output, expected); + + // Cleanup + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_tmp_storage)); + + if(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} diff --git a/test/rocprim/test_device_batch_memcpy.cpp b/test/rocprim/test_device_batch_memcpy.cpp index d899305aa..ab42985d5 100644 --- a/test/rocprim/test_device_batch_memcpy.cpp +++ b/test/rocprim/test_device_batch_memcpy.cpp @@ -341,165 +341,177 @@ TYPED_TEST(RocprimDeviceBatchMemcpyTests, SizeAndTypeVariation) constexpr int32_t num_tlev = num_buffers - num_blev - num_wlev; // Get random buffer sizes - uint32_t seed = 0; - SCOPED_TRACE(testing::Message() << "with seed= " << seed); - std::mt19937_64 rng{seed}; - - std::vector h_buffer_num_elements(num_buffers); - - auto iter = h_buffer_num_elements.begin(); - - if(num_tlev > 0) - iter = test_utils::generate_random_data_n(iter, num_tlev, 1, wlev_min_elems - 1, rng); - if(num_wlev > 0) - iter = test_utils::generate_random_data_n(iter, - num_wlev, - wlev_min_elems, - blev_min_elems - 1, - rng); - if(num_blev > 0) - iter = test_utils::generate_random_data_n(iter, num_blev, blev_min_elems, max_elems, rng); - - const byte_offset_type total_num_elements = std::accumulate(h_buffer_num_elements.begin(), - h_buffer_num_elements.end(), - byte_offset_type{0}); - - // Shuffle the sizes so that size classes aren't clustered - std::shuffle(h_buffer_num_elements.begin(), h_buffer_num_elements.end(), rng); - - // And the total byte size - const byte_offset_type total_num_bytes = total_num_elements * sizeof(value_type); - - // Device pointers - value_type* d_input = nullptr; - value_type* d_output = nullptr; - value_type** d_buffer_srcs = nullptr; - value_type** d_buffer_dsts = nullptr; - buffer_size_type* d_buffer_sizes = nullptr; - - // 
Calculate temporary storage - - size_t temp_storage_bytes = 0; - - batch_copy(nullptr, - temp_storage_bytes, - d_buffer_srcs, - d_buffer_dsts, - d_buffer_sizes, - num_buffers, - hipStreamDefault); - - void* d_temp_storage = nullptr; - - // Allocate memory. - HIP_CHECK(hipMalloc(&d_input, total_num_bytes)); - HIP_CHECK(hipMalloc(&d_output, total_num_bytes)); - - HIP_CHECK(hipMalloc(&d_buffer_srcs, num_buffers * sizeof(*d_buffer_srcs))); - HIP_CHECK(hipMalloc(&d_buffer_dsts, num_buffers * sizeof(*d_buffer_dsts))); - HIP_CHECK(hipMalloc(&d_buffer_sizes, num_buffers * sizeof(*d_buffer_sizes))); - - HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_bytes)); - - // Generate data. - std::vector h_input_for_memcpy; - std::vector h_input_for_copy; - init_input(h_input_for_memcpy, h_input_for_copy, rng, total_num_bytes); - - // Generate the source and shuffled destination offsets. - std::vector src_offsets; - std::vector dst_offsets; - - if(shuffled) + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; ++seed_index) { - src_offsets = shuffled_exclusive_scan(h_buffer_num_elements, rng); - dst_offsets = shuffled_exclusive_scan(h_buffer_num_elements, rng); - } - else - { - src_offsets = std::vector(num_buffers); - dst_offsets = std::vector(num_buffers); - - // Consecutive offsets (no shuffling). - // src/dst offsets first element is 0, so skip that! - std::partial_sum(h_buffer_num_elements.begin(), - h_buffer_num_elements.end() - 1, - src_offsets.begin() + 1); - std::partial_sum(h_buffer_num_elements.begin(), - h_buffer_num_elements.end() - 1, - dst_offsets.begin() + 1); - } - - // Get the byte size of each buffer - std::vector h_buffer_num_bytes(num_buffers); - for(size_t i = 0; i < num_buffers; ++i) - { - h_buffer_num_bytes[i] = h_buffer_num_elements[i] * sizeof(value_type); - } - - // Generate the source and destination pointers. 
- std::vector h_buffer_srcs(num_buffers); - std::vector h_buffer_dsts(num_buffers); - - for(int32_t i = 0; i < num_buffers; ++i) - { - h_buffer_srcs[i] = d_input + src_offsets[i]; - h_buffer_dsts[i] = d_output + dst_offsets[i]; - } - - // Prepare the batch memcpy. - if(isMemCpy) - { - HIP_CHECK( - hipMemcpy(d_input, h_input_for_memcpy.data(), total_num_bytes, hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(d_buffer_sizes, - h_buffer_num_bytes.data(), - h_buffer_num_bytes.size() * sizeof(*d_buffer_sizes), + seed_type seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed= " << seed_value); + std::mt19937_64 rng{seed_value}; + + std::vector h_buffer_num_elements(num_buffers); + + auto iter = h_buffer_num_elements.begin(); + + if(num_tlev > 0) + iter = test_utils::generate_random_data_n(iter, num_tlev, 1, wlev_min_elems - 1, rng); + if(num_wlev > 0) + iter = test_utils::generate_random_data_n(iter, + num_wlev, + wlev_min_elems, + blev_min_elems - 1, + rng); + if(num_blev > 0) + iter = test_utils::generate_random_data_n(iter, + num_blev, + blev_min_elems, + max_elems, + rng); + + const byte_offset_type total_num_elements = std::accumulate(h_buffer_num_elements.begin(), + h_buffer_num_elements.end(), + byte_offset_type{0}); + + // Shuffle the sizes so that size classes aren't clustered + std::shuffle(h_buffer_num_elements.begin(), h_buffer_num_elements.end(), rng); + + // And the total byte size + const byte_offset_type total_num_bytes = total_num_elements * sizeof(value_type); + + // Device pointers + value_type* d_input = nullptr; + value_type* d_output = nullptr; + value_type** d_buffer_srcs = nullptr; + value_type** d_buffer_dsts = nullptr; + buffer_size_type* d_buffer_sizes = nullptr; + + // Calculate temporary storage + + size_t temp_storage_bytes = 0; + + batch_copy(nullptr, + temp_storage_bytes, + d_buffer_srcs, + d_buffer_dsts, + d_buffer_sizes, + num_buffers, + 
hipStreamDefault); + + void* d_temp_storage = nullptr; + + // Allocate memory. + HIP_CHECK(hipMalloc(&d_input, total_num_bytes)); + HIP_CHECK(hipMalloc(&d_output, total_num_bytes)); + + HIP_CHECK(hipMalloc(&d_buffer_srcs, num_buffers * sizeof(*d_buffer_srcs))); + HIP_CHECK(hipMalloc(&d_buffer_dsts, num_buffers * sizeof(*d_buffer_dsts))); + HIP_CHECK(hipMalloc(&d_buffer_sizes, num_buffers * sizeof(*d_buffer_sizes))); + + HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_bytes)); + + // Generate data. + std::vector h_input_for_memcpy; + std::vector h_input_for_copy; + init_input(h_input_for_memcpy, h_input_for_copy, rng, total_num_bytes); + + // Generate the source and shuffled destination offsets. + std::vector src_offsets; + std::vector dst_offsets; + + if(shuffled) + { + src_offsets = shuffled_exclusive_scan(h_buffer_num_elements, rng); + dst_offsets = shuffled_exclusive_scan(h_buffer_num_elements, rng); + } + else + { + src_offsets = std::vector(num_buffers); + dst_offsets = std::vector(num_buffers); + + // Consecutive offsets (no shuffling). + // src/dst offsets first element is 0, so skip that! + std::partial_sum(h_buffer_num_elements.begin(), + h_buffer_num_elements.end() - 1, + src_offsets.begin() + 1); + std::partial_sum(h_buffer_num_elements.begin(), + h_buffer_num_elements.end() - 1, + dst_offsets.begin() + 1); + } + + // Get the byte size of each buffer + std::vector h_buffer_num_bytes(num_buffers); + for(size_t i = 0; i < num_buffers; ++i) + { + h_buffer_num_bytes[i] = h_buffer_num_elements[i] * sizeof(value_type); + } + + // Generate the source and destination pointers. + std::vector h_buffer_srcs(num_buffers); + std::vector h_buffer_dsts(num_buffers); + + for(int32_t i = 0; i < num_buffers; ++i) + { + h_buffer_srcs[i] = d_input + src_offsets[i]; + h_buffer_dsts[i] = d_output + dst_offsets[i]; + } + + // Prepare the batch memcpy. 
+ if(isMemCpy) + { + HIP_CHECK(hipMemcpy(d_input, + h_input_for_memcpy.data(), + total_num_bytes, + hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_buffer_sizes, + h_buffer_num_bytes.data(), + h_buffer_num_bytes.size() * sizeof(*d_buffer_sizes), + hipMemcpyHostToDevice)); + } + else + { + HIP_CHECK(hipMemcpy(d_input, + h_input_for_copy.data(), + total_num_bytes, + hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_buffer_sizes, + h_buffer_num_elements.data(), + h_buffer_num_elements.size() * sizeof(*d_buffer_sizes), + hipMemcpyHostToDevice)); + } + + HIP_CHECK(hipMemcpy(d_buffer_srcs, + h_buffer_srcs.data(), + h_buffer_srcs.size() * sizeof(*d_buffer_srcs), hipMemcpyHostToDevice)); - } - else - { - HIP_CHECK( - hipMemcpy(d_input, h_input_for_copy.data(), total_num_bytes, hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(d_buffer_sizes, - h_buffer_num_elements.data(), - h_buffer_num_elements.size() * sizeof(*d_buffer_sizes), + HIP_CHECK(hipMemcpy(d_buffer_dsts, + h_buffer_dsts.data(), + h_buffer_dsts.size() * sizeof(*d_buffer_dsts), hipMemcpyHostToDevice)); - } - HIP_CHECK(hipMemcpy(d_buffer_srcs, - h_buffer_srcs.data(), - h_buffer_srcs.size() * sizeof(*d_buffer_srcs), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(d_buffer_dsts, - h_buffer_dsts.data(), - h_buffer_dsts.size() * sizeof(*d_buffer_dsts), - hipMemcpyHostToDevice)); - - // Run batched memcpy. - batch_copy(d_temp_storage, - temp_storage_bytes, - d_buffer_srcs, - d_buffer_dsts, - d_buffer_sizes, - num_buffers, - hipStreamDefault); - - // Verify results. - check_result(h_input_for_memcpy, - h_input_for_copy, - d_output, - total_num_bytes, - total_num_elements, - num_buffers, - src_offsets, - dst_offsets, - h_buffer_num_bytes); - - HIP_CHECK(hipFree(d_temp_storage)); - HIP_CHECK(hipFree(d_buffer_sizes)); - HIP_CHECK(hipFree(d_buffer_dsts)); - HIP_CHECK(hipFree(d_buffer_srcs)); - HIP_CHECK(hipFree(d_output)); - HIP_CHECK(hipFree(d_input)); + // Run batched memcpy. 
+ batch_copy(d_temp_storage, + temp_storage_bytes, + d_buffer_srcs, + d_buffer_dsts, + d_buffer_sizes, + num_buffers, + hipStreamDefault); + + // Verify results. + check_result(h_input_for_memcpy, + h_input_for_copy, + d_output, + total_num_bytes, + total_num_elements, + num_buffers, + src_offsets, + dst_offsets, + h_buffer_num_bytes); + + HIP_CHECK(hipFree(d_temp_storage)); + HIP_CHECK(hipFree(d_buffer_sizes)); + HIP_CHECK(hipFree(d_buffer_dsts)); + HIP_CHECK(hipFree(d_buffer_srcs)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_input)); + } } diff --git a/test/rocprim/test_device_find_end.cpp b/test/rocprim/test_device_find_end.cpp new file mode 100644 index 000000000..64da6f00e --- /dev/null +++ b/test/rocprim/test_device_find_end.cpp @@ -0,0 +1,445 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +// required test headers +#include "indirect_iterator.hpp" +#include "test_utils_assertions.hpp" +#include "test_utils_custom_float_type.hpp" +#include "test_utils_custom_test_types.hpp" +#include "test_utils_data_generation.hpp" +#include "test_utils_types.hpp" + +#include "../common_test_header.hpp" + +#include "rocprim/device/device_find_end.hpp" + +#include +#include + +// Params for tests +template, + class Config = rocprim::default_config, + bool UseGraphs = false, + bool UseIndirectIterator = false> +struct DeviceFindEndParams +{ + using value_type = ValueType; + using key_type = KeyType; + using index_type = IndexType; + using compare_function = CompareFunction; + using config = Config; + static constexpr bool use_graphs = UseGraphs; + static constexpr bool use_indirect_iterator = UseIndirectIterator; +}; + +template +class RocprimDeviceFindEndTests : public ::testing::Test +{ +public: + using value_type = typename Params::value_type; + using key_type = typename Params::key_type; + using index_type = typename Params::index_type; + using compare_function = typename Params::compare_function; + using config = typename Params::config; + const bool debug_synchronous = false; + static constexpr bool use_graphs = Params::use_graphs; + static constexpr bool use_indirect_iterator = Params::use_indirect_iterator; +}; + +using RocprimDeviceFindEndTestsParams = ::testing::Types< + DeviceFindEndParams, + DeviceFindEndParams, + DeviceFindEndParams, + DeviceFindEndParams, + DeviceFindEndParams>, + DeviceFindEndParams, + DeviceFindEndParams, + DeviceFindEndParams, + DeviceFindEndParams, + DeviceFindEndParams, + DeviceFindEndParams>, + DeviceFindEndParams>, + DeviceFindEndParams, + DeviceFindEndParams, + DeviceFindEndParams>, + DeviceFindEndParams, + DeviceFindEndParams>, + DeviceFindEndParams, rocprim::default_config, true>, + DeviceFindEndParams, rocprim::default_config>, + DeviceFindEndParams, rocprim::default_config>, + DeviceFindEndParams, + rocprim::default_config, 
+ false, + true>, + DeviceFindEndParams, + rocprim::search_config<64, 16, 1024>, + false, + false>>; + +TYPED_TEST_SUITE(RocprimDeviceFindEndTests, RocprimDeviceFindEndTestsParams); + +TYPED_TEST(RocprimDeviceFindEndTests, FindEnd) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using value_type = typename TestFixture::value_type; + using key_type = typename TestFixture::key_type; + using index_type = typename TestFixture::index_type; + using compare_function = typename TestFixture::compare_function; + using config = typename TestFixture::config; + const bool debug_synchronous = TestFixture::debug_synchronous; + constexpr bool use_indirect_iterator = TestFixture::use_indirect_iterator; + + std::vector key_sizes = {0, 1, 10, 1000, 10000}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(size_t size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + + SCOPED_TRACE(testing::Message() << "with size = " << size); + + for(size_t key_size : key_sizes) + { + SCOPED_TRACE(testing::Message() << "with key size = " << key_size); + + if(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + size_t pattern = 0; + if(size > 0) + { + pattern = test_utils::get_random_value(0, size - 1, seed_value); + } + + SCOPED_TRACE(testing::Message() << "with index = " << pattern); + + // Generate data + std::vector input; + if(rocprim::is_floating_point::value) + { + input = test_utils::get_random_data(size, -1000, 1000, seed_value); + } + else + { + input = test_utils::get_random_data( + size, + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), + seed_value); + } + + std::vector keys(key_size); + if(pattern + key_size < size) + { + keys.assign(input.begin() + pattern, input.begin() + pattern + key_size); + } + else + { + keys.assign(input.begin() + pattern, input.end()); + } + + value_type* d_input; + key_type* d_keys; + index_type* d_output; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_input, input.size() * sizeof(*d_input))); + + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_keys, keys.size() * sizeof(*d_keys))); + + HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, sizeof(*d_output))); + + HIP_CHECK(hipMemcpy(d_input, + input.data(), + input.size() * sizeof(*d_input), + hipMemcpyHostToDevice)); + + HIP_CHECK(hipMemcpy(d_keys, + keys.data(), + keys.size() * sizeof(*d_keys), + hipMemcpyHostToDevice)); + + const auto input_it + = test_utils::wrap_in_indirect_iterator(d_input); + const auto keys_it + = test_utils::wrap_in_indirect_iterator(d_keys); + const 
auto output_keys + = test_utils::wrap_in_indirect_iterator(d_output); + + // compare function + compare_function compare_op; + + // temp storage + size_t temp_storage_size_bytes; + void* d_temp_storage = nullptr; + // Get size of d_temp_storage + HIP_CHECK(rocprim::find_end(nullptr, + temp_storage_size_bytes, + input_it, + keys_it, + output_keys, + input.size(), + keys.size(), + compare_op, + stream, + debug_synchronous)); + + // temp_storage_size_bytes must be >0 + ASSERT_GT(temp_storage_size_bytes, 0); + + // allocate temporary storage + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + // Run + HIP_CHECK(rocprim::find_end(d_temp_storage, + temp_storage_size_bytes, + input_it, + keys_it, + output_keys, + input.size(), + keys.size(), + compare_op, + stream, + debug_synchronous)); + + if(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + index_type output; + // Copy output to host + HIP_CHECK(hipMemcpy(&output, d_output, sizeof(*d_output), hipMemcpyDeviceToHost)); + + index_type expected = std::find_end(input.begin(), + input.end(), + keys.begin(), + keys.end(), + compare_op) + - input.begin(); + + ASSERT_EQ(output, expected); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_keys)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); + + if(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } + } +} + +TYPED_TEST(RocprimDeviceFindEndTests, FindEndRepetition) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using value_type = typename TestFixture::value_type; + using key_type = typename 
TestFixture::key_type; + using index_type = typename TestFixture::index_type; + using compare_function = typename TestFixture::compare_function; + using config = typename TestFixture::config; + const bool debug_synchronous = TestFixture::debug_synchronous; + constexpr bool use_indirect_iterator = TestFixture::use_indirect_iterator; + + size_t key_size = 10; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(size_t size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + + if(size < key_size) + { + continue; + } + + SCOPED_TRACE(testing::Message() << "with size = " << size); + + if(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + // Generate data + std::vector keys; + if(rocprim::is_floating_point::value) + { + keys = test_utils::get_random_data(key_size, -1000, 1000, seed_value); + } + else + { + keys = test_utils::get_random_data( + key_size, + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), + seed_value); + } + + std::vector input(size); + for(size_t i = 0; i < size / key_size; i++) + { + std::copy(keys.begin(), keys.end(), input.begin() + i * key_size); + } + + value_type* d_input; + key_type* d_keys; + index_type* d_output; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_input, input.size() * sizeof(*d_input))); + + HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys, keys.size() * sizeof(*d_keys))); + + HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, sizeof(*d_output))); + + HIP_CHECK(hipMemcpy(d_input, + input.data(), + input.size() * sizeof(*d_input), + hipMemcpyHostToDevice)); + + HIP_CHECK(hipMemcpy(d_keys, + keys.data(), + 
keys.size() * sizeof(*d_keys), + hipMemcpyHostToDevice)); + + const auto input_it + = test_utils::wrap_in_indirect_iterator(d_input); + + // compare function + compare_function compare_op; + + // temp storage + size_t temp_storage_size_bytes; + void* d_temp_storage = nullptr; + // Get size of d_temp_storage + HIP_CHECK(rocprim::find_end(nullptr, + temp_storage_size_bytes, + input_it, + d_keys, + d_output, + input.size(), + keys.size(), + compare_op, + stream, + debug_synchronous)); + + // temp_storage_size_bytes must be >0 + ASSERT_GT(temp_storage_size_bytes, 0); + + // allocate temporary storage + HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + // Run + HIP_CHECK(rocprim::find_end(d_temp_storage, + temp_storage_size_bytes, + input_it, + d_keys, + d_output, + input.size(), + keys.size(), + compare_op, + stream, + debug_synchronous)); + + if(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + index_type output; + // Copy output to host + HIP_CHECK(hipMemcpy(&output, d_output, sizeof(*d_output), hipMemcpyDeviceToHost)); + + index_type expected + = std::find_end(input.begin(), input.end(), keys.begin(), keys.end(), compare_op) + - input.begin(); + + ASSERT_EQ(output, expected); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_keys)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); + + if(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} diff --git a/test/rocprim/test_device_find_first_of.cpp b/test/rocprim/test_device_find_first_of.cpp index dd22ec7ff..98314cab6 100644 --- a/test/rocprim/test_device_find_first_of.cpp +++ b/test/rocprim/test_device_find_first_of.cpp @@ -224,9 +224,11 @@ TYPED_TEST(RocprimDeviceFindFirstOfTests, 
FindFirstOf) hipMemcpyHostToDevice)); const auto input_it - = test_utils::wrap_in_indirect_iterator(d_input); + = test_utils::wrap_in_indirect_iterator( + d_input); const auto keys_it - = test_utils::wrap_in_indirect_iterator(d_keys); + = test_utils::wrap_in_indirect_iterator( + d_keys); // compare function compare_function compare_op; diff --git a/test/rocprim/test_device_histogram.cpp b/test/rocprim/test_device_histogram.cpp index ffa46216a..37afd47f8 100644 --- a/test/rocprim/test_device_histogram.cpp +++ b/test/rocprim/test_device_histogram.cpp @@ -65,13 +65,14 @@ inline auto get_random_samples(size_t size, U min, U max, int seed_value) -> { const long long min1 = static_cast(min); const long long max1 = static_cast(max); - const long long d = max1 - min1; + const long long d = max1 - min1; return test_utils::get_random_data( size, - static_cast(std::max(min1 - d / 10, static_cast(std::numeric_limits::lowest()))), - static_cast(std::min(max1 + d / 10, static_cast(std::numeric_limits::max()))), - seed_value - ); + static_cast(std::max(min1 - d / 10, + static_cast(test_utils::numeric_limits::lowest()))), + static_cast( + std::min(max1 + d / 10, static_cast(test_utils::numeric_limits::max()))), + seed_value); } template @@ -80,13 +81,14 @@ inline auto get_random_samples(size_t size, U min, U max, int seed_value) -> { const double min1 = static_cast(min); const double max1 = static_cast(max); - const double d = max1 - min1; + const double d = max1 - min1; return test_utils::get_random_data( size, - static_cast(std::max(min1 - d / 10, static_cast(std::numeric_limits::lowest()))), - static_cast(std::min(max1 + d / 10, static_cast(std::numeric_limits::max()))), - seed_value - ); + static_cast( + std::max(min1 - d / 10, static_cast(test_utils::numeric_limits::lowest()))), + static_cast( + std::min(max1 + d / 10, static_cast(test_utils::numeric_limits::max()))), + seed_value); } // Does nothing, used for testing iterators (not raw pointers) as samples input diff --git 
a/test/rocprim/test_device_merge.cpp b/test/rocprim/test_device_merge.cpp index d028dcf14..0d0a6fde2 100644 --- a/test/rocprim/test_device_merge.cpp +++ b/test/rocprim/test_device_merge.cpp @@ -35,22 +35,25 @@ #include #include +#include #include #include +using DefaultConfig = rocprim::default_config; + // Params for tests -template< - class KeyType, - class ValueType, - class CompareOp = ::rocprim::less, - bool UseGraphs = false -> +template, + bool UseGraphs = false, + typename Config = rocprim::default_config> struct DeviceMergeParams { using key_type = KeyType; using value_type = ValueType; using compare_op_type = CompareOp; static constexpr bool use_graphs = UseGraphs; + using config = Config; }; template @@ -60,6 +63,7 @@ class RocprimDeviceMergeTests : public ::testing::Test using key_type = typename Params::key_type; using value_type = typename Params::value_type; using compare_op_type = typename Params::compare_op_type; + using config = typename Params::config; const bool debug_synchronous = false; static constexpr bool use_graphs = Params::use_graphs; }; @@ -107,6 +111,7 @@ TYPED_TEST_SUITE(RocprimDeviceMergeTests, RocprimDeviceMergeTestsParams); TYPED_TEST(RocprimDeviceMergeTests, MergeKey) { + using config = typename TestFixture::config; int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); @@ -195,16 +200,16 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK(rocprim::merge(d_temp_storage, - temp_storage_size_bytes, - d_keys_input1, - d_keys_input2, - d_keys_checking_output, - keys_input1.size(), - keys_input2.size(), - compare_op, - stream, - debug_synchronous)); + HIP_CHECK(rocprim::merge(d_temp_storage, + temp_storage_size_bytes, + d_keys_input1, + d_keys_input2, + d_keys_checking_output, + keys_input1.size(), + keys_input2.size(), + 
compare_op, + stream, + debug_synchronous)); // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -219,15 +224,16 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) } // Run - HIP_CHECK( - rocprim::merge( - d_temp_storage, temp_storage_size_bytes, - d_keys_input1, d_keys_input2, - d_keys_checking_output, - keys_input1.size(), keys_input2.size(), - compare_op, stream, debug_synchronous - ) - ); + HIP_CHECK(rocprim::merge(d_temp_storage, + temp_storage_size_bytes, + d_keys_input1, + d_keys_input2, + d_keys_checking_output, + keys_input1.size(), + keys_input2.size(), + compare_op, + stream, + debug_synchronous)); if(TestFixture::use_graphs) { @@ -251,10 +257,10 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) // Check if keys_output values are as expected ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, expected)); - hipFree(d_keys_input1); - hipFree(d_keys_input2); - hipFree(d_keys_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_keys_input1)); + HIP_CHECK(hipFree(d_keys_input2)); + HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_temp_storage)); if (TestFixture::use_graphs) gHelper.cleanupGraphHelper(); @@ -267,6 +273,7 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) { + using config = typename TestFixture::config; int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); @@ -399,17 +406,19 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; // Get size of d_temp_storage - HIP_CHECK( - rocprim::merge( - d_temp_storage, temp_storage_size_bytes, - d_keys_input1, d_keys_input2, - d_keys_checking_output, - d_values_input1, d_values_input2, - d_values_checking_output, - keys_input1.size(), keys_input2.size(), - compare_op, stream, TestFixture::debug_synchronous - ) - ); + 
HIP_CHECK(rocprim::merge(d_temp_storage, + temp_storage_size_bytes, + d_keys_input1, + d_keys_input2, + d_keys_checking_output, + d_values_input1, + d_values_input2, + d_values_checking_output, + keys_input1.size(), + keys_input2.size(), + compare_op, + stream, + TestFixture::debug_synchronous)); // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -424,17 +433,19 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) } // Run - HIP_CHECK( - rocprim::merge( - d_temp_storage, temp_storage_size_bytes, - d_keys_input1, d_keys_input2, - d_keys_checking_output, - d_values_input1, d_values_input2, - d_values_checking_output, - keys_input1.size(), keys_input2.size(), - compare_op, stream, TestFixture::debug_synchronous - ) - ); + HIP_CHECK(rocprim::merge(d_temp_storage, + temp_storage_size_bytes, + d_keys_input1, + d_keys_input2, + d_keys_checking_output, + d_values_input1, + d_values_input2, + d_values_checking_output, + keys_input1.size(), + keys_input2.size(), + compare_op, + stream, + TestFixture::debug_synchronous)); if(TestFixture::use_graphs) { @@ -472,13 +483,13 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, expected_key)); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(values_output, expected_value)); - hipFree(d_keys_input1); - hipFree(d_keys_input2); - hipFree(d_keys_output); - hipFree(d_values_input1); - hipFree(d_values_input2); - hipFree(d_values_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_keys_input1)); + HIP_CHECK(hipFree(d_keys_input2)); + HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_values_input1)); + HIP_CHECK(hipFree(d_values_input2)); + HIP_CHECK(hipFree(d_values_output)); + HIP_CHECK(hipFree(d_temp_storage)); if (TestFixture::use_graphs) gHelper.cleanupGraphHelper(); @@ -489,7 +500,7 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) HIP_CHECK(hipStreamDestroy(stream)); } -template +template void testMergeMismatchedIteratorTypes() { 
const int device_id = test_common_utils::obtain_device_from_ctest(); @@ -522,9 +533,9 @@ void testMergeMismatchedIteratorTypes() keys_input1.size() * sizeof(keys_input1[0]), hipMemcpyHostToDevice)); - const auto d_keys_input2 = rocprim::make_transform_iterator(rocprim::make_counting_iterator(0), - [] __host__ __device__(int value) - { return value * 2 + 1; }); + const auto d_keys_input2 + = rocprim::make_transform_iterator(rocprim::make_counting_iterator(0), + [](int value) { return value * 2 + 1; }); static constexpr bool debug_synchronous = false; @@ -536,16 +547,16 @@ void testMergeMismatchedIteratorTypes() } size_t temp_storage_size_bytes = 0; - HIP_CHECK(rocprim::merge(nullptr, - temp_storage_size_bytes, - d_keys_input1, - d_keys_input2, - d_keys_output, - keys_input1.size(), - keys_input1.size(), - rocprim::less{}, - stream, - debug_synchronous)); + HIP_CHECK(rocprim::merge(nullptr, + temp_storage_size_bytes, + d_keys_input1, + d_keys_input2, + d_keys_output, + keys_input1.size(), + keys_input1.size(), + rocprim::less{}, + stream, + debug_synchronous)); ASSERT_GT(temp_storage_size_bytes, 0); @@ -558,16 +569,16 @@ void testMergeMismatchedIteratorTypes() gHelper.startStreamCapture(stream); } - HIP_CHECK(rocprim::merge(d_temp_storage, - temp_storage_size_bytes, - d_keys_input1, - d_keys_input2, - d_keys_output, - keys_input1.size(), - keys_input1.size(), - rocprim::less{}, - hipStreamDefault, - debug_synchronous)); + HIP_CHECK(rocprim::merge(d_temp_storage, + temp_storage_size_bytes, + d_keys_input1, + d_keys_input2, + d_keys_output, + keys_input1.size(), + keys_input1.size(), + rocprim::less{}, + hipStreamDefault, + debug_synchronous)); if(UseGraphs) { @@ -595,10 +606,10 @@ void testMergeMismatchedIteratorTypes() TEST(RocprimDeviceMergeTests, MergeMismatchedIteratorTypes) { - testMergeMismatchedIteratorTypes(); + testMergeMismatchedIteratorTypes(); } TEST(RocprimDeviceMergeTests, MergeMismatchedIteratorTypesWithGraphs) { - testMergeMismatchedIteratorTypes(); + 
testMergeMismatchedIteratorTypes(); } diff --git a/test/rocprim/test_device_merge_sort.cpp b/test/rocprim/test_device_merge_sort.cpp index 81d4815d1..e0e5b5c20 100644 --- a/test/rocprim/test_device_merge_sort.cpp +++ b/test/rocprim/test_device_merge_sort.cpp @@ -490,7 +490,7 @@ void testLargeIndices() hipMemcpyDeviceToHost)); // Check if output values are as expected - const size_t unique_keys = size_t(std::numeric_limits::max()) + 1; + const size_t unique_keys = size_t(test_utils::numeric_limits::max()) + 1; const size_t segment_length = rocprim::detail::ceiling_div(size, unique_keys); const size_t full_segments = size % unique_keys == 0 ? unique_keys : size % unique_keys; for(size_t i = 0; i < size; i += 4321) diff --git a/test/rocprim/test_device_nth_element.cpp b/test/rocprim/test_device_nth_element.cpp index 9e74abd3d..6359de486 100644 --- a/test/rocprim/test_device_nth_element.cpp +++ b/test/rocprim/test_device_nth_element.cpp @@ -218,19 +218,12 @@ TYPED_TEST(RocprimDeviceNthelementTests, NthelementKey) SCOPED_TRACE(testing::Message() << "with nth_element = " << nth_element); // Generate data - std::vector input; - if(rocprim::is_floating_point::value) - { - input = test_utils::get_random_data(size, -1000, 1000, seed_value); - } - else - { - input = test_utils::get_random_data( - size, - test_utils::numeric_limits::min(), - test_utils::numeric_limits::max(), - seed_value); - } + std::vector input = test_utils::get_random_data( + size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + seed_value); + std::vector output(size); key_type* d_input; diff --git a/test/rocprim/test_device_partial_sort.cpp b/test/rocprim/test_device_partial_sort.cpp index fa08c0f41..2c9a3162f 100644 --- a/test/rocprim/test_device_partial_sort.cpp +++ b/test/rocprim/test_device_partial_sort.cpp @@ -27,6 +27,7 @@ #include "test_utils_custom_float_type.hpp" #include "test_utils_custom_test_types.hpp" #include "test_utils_data_generation.hpp" +#include 
"test_utils_sort_comparator.hpp" #include "test_utils_types.hpp" #include "../common_test_header.hpp" @@ -50,12 +51,14 @@ template, class Config = ::rocprim::default_config, bool UseGraphs = false, - bool UseIndirectIterator = false> + bool UseIndirectIterator = false, + class Decomposer = ::rocprim::identity_decomposer> struct DevicePartialSortParams { using key_type = KeyType; using compare_function = CompareFunction; using config = Config; + using decomposer = Decomposer; static constexpr bool use_graphs = UseGraphs; static constexpr bool use_indirect_iterator = UseIndirectIterator; }; @@ -67,6 +70,7 @@ class RocprimDevicePartialSortTests : public ::testing::Test using key_type = typename Params::key_type; using compare_function = typename Params::compare_function; using config = typename Params::config; + using decomposer = typename Params::decomposer; const bool debug_synchronous = false; static constexpr bool use_graphs = Params::use_graphs; static constexpr bool use_indirect_iterator = Params::use_indirect_iterator; @@ -95,7 +99,14 @@ using RocprimDevicePartialSortTestsParams = ::testing::Types< ::rocprim::less, rocprim::partial_sort_config< rocprim:: - nth_element_config<128, 4, 32, 16, rocprim::block_radix_rank_algorithm::basic>>>>; + nth_element_config<128, 4, 32, 16, rocprim::block_radix_rank_algorithm::basic>>>, + DevicePartialSortParams< + test_utils::custom_test_type, + ::rocprim::less>, + rocprim::default_config, + false, + false, + test_utils::custom_test_type_decomposer>>>; template void inline compare_partial_sort_cpp_14(InputVector input, @@ -179,6 +190,7 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) using key_type = std::remove_cv_t; using compare_function = typename TestFixture::compare_function; using config = typename TestFixture::config; + using decomposer = typename TestFixture::decomposer; const bool debug_synchronous = TestFixture::debug_synchronous; constexpr bool use_indirect_iterator = TestFixture::use_indirect_iterator; @@ 
-210,19 +222,12 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); } - std::vector input; - if(rocprim::is_floating_point::value) - { - input = test_utils::get_random_data(size, -1000, 1000, seed_value); - } - else - { - input = test_utils::get_random_data( - size, - test_utils::numeric_limits::min(), - test_utils::numeric_limits::max(), - seed_value); - } + std::vector input = test_utils::get_random_data( + size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + seed_value); + key_type* d_input; HIP_CHECK(test_common_utils::hipMallocHelper(&d_input, size * sizeof(key_type))); HIP_CHECK(hipMemcpy(d_input, @@ -237,6 +242,7 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) = test_utils::wrap_in_indirect_iterator(d_input); compare_function compare_op; + decomposer decomposer_op; // Allocate temporary storage size_t temp_storage_size_bytes{}; @@ -247,7 +253,8 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSort) size, compare_op, stream, - debug_synchronous)); + debug_synchronous, + decomposer_op)); ASSERT_GT(temp_storage_size_bytes, 0); void* d_temp_storage{}; @@ -383,6 +390,7 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSortCopy) using key_type = std::remove_cv_t; using compare_function = typename TestFixture::compare_function; using config = typename TestFixture::config; + using decomposer = typename TestFixture::decomposer; const bool debug_synchronous = TestFixture::debug_synchronous; constexpr bool input_is_const = std::is_const_v; constexpr bool use_indirect_iterator = TestFixture::use_indirect_iterator; @@ -415,27 +423,16 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSortCopy) HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); } - std::vector input; - std::vector output_original; - if(rocprim::is_floating_point::value) - { - input = test_utils::get_random_data(size, -1000, 1000, seed_value); - output_original - = 
test_utils::get_random_data(size, -1000, 1000, seed_value + 1); - } - else - { - input = test_utils::get_random_data( - size, - test_utils::numeric_limits::min(), - test_utils::numeric_limits::max(), - seed_value); - output_original = test_utils::get_random_data( - size, - test_utils::numeric_limits::min(), - test_utils::numeric_limits::max(), - seed_value + 1); - } + std::vector input = test_utils::get_random_data( + size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + seed_value); + std::vector output_original = test_utils::get_random_data( + size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + seed_value + 1); key_type* d_input; HIP_CHECK(test_common_utils::hipMallocHelper(&d_input, size * sizeof(key_type))); @@ -456,6 +453,7 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSortCopy) test_utils::wrap_in_const(d_input)); compare_function compare_op; + decomposer decomposer_op; // Allocate temporary storage size_t temp_storage_size_bytes{}; @@ -468,7 +466,8 @@ TYPED_TEST(RocprimDevicePartialSortTests, PartialSortCopy) size, compare_op, stream, - debug_synchronous)); + debug_synchronous, + decomposer_op)); ASSERT_GT(temp_storage_size_bytes, 0); void* d_temp_storage{}; diff --git a/test/rocprim/test_device_partition.cpp b/test/rocprim/test_device_partition.cpp index 902cd4ccb..a3dedc186 100644 --- a/test/rocprim/test_device_partition.cpp +++ b/test/rocprim/test_device_partition.cpp @@ -217,11 +217,11 @@ TYPED_TEST(RocprimDevicePartitionTests, Flagged) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected_selected, expected_selected.size())); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output_rejected, expected_rejected, expected_rejected.size())); - hipFree(d_input); - hipFree(d_flags); - hipFree(d_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_flags)); + HIP_CHECK(hipFree(d_output)); + 
HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); if(TestFixture::use_graphs) { @@ -236,6 +236,20 @@ TYPED_TEST(RocprimDevicePartitionTests, Flagged) } } +template +struct select_op_t +{ + __host__ __device__ + auto operator()(const T& value) -> bool + { + if(value == T(50)) + { + return true; + } + return false; + } +}; + TYPED_TEST(RocprimDevicePartitionTests, PredicateEmptyInput) { int device_id = test_common_utils::obtain_device_from_ctest(); @@ -254,11 +268,7 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateEmptyInput) HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); } - auto select_op = [] __host__ __device__ (const T& value) -> bool - { - if(value == T(50)) return true; - return false; - }; + auto select_op = select_op_t{}; U * d_output; unsigned int * d_selected_count_output; @@ -328,9 +338,9 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateEmptyInput) ); ASSERT_EQ(selected_count_output, 0); - hipFree(d_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); if (TestFixture::use_graphs) { @@ -358,11 +368,7 @@ TYPED_TEST(RocprimDevicePartitionTests, Predicate) HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); } - auto select_op = [] __host__ __device__ (const T& value) -> bool - { - if(value == T(50)) return true; - return false; - }; + auto select_op = select_op_t{}; for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -474,10 +480,10 @@ TYPED_TEST(RocprimDevicePartitionTests, Predicate) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected_selected, expected_selected.size())); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output_rejected, expected_rejected, expected_rejected.size())); - hipFree(d_input); - hipFree(d_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + 
HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); if(TestFixture::use_graphs) { @@ -511,12 +517,7 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); } - auto select_op = [] __host__ __device__(const T& value) -> bool - { - if(value == T(50)) - return true; - return false; - }; + auto select_op = select_op_t{}; for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -636,11 +637,11 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) ASSERT_NO_FATAL_FAILURE( test_utils::assert_eq(rejected, expected_rejected, expected_rejected.size())); - hipFree(d_input); - hipFree(d_selected); - hipFree(d_rejected); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_selected)); + HIP_CHECK(hipFree(d_rejected)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); if(TestFixture::use_graphs) { @@ -848,12 +849,12 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateThreeWay) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected, expected.size())); - hipFree(d_input); - hipFree(d_first_output); - hipFree(d_second_output); - hipFree(d_unselected_output); - hipFree(d_selected_counts); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_first_output)); + HIP_CHECK(hipFree(d_second_output)); + HIP_CHECK(hipFree(d_unselected_output)); + HIP_CHECK(hipFree(d_selected_counts)); + HIP_CHECK(hipFree(d_temp_storage)); if(TestFixture::use_graphs) { @@ -1480,6 +1481,20 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionThreeWay) } } +template +struct select_data_op_t +{ + __host__ __device__ + auto operator()(const T& value) -> bool + { + if(value.data[0] == 128) + { + return true; + } + return false; + } +}; + // This test checks to make 
sure that the block size is reduced correctly // when our data size and type are set in a way that we will exceed the shared // memory limit. Since the block size calculation is done at compile time, @@ -1518,12 +1533,7 @@ TEST(RocprimDevicePartitionBlockSizeTests, BlockSize) const bool debug_synchronous = false; const hipStream_t stream = 0; // default stream - auto select_op = [] __host__ __device__ (const T& value) -> bool - { - // The data values are in [0, 255]. Partition on the midpoint. - if(value.data[0] == 128) return true; - return false; - }; + auto select_op = select_data_op_t{}; // Use some power of two and off-by-one-from-power-of-two data sizes. const std::vector sizes = {256, 257, 511, 512, 1024, 1025}; @@ -1629,10 +1639,10 @@ TEST(RocprimDevicePartitionBlockSizeTests, BlockSize) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected_selected, expected_selected.size())); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output_rejected, expected_rejected, expected_rejected.size())); - hipFree(d_input); - hipFree(d_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); } } } \ No newline at end of file diff --git a/test/rocprim/test_device_radix_sort.cpp.in b/test/rocprim/test_device_radix_sort.cpp.in index c70343d03..a49138ada 100644 --- a/test/rocprim/test_device_radix_sort.cpp.in +++ b/test/rocprim/test_device_radix_sort.cpp.in @@ -47,12 +47,13 @@ #if ROCPRIM_TEST_SLICE == 0 TEST(SUITE, SortKeysOver4G) { sort_keys_over_4g(); } TEST(SUITE, SortKeysOver4GWithGraphs) { sort_keys_over_4g(); } + TEST(SUITE, SortKeysLargeSizes) { sort_keys_large_sizes(); } #endif #if ROCPRIM_TEST_TYPE_SLICE == 0 #if ROCPRIM_HAS_INT128_SUPPORT - INSTANTIATE(params<__int128_t, __int128_t>) - INSTANTIATE(params<__uint128_t, __uint128_t>) + INSTANTIATE(params) + INSTANTIATE(params) #endif INSTANTIATE(params) 
INSTANTIATE(params) diff --git a/test/rocprim/test_device_radix_sort.hpp b/test/rocprim/test_device_radix_sort.hpp index 3d1b973b9..2a109e58f 100644 --- a/test/rocprim/test_device_radix_sort.hpp +++ b/test/rocprim/test_device_radix_sort.hpp @@ -87,8 +87,8 @@ auto generate_key_input(KeyIter keys_input, size_t size, engine_type& rng_engine using key_type = typename std::iterator_traits::value_type; test_utils::generate_random_data_n(keys_input, size, - std::numeric_limits::min(), - std::numeric_limits::max(), + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), rng_engine); } @@ -1302,8 +1302,8 @@ void sort_keys_over_4g() std::vector keys_input = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), seed_value); //generate histogram of the randomly generated values @@ -1373,7 +1373,7 @@ void sort_keys_over_4g() hipMemcpyDeviceToHost)); size_t counter = 0; - for(size_t i = 0; i <= std::numeric_limits::max(); ++i) + for(size_t i = 0; i <= test_utils::numeric_limits::max(); ++i) { for(size_t j = 0; j < histogram[i]; ++j) { @@ -1393,4 +1393,86 @@ void sort_keys_over_4g() } } +inline void sort_keys_large_sizes() +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id= " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using key_type = uint8_t; + constexpr unsigned int start_bit = 0; + constexpr unsigned int end_bit = 8; + + hipStream_t stream = 0; + + // Currently, CI enforces a hard limit of 96 GB on memory allocations. + // Temporarily use sizes that will require less space than the limit. 
+ const std::vector sizes = test_utils::get_large_sizes<35>(seeds[0]); + for(const size_t size : sizes) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + + // Generate data + std::vector keys_input(size); + std::iota(keys_input.begin(), keys_input.end(), 0); + + key_type* d_keys; + HIP_CHECK_MEMORY(test_common_utils::hipMallocHelper(&d_keys, size * sizeof(key_type))); + HIP_CHECK( + hipMemcpy(d_keys, keys_input.data(), size * sizeof(key_type), hipMemcpyHostToDevice)); + + void* d_temporary_storage = nullptr; + size_t temporary_storage_bytes = 0; + HIP_CHECK(rocprim::radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys, + d_keys, + size, + start_bit, + end_bit, + stream)); + + ASSERT_GT(temporary_storage_bytes, 0U); + + HIP_CHECK_MEMORY( + test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + + HIP_CHECK(rocprim::radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys, + d_keys, + size, + start_bit, + end_bit, + stream)); + + HIP_CHECK(hipFree(d_temporary_storage)); + + std::vector keys_output(size); + HIP_CHECK( + hipMemcpy(keys_output.data(), d_keys, size * sizeof(key_type), hipMemcpyDeviceToHost)); + + HIP_CHECK(hipFree(d_keys)); + + // Check if output values are as expected + const size_t unique_keys = size_t(rocprim::numeric_limits::max()) + 1; + const size_t segment_length = rocprim::detail::ceiling_div(size, unique_keys); + const size_t full_segments = size % unique_keys == 0 ? 
unique_keys : size % unique_keys; + for(size_t i = 0; i < size; i += 4321) + { + key_type expected; + if(i / segment_length < full_segments) + { + expected = key_type(i / segment_length); + } + else + { + expected = key_type((i - full_segments * segment_length) / (segment_length - 1) + + full_segments); + } + ASSERT_EQ(keys_output[i], expected) << "with index = " << i; + } + } +} + #endif // TEST_DEVICE_RADIX_SORT_HPP_ diff --git a/test/rocprim/test_device_reduce.cpp b/test/rocprim/test_device_reduce.cpp index 589dfc06a..8514f1c5f 100644 --- a/test/rocprim/test_device_reduce.cpp +++ b/test/rocprim/test_device_reduce.cpp @@ -222,8 +222,8 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceEmptyInput) ); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, initial_value)); - hipFree(d_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); if (TestFixture::use_graphs) { @@ -354,9 +354,9 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceSum) ASSERT_NO_FATAL_FAILURE( test_utils::assert_near(output[0], expected, test_utils::precision * size)); - hipFree(d_input); - hipFree(d_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); if (TestFixture::use_graphs) { @@ -506,9 +506,9 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceArgMinimum) test_utils::assert_eq(output[0].key, expected.key); test_utils::assert_eq(output[0].value, expected.value); - hipFree(d_input); - hipFree(d_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); if (TestFixture::use_graphs) { @@ -605,8 +605,8 @@ void testLargeIndices() ASSERT_EQ(output, expected_output); - hipFree(d_temp_storage); - hipFree(d_output); + HIP_CHECK(hipFree(d_temp_storage)); + HIP_CHECK(hipFree(d_output)); if(use_graphs) { @@ -745,9 +745,9 @@ TYPED_TEST(RocprimDeviceReducePrecisionTests, ReduceSumInputEqualExponentFunctio // 
Check if output values are as expected ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output[0], expected, precision)); - hipFree(d_input); - hipFree(d_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); if (TestFixture::use_graphs) { @@ -874,9 +874,9 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceMinimum) ? 0 : std::max(test_utils::precision, test_utils::precision))); - hipFree(d_input); - hipFree(d_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); if (TestFixture::use_graphs) { diff --git a/test/rocprim/test_device_run_length_encode.cpp b/test/rocprim/test_device_run_length_encode.cpp index 2f1030365..6d837e047 100644 --- a/test/rocprim/test_device_run_length_encode.cpp +++ b/test/rocprim/test_device_run_length_encode.cpp @@ -210,6 +210,8 @@ TYPED_TEST(RocprimDeviceRunLengthEncode, Encode) stream, debug_synchronous)); + HIP_CHECK(hipDeviceSynchronize()); + HIP_CHECK(hipFree(d_temporary_storage)); std::vector unique_output(runs_count_expected); diff --git a/test/rocprim/test_device_scan.cpp b/test/rocprim/test_device_scan.cpp index b62746a48..03831c46f 100644 --- a/test/rocprim/test_device_scan.cpp +++ b/test/rocprim/test_device_scan.cpp @@ -317,8 +317,8 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) ASSERT_FALSE(out_of_bounds.get()); - hipFree(d_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); if (TestFixture::use_graphs) { @@ -471,9 +471,9 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) ASSERT_NO_FATAL_FAILURE( test_utils::assert_near(output, expected, single_op_precision * size)); - hipFree(d_input); - hipFree(d_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); if (TestFixture::use_graphs) { @@ -634,9 +634,9 @@ 
TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) ASSERT_NO_FATAL_FAILURE( test_utils::assert_near(output, expected, single_op_precision * size)); - hipFree(d_input); - hipFree(d_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); if (TestFixture::use_graphs) { @@ -812,10 +812,10 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) ASSERT_NO_FATAL_FAILURE( test_utils::assert_near(output, expected, single_op_precision * size)); - hipFree(d_keys); - hipFree(d_input); - hipFree(d_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_keys)); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); if (TestFixture::use_graphs) { @@ -996,10 +996,10 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) ASSERT_NO_FATAL_FAILURE( test_utils::assert_near(output, expected, single_op_precision * size)); - hipFree(d_keys); - hipFree(d_input); - hipFree(d_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_keys)); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); if (TestFixture::use_graphs) { @@ -1177,8 +1177,8 @@ void testLargeIndicesInclusiveScan() ASSERT_EQ(output, expected_output); - hipFree(d_temp_storage); - hipFree(d_output); + HIP_CHECK(hipFree(d_temp_storage)); + HIP_CHECK(hipFree(d_output)); if(UseGraphs) { @@ -1310,8 +1310,8 @@ void testLargeIndicesExclusiveScan() ASSERT_EQ(output, expected_output); - hipFree(d_temp_storage); - hipFree(d_output); + HIP_CHECK(hipFree(d_temp_storage)); + HIP_CHECK(hipFree(d_output)); if(UseGraphs) { @@ -1823,11 +1823,11 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) // Check if output values are as expected ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, precision)); - hipFree(d_input); - hipFree(d_output); - hipFree(d_future_input); - hipFree(d_initial_value); - hipFree(d_temp_storage); + 
HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_future_input)); + HIP_CHECK(hipFree(d_initial_value)); + HIP_CHECK(hipFree(d_temp_storage)); if (TestFixture::use_graphs) { diff --git a/test/rocprim/test_device_search.cpp b/test/rocprim/test_device_search.cpp new file mode 100644 index 000000000..33735f5f8 --- /dev/null +++ b/test/rocprim/test_device_search.cpp @@ -0,0 +1,442 @@ +// MIT License +// +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +// required test headers +#include "indirect_iterator.hpp" +#include "test_utils_assertions.hpp" +#include "test_utils_custom_float_type.hpp" +#include "test_utils_custom_test_types.hpp" +#include "test_utils_data_generation.hpp" +#include "test_utils_types.hpp" + +#include "../common_test_header.hpp" + +#include "rocprim/device/device_search.hpp" + +#include +#include + +// Params for tests +template, + class Config = rocprim::default_config, + bool UseGraphs = false, + bool UseIndirectIterator = false> +struct DeviceSearchParams +{ + using value_type = ValueType; + using key_type = KeyType; + using index_type = IndexType; + using compare_function = CompareFunction; + using config = Config; + static constexpr bool use_graphs = UseGraphs; + static constexpr bool use_indirect_iterator = UseIndirectIterator; +}; + +template +class RocprimDeviceSearchTests : public ::testing::Test +{ +public: + using value_type = typename Params::value_type; + using key_type = typename Params::key_type; + using index_type = typename Params::index_type; + using compare_function = typename Params::compare_function; + using config = typename Params::config; + const bool debug_synchronous = false; + static constexpr bool use_graphs = Params::use_graphs; + static constexpr bool use_indirect_iterator = Params::use_indirect_iterator; +}; + +using RocprimDeviceSearchTestsParams = ::testing::Types< + DeviceSearchParams, + DeviceSearchParams, + DeviceSearchParams, + DeviceSearchParams, + DeviceSearchParams>, + DeviceSearchParams, + DeviceSearchParams, + DeviceSearchParams, + DeviceSearchParams, + DeviceSearchParams, + DeviceSearchParams>, + DeviceSearchParams>, + DeviceSearchParams, + DeviceSearchParams, + DeviceSearchParams>, + DeviceSearchParams, + DeviceSearchParams>, + DeviceSearchParams, rocprim::default_config, true>, + DeviceSearchParams, rocprim::default_config>, + DeviceSearchParams, rocprim::default_config>, + DeviceSearchParams, + rocprim::default_config, + false, + true>, + 
DeviceSearchParams, + rocprim::search_config<64, 16, 1024>, + false, + false>>; + +TYPED_TEST_SUITE(RocprimDeviceSearchTests, RocprimDeviceSearchTestsParams); + +TYPED_TEST(RocprimDeviceSearchTests, Search) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using value_type = typename TestFixture::value_type; + using key_type = typename TestFixture::key_type; + using index_type = typename TestFixture::index_type; + using compare_function = typename TestFixture::compare_function; + using config = typename TestFixture::config; + const bool debug_synchronous = TestFixture::debug_synchronous; + constexpr bool use_indirect_iterator = TestFixture::use_indirect_iterator; + + std::vector key_sizes = {0, 1, 10, 1000, 10000}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(size_t size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + + SCOPED_TRACE(testing::Message() << "with size = " << size); + + for(size_t key_size : key_sizes) + { + SCOPED_TRACE(testing::Message() << "with key size = " << key_size); + + if(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + size_t pattern = 0; + if(size > 0) + { + pattern = test_utils::get_random_value(0, size - 1, seed_value); + } + + SCOPED_TRACE(testing::Message() << "with index = " << pattern); + + // Generate data + std::vector input; + if(rocprim::is_floating_point::value) + { + input = test_utils::get_random_data(size, -1000, 1000, seed_value); + } + else + { + input = test_utils::get_random_data( + size, + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), + seed_value); + } + + std::vector keys(key_size); + if(pattern + key_size < size) + { + keys.assign(input.begin() + pattern, input.begin() + pattern + key_size); + } + else + { + keys.assign(input.begin() + pattern, input.end()); + } + + value_type* d_input; + key_type* d_keys; + index_type* d_output; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_input, input.size() * sizeof(*d_input))); + + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_keys, keys.size() * sizeof(*d_keys))); + + HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, sizeof(*d_output))); + + HIP_CHECK(hipMemcpy(d_input, + input.data(), + input.size() * sizeof(*d_input), + hipMemcpyHostToDevice)); + + HIP_CHECK(hipMemcpy(d_keys, + keys.data(), + keys.size() * sizeof(*d_keys), + hipMemcpyHostToDevice)); + + const auto input_it + = test_utils::wrap_in_indirect_iterator(d_input); + const auto keys_it + = test_utils::wrap_in_indirect_iterator(d_keys); + const 
auto output_keys + = test_utils::wrap_in_indirect_iterator(d_output); + + // compare function + compare_function compare_op; + + // temp storage + size_t temp_storage_size_bytes; + void* d_temp_storage = nullptr; + // Get size of d_temp_storage + HIP_CHECK(rocprim::search(nullptr, + temp_storage_size_bytes, + input_it, + keys_it, + output_keys, + input.size(), + keys.size(), + compare_op, + stream, + debug_synchronous)); + + // temp_storage_size_bytes must be >0 + ASSERT_GT(temp_storage_size_bytes, 0); + + // allocate temporary storage + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + // Run + HIP_CHECK(rocprim::search(d_temp_storage, + temp_storage_size_bytes, + input_it, + keys_it, + output_keys, + input.size(), + keys.size(), + compare_op, + stream, + debug_synchronous)); + + if(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + index_type output; + // Copy output to host + HIP_CHECK(hipMemcpy(&output, d_output, sizeof(*d_output), hipMemcpyDeviceToHost)); + + index_type expected + = std::search(input.begin(), input.end(), keys.begin(), keys.end(), compare_op) + - input.begin(); + + ASSERT_EQ(output, expected); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_keys)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); + + if(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } + } +} + +TYPED_TEST(RocprimDeviceSearchTests, SearchRepetition) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using value_type = typename TestFixture::value_type; + using key_type = typename TestFixture::key_type; + using 
index_type = typename TestFixture::index_type; + using compare_function = typename TestFixture::compare_function; + using config = typename TestFixture::config; + const bool debug_synchronous = TestFixture::debug_synchronous; + constexpr bool use_indirect_iterator = TestFixture::use_indirect_iterator; + + size_t key_size = 10; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(size_t size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + + if(size < key_size) + { + continue; + } + + SCOPED_TRACE(testing::Message() << "with size = " << size); + + if(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + // Generate data + std::vector keys; + if(rocprim::is_floating_point::value) + { + keys = test_utils::get_random_data(key_size, -1000, 1000, seed_value); + } + else + { + keys = test_utils::get_random_data( + key_size, + test_utils::numeric_limits::min(), + test_utils::numeric_limits::max(), + seed_value); + } + + std::vector input(size); + for(size_t i = 0; i < size / key_size; i++) + { + std::copy(keys.begin(), keys.end(), input.begin() + i * key_size); + } + + value_type* d_input; + key_type* d_keys; + index_type* d_output; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_input, input.size() * sizeof(*d_input))); + + HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys, keys.size() * sizeof(*d_keys))); + + HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, sizeof(*d_output))); + + HIP_CHECK(hipMemcpy(d_input, + input.data(), + input.size() * sizeof(*d_input), + hipMemcpyHostToDevice)); + + HIP_CHECK(hipMemcpy(d_keys, + keys.data(), + keys.size() * sizeof(*d_keys), + 
hipMemcpyHostToDevice)); + + const auto input_it + = test_utils::wrap_in_indirect_iterator(d_input); + + // compare function + compare_function compare_op; + + // temp storage + size_t temp_storage_size_bytes; + void* d_temp_storage = nullptr; + // Get size of d_temp_storage + HIP_CHECK(rocprim::search(nullptr, + temp_storage_size_bytes, + input_it, + d_keys, + d_output, + input.size(), + keys.size(), + compare_op, + stream, + debug_synchronous)); + + // temp_storage_size_bytes must be >0 + ASSERT_GT(temp_storage_size_bytes, 0); + + // allocate temporary storage + HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + // Run + HIP_CHECK(rocprim::search(d_temp_storage, + temp_storage_size_bytes, + input_it, + d_keys, + d_output, + input.size(), + keys.size(), + compare_op, + stream, + debug_synchronous)); + + if(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipDeviceSynchronize()); + + index_type output; + // Copy output to host + HIP_CHECK(hipMemcpy(&output, d_output, sizeof(*d_output), hipMemcpyDeviceToHost)); + + index_type expected + = std::search(input.begin(), input.end(), keys.begin(), keys.end(), compare_op) + - input.begin(); + + ASSERT_EQ(output, expected); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_keys)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); + + if(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} diff --git a/test/rocprim/test_device_search_n.cpp b/test/rocprim/test_device_search_n.cpp new file mode 100644 index 000000000..af4c2fd09 --- /dev/null +++ b/test/rocprim/test_device_search_n.cpp @@ -0,0 +1,1392 @@ +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#include "../common_test_header.hpp" + +#include "test_utils_custom_test_types.hpp" +#include "test_utils_device_ptr.hpp" +#include "test_utils_types.hpp" + +#include + +#include + +template +using limit_type = test_utils::numeric_limits; + +template, + class Config = rocprim::default_config, + bool UseGraphs = false, + bool UseIndirectIterator = false> +struct DeviceSearchNParams +{ + using input_type = InputIterator; + using output_type = OutputIterator; + using op_type = BinaryPredicate; + using config = Config; + static constexpr bool use_graphs = UseGraphs; + static constexpr bool use_indirect_iterator = UseIndirectIterator; +}; + +template +class RocprimDeviceSearchNTests : public ::testing::Test +{ +public: + using input_type = typename Params::input_type; + using output_type = typename Params::output_type; + using op_type = typename Params::op_type; + using config = typename Params::config; + static constexpr bool use_graphs = Params::use_graphs; + static constexpr bool use_indirect_iterator = Params::use_indirect_iterator; + static constexpr bool debug_synchronous = false; +}; + +// Custom types +using custom_int2 = test_utils::custom_test_type; +using custom_double2 = test_utils::custom_test_type; +using custom_int64_array = test_utils::custom_test_array_type; + +// Custom configs +using custom_config_0 = rocprim::search_n_config<256, 4>; + +using RocprimDeviceSearchNTestsParams = ::testing::Types< + // Tests with default configuration + DeviceSearchNParams, + DeviceSearchNParams, + DeviceSearchNParams, + DeviceSearchNParams, + DeviceSearchNParams, + DeviceSearchNParams, + // Tests for custom types + DeviceSearchNParams, + DeviceSearchNParams, + DeviceSearchNParams, + // Tests for supported config structs + DeviceSearchNParams, + custom_config_0>, + // Tests for hipGraph support + DeviceSearchNParams, + rocprim::default_config, + true>, + // Tests for when output's value_type is void + DeviceSearchNParams, rocprim::default_config, false, true>>; + 
+TYPED_TEST_SUITE(RocprimDeviceSearchNTests, RocprimDeviceSearchNTestsParams); + +TYPED_TEST(RocprimDeviceSearchNTests, RandomTest) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + + HIP_CHECK(hipSetDevice(device_id)); + + using input_type = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + using op_type = typename TestFixture::op_type; + using config = typename TestFixture::config; + + constexpr bool debug_synchronous = TestFixture::debug_synchronous; + op_type op{}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + + for(const auto size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + size_t count = test_utils::get_random_value(0, size, ++seed_value); + size_t temp_storage_size; + input_type h_value + = test_utils::get_random_value(0, + limit_type::max(), + ++seed_value); + std::vector h_input + = test_utils::get_random_data(size, + 0, + limit_type::max(), + ++seed_value); + auto index = 0; + if(size > count) + { + index = test_utils::get_random_value(0, size - 1 - count, ++seed_value); + std::fill(h_input.begin() + index, h_input.begin() + index + count, h_value); + } + + output_type h_output; + test_utils::device_ptr d_input(h_input); + test_utils::device_ptr d_value(&h_value, 1); + test_utils::device_ptr d_output(1); + test_utils::device_ptr d_temp_storage; + + SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); + SCOPED_TRACE(testing::Message() << "with count = " << count); + SCOPED_TRACE(testing::Message() << "with value = " << h_value); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + // get size + HIP_CHECK(rocprim::search_n(0, + 
temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + nullptr)); + + d_temp_storage.resize(temp_storage_size); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + HIP_CHECK(rocprim::search_n(d_temp_storage.get(), + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + d_value.get(), + op, + stream, + debug_synchronous)); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipStreamSynchronize(stream)); + + const auto expected + = std::search_n(h_input.cbegin(), h_input.cend(), count, h_value, op) + - h_input.cbegin(); + h_output = d_output.load()[0]; + ASSERT_EQ(h_output, expected); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} + +TYPED_TEST(RocprimDeviceSearchNTests, MaxCount) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + + HIP_CHECK(hipSetDevice(device_id)); + + using input_type = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + using op_type = typename TestFixture::op_type; + using config = typename TestFixture::config; + + constexpr bool debug_synchronous = TestFixture::debug_synchronous; + op_type op{}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + + for(const auto size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + size_t count = size; + size_t temp_storage_size; + input_type h_value + = test_utils::get_random_value(0, + limit_type::max(), + ++seed_value); + std::vector h_input(size, h_value); + output_type h_output; + test_utils::device_ptr d_input(h_input); + test_utils::device_ptr d_value(&h_value, 1); + test_utils::device_ptr d_output(1); + test_utils::device_ptr d_temp_storage; + + SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); + SCOPED_TRACE(testing::Message() << "with count = " << count); + SCOPED_TRACE(testing::Message() << "with value = " << h_value); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + // get size + HIP_CHECK(rocprim::search_n(0, + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + nullptr)); + + d_temp_storage.resize(temp_storage_size); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + HIP_CHECK(rocprim::search_n(d_temp_storage.get(), + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + d_value.get(), + op, + stream, + debug_synchronous)); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipStreamSynchronize(stream)); + + const auto expected + = std::search_n(h_input.cbegin(), h_input.cend(), count, h_value, op) + - h_input.cbegin(); + + h_output = d_output.load()[0]; + + ASSERT_EQ(h_output, expected); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} + +TYPED_TEST(RocprimDeviceSearchNTests, MinCount) +{ + 
int device_id = test_common_utils::obtain_device_from_ctest(); + + HIP_CHECK(hipSetDevice(device_id)); + + using input_type = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + using op_type = typename TestFixture::op_type; + using config = typename TestFixture::config; + + constexpr bool debug_synchronous = TestFixture::debug_synchronous; + op_type op{}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + + for(const auto size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + size_t count = 0; + size_t temp_storage_size; + input_type h_value + = test_utils::get_random_value(0, + limit_type::max(), + ++seed_value); + std::vector h_input(size, h_value); + output_type h_output; + test_utils::device_ptr d_input(h_input); + test_utils::device_ptr d_value(&h_value, 1); + test_utils::device_ptr d_output(1); + test_utils::device_ptr d_temp_storage; + + SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); + SCOPED_TRACE(testing::Message() << "with count = " << count); + SCOPED_TRACE(testing::Message() << "with value = " << h_value); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + // get size + HIP_CHECK(rocprim::search_n(0, + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + nullptr)); + + d_temp_storage.resize(temp_storage_size); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + HIP_CHECK(rocprim::search_n(d_temp_storage.get(), + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + d_value.get(), + op, + stream, + debug_synchronous)); + + if 
ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipStreamSynchronize(stream)); + + const auto expected + = std::search_n(h_input.cbegin(), h_input.cend(), count, h_value, op) + - h_input.cbegin(); + + h_output = d_output.load()[0]; + + ASSERT_EQ(h_output, expected); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} + +TYPED_TEST(RocprimDeviceSearchNTests, StartFromBegin) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + + HIP_CHECK(hipSetDevice(device_id)); + + using input_type = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + using op_type = typename TestFixture::op_type; + using config = typename TestFixture::config; + + constexpr bool debug_synchronous = TestFixture::debug_synchronous; + op_type op{}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + + for(const auto size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + size_t count = size / 2; + size_t temp_storage_size; + input_type h_value{1}; + std::vector h_input(size); + std::fill(h_input.begin(), h_input.begin() + (size - count), h_value); + std::fill(h_input.begin() + count, h_input.end(), 0); + output_type h_output; + test_utils::device_ptr d_input(h_input); + test_utils::device_ptr d_value(&h_value, 1); + test_utils::device_ptr d_output(1); + test_utils::device_ptr d_temp_storage; + + SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); + SCOPED_TRACE(testing::Message() << "with count = " << count); + SCOPED_TRACE(testing::Message() << "with value = " << h_value); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + // get size + HIP_CHECK(rocprim::search_n(0, + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + nullptr)); + + d_temp_storage.resize(temp_storage_size); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + HIP_CHECK(rocprim::search_n(d_temp_storage.get(), + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + d_value.get(), + op, + stream, + debug_synchronous)); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipStreamSynchronize(stream)); + + const auto expected + = std::search_n(h_input.cbegin(), h_input.cend(), count, h_value, op) + - h_input.cbegin(); + + h_output = d_output.load()[0]; + + ASSERT_EQ(h_output, expected); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} + 
+TYPED_TEST(RocprimDeviceSearchNTests, StartFromMiddle) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + + HIP_CHECK(hipSetDevice(device_id)); + + using input_type = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + using op_type = typename TestFixture::op_type; + using config = typename TestFixture::config; + + constexpr bool debug_synchronous = TestFixture::debug_synchronous; + op_type op{}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + + for(const auto size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + size_t count = size / 2; + size_t temp_storage_size; + input_type h_value{1}; + std::vector h_input(size); + std::fill(h_input.begin(), h_input.begin() + (size - count), 0); + std::fill(h_input.begin() + count, h_input.end(), h_value); + output_type h_output; + test_utils::device_ptr d_input(h_input); + test_utils::device_ptr d_value(&h_value, 1); + test_utils::device_ptr d_output(1); + test_utils::device_ptr d_temp_storage; + + SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); + SCOPED_TRACE(testing::Message() << "with count = " << count); + SCOPED_TRACE(testing::Message() << "with value = " << h_value); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + // get size + HIP_CHECK(rocprim::search_n(0, + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + nullptr)); + + d_temp_storage.resize(temp_storage_size); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + HIP_CHECK(rocprim::search_n(d_temp_storage.get(), + temp_storage_size, + 
d_input.get(), + d_output.get(), + h_input.size(), + count, + d_value.get(), + op, + stream, + debug_synchronous)); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipStreamSynchronize(stream)); + + const auto expected + = std::search_n(h_input.cbegin(), h_input.cend(), count, h_value, op) + - h_input.cbegin(); + + h_output = d_output.load()[0]; + + ASSERT_EQ(h_output, expected); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} + +TYPED_TEST(RocprimDeviceSearchNTests, StartFromEnd) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + + HIP_CHECK(hipSetDevice(device_id)); + + using input_type = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + using op_type = typename TestFixture::op_type; + using config = typename TestFixture::config; + + constexpr bool debug_synchronous = TestFixture::debug_synchronous; + op_type op{}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + + for(const auto size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + size_t count = test_utils::get_random_value(0, size, ++seed_value); + size_t temp_storage_size; + input_type h_value{1}; + std::vector h_input(size); + std::fill(h_input.begin(), h_input.begin() + (size - count), 0); + std::fill(h_input.begin() + (size - count), h_input.end(), h_value); + output_type h_output; + test_utils::device_ptr d_input(h_input); + test_utils::device_ptr d_value(&h_value, 1); + test_utils::device_ptr d_output(1); + test_utils::device_ptr d_temp_storage; + + SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); + SCOPED_TRACE(testing::Message() << "with count = " << count); + SCOPED_TRACE(testing::Message() << "with value = " << h_value); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + // get size + HIP_CHECK(rocprim::search_n(0, + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + nullptr)); + + d_temp_storage.resize(temp_storage_size); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + HIP_CHECK(rocprim::search_n(d_temp_storage.get(), + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + d_value.get(), + op, + stream, + debug_synchronous)); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipStreamSynchronize(stream)); + + const auto expected + = std::search_n(h_input.cbegin(), h_input.cend(), count, h_value, op) + - h_input.cbegin(); + + h_output = d_output.load()[0]; + + ASSERT_EQ(h_output, expected); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + 
HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} + +TYPED_TEST(RocprimDeviceSearchNTests, StartFromEndButFail) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + + HIP_CHECK(hipSetDevice(device_id)); + + using input_type = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + using op_type = typename TestFixture::op_type; + using config = typename TestFixture::config; + + constexpr bool debug_synchronous = TestFixture::debug_synchronous; + op_type op{}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + + for(const auto size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + size_t count = test_utils::get_random_value(0, size, ++seed_value); + size_t temp_storage_size; + input_type h_value{1}; + std::vector h_input(size); + std::fill(h_input.begin(), h_input.begin() + (size - count), 0); + std::fill(h_input.begin() + (size - count), h_input.end(), h_value); + if(count + 2 <= size) + { + count += 2; + } + output_type h_output; + test_utils::device_ptr d_input(h_input); + test_utils::device_ptr d_value(&h_value, 1); + test_utils::device_ptr d_output(1); + test_utils::device_ptr d_temp_storage; + + SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); + SCOPED_TRACE(testing::Message() << "with count = " << count); + SCOPED_TRACE(testing::Message() << "with value = " << h_value); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + // get size + HIP_CHECK(rocprim::search_n(0, + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + nullptr)); + + d_temp_storage.resize(temp_storage_size); + + test_utils::GraphHelper gHelper; + 
if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + HIP_CHECK(rocprim::search_n(d_temp_storage.get(), + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + d_value.get(), + op, + stream, + debug_synchronous)); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipStreamSynchronize(stream)); + + const auto expected + = std::search_n(h_input.cbegin(), h_input.cend(), count, h_value, op) + - h_input.cbegin(); + + h_output = d_output.load()[0]; + + ASSERT_EQ(h_output, expected); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} + +TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_1block) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + + HIP_CHECK(hipSetDevice(device_id)); + + using input_type = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + using op_type = typename TestFixture::op_type; + using config = typename TestFixture::config; + + constexpr bool debug_synchronous = TestFixture::debug_synchronous; + op_type op{}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + + for(const auto size : test_utils::get_sizes(seed_value)) + { + using wrapped_config = rocprim::detail::wrapped_search_n_config; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; + HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); + const auto params = rocprim::detail::dispatch_target_arch(target_arch); + const unsigned int block_size = params.kernel_config.block_size; + const unsigned int items_per_thread = params.kernel_config.items_per_thread; + const unsigned int items_per_block = block_size * items_per_thread; + + /// Will do test like this: + /// |----------------------------------- size ------------------------------------| + /// |----- Item/block ----|----- Item/block ----|----- Item/block ----|----- Item/| + /// |--------count--------| + /// |111111111111111111110|111111111111111111110|111111111111111111111|11111111111| + + size_t count = 0; + input_type h_value{1}; + input_type h_noise{0}; + std::vector h_input(size, h_value); + output_type h_output; + + if(size > items_per_block) + { + count = items_per_block; + size_t cur_tile = 0; + size_t last_tile = size / count - 1; + while(cur_tile != last_tile) + { + h_input[cur_tile * count + count - 1] = h_noise; + ++cur_tile; + } + } + + test_utils::device_ptr d_input(h_input); + test_utils::device_ptr d_value(&h_value, 1); + test_utils::device_ptr d_output(1); + test_utils::device_ptr d_temp_storage; + + SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); + SCOPED_TRACE(testing::Message() << "with count = " << count); + SCOPED_TRACE(testing::Message() << "with value = " << h_value); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + // get size + HIP_CHECK(rocprim::search_n(0, + temp_storage_size, + 
d_input.get(), + d_output.get(), + h_input.size(), + count, + nullptr)); + + d_temp_storage.resize(temp_storage_size); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + HIP_CHECK(rocprim::search_n(d_temp_storage.get(), + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + d_value.get(), + op, + stream, + debug_synchronous)); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipStreamSynchronize(stream)); + + const auto expected + = std::search_n(h_input.cbegin(), h_input.cend(), count, h_value, op) + - h_input.cbegin(); + + h_output = d_output.load()[0]; + + ASSERT_EQ(h_output, expected); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} + +TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_2block) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + + HIP_CHECK(hipSetDevice(device_id)); + + using input_type = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + using op_type = typename TestFixture::op_type; + using config = typename TestFixture::config; + + constexpr bool debug_synchronous = TestFixture::debug_synchronous; + op_type op{}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + + for(const auto size : test_utils::get_sizes(seed_value)) + { + using wrapped_config = rocprim::detail::wrapped_search_n_config; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; + HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); + const auto params = rocprim::detail::dispatch_target_arch(target_arch); + const unsigned int block_size = params.kernel_config.block_size; + const unsigned int items_per_thread = params.kernel_config.items_per_thread; + const unsigned int items_per_block = block_size * items_per_thread; + + /// Will do test like this: + /// |----------------------------------- size ------------------------------------| + /// |----- Item/block ----|----- Item/block ----|----- Item/block ----|----- Item/| + /// |--------count------------------------------| + /// |1111111111111111111111111111111111111111110|111111111111111111111111111111111| + + size_t count = 0; + input_type h_value{1}; + input_type h_noise{0}; + std::vector h_input(size, h_value); + output_type h_output; + + if(size > items_per_block * 2) + { + count = items_per_block * 2; + size_t cur_tile = 0; + size_t last_tile = size / count - 1; + while(cur_tile != last_tile) + { + h_input[cur_tile * count + count - 1] = h_noise; + ++cur_tile; + } + } + + test_utils::device_ptr d_input(h_input); + test_utils::device_ptr d_value(&h_value, 1); + test_utils::device_ptr d_output(1); + test_utils::device_ptr d_temp_storage; + + SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); + SCOPED_TRACE(testing::Message() << "with count = " << count); + SCOPED_TRACE(testing::Message() << "with value = " << h_value); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + // get size + HIP_CHECK(rocprim::search_n(0, + 
temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + nullptr)); + + d_temp_storage.resize(temp_storage_size); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + HIP_CHECK(rocprim::search_n(d_temp_storage.get(), + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + d_value.get(), + op, + stream, + debug_synchronous)); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipStreamSynchronize(stream)); + + const auto expected + = std::search_n(h_input.cbegin(), h_input.cend(), count, h_value, op) + - h_input.cbegin(); + + h_output = d_output.load()[0]; + + ASSERT_EQ(h_output, expected); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} + +TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_3block) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + + HIP_CHECK(hipSetDevice(device_id)); + + using input_type = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + using op_type = typename TestFixture::op_type; + using config = typename TestFixture::config; + + constexpr bool debug_synchronous = TestFixture::debug_synchronous; + op_type op{}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + + for(const auto size : test_utils::get_sizes(seed_value)) + { + using wrapped_config = rocprim::detail::wrapped_search_n_config; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; + HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); + const auto params = rocprim::detail::dispatch_target_arch(target_arch); + const unsigned int block_size = params.kernel_config.block_size; + const unsigned int items_per_thread = params.kernel_config.items_per_thread; + const unsigned int items_per_block = block_size * items_per_thread; + + /// Will do test like this: + /// |----------------------------------- size ----------------------------------------------------------------------------------------------------------------| + /// |----- Item/block ----|----- Item/block ----|----- Item/block ----|----- Item/block ----|----- Item/block ----|----- Item/block ----|----- Item/block ----| + /// |------------------------------count------------------------------|------------------------------count------------------------------| + /// |11111111111111111111111111111111111111111111111111111111111111110|11111111111111111111111111111111111111111111111111111111111111111|111111111111111111111| + + size_t count = 0; + input_type h_value{1}; + input_type h_noise{0}; + std::vector h_input(size, h_value); + output_type h_output; + + if(size > items_per_block * 3) + { + count = items_per_block * 3; + size_t cur_tile = 0; + size_t last_tile = size / count - 1; + while(cur_tile != last_tile) + { + h_input[cur_tile * count + count - 1] = h_noise; + ++cur_tile; + } + } + + test_utils::device_ptr d_input(h_input); + test_utils::device_ptr d_value(&h_value, 1); + test_utils::device_ptr d_output(1); + test_utils::device_ptr d_temp_storage; + + SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); + SCOPED_TRACE(testing::Message() << "with count = " << count); + 
SCOPED_TRACE(testing::Message() << "with value = " << h_value); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + // get size + HIP_CHECK(rocprim::search_n(0, + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + nullptr)); + + d_temp_storage.resize(temp_storage_size); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + HIP_CHECK(rocprim::search_n(d_temp_storage.get(), + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + d_value.get(), + op, + stream, + debug_synchronous)); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipStreamSynchronize(stream)); + + const auto expected + = std::search_n(h_input.cbegin(), h_input.cend(), count, h_value, op) + - h_input.cbegin(); + + h_output = d_output.load()[0]; + + ASSERT_EQ(h_output, expected); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} + +TYPED_TEST(RocprimDeviceSearchNTests, MultiResult1) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + + HIP_CHECK(hipSetDevice(device_id)); + + using input_type = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + using op_type = typename TestFixture::op_type; + using config = typename TestFixture::config; + + constexpr bool debug_synchronous = TestFixture::debug_synchronous; + op_type op{}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + + for(const auto size : test_utils::get_sizes(seed_value)) + { + using wrapped_config = rocprim::detail::wrapped_search_n_config; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; + HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); + const auto params = rocprim::detail::dispatch_target_arch(target_arch); + const unsigned int block_size = params.kernel_config.block_size; + const unsigned int items_per_thread = params.kernel_config.items_per_thread; + const unsigned int items_per_block = block_size * items_per_thread; + + /// Will do test like this: + /// |----------------------------------- size ------------------------------------------------------------------------------------------------------------------------------------------------------------------... + /// |----- Item/block ----|----- Item/block ----|----- Item/block ----|----- Item/block ----|----- Item/block ----|----- Item/block ----|----- Item/block ----|----- Item/block ----|----- Item/block ----| + /// |------------------------------count----------------------------| |------------------------------count----------------------------| |------------------------------count----------------------------| + /// |01111111111111111111111111111111111111111111111111111111111111110|11111111111111111111111111111111111111111111111111111111111111110|11111111111111111111111111111111111111111111111111111111111111111|11111... 
+ + size_t count = 0; + input_type h_value{1}; + input_type h_noise{0}; + std::vector h_input(size, h_value); + output_type h_output; + + if(size > items_per_block * 3) + { + count = items_per_block * 3; + size_t cur_tile = 0; + size_t last_tile = size / count - 1; + while(cur_tile != last_tile) + { + h_input[cur_tile * count + count - 1] = h_noise; + ++cur_tile; + } + count -= 1; + h_input[0] = h_noise; + } + + test_utils::device_ptr d_input(h_input); + test_utils::device_ptr d_value(&h_value, 1); + test_utils::device_ptr d_output(1); + test_utils::device_ptr d_temp_storage; + + SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); + SCOPED_TRACE(testing::Message() << "with count = " << count); + SCOPED_TRACE(testing::Message() << "with value = " << h_value); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + // get size + HIP_CHECK(rocprim::search_n(0, + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + nullptr)); + + d_temp_storage.resize(temp_storage_size); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + HIP_CHECK(rocprim::search_n(d_temp_storage.get(), + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + d_value.get(), + op, + stream, + debug_synchronous)); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipStreamSynchronize(stream)); + + const auto expected + = std::search_n(h_input.cbegin(), h_input.cend(), count, h_value, op) + - h_input.cbegin(); + + h_output = d_output.load()[0]; + + ASSERT_EQ(h_output, expected); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} + 
+TYPED_TEST(RocprimDeviceSearchNTests, MultiResult2) /* For every seed and every size from test_utils::get_sizes: builds a host input, runs rocprim::search_n on the device (through a captured graph on a non-blocking stream when TestFixture::use_graphs is set) and asserts the returned index equals the offset found by std::search_n on the host. */ +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + + HIP_CHECK(hipSetDevice(device_id)); + + using input_type = typename TestFixture::input_type; + using output_type = typename TestFixture::output_type; + using op_type = typename TestFixture::op_type; + using config = typename TestFixture::config; + + constexpr bool debug_synchronous = TestFixture::debug_synchronous; + op_type op{}; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + + for(const auto size : test_utils::get_sizes(seed_value)) + { + using wrapped_config = rocprim::detail::wrapped_search_n_config; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; + HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); + const auto params = rocprim::detail::dispatch_target_arch(target_arch); + const unsigned int block_size = params.kernel_config.block_size; + const unsigned int items_per_thread = params.kernel_config.items_per_thread; + const unsigned int items_per_block = block_size * items_per_thread; + + size_t count = 0; /* stays 0, and h_input keeps its initial contents, when size <= items_per_block */ + input_type h_value{1}; + input_type h_noise{0}; + std::vector h_input(size); + output_type h_output; + + if(size > items_per_block) + { + count = items_per_block; + size_t start = size - 1 - count; /* the trailing count + 1 elements receive h_value below */ + std::fill(h_input.begin() + start, h_input.end(), h_value); + for(size_t i = 0; i < start; i++) + { + if(!(i % 3)) + { + h_input[i] = h_noise; /* NOTE(review): h_input(size) is built without a fill value and h_noise is 0; if the elements start out as zero this store changes nothing and the prefix holds no partial runs - confirm intent */ + } + } + } + + test_utils::device_ptr d_input(h_input); + test_utils::device_ptr d_value(&h_value, 1); + test_utils::device_ptr d_output(1); + test_utils::device_ptr d_temp_storage; + + SCOPED_TRACE(testing::Message() << "with size = " << h_input.size()); + SCOPED_TRACE(testing::Message() << "with count = " << count); + SCOPED_TRACE(testing::Message() << "with 
value = " << h_value); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + // get size + HIP_CHECK(rocprim::search_n(0, + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + nullptr)); + + d_temp_storage.resize(temp_storage_size); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } + + HIP_CHECK(rocprim::search_n(d_temp_storage.get(), + temp_storage_size, + d_input.get(), + d_output.get(), + h_input.size(), + count, + d_value.get(), + op, + stream, + debug_synchronous)); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipGetLastError()); + HIP_CHECK(hipStreamSynchronize(stream)); + + const auto expected + = std::search_n(h_input.cbegin(), h_input.cend(), count, h_value, op) + - h_input.cbegin(); + + h_output = d_output.load()[0]; + + ASSERT_EQ(h_output, expected); + + if ROCPRIM_IF_CONSTEXPR(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + HIP_CHECK(hipStreamDestroy(stream)); + } + } + } +} diff --git a/test/rocprim/test_device_segmented_radix_sort.cpp.in b/test/rocprim/test_device_segmented_radix_sort.cpp.in index 7b5217298..d212492a8 100644 --- a/test/rocprim/test_device_segmented_radix_sort.cpp.in +++ b/test/rocprim/test_device_segmented_radix_sort.cpp.in @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -41,6 +41,18 @@ #elif ROCPRIM_TEST_SUITE_SLICE == 3 TYPED_TEST_P(SUITE, SortPairsDoubleBuffer ) { sort_pairs_double_buffer(); } REGISTER_TYPED_TEST_SUITE_P(SUITE, SortPairsDoubleBuffer); +#elif ROCPRIM_TEST_SUITE_SLICE == 4 // same selector macro as the preceding branch; an undefined name would evaluate to 0 here + TYPED_TEST_P(SUITE, SortKeysEmptyData ) { sort_keys_empty_data(); } + REGISTER_TYPED_TEST_SUITE_P(SUITE, SortKeysEmptyData); +#elif ROCPRIM_TEST_SUITE_SLICE == 5 + TYPED_TEST_P(SUITE, SortKeysLargeSegments ) { sort_keys_large_segments(); } + REGISTER_TYPED_TEST_SUITE_P(SUITE, SortKeysLargeSegments); +#elif ROCPRIM_TEST_SUITE_SLICE == 6 + TYPED_TEST_P(SUITE, SortKeysUnspecifiedRanges ) { sort_keys_unspecified_ranges(); } + REGISTER_TYPED_TEST_SUITE_P(SUITE, SortKeysUnspecifiedRanges); +#elif ROCPRIM_TEST_SUITE_SLICE == 7 + TYPED_TEST_P(SUITE, SortPairsUnspecifiedRanges ) { sort_pairs_unspecified_ranges(); } + REGISTER_TYPED_TEST_SUITE_P(SUITE, SortPairsUnspecifiedRanges); #endif #if ROCPRIM_TEST_TYPE_SLICE == 0 diff --git a/test/rocprim/test_device_segmented_radix_sort.hpp b/test/rocprim/test_device_segmented_radix_sort.hpp index 52ae07121..3daa8f99e 100644 --- a/test/rocprim/test_device_segmented_radix_sort.hpp +++ b/test/rocprim/test_device_segmented_radix_sort.hpp @@ -139,22 +139,11 @@ inline void sort_keys() SCOPED_TRACE(testing::Message() << "with size = " << size); // Generate data - std::vector keys_input; - if(rocprim::is_floating_point::value) - { - keys_input = test_utils::get_random_data(size, - static_cast(-1000), - static_cast(+1000), - seed_value); - } - else - { - keys_input - = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), - seed_value); - } + std::vector keys_input = test_utils::get_random_data( + size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + seed_value); std::vector offsets; 
unsigned int segments_count = 0; @@ -262,14 +251,130 @@ inline void sort_keys() } template -inline void sort_pairs() +inline void sort_keys_empty_data() { int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); + using key_type = typename TestFixture::params::key_type; + using config = typename TestFixture::params::config; + static constexpr bool descending = TestFixture::params::descending; + static constexpr unsigned int start_bit = TestFixture::params::start_bit; + static constexpr unsigned int end_bit = TestFixture::params::end_bit; + + using offset_type = unsigned int; + + hipStream_t stream = 0; + + const std::vector sizes = {0, 1024}; + for(size_t size : sizes) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + const std::vector segments_counts = {0, 1}; + for(size_t segments_count : segments_counts) + { + unsigned int seed_value = seeds[0]; + SCOPED_TRACE(testing::Message() << "with segments_count = " << segments_count); + + // Generate data + std::vector keys_input = test_utils::get_random_data( + size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + seed_value); + + std::vector offsets(2); + offsets[0] = 0; + offsets[1] = 0; + + key_type* d_keys; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys, size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_keys, + keys_input.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + + offset_type* d_offsets; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_offsets, + (segments_count + 1) * sizeof(offset_type))); + HIP_CHECK(hipMemcpy(d_offsets, + offsets.data(), + (segments_count + 1) * sizeof(offset_type), + hipMemcpyHostToDevice)); + + size_t temporary_storage_bytes = 0; + HIP_CHECK(rocprim::segmented_radix_sort_keys(nullptr, + temporary_storage_bytes, + d_keys, + d_keys, + size, + segments_count, + d_offsets, + d_offsets + 1, + start_bit, + 
end_bit)); + + ASSERT_GT(temporary_storage_bytes, 0U); + + void* d_temporary_storage; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + + if(descending) + { + HIP_CHECK(rocprim::segmented_radix_sort_pairs_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys, + d_keys, + size, + segments_count, + d_offsets, + d_offsets + 1, + start_bit, + end_bit, + stream)); + } + else + { + HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys, + d_keys, + size, + segments_count, + d_offsets, + d_offsets + 1, + start_bit, + end_bit, + stream)); + } + + std::vector keys_output(size); + HIP_CHECK(hipMemcpy(keys_output.data(), + d_keys, + size * sizeof(key_type), + hipMemcpyDeviceToHost)); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_keys)); + HIP_CHECK(hipFree(d_offsets)); + + // Output should not have changed + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, keys_input)); + } + } +} + +template +inline void sort_keys_large_segments() +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id= " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + using key_type = typename TestFixture::params::key_type; - using value_type = typename TestFixture::params::value_type; using config = typename TestFixture::params::config; constexpr bool descending = TestFixture::params::descending; constexpr unsigned int start_bit = TestFixture::params::start_bit; @@ -279,7 +384,132 @@ inline void sort_pairs() hipStream_t stream = 0; - const bool debug_synchronous = false; + size_t size = 1 << 20; + size_t segments_count = 2; + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + // Generate data + std::vector keys_input + = test_utils::get_random_data(size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + seed_value); + + std::vector offsets(3); + offsets[0] = 0; + offsets[1] = static_cast(size / 2); + offsets[2] = static_cast(size); + + key_type* d_keys_input; + key_type* d_keys_output; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_keys_input, + keys_input.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + + offset_type* d_offsets; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_offsets, + (segments_count + 1) * sizeof(offset_type))); + HIP_CHECK(hipMemcpy(d_offsets, + offsets.data(), + (segments_count + 1) * sizeof(offset_type), + hipMemcpyHostToDevice)); + + // Calculate expected results on host + std::vector expected(keys_input); + for(size_t i = 0; i < segments_count; i++) + { + std::stable_sort( + expected.begin() + offsets[i], + expected.begin() + offsets[i + 1], + test_utils::key_comparator()); + } + + size_t temporary_storage_bytes = 0; + HIP_CHECK(rocprim::segmented_radix_sort_keys(nullptr, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + start_bit, + end_bit)); + + ASSERT_GT(temporary_storage_bytes, 0U); + + void* d_temporary_storage; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + + if(descending) + { + HIP_CHECK(rocprim::segmented_radix_sort_keys_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + start_bit, + end_bit, + stream)); + } + else + { + 
HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets, + d_offsets + 1, + start_bit, + end_bit, + stream)); + } + + std::vector keys_output(size); + HIP_CHECK(hipMemcpy(keys_output.data(), + d_keys_output, + size * sizeof(key_type), + hipMemcpyDeviceToHost)); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_offsets)); + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, expected)); + } +} + +template +inline void sort_keys_unspecified_ranges() +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using key_type = typename TestFixture::params::key_type; + using config = typename TestFixture::params::config; + constexpr bool descending = TestFixture::params::descending; + constexpr unsigned int start_bit = TestFixture::params::start_bit; + constexpr unsigned int end_bit = TestFixture::params::end_bit; + + using offset_type = unsigned int; + + hipStream_t stream = 0; std::random_device rd; std::default_random_engine gen(rd()); @@ -299,23 +529,187 @@ inline void sort_pairs() SCOPED_TRACE(testing::Message() << "with size = " << size); // Generate data - std::vector keys_input; - if(rocprim::is_floating_point::value) + std::vector keys_input = test_utils::get_random_data( + size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + seed_value); + + std::vector begin_offsets; + unsigned int segments_count = 0; + size_t offset = 0; + while(offset < size) + { + const size_t segment_length = segment_length_dis(gen); + begin_offsets.push_back(offset); + segments_count++; + offset += segment_length; + } + begin_offsets.push_back(size); + std::vector end_offsets(begin_offsets.cbegin() + 1, begin_offsets.cend()); + 
begin_offsets.pop_back(); + + size_t empty_segments = rocprim::max(segments_count / 16, 1u); + std::vector is_empty_segment(segments_count, false); + std::fill(is_empty_segment.begin(), is_empty_segment.begin() + empty_segments, true); + std::shuffle(is_empty_segment.begin(), is_empty_segment.end(), gen); + + for(size_t i = 0; i < segments_count; i++) + { + if(is_empty_segment[i]) + { + begin_offsets[i] = 0; + end_offsets[i] = 0; + } + } + + key_type* d_keys_input; + key_type* d_keys_output; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_keys_input, + keys_input.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_keys_output, + keys_input.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + + offset_type* d_offsets_begin; + offset_type* d_offsets_end; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_offsets_begin, + segments_count * sizeof(offset_type))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_offsets_end, + segments_count * sizeof(offset_type))); + HIP_CHECK(hipMemcpy(d_offsets_begin, + begin_offsets.data(), + segments_count * sizeof(offset_type), + hipMemcpyHostToDevice)); + + HIP_CHECK(hipMemcpy(d_offsets_end, + end_offsets.data(), + segments_count * sizeof(offset_type), + hipMemcpyHostToDevice)); + + // Calculate expected results on host + std::vector expected(keys_input); + for(size_t i = 0; i < segments_count; i++) + { + std::stable_sort( + expected.begin() + begin_offsets[i], + expected.begin() + end_offsets[i], + test_utils::key_comparator()); + } + + size_t temporary_storage_bytes = 0; + HIP_CHECK(rocprim::segmented_radix_sort_keys(nullptr, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets_begin, + d_offsets_end, + start_bit, + end_bit)); + + ASSERT_GT(temporary_storage_bytes, 0U); + + void* 
d_temporary_storage; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + + if(descending) { - keys_input = test_utils::get_random_data(size, - static_cast(-1000), - static_cast(+1000), - seed_value); + HIP_CHECK(rocprim::segmented_radix_sort_keys_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets_begin, + d_offsets_end, + start_bit, + end_bit, + stream)); } else { - keys_input - = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), - seed_value); + HIP_CHECK(rocprim::segmented_radix_sort_keys(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + size, + segments_count, + d_offsets_begin, + d_offsets_end, + start_bit, + end_bit, + stream)); } + std::vector keys_output(size); + HIP_CHECK(hipMemcpy(keys_output.data(), + d_keys_output, + size * sizeof(key_type), + hipMemcpyDeviceToHost)); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_offsets_begin)); + HIP_CHECK(hipFree(d_offsets_end)); + + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, expected)); + } + } +} + +template +inline void sort_pairs() +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using key_type = typename TestFixture::params::key_type; + using value_type = typename TestFixture::params::value_type; + using config = typename TestFixture::params::config; + constexpr bool descending = TestFixture::params::descending; + constexpr unsigned int start_bit = TestFixture::params::start_bit; + constexpr unsigned int end_bit = TestFixture::params::end_bit; + + using offset_type = unsigned int; + + hipStream_t stream = 0; + + const bool debug_synchronous = false; + + std::random_device rd; + 
std::default_random_engine gen(rd()); + + std::uniform_int_distribution segment_length_dis( + TestFixture::params::min_segment_length, + TestFixture::params::max_segment_length); + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(size_t size : test_utils::get_sizes(seed_value)) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + + // Generate data + std::vector keys_input = test_utils::get_random_data( + size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + seed_value); + std::vector offsets; unsigned int segments_count = 0; size_t offset = 0; @@ -467,13 +861,14 @@ inline void sort_pairs() } template -inline void sort_keys_double_buffer() +inline void sort_pairs_unspecified_ranges() { int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); using key_type = typename TestFixture::params::key_type; + using value_type = typename TestFixture::params::value_type; using config = typename TestFixture::params::config; constexpr bool descending = TestFixture::params::descending; constexpr unsigned int start_bit = TestFixture::params::start_bit; @@ -483,8 +878,6 @@ inline void sort_keys_double_buffer() hipStream_t stream = 0; - const bool debug_synchronous = false; - std::random_device rd; std::default_random_engine gen(rd()); @@ -503,22 +896,229 @@ inline void sort_keys_double_buffer() SCOPED_TRACE(testing::Message() << "with size = " << size); // Generate data - std::vector keys_input; - if(rocprim::is_floating_point::value) + std::vector keys_input = test_utils::get_random_data( + size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + seed_value); + + 
std::vector values_input(size); + std::iota(values_input.begin(), values_input.end(), 0); + + std::vector begin_offsets; + unsigned int segments_count = 0; + size_t offset = 0; + while(offset < size) + { + const size_t segment_length = segment_length_dis(gen); + begin_offsets.push_back(offset); + segments_count++; + offset += segment_length; + } + begin_offsets.push_back(size); + std::vector end_offsets(begin_offsets.cbegin() + 1, begin_offsets.cend()); + begin_offsets.pop_back(); + + size_t empty_segments = rocprim::max(segments_count / 16, 1u); + std::vector is_empty_segment(segments_count, false); + std::fill(is_empty_segment.begin(), is_empty_segment.begin() + empty_segments, true); + std::shuffle(is_empty_segment.begin(), is_empty_segment.end(), gen); + + for(size_t i = 0; i < segments_count; i++) + { + if(is_empty_segment[i]) + { + begin_offsets[i] = 0; + end_offsets[i] = 0; + } + } + + key_type* d_keys_input; + key_type* d_keys_output; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input, size * sizeof(key_type))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_output, size * sizeof(key_type))); + HIP_CHECK(hipMemcpy(d_keys_input, + keys_input.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_keys_output, + keys_input.data(), + size * sizeof(key_type), + hipMemcpyHostToDevice)); + + value_type* d_values_input; + value_type* d_values_output; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_values_input, size * sizeof(value_type))); + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_values_output, size * sizeof(value_type))); + HIP_CHECK(hipMemcpy(d_values_input, + values_input.data(), + size * sizeof(value_type), + hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_values_output, + values_input.data(), + size * sizeof(value_type), + hipMemcpyHostToDevice)); + + offset_type* d_offsets_begin; + offset_type* d_offsets_end; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_offsets_begin, + segments_count * 
sizeof(offset_type))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_offsets_end, + segments_count * sizeof(offset_type))); + HIP_CHECK(hipMemcpy(d_offsets_begin, + begin_offsets.data(), + segments_count * sizeof(offset_type), + hipMemcpyHostToDevice)); + + HIP_CHECK(hipMemcpy(d_offsets_end, + end_offsets.data(), + segments_count * sizeof(offset_type), + hipMemcpyHostToDevice)); + using key_value = std::pair; + + // Calculate expected results on host + std::vector expected(size); + for(size_t i = 0; i < size; i++) + { + expected[i] = key_value(keys_input[i], values_input[i]); + } + for(size_t i = 0; i < segments_count; i++) + { + std::stable_sort(expected.begin() + begin_offsets[i], + expected.begin() + end_offsets[i], + test_utils::key_value_comparator()); + } + + size_t temporary_storage_bytes = 0; + HIP_CHECK(rocprim::segmented_radix_sort_pairs(nullptr, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets_begin, + d_offsets_end, + start_bit, + end_bit)); + + ASSERT_GT(temporary_storage_bytes, 0U); + + void* d_temporary_storage; + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + + if(descending) { - keys_input = test_utils::get_random_data(size, - static_cast(-1000), - static_cast(+1000), - seed_value); + HIP_CHECK(rocprim::segmented_radix_sort_pairs_desc(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets_begin, + d_offsets_end, + start_bit, + end_bit, + stream)); } else { - keys_input - = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), - seed_value); + HIP_CHECK(rocprim::segmented_radix_sort_pairs(d_temporary_storage, + temporary_storage_bytes, + d_keys_input, + d_keys_output, + d_values_input, + d_values_output, + size, + segments_count, + d_offsets_begin, + d_offsets_end, + 
start_bit, + end_bit, + stream)); + } + + std::vector keys_output(size); + HIP_CHECK(hipMemcpy(keys_output.data(), + d_keys_output, + size * sizeof(key_type), + hipMemcpyDeviceToHost)); + + std::vector values_output(size); + HIP_CHECK(hipMemcpy(values_output.data(), + d_values_output, + size * sizeof(value_type), + hipMemcpyDeviceToHost)); + + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_values_input)); + HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_values_output)); + HIP_CHECK(hipFree(d_offsets_begin)); + HIP_CHECK(hipFree(d_offsets_end)); + + for(size_t i = 0; i < size; i++) + { + ASSERT_EQ(keys_output[i], expected[i].first); + ASSERT_EQ(values_output[i], expected[i].second); } + } + } +} + +template +inline void sort_keys_double_buffer() +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using key_type = typename TestFixture::params::key_type; + using config = typename TestFixture::params::config; + constexpr bool descending = TestFixture::params::descending; + constexpr unsigned int start_bit = TestFixture::params::start_bit; + constexpr unsigned int end_bit = TestFixture::params::end_bit; + + using offset_type = unsigned int; + + hipStream_t stream = 0; + + const bool debug_synchronous = false; + + std::random_device rd; + std::default_random_engine gen(rd()); + + std::uniform_int_distribution segment_length_dis( + TestFixture::params::min_segment_length, + TestFixture::params::max_segment_length); + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(size_t size : test_utils::get_sizes(seed_value)) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + + // Generate data + std::vector keys_input = test_utils::get_random_data( + size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + seed_value); std::vector offsets; unsigned int segments_count = 0; @@ -662,22 +1262,11 @@ inline void sort_pairs_double_buffer() SCOPED_TRACE(testing::Message() << "with size = " << size); // Generate data - std::vector keys_input; - if(rocprim::is_floating_point::value) - { - keys_input = test_utils::get_random_data(size, - static_cast(-1000), - static_cast(+1000), - seed_value); - } - else - { - keys_input - = test_utils::get_random_data(size, - std::numeric_limits::min(), - std::numeric_limits::max(), - seed_value); - } + std::vector keys_input = test_utils::get_random_data( + size, + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), + seed_value); std::vector offsets; unsigned int segments_count = 0; diff --git a/test/rocprim/test_device_segmented_reduce.cpp b/test/rocprim/test_device_segmented_reduce.cpp index cf00db384..ad3ebdc4a 100644 --- a/test/rocprim/test_device_segmented_reduce.cpp +++ b/test/rocprim/test_device_segmented_reduce.cpp @@ -452,9 +452,9 @@ void testLargeIndices() SCOPED_TRACE(testing::Message() << "with seed = " << seed); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(aggregates_output, aggregates_expected)); - hipFree(d_offsets); - hipFree(d_temp_storage); - hipFree(d_aggregates_output); + HIP_CHECK(hipFree(d_offsets)); + HIP_CHECK(hipFree(d_temp_storage)); + HIP_CHECK(hipFree(d_aggregates_output)); if(use_graphs) { diff --git a/test/rocprim/test_device_select.cpp b/test/rocprim/test_device_select.cpp index 79f31eef2..ffe841f0a 100644 --- a/test/rocprim/test_device_select.cpp +++ b/test/rocprim/test_device_select.cpp @@ 
-221,11 +221,11 @@ TYPED_TEST(RocprimDeviceSelectTests, Flagged) HIP_CHECK(hipDeviceSynchronize()); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected, expected.size())); - hipFree(d_input); - hipFree(d_flags); - hipFree(d_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_flags)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); if(TestFixture::use_graphs) { @@ -384,10 +384,152 @@ TYPED_TEST(RocprimDeviceSelectTests, SelectOp) HIP_CHECK(hipDeviceSynchronize()); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected, expected.size())); - hipFree(d_input); - hipFree(d_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); + + if(TestFixture::use_graphs) + { + gHelper.cleanupGraphHelper(); + } + } + } + + if(TestFixture::use_graphs) + { + HIP_CHECK(hipStreamDestroy(stream)); + } +} + +TYPED_TEST(RocprimDeviceSelectTests, SelectFlagged) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using T = typename TestFixture::input_type; + using U = typename TestFixture::output_type; + using F = typename TestFixture::flag_type; + static constexpr bool use_identity_iterator = TestFixture::use_identity_iterator; + + hipStream_t stream = 0; // default stream + if(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value + = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(auto size : test_utils::get_sizes(seed_value)) + { + SCOPED_TRACE(testing::Message() << "with size = " << size); + + // Generate data + std::vector input = test_utils::get_random_data(size, 1, 100, seed_value); + std::vector flags = test_utils::get_random_data(size, 0, 1, seed_value); + + T* d_input; + F* d_flags; + U* d_output; + unsigned int* d_selected_count_output; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_input, input.size() * sizeof(T))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_flags, flags.size() * sizeof(F))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, input.size() * sizeof(U))); + HIP_CHECK( + test_common_utils::hipMallocHelper(&d_selected_count_output, sizeof(unsigned int))); + HIP_CHECK( + hipMemcpy(d_input, input.data(), input.size() * sizeof(T), hipMemcpyHostToDevice)); + HIP_CHECK( + hipMemcpy(d_flags, flags.data(), flags.size() * sizeof(F), hipMemcpyHostToDevice)); + + // Calculate expected results on host + std::vector expected; + expected.reserve(input.size()); + for(size_t i = 0; i < input.size(); i++) + { + if(select_op()(flags[i]) != 0) + { + expected.push_back(input[i]); + } + } + + // temp storage + size_t temp_storage_size_bytes; + // Get size of d_temp_storage + HIP_CHECK(rocprim::select( + nullptr, + temp_storage_size_bytes, + d_input, + d_flags, + test_utils::wrap_in_identity_iterator(d_output), + d_selected_count_output, + input.size(), + select_op(), + stream, + TestFixture::debug_synchronous)); + + HIP_CHECK(hipDeviceSynchronize()); + + // temp_storage_size_bytes must be >0 + ASSERT_GT(temp_storage_size_bytes, 0); + + // allocate temporary storage + void* d_temp_storage = nullptr; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + + test_utils::GraphHelper gHelper; + if(TestFixture::use_graphs) + { + gHelper.startStreamCapture(stream); + } 
+ + // Run + HIP_CHECK(rocprim::select( + d_temp_storage, + temp_storage_size_bytes, + d_input, + d_flags, + test_utils::wrap_in_identity_iterator(d_output), + d_selected_count_output, + input.size(), + select_op(), + stream, + TestFixture::debug_synchronous)); + + if(TestFixture::use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipDeviceSynchronize()); + + // Check if number of selected value is as expected + unsigned int selected_count_output = 0; + HIP_CHECK(hipMemcpy(&selected_count_output, + d_selected_count_output, + sizeof(unsigned int), + hipMemcpyDeviceToHost)); + ASSERT_EQ(selected_count_output, expected.size()); + + // Check if output values are as expected + std::vector output(input.size()); + HIP_CHECK(hipMemcpy(output.data(), + d_output, + output.size() * sizeof(U), + hipMemcpyDeviceToHost)); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected, expected.size())); + + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_flags)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); if(TestFixture::use_graphs) { @@ -562,10 +704,10 @@ TYPED_TEST(RocprimDeviceSelectTests, Unique) HIP_CHECK(hipDeviceSynchronize()); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected, expected.size())); - hipFree(d_input); - hipFree(d_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); if(TestFixture::use_graphs) { @@ -770,11 +912,11 @@ void testUniqueGuardedOperator() HIP_CHECK(hipDeviceSynchronize()); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected, expected.size())); - hipFree(d_input); - hipFree(d_flag); - hipFree(d_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_flag)); + HIP_CHECK(hipFree(d_output)); + 
HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); if(UseGraphs) { @@ -1033,12 +1175,12 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output_keys, expected_keys, expected_keys.size())); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output_values, expected_values, expected_values.size())); - hipFree(d_keys_input); - hipFree(d_values_input); - hipFree(d_keys_output); - hipFree(d_values_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_values_input)); + HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_values_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); if(TestFixture::use_graphs) { @@ -1229,10 +1371,10 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKeyAlias) ASSERT_NO_FATAL_FAILURE( test_utils::assert_eq(output_values, expected_values, expected_values.size())); - hipFree(d_keys_input); - hipFree(d_values_input); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_values_input)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); if(TestFixture::use_graphs) { @@ -1397,9 +1539,9 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputFlagged) ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected_output, expected_output.size())); - hipFree(d_output); - hipFree(d_selected_count_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); if(use_graphs) { @@ -1411,6 +1553,275 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputFlagged) HIP_CHECK(hipStreamDestroy(stream)); } +template +struct large_select_op +{ + T max_value; + __device__ __host__ + inline bool + operator()(const T& value) const + { + return rocprim::less()(value, 
T(max_value)); + } +}; + +TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputSelectOp) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + auto param = GetParam(); + const bool use_graphs = std::get<1>(param); + + const bool debug_synchronous = RocprimDeviceSelectLargeInputTests::debug_synchronous; + + hipStream_t stream = 0; // default stream + if(use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + for(auto size : test_utils::get_large_sizes(0)) + { + const size_t selected_input = std::get<0>(param); + auto select_op = large_select_op{selected_input}; + + // otherwise test is too long + if(size > (size_t{1} << 35)) + break; + SCOPED_TRACE(testing::Message() << "with size = " << size); + + // Generate data + auto input_iota = rocprim::make_counting_iterator(std::size_t{0}); + + size_t selected_count_output = 0; + size_t* d_selected_count_output; + + size_t expected_output_size = selected_input; + + size_t* d_output; + std::vector output(expected_output_size); + + // Calculate expected results on host + std::vector expected_output(expected_output_size); + std::iota(expected_output.begin(), expected_output.end(), 0); + + HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, + sizeof(d_output[0]) * expected_output_size)); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_selected_count_output, + sizeof(d_selected_count_output[0]))); + + // temp storage + size_t temp_storage_size_bytes; + void* d_temp_storage = nullptr; + + // Get size of d_temp_storage + HIP_CHECK(rocprim::select(d_temp_storage, + temp_storage_size_bytes, + input_iota, + d_output, + d_selected_count_output, + size, + select_op, + stream, + debug_synchronous)); + + HIP_CHECK(hipDeviceSynchronize()); + + // temp_storage_size_bytes must be >0 + 
ASSERT_GT(temp_storage_size_bytes, 0); + + // allocate temporary storage + HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + + test_utils::GraphHelper gHelper; + if(use_graphs) + { + gHelper.startStreamCapture(stream); + } + + // Run + HIP_CHECK(rocprim::select(d_temp_storage, + temp_storage_size_bytes, + input_iota, + d_output, + d_selected_count_output, + size, + select_op, + stream, + debug_synchronous)); + + if(use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipDeviceSynchronize()); + + // Check if number of selected value is as expected + HIP_CHECK(hipMemcpy(&selected_count_output, + d_selected_count_output, + sizeof(size_t), + hipMemcpyDeviceToHost)); + ASSERT_EQ(selected_count_output, expected_output_size); + + // Check if output values are as expected + HIP_CHECK(hipMemcpy(output.data(), + d_output, + sizeof(output[0]) * expected_output_size, + hipMemcpyDeviceToHost)); + + ASSERT_NO_FATAL_FAILURE( + test_utils::assert_eq(output, expected_output, expected_output.size())); + + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); + + if(use_graphs) + { + gHelper.cleanupGraphHelper(); + } + } + + if(use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); +} + +TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputSelectFlagged) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + auto param = GetParam(); + const bool use_graphs = std::get<1>(param); + + using InputIterator = typename rocprim::counting_iterator; + + const bool debug_synchronous = RocprimDeviceSelectLargeInputTests::debug_synchronous; + + hipStream_t stream = 0; // default stream + if(use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + 
+ for(auto size : test_utils::get_large_sizes(0)) + { + // otherwise test is too long + if(size > (size_t{1} << 35)) + break; + SCOPED_TRACE(testing::Message() << "with size = " << size); + + const size_t selected_flags = std::get<0>(param); + auto select_op = large_select_op{selected_flags}; + + // Generate data + size_t initial_value = 0; + InputIterator input_begin(initial_value); + + auto flags_it = rocprim::make_counting_iterator(size_t(0)); + + size_t selected_count_output = 0; + size_t* d_selected_count_output; + + size_t expected_output_size = selected_flags; + + size_t* d_output; + std::vector output(expected_output_size); + + // Calculate expected results on host + std::vector expected_output(expected_output_size); + std::iota(expected_output.begin(), expected_output.end(), 0); + + HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, + sizeof(d_output[0]) * expected_output_size)); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_selected_count_output, + sizeof(d_selected_count_output[0]))); + + // temp storage + size_t temp_storage_size_bytes; + void* d_temp_storage = nullptr; + + // Get size of d_temp_storage + HIP_CHECK(rocprim::select(d_temp_storage, + temp_storage_size_bytes, + input_begin, + flags_it, + d_output, + d_selected_count_output, + size, + select_op, + stream, + debug_synchronous)); + + HIP_CHECK(hipDeviceSynchronize()); + + // temp_storage_size_bytes must be >0 + ASSERT_GT(temp_storage_size_bytes, 0); + + // allocate temporary storage + HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + + test_utils::GraphHelper gHelper; + if(use_graphs) + { + gHelper.startStreamCapture(stream); + } + + // Run + HIP_CHECK(rocprim::select(d_temp_storage, + temp_storage_size_bytes, + input_begin, + flags_it, + d_output, + d_selected_count_output, + size, + select_op, + stream, + debug_synchronous)); + + if(use_graphs) + { + gHelper.createAndLaunchGraph(stream); + } + + HIP_CHECK(hipDeviceSynchronize()); + + // 
Check if number of selected value is as expected + HIP_CHECK(hipMemcpy(&selected_count_output, + d_selected_count_output, + sizeof(size_t), + hipMemcpyDeviceToHost)); + ASSERT_EQ(selected_count_output, expected_output_size); + + // Check if output values are as expected + HIP_CHECK(hipMemcpy(output.data(), + d_output, + sizeof(output[0]) * expected_output_size, + hipMemcpyDeviceToHost)); + + ASSERT_NO_FATAL_FAILURE( + test_utils::assert_eq(output, expected_output, expected_output.size())); + + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_selected_count_output)); + HIP_CHECK(hipFree(d_temp_storage)); + + if(use_graphs) + { + gHelper.cleanupGraphHelper(); + } + } + + if(use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); +} + TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputUnique) { static constexpr bool debug_synchronous = false; diff --git a/test/rocprim/test_device_transform.cpp b/test/rocprim/test_device_transform.cpp index e9b60aac1..6efa48c6e 100644 --- a/test/rocprim/test_device_transform.cpp +++ b/test/rocprim/test_device_transform.cpp @@ -196,8 +196,8 @@ TYPED_TEST(RocprimDeviceTransformTests, Transform) ASSERT_NO_FATAL_FAILURE( test_utils::assert_near(output, expected, test_utils::precision)); - hipFree(d_input); - hipFree(d_output); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); if (TestFixture::use_graphs) { @@ -319,10 +319,10 @@ TYPED_TEST(RocprimDeviceTransformTests, BinaryTransform) ASSERT_NO_FATAL_FAILURE( test_utils::assert_near(output, expected, test_utils::precision)); - hipFree(d_input1); - hipFree(d_input2); - hipFree(d_output); - + HIP_CHECK(hipFree(d_input1)); + HIP_CHECK(hipFree(d_input2)); + HIP_CHECK(hipFree(d_output)); + if (TestFixture::use_graphs) { gHelper.cleanupGraphHelper(); @@ -332,6 +332,28 @@ TYPED_TEST(RocprimDeviceTransformTests, BinaryTransform) } } +template +struct flag_expected_op_t +{ + bool* d_flag; + T expected; + T expected_above_limit; + + __device__ + auto operator()(const T& value) -> int + { 
+ if(value == expected) + { + d_flag[0] = true; + } + if(value == expected_above_limit) + { + d_flag[1] = true; + } + return 0; + } +}; + template void testLargeIndices() { @@ -378,17 +400,8 @@ void testLargeIndices() SCOPED_TRACE(testing::Message() << "expected = " << expected); SCOPED_TRACE(testing::Message() << "expected_above_limit = " << expected_above_limit); - const auto flag_expected = [=] __device__ (T value) -> int { - if(value == expected) - { - d_flag[0] = true; - } - if(value == expected_above_limit) - { - d_flag[1] = true; - } - return 0; - }; + const auto flag_expected + = flag_expected_op_t{d_flag, expected, expected_above_limit}; test_utils::GraphHelper gHelper; if(UseGraphs) diff --git a/test/rocprim/test_intrinsics.cpp b/test/rocprim/test_intrinsics.cpp index e2818612d..d2f4c45cc 100644 --- a/test/rocprim/test_intrinsics.cpp +++ b/test/rocprim/test_intrinsics.cpp @@ -372,7 +372,7 @@ void test_shuffle() } } - hipFree(d_data); + HIP_CHECK(hipFree(d_data)); } TYPED_TEST(RocprimIntrinsicsTests, ShuffleUp) @@ -501,8 +501,8 @@ TYPED_TEST(RocprimIntrinsicsTests, ShuffleIndex) ASSERT_EQ(output[j], expected[j]) << "where index = " << j; } } - hipFree(device_data); - hipFree(device_src_lanes); + HIP_CHECK(hipFree(device_data)); + HIP_CHECK(hipFree(device_src_lanes)); } } @@ -555,7 +555,7 @@ TEST(RocprimIntrinsicsTests, LaneId) } } - hipFree(d_output); + HIP_CHECK(hipFree(d_output)); } __global__ void masked_bit_count_kernel(unsigned int* out, @@ -669,8 +669,8 @@ TEST(RocprimIntrinsicsTests, MaskedBitCount) } } - hipFree(d_input); - hipFree(d_output); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); } enum class warp_any_all_test_type @@ -795,8 +795,8 @@ void warp_any_all_test() } } - hipFree(d_input); - hipFree(d_output); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); } TEST(RocprimIntrinsicsTests, WarpAny) @@ -949,9 +949,9 @@ TYPED_TEST(RocprimIntrinsicsTests, WarpPermute) } } - hipFree(d_input); - hipFree(d_output); - 
hipFree(d_indices); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_indices)); } template @@ -1074,8 +1074,8 @@ TEST(RocprimIntrinsicsTests, MatchAny) } } - hipFree(d_input); - hipFree(d_output); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); } __global__ void @@ -1172,8 +1172,8 @@ TEST(RocprimIntrinsicsTests, Ballot) } } - hipFree(d_input); - hipFree(d_output); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); } __global__ void group_elect_kernel(max_lane_mask_type* output, @@ -1294,6 +1294,6 @@ TEST(RocprimIntrinsicsTests, GroupElect) } } - hipFree(d_input); - hipFree(d_output); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); } diff --git a/test/rocprim/test_radix_key_codec.cpp b/test/rocprim/test_radix_key_codec.cpp index 58b4b4c54..8189ccd61 100644 --- a/test/rocprim/test_radix_key_codec.cpp +++ b/test/rocprim/test_radix_key_codec.cpp @@ -427,8 +427,8 @@ TYPED_TEST(TypedRadixKeyCodecTest, EncodeDecodeExtract) const size_t size = (1 << 20) + 123; std::vector input_keys = test_utils::get_random_data(size, - test_utils::numeric_limits::min(), - test_utils::numeric_limits::max(), + test_utils::generate_limits::min(), + test_utils::generate_limits::max(), seed_value); for(size_t i = 0; i < size; ++i) diff --git a/test/rocprim/test_temporary_storage_partitioning.cpp b/test/rocprim/test_temporary_storage_partitioning.cpp index 92c07042a..b93f247f2 100644 --- a/test/rocprim/test_temporary_storage_partitioning.cpp +++ b/test/rocprim/test_temporary_storage_partitioning.cpp @@ -66,7 +66,7 @@ TEST(RocprimTemporaryStoragePartitioningTests, Basic) ASSERT_EQ(storage_size, size); ASSERT_EQ(test_allocation, temporary_storage); - hipFree(temporary_storage); + HIP_CHECK(hipFree(temporary_storage)); } TEST(RocprimTemporaryStoragePartitioningTests, ZeroSizePartition) diff --git a/test/rocprim/test_texture_cache_iterator.cpp b/test/rocprim/test_texture_cache_iterator.cpp index 8b0e5b24a..a8899d20c 
100644 --- a/test/rocprim/test_texture_cache_iterator.cpp +++ b/test/rocprim/test_texture_cache_iterator.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -116,7 +116,7 @@ TYPED_TEST(RocprimTextureCacheIteratorTests, Transform) HIP_CHECK(hipDeviceSynchronize()); Iterator x; - x.bind_texture(d_input, sizeof(T) * input.size()); + HIP_CHECK(x.bind_texture(d_input, sizeof(T) * input.size())); // Calculate expected results on host std::vector expected(size); @@ -153,8 +153,8 @@ TYPED_TEST(RocprimTextureCacheIteratorTests, Transform) ASSERT_EQ(output[i], expected[i]) << "where index = " << i; } - x.unbind_texture(); - hipFree(d_input); - hipFree(d_output); + HIP_CHECK(x.unbind_texture()); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); } } diff --git a/test/rocprim/test_transform_iterator.cpp b/test/rocprim/test_transform_iterator.cpp index c3d71c281..b8b51a457 100644 --- a/test/rocprim/test_transform_iterator.cpp +++ b/test/rocprim/test_transform_iterator.cpp @@ -1,6 +1,6 @@ // MIT License // -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -182,9 +182,9 @@ TYPED_TEST(RocprimTransformIteratorTests, TransformReduce) // Check if output values are as expected test_utils::assert_near(output[0], expected, test_utils::precision * size); - hipFree(d_input); - hipFree(d_output); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input)); + HIP_CHECK(hipFree(d_output)); + HIP_CHECK(hipFree(d_temp_storage)); } } diff --git a/test/rocprim/test_utils_assertions.hpp b/test/rocprim/test_utils_assertions.hpp index 1d72eeb78..d8d3c6a3d 100644 --- a/test/rocprim/test_utils_assertions.hpp +++ b/test/rocprim/test_utils_assertions.hpp @@ -294,14 +294,15 @@ auto assert_near(const custom_test_type& result, const custom_test_type& e #if ROCPRIM_HAS_INT128_SUPPORT template auto operator<<(std::ostream& os, const T& value) - -> std::enable_if_t::value || std::is_same::value, + -> std::enable_if_t::value + || std::is_same::value, std::ostream&> { static const char* charmap = "0123456789"; std::string result; result.reserve(41); // max. 40 digits possible ( uint64_t has 20) plus sign - __uint128_t helper = (value < 0) ? -value : value; + rocprim::uint128_t helper = (value < 0) ? 
-value : value; do { diff --git a/test/rocprim/test_utils_data_generation.hpp b/test/rocprim/test_utils_data_generation.hpp index 6358dffbe..df9df44a6 100644 --- a/test/rocprim/test_utils_data_generation.hpp +++ b/test/rocprim/test_utils_data_generation.hpp @@ -58,7 +58,7 @@ struct is_valid_for_int_distribution : namespace detail { template -struct numeric_limits_custom_test_type : public std::numeric_limits +struct numeric_limits_custom_test_type : public rocprim::numeric_limits {}; } // namespace detail @@ -67,7 +67,7 @@ template struct numeric_limits : public std::conditional::value || is_custom_test_array_type::value, detail::numeric_limits_custom_test_type, - std::numeric_limits>::type + rocprim::numeric_limits>::type {}; template<> struct numeric_limits : public std::numeric_limits { @@ -118,6 +118,48 @@ template<> class numeric_limits : public std::numeric_limi }; // End of extended numeric_limits +template +struct generate_limits +{ + static inline T min() + { + return rocprim::numeric_limits::min(); + } + static inline T max() + { + return rocprim::numeric_limits::max(); + } +}; + +template +struct generate_limits< + T, + std::enable_if_t::value || is_custom_test_type::value>> +{ + using Type = typename T::value_type; + static inline Type min() + { + return generate_limits::min(); + } + static inline Type max() + { + return generate_limits::max(); + } +}; + +template +struct generate_limits::value>> +{ + static inline T min() + { + return T(-1000); + } + static inline T max() + { + return T(1000); + } +}; + // Converts possible device side types to their relevant host side native types inline rocprim::native_half convert_to_native(const rocprim::half& value) { @@ -229,13 +271,13 @@ template constexpr Res saturate_cast(T x) noexcept { // Handle overflow - if(test_utils::cmp_less(x, std::numeric_limits::min())) + if(test_utils::cmp_less(x, numeric_limits::min())) { - return std::numeric_limits::min(); + return numeric_limits::min(); } - 
if(test_utils::cmp_greater(x, std::numeric_limits::max())) + if(test_utils::cmp_greater(x, numeric_limits::max())) { - return std::numeric_limits::max(); + return numeric_limits::max(); } // No overflow return static_cast(x); @@ -460,7 +502,7 @@ std::vector get_sizes(T seed_value) return sizes; } -template +template std::vector get_large_sizes(T seed_value) { std::vector sizes = { @@ -470,12 +512,12 @@ std::vector get_large_sizes(T seed_value) (size_t{1} << 33) + (size_t{1} << 32) - 876543, (size_t{1} << 34) - 12346, (size_t{1} << 35) + 1, - (size_t{1} << 37) - 1, + (size_t{1} << MaxPow2) - 1, }; const std::vector random_sizes = test_utils::get_random_data(2, (size_t{1} << 30) + 1, - (size_t{1} << 37) - 2, + (size_t{1} << MaxPow2) - 2, seed_value); sizes.insert(sizes.end(), random_sizes.begin(), random_sizes.end()); std::sort(sizes.begin(), sizes.end()); diff --git a/test/rocprim/test_utils_device_ptr.hpp b/test/rocprim/test_utils_device_ptr.hpp new file mode 100644 index 000000000..badad886e --- /dev/null +++ b/test/rocprim/test_utils_device_ptr.hpp @@ -0,0 +1,242 @@ +// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#ifndef ROCPRIM_TEST_UTILS_DEVICE_PTR_HPP +#define ROCPRIM_TEST_UTILS_DEVICE_PTR_HPP + +#include +#include + +#include "../common_test_header.hpp" + +namespace test_utils +{ + +/// \brief An RAII friendly class to manage the memory allocated on device. +/// +/// \tparam A Template type used by the class. +template +class device_ptr +{ +public: + using decay_type = std::decay_t; + using size_type = std::size_t; + using value_type = typename std:: + conditional_t::value, unsigned char, PointerType>; + + device_ptr() : device_raw_ptr_(nullptr), number_of_ele_(0) {}; + + /// \brief Construct with a pre-allocated memory space. + device_ptr(size_type pre_alloc_number_of_ele) + : device_raw_ptr_(nullptr), number_of_ele_(pre_alloc_number_of_ele) + { + size_type storage_size = number_of_ele_ * sizeof(value_type); + HIP_CHECK(test_common_utils::hipMallocHelper(&device_raw_ptr_, storage_size)); + }; + + device_ptr(device_ptr const&) = delete; + + device_ptr(device_ptr&& other) noexcept + : device_raw_ptr_(other.device_raw_ptr_), number_of_ele_(other.number_of_ele_) + { + other.leak(); + }; + + /// \brief Construct by host vectors with the same sized value_type + template + explicit device_ptr(std::vector const& data) + : device_raw_ptr_(nullptr), number_of_ele_(data.size()) + { + static_assert( + sizeof(InVecValueType) == sizeof(value_type), + "value_type of input vector must have the same size with device_ptr::value_type"); + + size_type storage_size = number_of_ele_ * sizeof(value_type); + HIP_CHECK(test_common_utils::hipMallocHelper(&device_raw_ptr_, storage_size)); + HIP_CHECK(hipMemcpy(device_raw_ptr_, data.data(), storage_size, hipMemcpyHostToDevice)); + } + + /// \brief Construct with a copy of 
this `host_buffer` + /// + /// \param _number_of_ele be aware, this is NOT the sizeof `host_buffer`, this is the `number of elements` in the `host_buffer` + device_ptr(const void* host_buffer, size_type _number_of_ele) + : device_raw_ptr_(nullptr), number_of_ele_(_number_of_ele) + { + size_type storage_size = number_of_ele_ * sizeof(value_type); + HIP_CHECK(test_common_utils::hipMallocHelper(&device_raw_ptr_, storage_size)); + HIP_CHECK(hipMemcpy(device_raw_ptr_, host_buffer, storage_size, hipMemcpyHostToDevice)); + }; + + ~device_ptr() + { + free_manually(); + }; + + device_ptr& operator=(device_ptr const&) = delete; + + device_ptr& operator=(device_ptr&& other) noexcept + { + free_manually(); + device_raw_ptr_ = other.device_raw_ptr_; + number_of_ele_ = other.number_of_ele_; + other.leak(); + }; + + /// \brief Do copy on the device. + /// + /// \return A new `device_ptr` rvalue. + device_ptr duplicate() const + { + device_ptr ret; + ret.number_of_ele_ = number_of_ele_; + size_type storage_size = number_of_ele_ * sizeof(value_type); + HIP_CHECK(test_common_utils::hipMallocHelper(&ret.device_raw_ptr_, storage_size)); + HIP_CHECK( + hipMemcpy(ret.device_raw_ptr_, device_raw_ptr_, storage_size, hipMemcpyDeviceToDevice)); + return ret; + } + + /// \brief Do type cast and move the ownership to the new `device_ptr`. + /// + /// \return A new `device_ptr` rvalue. + template + device_ptr move_cast() noexcept + { + using target_value_t = typename device_ptr::value_type; + + auto ret_deivce_raw_ptr_ + = static_cast(static_cast(device_raw_ptr_)); + auto ret_number_of_ele_ = sizeof(value_type) * number_of_ele_ / sizeof(target_value_t); + leak(); + return {ret_deivce_raw_ptr_, ret_number_of_ele_}; + } + + /// \brief Get the device raw pointer + value_type* get() const noexcept + { + return device_raw_ptr_; + } + + /// \brief Clean every thing on this instance, which could lead to memory leak. 
Should call `get()` and free the raw pointer manually + void leak() noexcept + { + device_raw_ptr_ = nullptr; + number_of_ele_ = 0; + } + + /// \brief Call this function to garbage the memory in advance + void free_manually() + { + if(device_raw_ptr_) + { + HIP_CHECK(hipFree(device_raw_ptr_)); + } + leak(); + } + + void resize(size_type _new_number_of_ele) + { + if(_new_number_of_ele == 0) + { + free_manually(); + } + else + { + value_type* device_temp_ptr = nullptr; + HIP_CHECK(test_common_utils::hipMallocHelper(&device_temp_ptr, + _new_number_of_ele * sizeof(value_type))); + HIP_CHECK(hipMemcpy(device_temp_ptr, + device_raw_ptr_, + std::min(_new_number_of_ele, number_of_ele_) * sizeof(value_type), + hipMemcpyDeviceToDevice)); + free_manually(); + device_raw_ptr_ = device_temp_ptr; + number_of_ele_ = _new_number_of_ele; + } + } + + /// \brief Get the size of this memory space + size_type msize() const noexcept + { + return number_of_ele_ * sizeof(value_type); + } + + /// \brief Get the number of elements + size_type size() const noexcept + { + return number_of_ele_; + } + + /// \brief Copy from host to device + template + void store(std::vector const& host_vec, size_type offset = 0) + { + static_assert( + sizeof(InVecValueType) == sizeof(value_type), + "value_type of input vector must have the same size with device_ptr::value_type"); + + if(host_vec.size() + offset > number_of_ele_) + { + resize(host_vec.size() + offset); + } + + HIP_CHECK(hipMemcpy(device_raw_ptr_ + offset, + host_vec.data(), + host_vec.size() * sizeof(value_type), + hipMemcpyHostToDevice)); + } + + /// \brief Copy from host to device + template + void store(device_ptr const& device_ptr, size_type offset = 0) + { + static_assert(sizeof(InPtrValueType) == sizeof(value_type), + "sizeof(InPtrValueType) must equal to sizeof(value_type)"); + + if(device_ptr.number_of_ele_ + offset > number_of_ele_) + { + resize(device_ptr.number_of_ele_ + offset); + } + + HIP_CHECK(hipMemcpy(device_raw_ptr_ + offset, 
+ device_ptr.device_raw_ptr_, + device_ptr.number_of_ele_ * sizeof(value_type), + hipMemcpyDeviceToDevice)); + } + + /// \brief Copy from device to host + std::vector load() const + { + std::vector ret(number_of_ele_); + HIP_CHECK(hipMemcpy(ret.data(), + device_raw_ptr_, + number_of_ele_ * sizeof(value_type), + hipMemcpyDeviceToHost)); + return ret; + } + +private: + value_type* device_raw_ptr_; + size_type number_of_ele_; +}; + +} // namespace test_utils + +#endif diff --git a/test/rocprim/test_utils_sort_comparator.hpp b/test/rocprim/test_utils_sort_comparator.hpp index bd4d6db00..5f26eb00e 100644 --- a/test/rocprim/test_utils_sort_comparator.hpp +++ b/test/rocprim/test_utils_sort_comparator.hpp @@ -42,25 +42,38 @@ template::value && !std::is_same::value) - || std::is_same::value - || std::is_same::value, + || std::is_same::value + || std::is_same::value, int> = 0> -Key to_bits(const Key key) +auto to_bits(const Key key) -> typename rocprim::get_unsigned_bits_type::unsigned_type { - static constexpr Key radix_mask_upper - = EndBit == 8 * sizeof(Key) ? ~Key(0) : static_cast((Key(1) << EndBit) - 1); - static constexpr Key radix_mask_bottom = static_cast((Key(1) << StartBit) - 1); - static constexpr Key radix_mask = radix_mask_upper ^ radix_mask_bottom; + using unsigned_bits_type = typename rocprim::get_unsigned_bits_type::unsigned_type; + + static constexpr unsigned_bits_type radix_mask_upper + = EndBit == 8 * sizeof(Key) + ? 
~unsigned_bits_type(0) + : static_cast((unsigned_bits_type(1) << EndBit) - 1); + static constexpr unsigned_bits_type radix_mask_bottom + = static_cast((unsigned_bits_type(1) << StartBit) - 1); + static constexpr unsigned_bits_type radix_mask = radix_mask_upper ^ radix_mask_bottom; + + auto bit_key = static_cast(key); + // Flip sign bit to properly order signed types + if(::rocprim::is_signed::value) + { + constexpr auto sign_bit = static_cast(1) << (sizeof(Key) * 8 - 1); + bit_key ^= sign_bit; + } - return key & radix_mask; + return bit_key & radix_mask; } template::value, int> = 0> -Key to_bits(const Key key) +auto to_bits(const Key key) -> typename rocprim::get_unsigned_bits_type::unsigned_type { using unsigned_bits_type = typename rocprim::get_unsigned_bits_type::unsigned_type; unsigned_bits_type bit_key; @@ -79,7 +92,7 @@ template::value, int> = 0> -auto to_bits(const Key key) +auto to_bits(const Key key) -> typename rocprim::get_unsigned_bits_type::unsigned_type { using unsigned_bits_type = typename rocprim::get_unsigned_bits_type::unsigned_type; @@ -117,14 +130,14 @@ template::value, int> = 0> -auto to_bits(const Key& key) +auto to_bits(const Key& key) -> typename rocprim::get_unsigned_bits_type::unsigned_type { using inner_t = typename inner_type::type; using unsigned_bits_type = typename ::rocprim::get_unsigned_bits_type::unsigned_type; // For two doubles, we need uint128, but that is not part of rocprim::get_unsigned_bits_type using result_bits_type = std::conditional_t< sizeof(inner_t) == 8, - __uint128_t, + rocprim::uint128_t, typename rocprim::get_unsigned_bits_type(8), sizeof(inner_t) * 2)>::unsigned_type>; @@ -132,14 +145,6 @@ auto to_bits(const Key& key) auto bit_key_upper = static_cast(to_bits<0, sizeof(key.x) * 8>(key.x)); auto bit_key_lower = static_cast(to_bits<0, sizeof(key.y) * 8>(key.y)); - // Flip sign bit to properly order signed types - if(::rocprim::is_signed::value) - { - constexpr auto sign_bit = static_cast(1) << (sizeof(inner_t) * 8 - 
1); - bit_key_upper ^= sign_bit; - bit_key_lower ^= sign_bit; - } - // Create the result containing both parts const auto bit_key = (static_cast(bit_key_upper) << (8 * sizeof(unsigned_bits_type))) @@ -153,7 +158,7 @@ template::value, int> = 0> -auto to_bits(const Key key) +auto to_bits(const Key key) -> typename rocprim::get_unsigned_bits_type::unsigned_type { return to_bits(key.x); } @@ -193,7 +198,8 @@ struct custom_test_type_decomposer "custom_test_type_decomposer can only be used with custom_test_type"); using inner_t = typename inner_type::type; - __host__ __device__ auto operator()(CustomTestType& key) const + __host__ __device__ + auto operator()(CustomTestType& key) const { return ::rocprim::tuple{key.x, key.y}; } diff --git a/test/rocprim/test_utils_types.hpp b/test/rocprim/test_utils_types.hpp index 3a04ed224..73c952667 100644 --- a/test/rocprim/test_utils_types.hpp +++ b/test/rocprim/test_utils_types.hpp @@ -164,8 +164,8 @@ typedef ::testing::Types block_param_type(bool, rocprim::half) #if ROCPRIM_HAS_INT128_SUPPORT , - block_param_type(__uint128_t, short), - block_param_type(__int128_t, float) + block_param_type(rocprim::uint128_t, short), + block_param_type(rocprim::int128_t, float) #endif > BlockParamsIntegralExtended; diff --git a/test/rocprim/test_warp_load.cpp b/test/rocprim/test_warp_load.cpp index 2adc6ffa9..b4fac714a 100644 --- a/test/rocprim/test_warp_load.cpp +++ b/test/rocprim/test_warp_load.cpp @@ -279,7 +279,8 @@ TYPED_TEST(WarpLoadTest, WarpLoadGuarded) constexpr unsigned int block_size = 1024; constexpr unsigned int items_count = items_per_thread * block_size; constexpr unsigned int valid_items = warp_size / 4; - constexpr T oob_default = std::numeric_limits::max(); + + const T oob_default = test_utils::numeric_limits::max(); int device_id = test_common_utils::obtain_device_from_ctest(); SKIP_IF_UNSUPPORTED_WARP_SIZE(warp_size, device_id); diff --git a/test/rocprim/test_zip_iterator.cpp b/test/rocprim/test_zip_iterator.cpp index 
83fe7af91..5e105184e 100644 --- a/test/rocprim/test_zip_iterator.cpp +++ b/test/rocprim/test_zip_iterator.cpp @@ -252,10 +252,10 @@ TEST(RocprimZipIteratorTests, Transform) expected, std::max(test_utils::precision, test_utils::precision * 2)); - hipFree(d_input1); - hipFree(d_input2); - hipFree(d_input3); - hipFree(d_output); + HIP_CHECK(hipFree(d_input1)); + HIP_CHECK(hipFree(d_input2)); + HIP_CHECK(hipFree(d_input3)); + HIP_CHECK(hipFree(d_output)); } } @@ -441,12 +441,12 @@ TEST(RocprimZipIteratorTests, TransformReduce) (std::max(test_utils::precision, test_utils::precision) + test_utils::precision)*size); - hipFree(d_input1); - hipFree(d_input2); - hipFree(d_input3); - hipFree(d_output1); - hipFree(d_output2); - hipFree(d_temp_storage); + HIP_CHECK(hipFree(d_input1)); + HIP_CHECK(hipFree(d_input2)); + HIP_CHECK(hipFree(d_input3)); + HIP_CHECK(hipFree(d_output1)); + HIP_CHECK(hipFree(d_output2)); + HIP_CHECK(hipFree(d_temp_storage)); } }