Skip to content

Commit

Permalink
Fix compilation error with ALICE (#645)
Browse files Browse the repository at this point in the history
* Fix compilation error with ALICE

* Fix device_adjacent_find for ALICE too
  • Loading branch information
NB4444 authored Nov 21, 2024
1 parent 3557be9 commit d120db6
Show file tree
Hide file tree
Showing 5 changed files with 454 additions and 401 deletions.
192 changes: 99 additions & 93 deletions rocprim/include/rocprim/device/detail/device_adjacent_find.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,123 +32,129 @@ BEGIN_ROCPRIM_NAMESPACE

namespace detail
{
namespace adjacent_find
{
// Initialization kernel for adjacent_find: prepares the global result slot and
// the ordered tile-id state before the block reduction kernel runs.
//
// reduce_output   - pointer to the result slot; it is overwritten with `size`.
//                   block_reduce_kernel only stores a value when
//                   `output_value < size`, so `size` presumably serves as the
//                   "no adjacent pair found" sentinel - confirm against callers.
// ordered_tile_id - ordered_block_id handle (taken by value) whose reset() is called.
// size            - value written to *reduce_output.
//
// Declared with __launch_bounds__(1), i.e. intended for a single-thread launch.
// NOTE(review): ROCPRIM_KERNEL is defined elsewhere; presumably the kernel
// qualifier macro - confirm.
template<class OutputT, class IdT>
ROCPRIM_KERNEL __launch_bounds__(1)
void init_adjacent_find(OutputT* reduce_output,
ordered_block_id<IdT> ordered_tile_id,
const size_t size)
{
// Reset the output value to `size` (assigned through OutputT's conversion from size_t).
*reduce_output = size;

// Reset ordered_block_id. reset() is defined elsewhere; presumably it clears the
// counter that ordered_tile_id.get() hands out in block_reduce_kernel - confirm.
ordered_tile_id.reset();
}

template<typename Config,
typename TransformedInputIterator,
typename ReduceIndexIterator,
typename BinaryPred,
typename OrderedTileIdType>
struct adjacent_find_impl_kernels
{
// Initialization kernel, declared as a static member of adjacent_find_impl_kernels:
// prepares the global result slot and the ordered tile-id state before
// block_reduce_kernel runs.
//
// reduce_output   - pointer to the result slot; it is overwritten with `size`.
//                   block_reduce_kernel only stores a value when
//                   `output_value < size`, so `size` presumably serves as the
//                   "no adjacent pair found" sentinel - confirm against callers.
// ordered_tile_id - ordered_block_id handle (taken by value) whose reset() is called.
// size            - value written to *reduce_output.
//
// Declared with __launch_bounds__(1), i.e. intended for a single-thread launch.
// NOTE(review): ROCPRIM_KERNEL is defined elsewhere; presumably the kernel
// qualifier macro - confirm.
template<class OutputT, class IdT>
static
ROCPRIM_KERNEL __launch_bounds__(1)
void init_adjacent_find(OutputT* reduce_output,
ordered_block_id<IdT> ordered_tile_id,
const size_t size)
{
// Reset the output value to `size` (assigned through OutputT's conversion from size_t).
*reduce_output = size;

// Reset ordered_block_id. reset() is defined elsewhere; presumably it clears the
// counter that ordered_tile_id.get() hands out in block_reduce_kernel - confirm.
ordered_tile_id.reset();
}

static
ROCPRIM_KERNEL
#ifndef DOXYGEN_DOCUMENTATION_BUILD
__launch_bounds__(device_params<Config>().kernel_config.block_size)
#endif
void block_reduce_kernel(TransformedInputIterator transformed_input,
ReduceIndexIterator reduce_output,
const std::size_t size,
BinaryPred op,
OrderedTileIdType ordered_tile_id)
{
static constexpr adjacent_find_config_params params = device_params<Config>();
static constexpr unsigned int block_size = params.kernel_config.block_size;
static constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread;
static constexpr unsigned int items_per_tile = block_size * items_per_thread;

using transformed_input_type =
typename std::iterator_traits<TransformedInputIterator>::value_type;
using block_reduce_type = ::rocprim::block_reduce<
transformed_input_type,
block_size,
block_reduce_algorithm::raking_reduce>;

ROCPRIM_SHARED_MEMORY union
void block_reduce_kernel(TransformedInputIterator transformed_input,
ReduceIndexIterator reduce_output,
const std::size_t size,
BinaryPred op,
OrderedTileIdType ordered_tile_id)
{
typename decltype(ordered_tile_id)::storage_type tile_id;
std::size_t global_reduce_output;
} storage;
static constexpr adjacent_find_config_params params = device_params<Config>();
static constexpr unsigned int block_size = params.kernel_config.block_size;
static constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread;
static constexpr unsigned int items_per_tile = block_size * items_per_thread;

using transformed_input_type =
typename std::iterator_traits<TransformedInputIterator>::value_type;
using block_reduce_type = ::rocprim::block_reduce<
transformed_input_type,
block_size,
block_reduce_algorithm::raking_reduce>; // TODO?: params.block_reduce_method>;

ROCPRIM_SHARED_MEMORY union
{
typename decltype(ordered_tile_id)::storage_type tile_id;
std::size_t global_reduce_output;
} storage;

// Get initial tile id
const unsigned int thread_id = threadIdx.x;
std::size_t tile_offset = ordered_tile_id.get(threadIdx.x, storage.tile_id) * items_per_tile;
// Get initial tile id
const unsigned int thread_id = threadIdx.x;
std::size_t tile_offset
= ordered_tile_id.get(threadIdx.x, storage.tile_id) * items_per_tile;

while(tile_offset < size)
{
// First thread of each block loads the latest global adjacent index found
if(thread_id == 0)
while(tile_offset < size)
{
storage.global_reduce_output = atomic_load(reduce_output);
}
syncthreads();
// First thread of each block loads the latest global adjacent index found
if(thread_id == 0)
{
storage.global_reduce_output = atomic_load(reduce_output);
}
syncthreads();

// Early exit if a previous block or tile found an adjacent pair
if(storage.global_reduce_output < tile_offset)
{
return;
}
// Early exit if a previous block or tile found an adjacent pair
if(storage.global_reduce_output < tile_offset)
{
return;
}

// Do block reduction
transformed_input_type transformed_input_values[items_per_thread];
transformed_input_type output_value;
// Do block reduction
transformed_input_type transformed_input_values[items_per_thread];
transformed_input_type output_value;

if(tile_offset + items_per_tile > size_t{size - 1}) /* Last incomplete processing */
{
const std::size_t valid_in_last_iteration = size - 1 - tile_offset;
block_load_direct_striped<block_size>(thread_id,
transformed_input + tile_offset,
transformed_input_values,
valid_in_last_iteration);

// Thread reductions with boundary check
output_value = transformed_input_values[0];
ROCPRIM_UNROLL
for(unsigned int i = 1; i < items_per_thread; i++)
if(tile_offset + items_per_tile > size_t{size - 1}) /* Last incomplete processing */
{
if(thread_id + i * block_size < valid_in_last_iteration)
const std::size_t valid_in_last_iteration = size - 1 - tile_offset;
block_load_direct_striped<block_size>(thread_id,
transformed_input + tile_offset,
transformed_input_values,
valid_in_last_iteration);

// Thread reductions with boundary check
output_value = transformed_input_values[0];
ROCPRIM_UNROLL
for(unsigned int i = 1; i < items_per_thread; i++)
{
output_value = op(output_value, transformed_input_values[i]);
if(thread_id + i * block_size < valid_in_last_iteration)
{
output_value = op(output_value, transformed_input_values[i]);
}
}
// Reduce thread reductions
block_reduce_type().reduce(
output_value, // input
output_value, // output
std::min(valid_in_last_iteration, std::size_t{block_size}),
op);
}
else /* Complete processings */
{
block_load_direct_striped<block_size>(thread_id,
transformed_input + tile_offset,
transformed_input_values);
block_reduce_type().reduce(transformed_input_values, // input
output_value, // output
op);
}
// Reduce thread reductions
block_reduce_type().reduce(output_value, // input
output_value, // output
std::min(valid_in_last_iteration, std::size_t{block_size}),
op);
}
else /* Complete processings */
{
block_load_direct_striped<block_size>(thread_id,
transformed_input + tile_offset,
transformed_input_values);
block_reduce_type().reduce(transformed_input_values, // input
output_value, // output
op);
}

// Save reduction's index into output if an adjacent pair is found
if(thread_id == 0 && output_value < size)
{
// Store global minimum
atomic_min(reduce_output, output_value);
}
// Save reduction's index into output if an adjacent pair is found
if(thread_id == 0 && output_value < size)
{
// Store global minimum
atomic_min(reduce_output, output_value);
}

// Get next tile's id
tile_offset = ordered_tile_id.get(threadIdx.x, storage.tile_id) * items_per_tile;
// Get next tile's id
tile_offset = ordered_tile_id.get(threadIdx.x, storage.tile_id) * items_per_tile;
}
}
}
} // namespace adjacent_find
};

} // namespace detail

END_ROCPRIM_NAMESPACE
Expand Down
Loading

0 comments on commit d120db6

Please sign in to comment.