Skip to content

Commit

Permalink
fix return type
Browse files Browse the repository at this point in the history
  • Loading branch information
fbusato committed Jan 7, 2025
1 parent c0be178 commit f0caaa9
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions libcudacxx/include/cuda/__ptx/instructions/shfl_sync.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,15 @@ _CCCL_DEVICE static inline _CUDA_VSTD::uint32_t __shfl_sync_dst_lane(
return (1 << __dst);
}

// Aggregate returned by shfl_sync: the shuffled payload plus the predicate
// flag that the shfl instruction wrote (shfl_sync fills it from
// static_cast<bool>(__pred)).
// NOTE(review): this listing appears to be a rendered diff with the +/- markers
// stripped, so two `data` declarations are shown back to back; presumably the
// uint32_t one is the removed line and the _Tp one is its replacement — confirm
// against the actual header before relying on this.
template <typename _Tp>
struct shfl_return_values
{
// Payload declared as a fixed 32-bit unsigned integer (presumably the pre-change side of the diff).
_CUDA_VSTD::uint32_t data;
// Payload declared with the caller's element type _Tp (presumably the post-change side of the diff).
_Tp data;
// Predicate result of the shuffle, converted to bool by shfl_sync.
bool pred;
};

template <typename _Tp, dot_shfl_mode _ShuffleMode>
_CCCL_NODISCARD _CCCL_DEVICE static inline shfl_return_values shfl_sync(
_CCCL_NODISCARD _CCCL_DEVICE static inline shfl_return_values<_Tp> shfl_sync(
shfl_mode_t<_ShuffleMode> __shfl_mode,
_Tp __data,
_CUDA_VSTD::uint32_t __lane_idx_offset,
Expand All @@ -98,7 +99,7 @@ _CCCL_NODISCARD _CCCL_DEVICE static inline shfl_return_values shfl_sync(
asm volatile(
"{ \n\t\t"
".reg .pred p; \n\t\t"
"shfl.sync.sync.idx.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"shfl.sync.sync.idx.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"selp.s32 %1, 1, 0, p; \n\t"
"}"
: "=r"(__ret), "=r"(__pred)
Expand All @@ -109,7 +110,7 @@ _CCCL_NODISCARD _CCCL_DEVICE static inline shfl_return_values shfl_sync(
asm volatile(
"{ \n\t\t"
".reg .pred p; \n\t\t"
"shfl.sync.sync.up.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"shfl.sync.sync.up.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"selp.s32 %1, 1, 0, p; \n\t"
"}"
: "=r"(__ret), "=r"(__pred)
Expand All @@ -120,7 +121,7 @@ _CCCL_NODISCARD _CCCL_DEVICE static inline shfl_return_values shfl_sync(
asm volatile(
"{ \n\t\t"
".reg .pred p; \n\t\t"
"shfl.sync.sync.down.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"shfl.sync.sync.down.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"selp.s32 %1, 1, 0, p; \n\t"
"}"
: "=r"(__ret), "=r"(__pred)
Expand All @@ -131,13 +132,14 @@ _CCCL_NODISCARD _CCCL_DEVICE static inline shfl_return_values shfl_sync(
asm volatile(
"{ \n\t\t"
".reg .pred p; \n\t\t"
"shfl.sync.sync.bfly.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"shfl.sync.sync.bfly.b32 %0|p, %2, %3, %4, %5; \n\t\t"
"selp.s32 %1, 1, 0, p; \n\t"
"}"
: "=r"(__ret), "=r"(__pred)
: "r"(__data1), "r"(__lane_idx_offset), "r"(__clamp_segmask), "r"(__lane_mask));
}
return shfl_return_values{__ret, static_cast<bool>(__pred)};
auto __ret1 = _CUDA_VSTD::bit_cast<_CUDA_VSTD::uint32_t>(__ret);
return shfl_return_values<_Tp>{__ret1, static_cast<bool>(__pred)};
}

# endif // __cccl_ptx_isa >= 600
Expand Down

0 comments on commit f0caaa9

Please sign in to comment.