Fix compilation error with ALICE (#646)
The ALICE HPC application was hitting a compile-time error related to the device_find_first_of kernels.
This change fixes the problem by wrapping the kernels in a struct and making them static member functions.
umfranzw authored Nov 15, 2024
1 parent 16212a6 commit c6b1468
Showing 1 changed file with 98 additions and 91 deletions.
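For context, here is a minimal sketch of the restructuring pattern this commit applies, written against a toy HIP kernel. It is not rocPRIM code: the fill_kernels struct and fill_kernel name are made up purely for illustration. A free template kernel becomes a static member of a template struct (__global__ stands in for the ROCPRIM_KERNEL macro), and launch sites name it through the struct, analogous to find_first_of_impl_kernels<...>::find_first_of_kernel in the diff below.

#include <hip/hip_runtime.h>

// After the restructuring, the kernel is a static member of a template struct
// instead of a free function template; it is still a __global__ function and
// is launched the same way, only the name is qualified by the struct.
template<class T>
struct fill_kernels
{
    static __global__ void fill_kernel(T* out, T value, size_t n)
    {
        const size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
        if(i < n)
        {
            out[i] = value;
        }
    }
};

int main()
{
    constexpr size_t n = 256;
    int* d_out{};
    if(hipMalloc(reinterpret_cast<void**>(&d_out), n * sizeof(int)) != hipSuccess)
    {
        return 1;
    }
    // Launch through the wrapping struct, mirroring how find_first_of_impl
    // now refers to its kernels.
    fill_kernels<int>::fill_kernel<<<1, 256>>>(d_out, 42, n);
    (void)hipDeviceSynchronize();
    (void)hipFree(d_out);
    return 0;
}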
189 changes: 98 additions & 91 deletions rocprim/include/rocprim/device/device_find_first_of.hpp
@@ -58,125 +58,128 @@ namespace detail
     } \
     while(0)

-template<class T>
-ROCPRIM_KERNEL
-void init_find_first_of_kernel(T* output, T size, ordered_block_id<T> ordered_bid)
-{
-    *output = size;
-    ordered_bid.reset();
-}
-
 template<class Config, class InputIterator1, class InputIterator2, class BinaryFunction>
-ROCPRIM_KERNEL
+struct find_first_of_impl_kernels
+{
+    template<class T>
+    static ROCPRIM_KERNEL
+    void init_find_first_of_kernel(T* output, T size, ordered_block_id<T> ordered_bid)
+    {
+        *output = size;
+        ordered_bid.reset();
+    }
+
+    static ROCPRIM_KERNEL
 #ifndef DOXYGEN_DOCUMENTATION_BUILD
     __launch_bounds__(device_params<Config>().kernel_config.block_size)
 #endif
     void find_first_of_kernel(InputIterator1 input,
                               InputIterator2 keys,
                               size_t* output,
                               size_t size,
                               size_t keys_size,
                               ordered_block_id<size_t> ordered_bid,
                               BinaryFunction compare_function)
     {
         constexpr find_first_of_config_params params = device_params<Config>();

         constexpr unsigned int block_size = params.kernel_config.block_size;
         constexpr unsigned int items_per_thread = params.kernel_config.items_per_thread;
         constexpr unsigned int items_per_block = block_size * items_per_thread;
         constexpr unsigned int identity = std::numeric_limits<unsigned int>::max();

         using type = typename std::iterator_traits<InputIterator1>::value_type;
         using key_type = typename std::iterator_traits<InputIterator2>::value_type;

         const unsigned int thread_id = ::rocprim::detail::block_thread_id<0>();

         ROCPRIM_SHARED_MEMORY struct
         {
             unsigned int block_first_index;
             size_t global_first_index;

             typename decltype(ordered_bid)::storage_type ordered_bid;
         } storage;

         if(thread_id == 0)
         {
             storage.block_first_index = identity;
         }
         syncthreads();

         while(true)
         {
             if(thread_id == 0)
             {
                 storage.global_first_index = atomic_load(output);
             }
             const size_t block_id = ordered_bid.get(thread_id, storage.ordered_bid);
             const size_t block_offset = block_id * items_per_block;
             // ordered_bid.get() calls syncthreads(), it is safe to read global_first_index

             // Exit if all input has been processed or one of previous blocks has found a match
             if(block_offset >= storage.global_first_index)
             {
                 break;
             }

             unsigned int thread_first_index = identity;

             if(block_offset + items_per_block <= size)
             {
                 type items[items_per_thread];
                 block_load_direct_striped<block_size>(thread_id, input + block_offset, items);
                 for(size_t key_index = 0; key_index < keys_size; ++key_index)
                 {
                     const key_type key = keys[key_index];
                     ROCPRIM_UNROLL
                     for(unsigned int i = 0; i < items_per_thread; ++i)
                     {
                         if(compare_function(key, items[i]))
                         {
                             thread_first_index = min(thread_first_index, i);
                         }
                     }
                 }
             }
             else
             {
                 const unsigned int valid = size - block_offset;

                 type items[items_per_thread];
                 block_load_direct_striped<block_size>(thread_id, input + block_offset, items, valid);
                 for(size_t key_index = 0; key_index < keys_size; ++key_index)
                 {
                     const key_type key = keys[key_index];
                     ROCPRIM_UNROLL
                     for(unsigned int i = 0; i < items_per_thread; ++i)
                     {
                         if(i * block_size + thread_id < valid && compare_function(key, items[i]))
                         {
                             thread_first_index = min(thread_first_index, i);
                         }
                     }
                 }
             }

             if(thread_first_index != identity)
             {
                 // This happens to some blocks rarely so it is not beneficial to avoid atomic conflicts
                 // with block_reduce which needs to be computed even if no threads have a match.
                 atomic_min(&storage.block_first_index, thread_first_index * block_size + thread_id);
             }
             syncthreads();
             if(storage.block_first_index != identity)
             {
                 if(thread_id == 0)
                 {
                     atomic_min(output, block_offset + storage.block_first_index);
                 }
                 break;
             }
         }
     }
+};

 template<class Config,
          class InputIterator1,
@@ -197,6 +200,8 @@ hipError_t find_first_of_impl(void* temporary_storage,
 {
     using type = typename std::iterator_traits<InputIterator1>::value_type;
     using config = wrapped_find_first_of_config<Config, type>;
+    using find_first_of_kernels
+        = find_first_of_impl_kernels<config, InputIterator1, InputIterator2, BinaryFunction>;

     target_arch target_arch;
     hipError_t result = host_target_arch(stream, target_arch);
@@ -238,12 +243,14 @@ hipError_t find_first_of_impl(void* temporary_storage,
     {
         start = std::chrono::steady_clock::now();
     }
-    init_find_first_of_kernel<<<1, 1, 0, stream>>>(tmp_output, size, ordered_bid);
+    find_first_of_kernels::init_find_first_of_kernel<<<1, 1, 0, stream>>>(tmp_output,
+                                                                          size,
+                                                                          ordered_bid);
     ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("init_find_first_of_kernel", 1, start);

     if(size > 0 && keys_size > 0)
     {
-        auto kernel = find_first_of_kernel<config, InputIterator1, InputIterator2, BinaryFunction>;
+        auto kernel = find_first_of_kernels::find_first_of_kernel;

         const size_t shared_memory_size = 0;
