Skip to content

Commit

Permalink
replace voting instructions
Browse files Browse the repository at this point in the history
  • Loading branch information
fbusato committed Jan 13, 2025
1 parent 8d44adb commit 96cf032
Show file tree
Hide file tree
Showing 15 changed files with 56 additions and 52 deletions.
6 changes: 3 additions & 3 deletions cub/cub/agent/agent_radix_sort_onesweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ struct AgentRadixSortOnesweep
} while (value_j == 0);

inc_sum += value_j & LOOKBACK_VALUE_MASK;
want_mask = WARP_BALLOT((value_j & LOOKBACK_GLOBAL_MASK) == 0, want_mask);
want_mask = __ballot_sync(want_mask, (value_j & LOOKBACK_GLOBAL_MASK) == 0);
if (value_j & LOOKBACK_GLOBAL_MASK)
{
break;
Expand Down Expand Up @@ -484,7 +484,7 @@ struct AgentRadixSortOnesweep
{
d_keys_out[global_idx] = Twiddle::Out(key, decomposer);
}
WARP_SYNC(WARP_MASK);
__syncwarp(WARP_MASK);
}
}

Expand All @@ -502,7 +502,7 @@ struct AgentRadixSortOnesweep
{
d_values_out[global_idx] = value;
}
WARP_SYNC(WARP_MASK);
__syncwarp(WARP_MASK);
}
}

Expand Down
2 changes: 1 addition & 1 deletion cub/cub/agent/agent_rle.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ struct AgentRle
WarpExchangeOffsets(temp_storage.aliasable.scatter_aliasable.exchange_offsets[warp_id])
.ScatterToStriped(run_offsets, thread_num_runs_exclusive_in_warp);

WARP_SYNC(0xffffffff);
__syncwarp(0xffffffff);

WarpExchangeLengths(temp_storage.aliasable.scatter_aliasable.exchange_lengths[warp_id])
.ScatterToStriped(run_lengths, thread_num_runs_exclusive_in_warp);
Expand Down
8 changes: 4 additions & 4 deletions cub/cub/agent/agent_sub_warp_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -233,23 +233,23 @@ public:
KeyT oob_default = AgentSubWarpSort::get_oob_default(Int2Type<std::is_same<bool, KeyT>::value>{});

WarpLoadKeysT(storage.load_keys).Load(keys_input, keys, segment_size, oob_default);
WARP_SYNC(warp_merge_sort.get_member_mask());
__syncwarp(warp_merge_sort.get_member_mask());

if (!KEYS_ONLY)
{
WarpLoadItemsT(storage.load_items).Load(values_input, values, segment_size);

WARP_SYNC(warp_merge_sort.get_member_mask());
__syncwarp(warp_merge_sort.get_member_mask());
}

warp_merge_sort.Sort(keys, values, BinaryOpT{}, segment_size, oob_default);
WARP_SYNC(warp_merge_sort.get_member_mask());
__syncwarp(warp_merge_sort.get_member_mask());

WarpStoreKeysT(storage.store_keys).Store(keys_output, keys, segment_size);

if (!KEYS_ONLY)
{
WARP_SYNC(warp_merge_sort.get_member_mask());
__syncwarp(warp_merge_sort.get_member_mask());
WarpStoreItemsT(storage.store_items).Store(values_output, values, segment_size);
}
}
Expand Down
8 changes: 4 additions & 4 deletions cub/cub/agent/single_pass_scan_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ public:
tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
}

while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff))
while (__any_sync(0xffffffff, (tile_descriptor.status == SCAN_TILE_INVALID)))
{
delay_or_prevent_hoisting();
TxnWord alias = LoadStatus<Order>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
Expand Down Expand Up @@ -918,7 +918,7 @@ struct ScanTileState<T, false>
delay();
status = detail::load_relaxed(d_tile_status + TILE_STATUS_PADDING + tile_idx);
__threadfence();
} while (WARP_ANY((status == SCAN_TILE_INVALID), 0xffffffff));
} while (__any_sync(0xffffffff, (status == SCAN_TILE_INVALID)));

if (status == StatusWord(SCAN_TILE_PARTIAL))
{
Expand Down Expand Up @@ -1145,7 +1145,7 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>
TxnWord alias = detail::load_relaxed(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);

} while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));
} while (__any_sync(0xffffffff, (tile_descriptor.status == SCAN_TILE_INVALID)));

status = tile_descriptor.status;
value.value = tile_descriptor.value;
Expand Down Expand Up @@ -1268,7 +1268,7 @@ struct TilePrefixCallbackOp
exclusive_prefix = window_aggregate;

// Keep sliding the window back until we come across a tile whose inclusive prefix is known
while (WARP_ALL((predecessor_status != StatusWord(SCAN_TILE_INCLUSIVE)), 0xffffffff))
while (__all_sync(0xffffffff, (predecessor_status != StatusWord(SCAN_TILE_INCLUSIVE))))
{
predecessor_idx -= CUB_PTX_WARP_THREADS;

Expand Down
10 changes: 5 additions & 5 deletions cub/cub/block/block_exchange.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ private:
detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
}

WARP_SYNC(0xffffffff);
__syncwarp(0xffffffff);

#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; i++)
Expand Down Expand Up @@ -363,7 +363,7 @@ private:
detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
}

WARP_SYNC(0xffffffff);
__syncwarp(0xffffffff);

#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; i++)
Expand Down Expand Up @@ -395,7 +395,7 @@ private:
detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
}

WARP_SYNC(0xffffffff);
__syncwarp(0xffffffff);

#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; i++)
Expand Down Expand Up @@ -545,7 +545,7 @@ private:
detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
}

WARP_SYNC(0xffffffff);
__syncwarp(0xffffffff);

#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; i++)
Expand Down Expand Up @@ -589,7 +589,7 @@ private:
detail::uninitialized_copy_single(temp_storage.buff + item_offset, input_items[i]);
}

WARP_SYNC(0xffffffff);
__syncwarp(0xffffffff);

#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; i++)
Expand Down
12 changes: 6 additions & 6 deletions cub/cub/block/block_radix_rank.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ public:
DigitCounterT warp_digit_prefix = *digit_counters[ITEM];

// Warp-sync
WARP_SYNC(0xFFFFFFFF);
__syncwarp(0xFFFFFFFF);

// Number of peers having same digit as me
int32_t digit_count = __popc(peer_mask);
Expand All @@ -756,7 +756,7 @@ public:
}

// Warp-sync
WARP_SYNC(0xFFFFFFFF);
__syncwarp(0xFFFFFFFF);

// Number of prior keys having same digit
ranks[ITEM] = warp_digit_prefix + DigitCounterT(peer_digit_prefix);
Expand Down Expand Up @@ -978,7 +978,7 @@ struct BlockRadixRankMatchEarlyCounts
match_masks[bin] = 0;
}
}
WARP_SYNC(WARP_MASK);
__syncwarp(WARP_MASK);

// compute private per-part histograms
int part = lane % NUM_PARTS;
Expand All @@ -992,7 +992,7 @@ struct BlockRadixRankMatchEarlyCounts
// no extra work is necessary if NUM_PARTS == 1
if (NUM_PARTS > 1)
{
WARP_SYNC(WARP_MASK);
__syncwarp(WARP_MASK);
// TODO: handle RADIX_DIGITS % WARP_THREADS != 0 if it becomes necessary
constexpr int WARP_BINS_PER_THREAD = RADIX_DIGITS / WARP_THREADS;
int bins[WARP_BINS_PER_THREAD];
Expand Down Expand Up @@ -1067,7 +1067,7 @@ struct BlockRadixRankMatchEarlyCounts
::cuda::std::uint32_t bin = Digit(keys[u]);
int* p_match_mask = &match_masks[bin];
atomicOr(p_match_mask, lane_mask);
WARP_SYNC(WARP_MASK);
__syncwarp(WARP_MASK);
int bin_mask = *p_match_mask;
int leader = (WARP_THREADS - 1) - __clz(bin_mask);
int warp_offset = 0;
Expand All @@ -1082,7 +1082,7 @@ struct BlockRadixRankMatchEarlyCounts
{
*p_match_mask = 0;
}
WARP_SYNC(WARP_MASK);
__syncwarp(WARP_MASK);
ranks[u] = warp_offset + popc - 1;
}
}
Expand Down
2 changes: 1 addition & 1 deletion cub/cub/block/specializations/block_reduce_raking.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ struct BlockReduceRaking
// sync before re-using shmem (warp_storage/raking_grid are aliased)
static_assert(RAKING_THREADS <= CUB_PTX_WARP_THREADS, "RAKING_THREADS must be <= warp size.");
unsigned int mask = static_cast<unsigned int>((1ull << RAKING_THREADS) - 1);
WARP_SYNC(mask);
__syncwarp(mask);

partial = WarpReduce(temp_storage.warp_storage)
.template Reduce<(IS_FULL_TILE && RAKING_UNGUARDED)>(partial, valid_raking_threads, reduction_op);
Expand Down
4 changes: 4 additions & 0 deletions cub/cub/util_ptx.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ _CCCL_DEVICE _CCCL_FORCEINLINE int CTA_SYNC_OR(int p)
/**
* Warp barrier
*/
CCCL_DEPRECATED_BECAUSE("use __syncwarp() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE void WARP_SYNC(unsigned int member_mask)
{
__syncwarp(member_mask);
Expand All @@ -225,6 +226,7 @@ _CCCL_DEVICE _CCCL_FORCEINLINE void WARP_SYNC(unsigned int member_mask)
/**
* Warp any
*/
CCCL_DEPRECATED_BECAUSE("use __any_sync() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE int WARP_ANY(int predicate, unsigned int member_mask)
{
return __any_sync(member_mask, predicate);
Expand All @@ -233,6 +235,7 @@ _CCCL_DEVICE _CCCL_FORCEINLINE int WARP_ANY(int predicate, unsigned int member_m
/**
* Warp all
*/
CCCL_DEPRECATED_BECAUSE("use __all_sync() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE int WARP_ALL(int predicate, unsigned int member_mask)
{
return __all_sync(member_mask, predicate);
Expand All @@ -241,6 +244,7 @@ _CCCL_DEVICE _CCCL_FORCEINLINE int WARP_ALL(int predicate, unsigned int member_m
/**
* Warp ballot
*/
CCCL_DEPRECATED_BECAUSE("use __ballot_sync() instead")
_CCCL_DEVICE _CCCL_FORCEINLINE int WARP_BALLOT(int predicate, unsigned int member_mask)
{
return __ballot_sync(member_mask, predicate);
Expand Down
6 changes: 3 additions & 3 deletions cub/cub/warp/specializations/warp_exchange_smem.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public:
const int idx = ITEMS_PER_THREAD * lane_id + item;
temp_storage.items_shared[idx] = input_items[item];
}
WARP_SYNC(member_mask);
__syncwarp(member_mask);

for (int item = 0; item < ITEMS_PER_THREAD; item++)
{
Expand All @@ -122,7 +122,7 @@ public:
const int idx = LOGICAL_WARP_THREADS * item + lane_id;
temp_storage.items_shared[idx] = input_items[item];
}
WARP_SYNC(member_mask);
__syncwarp(member_mask);

for (int item = 0; item < ITEMS_PER_THREAD; item++)
{
Expand Down Expand Up @@ -155,7 +155,7 @@ public:
temp_storage.items_shared[ranks[ITEM]] = input_items[ITEM];
}

WARP_SYNC(member_mask);
__syncwarp(member_mask);

#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ITEM++)
Expand Down
2 changes: 1 addition & 1 deletion cub/cub/warp/specializations/warp_reduce_shfl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ struct WarpReduceShfl
_CCCL_DEVICE _CCCL_FORCEINLINE T SegmentedReduce(T input, FlagT flag, ReductionOp reduction_op)
{
// Get the start flags for each thread in the warp.
int warp_flags = WARP_BALLOT(flag, member_mask);
int warp_flags = __ballot_sync(member_mask, flag);

// Convert to tail-segmented
if (HEAD_SEGMENTED)
Expand Down
14 changes: 7 additions & 7 deletions cub/cub/warp/specializations/warp_reduce_smem.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ struct WarpReduceSmem
// Share input through buffer
ThreadStore<STORE_VOLATILE>(&temp_storage.reduce[lane_id], input);

WARP_SYNC(member_mask);
__syncwarp(member_mask);

// Update input if peer_addend is in range
if ((ALL_LANES_VALID && IS_POW_OF_TWO) || ((lane_id + OFFSET) < valid_items))
Expand All @@ -170,7 +170,7 @@ struct WarpReduceSmem
input = reduction_op(input, peer_addend);
}

WARP_SYNC(member_mask);
__syncwarp(member_mask);

return ReduceStep<ALL_LANES_VALID>(input, valid_items, reduction_op, Int2Type<STEP + 1>());
}
Expand Down Expand Up @@ -224,7 +224,7 @@ struct WarpReduceSmem
SegmentedReduce(T input, FlagT flag, ReductionOp reduction_op, Int2Type<true> /*has_ballot*/)
{
// Get the start flags for each thread in the warp.
int warp_flags = WARP_BALLOT(flag, member_mask);
int warp_flags = __ballot_sync(member_mask, flag);

if (!HEAD_SEGMENTED)
{
Expand Down Expand Up @@ -257,7 +257,7 @@ struct WarpReduceSmem
// Share input into buffer
ThreadStore<STORE_VOLATILE>(&temp_storage.reduce[lane_id], input);

WARP_SYNC(member_mask);
__syncwarp(member_mask);

// Update input if peer_addend is in range
if (OFFSET + lane_id < next_flag)
Expand All @@ -266,7 +266,7 @@ struct WarpReduceSmem
input = reduction_op(input, peer_addend);
}

WARP_SYNC(member_mask);
__syncwarp(member_mask);
}

return input;
Expand Down Expand Up @@ -313,12 +313,12 @@ struct WarpReduceSmem
// Share input through buffer
ThreadStore<STORE_VOLATILE>(&temp_storage.reduce[lane_id], input);

WARP_SYNC(member_mask);
__syncwarp(member_mask);

// Get peer from buffer
T peer_addend = ThreadLoad<LOAD_VOLATILE>(&temp_storage.reduce[lane_id + OFFSET]);

WARP_SYNC(member_mask);
__syncwarp(member_mask);

// Share flag through buffer
flag_storage[lane_id] = flag_status;
Expand Down
2 changes: 1 addition & 1 deletion cub/cub/warp/specializations/warp_scan_shfl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ struct WarpScanShfl

KeyT pred_key = ShuffleUp<LOGICAL_WARP_THREADS>(inclusive_output.key, 1, 0, member_mask);

unsigned int ballot = WARP_BALLOT((pred_key != inclusive_output.key), member_mask);
unsigned int ballot = __ballot_sync(member_mask, (pred_key != inclusive_output.key));

// Mask away all lanes greater than ours
ballot = ballot & ::cuda::ptx::get_sreg_lanemask_le();
Expand Down
Loading

0 comments on commit 96cf032

Please sign in to comment.