Skip to content

Commit

Permalink
Simplify code around castIfHalfPointer
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed May 7, 2024
1 parent 8d8f5e4 commit 5b25b78
Showing 1 changed file with 7 additions and 19 deletions.
26 changes: 7 additions & 19 deletions cub/test/test_device_histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,11 @@ auto castIfHalfPointer(It it) -> It
return it;
}

template <typename T>
auto castIfArrayOfHalfPointers(T** p) -> T**
{
return p;
}

#if TEST_HALF_T
auto castIfHalfPointer(half_t* p) -> __half*
{
return reinterpret_cast<__half*>(p);
}

template <std::size_t N>
auto castIfArrayOfHalfPointers(half_t** a) -> __half**
{
return reinterpret_cast<__half**>(a);
}
#endif

//---------------------------------------------------------------------
Expand Down Expand Up @@ -135,7 +123,7 @@ struct Dispatch
castIfHalfPointer(d_samples),
d_histogram,
num_levels,
castIfArrayOfHalfPointers(d_levels),
castIfHalfPointer(d_levels),
num_row_pixels,
num_rows,
row_stride_bytes);
Expand Down Expand Up @@ -211,9 +199,9 @@ struct Dispatch<1, 1>
int* num_levels, ///< [in] The number of boundaries (levels) for delineating histogram samples in each active
///< channel. Implies that the number of bins for channel<sub><em>i</em></sub> is
///< <tt>num_levels[i]</tt> - 1.
LevelT* d_levels, ///< [in] The pointers to the arrays of boundaries (levels), one for each active channel. Bin
///< ranges are defined by consecutive boundary pairings: lower sample value boundaries are
///< inclusive and upper sample value boundaries are exclusive.
LevelT** d_levels, ///< [in] The pointers to the arrays of boundaries (levels), one for each active channel. Bin
///< ranges are defined by consecutive boundary pairings: lower sample value boundaries are
///< inclusive and upper sample value boundaries are exclusive.
OffsetT num_row_pixels, ///< [in] The number of multi-channel pixels per row in the region of interest
OffsetT num_rows, ///< [in] The number of rows in the region of interest
OffsetT row_stride_bytes) ///< [in] The number of bytes between starts of consecutive rows in the region of interest
Expand All @@ -226,7 +214,7 @@ struct Dispatch<1, 1>
castIfHalfPointer(d_samples),
d_histogram[0],
num_levels[0],
*castIfHalfPointer(&d_levels[0]),
castIfHalfPointer(d_levels)[0],
num_row_pixels,
num_rows,
row_stride_bytes);
Expand Down Expand Up @@ -270,8 +258,8 @@ struct Dispatch<1, 1>
castIfHalfPointer(d_samples),
d_histogram[0],
num_levels[0],
*castIfHalfPointer(&lower_level[0]),
*castIfHalfPointer(&upper_level[0]),
castIfHalfPointer(lower_level)[0],
castIfHalfPointer(upper_level)[0],
num_row_pixels,
num_rows,
row_stride_bytes);
Expand Down

0 comments on commit 5b25b78

Please sign in to comment.