diff --git a/rocprim/include/rocprim/device/detail/device_merge.hpp b/rocprim/include/rocprim/device/detail/device_merge.hpp index f10ac8f80..3126f05c9 100644 --- a/rocprim/include/rocprim/device/detail/device_merge.hpp +++ b/rocprim/include/rocprim/device/detail/device_merge.hpp @@ -1,4 +1,4 @@ -// Copyright (c) 2017-2022 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2017-2024 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -68,24 +68,23 @@ void partition_kernel_impl(IndexIterator indices, const unsigned int spacing, BinaryFunction compare_function) { - const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); - const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); + const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); const unsigned int flat_block_size = ::rocprim::detail::block_size<0>(); + const unsigned int input_size = input1_size + input2_size; + const unsigned int id = flat_block_id * flat_block_size + flat_id; + const unsigned int partition_id = id * spacing; + const unsigned int partitions = (input_size + spacing - 1) / spacing; - unsigned int id = flat_block_id * flat_block_size + flat_id; + if(id > partitions) + { + return; + } - unsigned int partition_id = id * spacing; size_t diag = min(static_cast<size_t>(partition_id), input1_size + input2_size); - unsigned int begin = - merge_path( - keys_input1, - keys_input2, - input1_size, - input2_size, - diag, - compare_function - ); + unsigned int begin + = merge_path(keys_input1, keys_input2, input1_size, input2_size, diag, compare_function); indices[id] = begin; } @@ -310,8 +309,10 @@ void merge_kernel_impl(IndexIterator indices, const unsigned int valid_in_last_block = count - block_offset; const bool 
is_incomplete_block = valid_in_last_block < items_per_block; - const unsigned int p1 = indices[flat_block_id]; - const unsigned int p2 = indices[flat_block_id + 1]; + const unsigned int partitions = (count + items_per_block - 1) / items_per_block; + + const unsigned int p1 = indices[rocprim::min(flat_block_id, partitions)]; + const unsigned int p2 = indices[rocprim::min(flat_block_id + 1, partitions)]; range_t range = compute_range(