Skip to content

Commit

Permalink
removes superfluous copy of predecessor keys
Browse files Browse the repository at this point in the history
  • Loading branch information
elstehle committed Sep 28, 2024
1 parent 17ea266 commit 2b7dc2a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 30 deletions.
5 changes: 1 addition & 4 deletions cub/cub/agent/agent_scan_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ struct AgentScanByKey

TempStorage_& storage;
WrappedKeysInputIteratorT d_keys_in;
KeyT* d_keys_prev_in;
WrappedValuesInputIteratorT d_values_in;
ValuesOutputIteratorT d_values_out;
InequalityWrapper<EqualityOp> inequality_op;
Expand Down Expand Up @@ -373,7 +372,7 @@ struct AgentScanByKey
}
else
{
KeyT tile_pred_key = (threadIdx.x == 0) ? d_keys_prev_in[tile_idx] : KeyT();
KeyT tile_pred_key = d_keys_in[tile_base - 1];

BlockDiscontinuityKeysT(storage.scan_storage.discontinuity)
.FlagHeads(segment_flags, keys, inequality_op, tile_pred_key);
Expand Down Expand Up @@ -412,15 +411,13 @@ struct AgentScanByKey
_CCCL_DEVICE _CCCL_FORCEINLINE AgentScanByKey(
TempStorage& storage,
KeysInputIteratorT d_keys_in,
KeyT* d_keys_prev_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
EqualityOp equality_op,
ScanOpT scan_op,
InitValueT init_value)
: storage(storage.Alias())
, d_keys_in(d_keys_in)
, d_keys_prev_in(d_keys_prev_in)
, d_values_in(d_values_in)
, d_values_out(d_values_out)
, inequality_op(equality_op)
Expand Down
33 changes: 7 additions & 26 deletions cub/cub/device/dispatch/dispatch_scan_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ template <typename ChainedPolicyT,
__launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanByKeyPolicyT::BLOCK_THREADS))
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceScanByKeyKernel(
KeysInputIteratorT d_keys_in,
KeyT* d_keys_prev_in,
ValuesInputIteratorT d_values_in,
ValuesOutputIteratorT d_values_out,
ScanByKeyTileStateT tile_state,
Expand Down Expand Up @@ -166,28 +165,15 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanByKeyPolicyT::BLOCK_THRE
__shared__ typename AgentScanByKeyT::TempStorage temp_storage;

// Process tiles
AgentScanByKeyT(temp_storage, d_keys_in, d_keys_prev_in, d_values_in, d_values_out, equality_op, scan_op, init_value)
AgentScanByKeyT(temp_storage, d_keys_in, d_values_in, d_values_out, equality_op, scan_op, init_value)
.ConsumeRange(num_items, tile_state, start_tile);
}

template <typename ScanTileStateT, typename KeysInputIteratorT, typename OffsetT>
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceScanByKeyInitKernel(
ScanTileStateT tile_state,
KeysInputIteratorT d_keys_in,
cub::detail::value_t<KeysInputIteratorT>* d_keys_prev_in,
OffsetT items_per_tile,
int num_tiles)
template <typename ScanTileStateT>
CUB_DETAIL_KERNEL_ATTRIBUTES void DeviceScanByKeyInitKernel(ScanTileStateT tile_state, int num_tiles)
{
// Initialize tile status
tile_state.InitializeStatus(num_tiles);

const unsigned tid = threadIdx.x + blockDim.x * blockIdx.x;
const OffsetT tile_base = static_cast<OffsetT>(tid) * items_per_tile;

if (tid > 0 && tid < num_tiles)
{
d_keys_prev_in[tid] = d_keys_in[tile_base - 1];
}
}

/******************************************************************************
Expand Down Expand Up @@ -394,18 +380,16 @@ struct DispatchScanByKey : SelectedPolicy
int num_tiles = static_cast<int>(::cuda::ceil_div(num_items, tile_size));

// Specify temporary storage allocation requirements
size_t allocation_sizes[2];
size_t allocation_sizes[1];
error = CubDebug(ScanByKeyTileStateT::AllocationSize(num_tiles, allocation_sizes[0]));
if (cudaSuccess != error)
{
break; // bytes needed for tile status descriptors
}

allocation_sizes[1] = sizeof(KeyT) * (num_tiles + 1);

// Compute allocation pointers into the single storage blob (or compute
// the necessary size of the blob)
void* allocations[2] = {};
void* allocations[1] = {};

error = CubDebug(AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes));
if (cudaSuccess != error)
Expand All @@ -426,8 +410,6 @@ struct DispatchScanByKey : SelectedPolicy
break;
}

KeyT* d_keys_prev_in = reinterpret_cast<KeyT*>(allocations[1]);

// Construct the tile status interface
ScanByKeyTileStateT tile_state;
error = CubDebug(tile_state.Init(num_tiles, allocations[0], allocation_sizes[0]));
Expand All @@ -444,7 +426,7 @@ struct DispatchScanByKey : SelectedPolicy

// Invoke init_kernel to initialize tile descriptors
THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(init_grid_size, INIT_KERNEL_THREADS, 0, stream)
.doit(init_kernel, tile_state, d_keys_in, d_keys_prev_in, static_cast<OffsetT>(tile_size), num_tiles);
.doit(init_kernel, tile_state, num_tiles);

// Check for failure to launch
error = CubDebug(cudaPeekAtLastError());
Expand Down Expand Up @@ -487,7 +469,6 @@ struct DispatchScanByKey : SelectedPolicy
THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(scan_grid_size, Policy::BLOCK_THREADS, 0, stream)
.doit(scan_kernel,
d_keys_in,
d_keys_prev_in,
d_values_in,
d_values_out,
tile_state,
Expand Down Expand Up @@ -523,7 +504,7 @@ struct DispatchScanByKey : SelectedPolicy

// Ensure kernels are instantiated.
return Invoke<ActivePolicyT>(
DeviceScanByKeyInitKernel<ScanByKeyTileStateT, KeysInputIteratorT, OffsetT>,
DeviceScanByKeyInitKernel<ScanByKeyTileStateT>,
DeviceScanByKeyKernel<MaxPolicyT,
KeysInputIteratorT,
ValuesInputIteratorT,
Expand Down

0 comments on commit 2b7dc2a

Please sign in to comment.