Skip to content

Commit

Permalink
use template parameter for input data
Browse files Browse the repository at this point in the history
  • Loading branch information
fbusato committed Jan 7, 2025
1 parent eb6df1d commit c0be178
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
23 changes: 13 additions & 10 deletions libcudacxx/include/cuda/__ptx/instructions/shfl_sync.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

# include <cuda/__ptx/instructions/get_sreg.h>
# include <cuda/__ptx/ptx_dot_variants.h>
# include <cuda/std/__bit/bit_cast.h>
# include <cuda/std/__type_traits/is_integral.h>
# include <cuda/std/__type_traits/is_signed.h>
# include <cuda/std/cstdint>
Expand Down Expand Up @@ -74,65 +75,67 @@ struct shfl_return_values
bool pred;
};

template <dot_shfl_mode _ShuffleMode>
template <typename _Tp, dot_shfl_mode _ShuffleMode>
_CCCL_NODISCARD _CCCL_DEVICE static inline shfl_return_values shfl_sync(
shfl_mode_t<_ShuffleMode> __shfl_mode,
_CUDA_VSTD::uint32_t __data,
_Tp __data,
_CUDA_VSTD::uint32_t __lane_idx_offset,
_CUDA_VSTD::uint32_t __clamp_segmask,
_CUDA_VSTD::uint32_t __lane_mask) noexcept
{
static_assert(sizeof(_Tp) == 4, "shfl.sync only accepts 4-byte data types");
_CCCL_ASSERT(__lane_idx_offset < 32, "the lane index or offset must be less than the warp size");
_CCCL_ASSERT((__clamp_segmask | 0b1111100011111) == 0b1111100011111,
"clamp value + segmentation mask must be less or equal than 12 bits");
_CCCL_ASSERT((__lane_mask & __activemask()) == __lane_mask, "lane mask must be a subset of the active mask");
_CCCL_ASSERT(__shfl_sync_dst_lane(__shfl_mode, __lane_idx_offset, __clamp_segmask, __lane_mask) & __lane_mask,
"the destination lane must be a member of the lane mask");
auto __data1 = _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__data);
_CUDA_VSTD::int32_t __pred;
_CUDA_VSTD::uint32_t __ret;
if constexpr (__shfl_mode == shfl_mode_idx)
{
asm volatile(
"{ \n\t\t"
".reg .pred p; \n\t\t"
"shfl_sync.sync.idx.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"shfl.sync.sync.idx.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"selp.s32 %1, 1, 0, p; \n\t"
"}"
: "=r"(__ret), "=r"(__pred)
: "r"(__data), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask));
: "r"(__data1), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask));
}
else if constexpr (__shfl_mode == shfl_mode_up)
{
asm volatile(
"{ \n\t\t"
".reg .pred p; \n\t\t"
"shfl_sync.sync.up.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"shfl.sync.sync.up.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"selp.s32 %1, 1, 0, p; \n\t"
"}"
: "=r"(__ret), "=r"(__pred)
: "r"(__data), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask));
: "r"(__data1), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask));
}
else if constexpr (__shfl_mode == shfl_mode_down)
{
asm volatile(
"{ \n\t\t"
".reg .pred p; \n\t\t"
"shfl_sync.sync.down.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"shfl.sync.sync.down.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"selp.s32 %1, 1, 0, p; \n\t"
"}"
: "=r"(__ret), "=r"(__pred)
: "r"(__data), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask));
: "r"(__data1), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask));
}
else
{
asm volatile(
"{ \n\t\t"
".reg .pred p; \n\t\t"
"shfl_sync.sync.bfly.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"shfl.sync.sync.bfly.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"selp.s32 %1, 1, 0, p; \n\t"
"}"
: "=r"(__ret), "=r"(__pred)
: "r"(__data), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask));
: "r"(__data1), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask));
}
return shfl_return_values{__ret, static_cast<bool>(__pred)};
}
Expand Down
2 changes: 1 addition & 1 deletion libcudacxx/include/cuda/ptx
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
#include <cuda/__ptx/instructions/mbarrier_init.h>
#include <cuda/__ptx/instructions/mbarrier_wait.h>
#include <cuda/__ptx/instructions/red_async.h>
#include <cuda/__ptx/instructions/shfl.h>
#include <cuda/__ptx/instructions/shfl_sync.h>
#include <cuda/__ptx/instructions/st_async.h>
#include <cuda/__ptx/instructions/tensormap_cp_fenceproxy.h>
#include <cuda/__ptx/instructions/tensormap_replace.h>
Expand Down

0 comments on commit c0be178

Please sign in to comment.