Skip to content

Commit

Permalink
Add elaborate unit testing
Browse files Browse the repository at this point in the history
  • Loading branch information
gonidelis committed Oct 1, 2024
1 parent f1c3580 commit b3e99bc
Showing 1 changed file with 121 additions and 8 deletions.
129 changes: 121 additions & 8 deletions cub/test/catch2_test_device_find_if.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,14 @@
#include <thrust/device_vector.h>
#include <thrust/iterator/constant_iterator.h>

#include <iostream>

#include "c2h/custom_type.cuh"
#include "catch2_test_device_reduce.cuh"
#include "catch2_test_helper.h"
#include "catch2_test_launch_helper.h"
#include "thrust/detail/raw_pointer_cast.h"
#include <catch2/catch.hpp>
#include <nv/target>

// %PARAM% TEST_LAUNCH lid 0:1
Expand Down Expand Up @@ -72,15 +76,20 @@ void compute_find_if_reference(InputIt first, InputIt last, OutputIt& result, Bi
}

template <typename T>
struct equals_2
struct equals
{
T val;
equals(T _val)
: val(_val)
{}

__device__ __host__ bool operator()(T i)
{
return i == 2;
return i == val;
}
};

CUB_TEST("Device find if works", "[device]", full_type_list)
CUB_TEST("Device find_if works", "[device]", full_type_list)
{
using params = params_t<TestType>;
using input_t = typename params::item_t;
Expand Down Expand Up @@ -115,14 +124,15 @@ CUB_TEST("Device find if works", "[device]", full_type_list)
}
auto d_in_it = thrust::raw_pointer_cast(in_items.data());

SECTION("find if")
SECTION("Generic find if case")
{
using op_t = equals_2<std::int32_t>;
using op_t = equals<input_t>;
input_t val_to_find{2};

// Prepare verification data
c2h::host_vector<input_t> host_items(in_items);
c2h::host_vector<output_t> expected_result(1);
compute_find_if_reference(host_items.begin(), host_items.end(), expected_result[0], op_t{});
compute_find_if_reference(host_items.begin(), host_items.end(), expected_result[0], op_t{val_to_find});

void* d_temp_storage = nullptr;
size_t temp_storage_bytes{};
Expand All @@ -132,13 +142,116 @@ CUB_TEST("Device find if works", "[device]", full_type_list)
auto d_out_it = thrust::raw_pointer_cast(out_result.data());

cub::DeviceFind::FindIf(
d_temp_storage, temp_storage_bytes, unwrap_it(d_in_it), unwrap_it(d_out_it), op_t{}, num_items);
d_temp_storage, temp_storage_bytes, unwrap_it(d_in_it), unwrap_it(d_out_it), op_t{val_to_find}, num_items);

thrust::device_vector<uint8_t> temp_storage(temp_storage_bytes);
d_temp_storage = thrust::raw_pointer_cast(temp_storage.data());

cub::DeviceFind::FindIf(
d_temp_storage, temp_storage_bytes, unwrap_it(d_in_it), unwrap_it(d_out_it), op_t{val_to_find}, num_items);

// Verify result
REQUIRE(expected_result == out_result);
}

SECTION("find_if works with non raw pointers - .begin() iterator")
{
using op_t = equals<input_t>;
input_t val_to_find{2};

// Prepare verification data
c2h::host_vector<input_t> host_items(in_items);
c2h::host_vector<output_t> expected_result(1);
compute_find_if_reference(host_items.begin(), host_items.end(), expected_result[0], op_t{val_to_find});

void* d_temp_storage = nullptr;
size_t temp_storage_bytes{};

// Run test
c2h::device_vector<output_t> out_result(1);

cub::DeviceFind::FindIf(
d_temp_storage, temp_storage_bytes, in_items.begin(), out_result.begin(), op_t{val_to_find}, num_items);

thrust::device_vector<uint8_t> temp_storage(temp_storage_bytes);
d_temp_storage = thrust::raw_pointer_cast(temp_storage.data());

cub::DeviceFind::FindIf(
d_temp_storage, temp_storage_bytes, unwrap_it(d_in_it), unwrap_it(d_out_it), op_t{}, num_items);
d_temp_storage, temp_storage_bytes, in_items.begin(), out_result.begin(), op_t{val_to_find}, num_items);

// Verify result
REQUIRE(expected_result == out_result);
}

SECTION("find_if works for unaligned input")
{
for (int offset = 1; offset < 4; ++offset)
{
if (num_items - offset > 0)
{
using op_t = equals<input_t>;
input_t val_to_find{2};

// Prepare verification data
c2h::host_vector<input_t> host_items(in_items);
c2h::host_vector<output_t> expected_result(1);
compute_find_if_reference(host_items.begin() + offset, host_items.end(), expected_result[0], op_t{val_to_find});

void* d_temp_storage = nullptr;
size_t temp_storage_bytes{};

// Run test
c2h::device_vector<output_t> out_result(1);
auto d_out_it = thrust::raw_pointer_cast(out_result.data());

cub::DeviceFind::FindIf(
d_temp_storage,
temp_storage_bytes,
unwrap_it(d_in_it + offset),
unwrap_it(d_out_it),
op_t{val_to_find},
num_items - offset);

thrust::device_vector<uint8_t> temp_storage(temp_storage_bytes);
d_temp_storage = thrust::raw_pointer_cast(temp_storage.data());

cub::DeviceFind::FindIf(
d_temp_storage,
temp_storage_bytes,
unwrap_it(d_in_it + offset),
unwrap_it(d_out_it),
op_t{val_to_find},
num_items - offset);

// Verify result
REQUIRE(expected_result == out_result);
}
}
}

SECTION("find_if works with non primitive iterator")
{
using op_t = equals<input_t>;
input_t val_to_find{2};

// Prepare verification data
auto it = thrust::make_counting_iterator(0); // non-primitive iterator
c2h::host_vector<output_t> expected_result(1);
compute_find_if_reference(it, it + num_items, expected_result[0], op_t{val_to_find});

void* d_temp_storage = nullptr;
size_t temp_storage_bytes{};

// Run test
c2h::device_vector<output_t> out_result(1);
auto d_out_it = thrust::raw_pointer_cast(out_result.data());

cub::DeviceFind::FindIf(d_temp_storage, temp_storage_bytes, it, unwrap_it(d_out_it), op_t{val_to_find}, num_items);

thrust::device_vector<uint8_t> temp_storage(temp_storage_bytes);
d_temp_storage = thrust::raw_pointer_cast(temp_storage.data());

cub::DeviceFind::FindIf(d_temp_storage, temp_storage_bytes, it, unwrap_it(d_out_it), op_t{val_to_find}, num_items);

// Verify result
REQUIRE(expected_result == out_result);
Expand Down

0 comments on commit b3e99bc

Please sign in to comment.