Skip to content

Commit 9475d95

Browse files
committed
restores empty problem behaviour
1 parent fe112da commit 9475d95

File tree

1 file changed

+10
-15
lines changed

1 file changed

+10
-15
lines changed

cub/cub/device/dispatch/dispatch_select_if.cuh

+10 -15
Original file line number | Diff line number | Diff line change
@@ -580,19 +580,6 @@ struct DispatchSelectIf : SelectedPolicy
580580
EqualityOpT,
581581
per_partition_offset_t,
582582
streaming_context_t>;
583-
584-
// Return for empty problem (also needed to avoid division by zero)
585-
// TODO(elstehle): In this case d_num_selected_out will never be written. Maybe we want to write it despite?
586-
if (num_items == 0)
587-
{
588-
// If this was just to query temporary storage requirements, return non-empty bytes
589-
if (d_temp_storage == nullptr)
590-
{
591-
temp_storage_bytes = std::size_t{1};
592-
}
593-
return cudaSuccess;
594-
}
595-
596583
cudaError error = cudaSuccess;
597584

598585
constexpr auto block_threads = VsmemHelperT::agent_policy_t::BLOCK_THREADS;
@@ -607,8 +594,9 @@ struct DispatchSelectIf : SelectedPolicy
607594
? static_cast<OffsetT>(partition_size)
608595
: num_items;
609596

610-
// The number of partitions required to "iterate" over the total input
611-
auto const num_partitions = ::cuda::ceil_div(num_items, max_partition_size);
597+
// The number of partitions required to "iterate" over the total input (ternary to avoid div-by-zero)
598+
auto const num_partitions =
599+
(max_partition_size == 0) ? static_cast<OffsetT>(1) : ::cuda::ceil_div(num_items, max_partition_size);
612600

613601
// The maximum number of tiles for which we will ever invoke the kernel
614602
auto const max_num_tiles_per_invocation = static_cast<OffsetT>(::cuda::ceil_div(max_partition_size, tile_size));
@@ -704,6 +692,13 @@ struct DispatchSelectIf : SelectedPolicy
704692
return error;
705693
}
706694

695+
// No more items to process (note, we do not want to return early for num_items==0, because we need to make sure
696+
// that `scan_init_kernel` has written '0' to d_num_selected_out)
697+
if (current_num_items == 0)
698+
{
699+
return cudaSuccess;
700+
}
701+
707702
// Log select_if_kernel configuration
708703
#ifdef CUB_DETAIL_DEBUG_ENABLE_LOG
709704
{

0 commit comments

Comments
 (0)