Skip to content

Commit

Permalink
Develop fix rocm530 device merge iterator mismatch (#381) (#388)
Browse files Browse the repository at this point in the history
* Add test for device_merge with mismatched iterators

Add a test case (failing) for using device_merge with different types for
`keys_input1` and `keys_input2`. This is supported and should work.

* Fix device_merge with mismatched input iterators

Co-authored-by: Gergely Meszaros <[email protected]>

Co-authored-by: Vincent van Heertum <[email protected]>
Co-authored-by: Gergely Meszaros <[email protected]>
  • Loading branch information
3 people authored Nov 18, 2022
1 parent 3529188 commit d8726e2
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 21 deletions.
16 changes: 11 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,22 @@

Full documentation for rocPRIM is available at [https://codedocs.xyz/ROCmSoftwarePlatform/rocPRIM/](https://codedocs.xyz/ROCmSoftwarePlatform/rocPRIM/)

## [Unreleased rocPRIM-2.12.0 for ROCm 5.4.0]
## Changed
## [rocPRIM-2.12.0 for ROCm 5.4.0]
### Changed
- `device_partition`, `device_unique`, and `device_reduce_by_key` now support problem
sizes larger than 2^32 items.
### Removed
- `block_sort::sort()` overload for keys and values with a dynamic size. This overload was documented but the
implementation is missing. To avoid further confusion the documentation is removed until a decision is made on
implementing the function.

## [Unreleased rocPRIM-2.11.0 for ROCm 5.3.0]
### Fixed
- Fixed the compilation failure in `device_merge` if the two key iterators don't match.

## [rocPRIM-2.11.0 for ROCm 5.3.2]
### Known Issue
- device_merge no longer correctly supports using different types for `keys_input1` and `keys_input2` (starting from the 5.3.0 release).

## [rocPRIM-2.11.0 for ROCm 5.3.0]
### Added
- New functions `subtract_left` and `subtract_right` in `block_adjacent_difference` to apply functions
on pairs of adjacent items distributed between threads in a block.
Expand All @@ -21,7 +27,7 @@ Full documentation for rocPRIM is available at [https://codedocs.xyz/ROCmSoftwar
- CMake functionality to improve build parallelism of the test suite that splits compilation units by
function or by parameters.
- Reverse iterator.
## Changed
### Changed
- Improved the performance of warp primitives using the swizzle operation on Navi
- Improved build parallelism of the test suite by splitting up large compilation units
- `device_select` now supports problem sizes larger than 2^32 items.
Expand Down
27 changes: 15 additions & 12 deletions rocprim/include/rocprim/detail/merge_path.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

#include "../config.hpp"

#include <iterator>

BEGIN_ROCPRIM_NAMESPACE

namespace detail
Expand All @@ -46,25 +48,26 @@ struct range_t
}
};

template<class KeysInputIterator, class OffsetT, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE OffsetT merge_path(KeysInputIterator keys_input1,
KeysInputIterator keys_input2,
const OffsetT input1_size,
const OffsetT input2_size,
const OffsetT diag,
BinaryFunction compare_function)
template<class KeysInputIterator1, class KeysInputIterator2, class OffsetT, class BinaryFunction>
ROCPRIM_DEVICE ROCPRIM_INLINE OffsetT merge_path(KeysInputIterator1 keys_input1,
KeysInputIterator2 keys_input2,
const OffsetT input1_size,
const OffsetT input2_size,
const OffsetT diag,
BinaryFunction compare_function)
{
using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
using key_type_1 = typename std::iterator_traits<KeysInputIterator1>::value_type;
using key_type_2 = typename std::iterator_traits<KeysInputIterator2>::value_type;

OffsetT begin = diag < input2_size ? 0u : diag - input2_size;
OffsetT end = min(diag, input1_size);

while(begin < end)
{
OffsetT a = (begin + end) / 2;
OffsetT b = diag - 1 - a;
key_type input_a = keys_input1[a];
key_type input_b = keys_input2[b];
OffsetT a = (begin + end) / 2;
OffsetT b = diag - 1 - a;
key_type_1 input_a = keys_input1[a];
key_type_2 input_b = keys_input2[b];
if(!compare_function(input_b, input_a))
{
begin = a + 1;
Expand Down
96 changes: 92 additions & 4 deletions test/rocprim/test_device_merge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,23 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#include "common_test_header.hpp"
// required test headers
#include "../common_test_header.hpp"
#include "test_utils_types.hpp"

// required rocprim headers
#include <rocprim/functional.hpp>
#include <rocprim/device/device_merge.hpp>
#include <rocprim/functional.hpp>
#include <rocprim/iterator/counting_iterator.hpp>
#include <rocprim/iterator/transform_iterator.hpp>

// required test headers
#include "test_utils_types.hpp"
#include <gtest/gtest.h>

#include <hip/hip_runtime.h>

#include <algorithm>
#include <numeric>
#include <vector>

// Params for tests
template<
Expand Down Expand Up @@ -433,3 +442,82 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue)

}
}

TEST(RocprimDeviceMergeTests, MergeMismatchedIteratorTypes)
{
const int device_id = test_common_utils::obtain_device_from_ctest();
SCOPED_TRACE(testing::Message() << "with device_id = " << device_id);
HIP_CHECK(hipSetDevice(device_id));

std::vector<int> keys_input1(1'024);
std::generate(keys_input1.begin(),
keys_input1.end(),
[n = 0]() mutable
{
const int temp = n;
n += 2;
return temp;
});

std::vector<int> expected_keys_output(2 * keys_input1.size());
std::iota(expected_keys_output.begin(), expected_keys_output.end(), 0);

int* d_keys_input1 = nullptr;
int* d_keys_output = nullptr;
HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input1,
keys_input1.size() * sizeof(keys_input1[0])));
HIP_CHECK(
test_common_utils::hipMallocHelper(&d_keys_output,
expected_keys_output.size() * sizeof(keys_input1[0])));

HIP_CHECK(hipMemcpy(d_keys_input1,
keys_input1.data(),
keys_input1.size() * sizeof(keys_input1[0]),
hipMemcpyHostToDevice));

const auto d_keys_input2 = rocprim::make_transform_iterator(rocprim::make_counting_iterator(0),
[] __host__ __device__(int value)
{ return value * 2 + 1; });

static constexpr bool debug_synchronous = false;

size_t temp_storage_size_bytes = 0;
HIP_CHECK(rocprim::merge(nullptr,
temp_storage_size_bytes,
d_keys_input1,
d_keys_input2,
d_keys_output,
keys_input1.size(),
keys_input1.size(),
rocprim::less<int>{},
hipStreamDefault,
debug_synchronous));

ASSERT_GT(temp_storage_size_bytes, 0);

void* d_temp_storage = nullptr;
HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes));

HIP_CHECK(rocprim::merge(d_temp_storage,
temp_storage_size_bytes,
d_keys_input1,
d_keys_input2,
d_keys_output,
keys_input1.size(),
keys_input1.size(),
rocprim::less<int>{},
hipStreamDefault,
debug_synchronous));

std::vector<int> keys_output(expected_keys_output.size());
HIP_CHECK(hipMemcpy(keys_output.data(),
d_keys_output,
keys_output.size() * sizeof(keys_output[0]),
hipMemcpyDeviceToHost));

ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, expected_keys_output));

HIP_CHECK(hipFree(d_temp_storage));
HIP_CHECK(hipFree(d_keys_output));
HIP_CHECK(hipFree(d_keys_input1));
}

0 comments on commit d8726e2

Please sign in to comment.