diff --git a/cub/test/catch2_test_device_find_if.cu b/cub/test/catch2_test_device_find_if.cu index b92030c9eaa..2f7a3175049 100644 --- a/cub/test/catch2_test_device_find_if.cu +++ b/cub/test/catch2_test_device_find_if.cu @@ -36,10 +36,14 @@ #include #include +#include + #include "c2h/custom_type.cuh" #include "catch2_test_device_reduce.cuh" #include "catch2_test_helper.h" #include "catch2_test_launch_helper.h" +#include "thrust/detail/raw_pointer_cast.h" +#include #include // %PARAM% TEST_LAUNCH lid 0:1 @@ -72,15 +76,20 @@ void compute_find_if_reference(InputIt first, InputIt last, OutputIt& result, Bi } template -struct equals_2 +struct equals { + T val; + equals(T _val) + : val(_val) + {} + __device__ __host__ bool operator()(T i) { - return i == 2; + return i == val; } }; -CUB_TEST("Device find if works", "[device]", full_type_list) +CUB_TEST("Device find_if works", "[device]", full_type_list) { using params = params_t; using input_t = typename params::item_t; @@ -115,14 +124,15 @@ CUB_TEST("Device find if works", "[device]", full_type_list) } auto d_in_it = thrust::raw_pointer_cast(in_items.data()); - SECTION("find if") + SECTION("Generic find if case") { - using op_t = equals_2; + using op_t = equals; + input_t val_to_find{2}; // Prepare verification data c2h::host_vector host_items(in_items); c2h::host_vector expected_result(1); - compute_find_if_reference(host_items.begin(), host_items.end(), expected_result[0], op_t{}); + compute_find_if_reference(host_items.begin(), host_items.end(), expected_result[0], op_t{val_to_find}); void* d_temp_storage = nullptr; size_t temp_storage_bytes{}; @@ -132,13 +142,116 @@ CUB_TEST("Device find if works", "[device]", full_type_list) auto d_out_it = thrust::raw_pointer_cast(out_result.data()); cub::DeviceFind::FindIf( - d_temp_storage, temp_storage_bytes, unwrap_it(d_in_it), unwrap_it(d_out_it), op_t{}, num_items); + d_temp_storage, temp_storage_bytes, unwrap_it(d_in_it), unwrap_it(d_out_it), op_t{val_to_find}, num_items); + + thrust::device_vector temp_storage(temp_storage_bytes); + d_temp_storage = thrust::raw_pointer_cast(temp_storage.data()); + + cub::DeviceFind::FindIf( + d_temp_storage, temp_storage_bytes, unwrap_it(d_in_it), unwrap_it(d_out_it), op_t{val_to_find}, num_items); + + // Verify result + REQUIRE(expected_result == out_result); + } + + SECTION("find_if works with non raw pointers - .begin() iterator") + { + using op_t = equals; + input_t val_to_find{2}; + + // Prepare verification data + c2h::host_vector host_items(in_items); + c2h::host_vector expected_result(1); + compute_find_if_reference(host_items.begin(), host_items.end(), expected_result[0], op_t{val_to_find}); + + void* d_temp_storage = nullptr; + size_t temp_storage_bytes{}; + + // Run test + c2h::device_vector out_result(1); + + cub::DeviceFind::FindIf( + d_temp_storage, temp_storage_bytes, in_items.begin(), out_result.begin(), op_t{val_to_find}, num_items); thrust::device_vector temp_storage(temp_storage_bytes); d_temp_storage = thrust::raw_pointer_cast(temp_storage.data()); cub::DeviceFind::FindIf( - d_temp_storage, temp_storage_bytes, unwrap_it(d_in_it), unwrap_it(d_out_it), op_t{}, num_items); + d_temp_storage, temp_storage_bytes, in_items.begin(), out_result.begin(), op_t{val_to_find}, num_items); + + // Verify result + REQUIRE(expected_result == out_result); + } + + SECTION("find_if works for unaligned input") + { + for (int offset = 1; offset < 4; ++offset) + { + if (num_items - offset > 0) + { + using op_t = equals; + input_t val_to_find{2}; + + // Prepare verification data + c2h::host_vector host_items(in_items); + c2h::host_vector expected_result(1); + compute_find_if_reference(host_items.begin() + offset, host_items.end(), expected_result[0], op_t{val_to_find}); + + void* d_temp_storage = nullptr; + size_t temp_storage_bytes{}; + + // Run test + c2h::device_vector out_result(1); + auto d_out_it = thrust::raw_pointer_cast(out_result.data()); + + cub::DeviceFind::FindIf( + d_temp_storage, + temp_storage_bytes, + unwrap_it(d_in_it + offset), + unwrap_it(d_out_it), + op_t{val_to_find}, + num_items - offset); + + thrust::device_vector temp_storage(temp_storage_bytes); + d_temp_storage = thrust::raw_pointer_cast(temp_storage.data()); + + cub::DeviceFind::FindIf( + d_temp_storage, + temp_storage_bytes, + unwrap_it(d_in_it + offset), + unwrap_it(d_out_it), + op_t{val_to_find}, + num_items - offset); + + // Verify result + REQUIRE(expected_result == out_result); + } + } + } + + SECTION("find_if works with non primitive iterator") + { + using op_t = equals; + input_t val_to_find{2}; + + // Prepare verification data + auto it = thrust::make_counting_iterator(0); // non-primitive iterator + c2h::host_vector expected_result(1); + compute_find_if_reference(it, it + num_items, expected_result[0], op_t{val_to_find}); + + void* d_temp_storage = nullptr; + size_t temp_storage_bytes{}; + + // Run test + c2h::device_vector out_result(1); + auto d_out_it = thrust::raw_pointer_cast(out_result.data()); + + cub::DeviceFind::FindIf(d_temp_storage, temp_storage_bytes, it, unwrap_it(d_out_it), op_t{val_to_find}, num_items); + + thrust::device_vector temp_storage(temp_storage_bytes); + d_temp_storage = thrust::raw_pointer_cast(temp_storage.data()); + + cub::DeviceFind::FindIf(d_temp_storage, temp_storage_bytes, it, unwrap_it(d_out_it), op_t{val_to_find}, num_items); // Verify result REQUIRE(expected_result == out_result);