Skip to content

Commit

Permalink
add a __type_switch utility and use it the ptx generator (#2946)
Browse files Browse the repository at this point in the history
  • Loading branch information
ericniebler authored Nov 27, 2024
1 parent 27d8c87 commit 83aca35
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 70 deletions.
5 changes: 4 additions & 1 deletion docs/repo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ doxygen_predefined = [
"_CCCL_FORCEINLINE",
"_CCCL_STD_VER",
"_CCCL_NODISCARD",
"_CCCL_NTTP_AUTO=auto",
"_CCCL_VISIBILITY_HIDDEN",
"_CCCL_SUPPRESS_DEPRECATED_PUSH",
"_CCCL_SUPPRESS_DEPRECATED_POP",
Expand Down Expand Up @@ -261,6 +262,7 @@ doxygen_predefined = [
"_CCCL_HOST=",
"_CCCL_HOST_DEVICE=",
"_CCCL_NODISCARD=[[nodiscard]]",
"_CCCL_NTTP_AUTO=auto",
"_CCCL_STD_VER",
"_CCCL_SUPPRESS_DEPRECATED_PUSH",
"_CCCL_SUPPRESS_DEPRECATED_POP",
Expand Down Expand Up @@ -408,6 +410,7 @@ doxygen_predefined = [
"_CCCL_CUDACC_AT_LEAST(x, y)=1",
"_CCCL_CUDACC_BELOW(x, y)=0",
"_CCCL_DEVICE=",
"_CCCL_DOXYGEN_INVOKED",
"_CCCL_EAT_REST(x)=",
"_CCCL_EXEC_CHECK_DISABLE=",
"_CCCL_FORCEINLINE=",
Expand All @@ -419,6 +422,7 @@ doxygen_predefined = [
"_CCCL_INLINE_VAR=inline",
"_CCCL_NODISCARD=[[nodiscard]]",
"_CCCL_NODISCARD_FRIEND=",
"_CCCL_NTTP_AUTO=auto",
"_CCCL_STD_VER=2020",
"_CCCL_TRAIT(x, y)=x<y>::value",
"_CUDA_VMR=cuda::mr",
Expand All @@ -443,7 +447,6 @@ doxygen_predefined = [
"_CUDAX_TRIVIAL_DEVICE_API",
"_CUDAX_PUBLIC_API",
"LIBCUDACXX_ENABLE_EXPERIMENTAL_MEMORY_RESOURCE=",
"_CCCL_DOXYGEN_INVOKED",
]

# make sure to use ./fetch_imgs.sh
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <cuda/std/__type_traits/is_floating_point.h>
#include <cuda/std/__type_traits/is_scalar.h>
#include <cuda/std/__type_traits/is_signed.h>
#include <cuda/std/__type_traits/type_list.h>
#include <cuda/std/cstddef>
#include <cuda/std/cstdint>

Expand Down Expand Up @@ -110,61 +111,50 @@ struct __atomic_longlong2

template <class _Type>
using __atomic_cuda_deduce_bitwise =
_If<sizeof(_Type) == 1,
__atomic_cuda_operand_deduction<uint8_t, __atomic_cuda_operand_b8>,
_If<sizeof(_Type) == 2,
__atomic_cuda_operand_deduction<uint16_t, __atomic_cuda_operand_b16>,
_If<sizeof(_Type) == 4,
__atomic_cuda_operand_deduction<uint32_t, __atomic_cuda_operand_b32>,
_If<sizeof(_Type) == 8,
__atomic_cuda_operand_deduction<uint64_t, __atomic_cuda_operand_b64>,
__atomic_cuda_operand_deduction<__atomic_longlong2, __atomic_cuda_operand_b128>>>>>;
__type_switch<sizeof(_Type),
__type_case<1, __atomic_cuda_operand_deduction<uint8_t, __atomic_cuda_operand_b8>>,
__type_case<2, __atomic_cuda_operand_deduction<uint16_t, __atomic_cuda_operand_b16>>,
__type_case<4, __atomic_cuda_operand_deduction<uint32_t, __atomic_cuda_operand_b32>>,
__type_case<8, __atomic_cuda_operand_deduction<uint64_t, __atomic_cuda_operand_b64>>,
__type_default<__atomic_cuda_operand_deduction<__atomic_longlong2, __atomic_cuda_operand_b128>>>;

template <class _Type>
using __atomic_cuda_deduce_arithmetic =
_If<_CCCL_TRAIT(is_floating_point, _Type),
_If<sizeof(_Type) == 4,
__atomic_cuda_operand_deduction<float, __atomic_cuda_operand_f32>,
__atomic_cuda_operand_deduction<double, __atomic_cuda_operand_f64>>,
_If<_CCCL_TRAIT(is_signed, _Type),
_If<sizeof(_Type) == 1,
__atomic_cuda_operand_deduction<int8_t, __atomic_cuda_operand_s8>,
_If<sizeof(_Type) == 2,
__atomic_cuda_operand_deduction<int16_t, __atomic_cuda_operand_s16>,
_If<sizeof(_Type) == 4,
__atomic_cuda_operand_deduction<int32_t, __atomic_cuda_operand_s32>,
__atomic_cuda_operand_deduction<int64_t, __atomic_cuda_operand_u64>>>>, // There is no
// atom.add.s64
_If<sizeof(_Type) == 1,
__atomic_cuda_operand_deduction<uint8_t, __atomic_cuda_operand_u8>,
_If<sizeof(_Type) == 2,
__atomic_cuda_operand_deduction<uint16_t, __atomic_cuda_operand_u16>,
_If<sizeof(_Type) == 4,
__atomic_cuda_operand_deduction<uint32_t, __atomic_cuda_operand_u32>,
__atomic_cuda_operand_deduction<uint64_t, __atomic_cuda_operand_u64>>>>>>;
using __atomic_cuda_deduce_arithmetic = _If<
_CCCL_TRAIT(is_floating_point, _Type),
_If<sizeof(_Type) == 4,
__atomic_cuda_operand_deduction<float, __atomic_cuda_operand_f32>,
__atomic_cuda_operand_deduction<double, __atomic_cuda_operand_f64>>,
_If<_CCCL_TRAIT(is_signed, _Type),
__type_switch<sizeof(_Type),
__type_case<1, __atomic_cuda_operand_deduction<int8_t, __atomic_cuda_operand_s8>>,
__type_case<2, __atomic_cuda_operand_deduction<int16_t, __atomic_cuda_operand_s16>>,
__type_case<4, __atomic_cuda_operand_deduction<int32_t, __atomic_cuda_operand_s32>>,
__type_default<__atomic_cuda_operand_deduction<int64_t, __atomic_cuda_operand_u64>>>, // There is no
// atom.add.s64
__type_switch<sizeof(_Type),
__type_case<1, __atomic_cuda_operand_deduction<uint8_t, __atomic_cuda_operand_u8>>,
__type_case<2, __atomic_cuda_operand_deduction<uint16_t, __atomic_cuda_operand_u16>>,
__type_case<4, __atomic_cuda_operand_deduction<uint32_t, __atomic_cuda_operand_u32>>,
__type_default<__atomic_cuda_operand_deduction<uint64_t, __atomic_cuda_operand_u64>>>>>;

template <class _Type>
using __atomic_cuda_deduce_minmax =
_If<_CCCL_TRAIT(is_floating_point, _Type),
_If<sizeof(_Type) == 4,
__atomic_cuda_operand_deduction<float, __atomic_cuda_operand_f32>,
__atomic_cuda_operand_deduction<double, __atomic_cuda_operand_f64>>,
_If<_CCCL_TRAIT(is_signed, _Type),
_If<sizeof(_Type) == 1,
__atomic_cuda_operand_deduction<int8_t, __atomic_cuda_operand_s8>,
_If<sizeof(_Type) == 2,
__atomic_cuda_operand_deduction<int16_t, __atomic_cuda_operand_s16>,
_If<sizeof(_Type) == 4,
__atomic_cuda_operand_deduction<int32_t, __atomic_cuda_operand_s32>,
__atomic_cuda_operand_deduction<int64_t, __atomic_cuda_operand_s64>>>>, // atom.min|max.s64
// supported
_If<sizeof(_Type) == 1,
__atomic_cuda_operand_deduction<uint8_t, __atomic_cuda_operand_u8>,
_If<sizeof(_Type) == 2,
__atomic_cuda_operand_deduction<uint16_t, __atomic_cuda_operand_u16>,
_If<sizeof(_Type) == 4,
__atomic_cuda_operand_deduction<uint32_t, __atomic_cuda_operand_u32>,
__atomic_cuda_operand_deduction<uint64_t, __atomic_cuda_operand_u64>>>>>>;
using __atomic_cuda_deduce_minmax = _If<
_CCCL_TRAIT(is_floating_point, _Type),
_If<sizeof(_Type) == 4,
__atomic_cuda_operand_deduction<float, __atomic_cuda_operand_f32>,
__atomic_cuda_operand_deduction<double, __atomic_cuda_operand_f64>>,
_If<_CCCL_TRAIT(is_signed, _Type),
__type_switch<sizeof(_Type),
__type_case<1, __atomic_cuda_operand_deduction<int8_t, __atomic_cuda_operand_s8>>,
__type_case<2, __atomic_cuda_operand_deduction<int16_t, __atomic_cuda_operand_s16>>,
__type_case<4, __atomic_cuda_operand_deduction<int32_t, __atomic_cuda_operand_s32>>,
__type_default<__atomic_cuda_operand_deduction<int64_t, __atomic_cuda_operand_s64>>>, // atom.min|max.s64
// supported
__type_switch<sizeof(_Type),
__type_case<1, __atomic_cuda_operand_deduction<uint8_t, __atomic_cuda_operand_u8>>,
__type_case<2, __atomic_cuda_operand_deduction<uint16_t, __atomic_cuda_operand_u16>>,
__type_case<4, __atomic_cuda_operand_deduction<uint32_t, __atomic_cuda_operand_u32>>,
__type_default<__atomic_cuda_operand_deduction<uint64_t, __atomic_cuda_operand_u64>>>>>;

template <class _Type>
using __atomic_enable_if_native_bitwise = bool;
Expand Down
9 changes: 9 additions & 0 deletions libcudacxx/include/cuda/std/__cccl/dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@
# define _CCCL_NO_VARIABLE_TEMPLATES
#endif // _CCCL_STD_VER <= 2011

// Declaring non-type template parameters with `auto` is only available from C++17 onwards.
// When the feature is available, _CCCL_NTTP_AUTO expands to `auto`. Otherwise
// _CCCL_NO_NONTYPE_TEMPLATE_PARAMETER_AUTO is defined and _CCCL_NTTP_AUTO falls back to
// `unsigned long long int`, so every such parameter has that one fixed type.
#if _CCCL_STD_VER >= 2017 && defined(__cpp_nontype_template_parameter_auto) \
&& (__cpp_nontype_template_parameter_auto >= 201606L)
# define _CCCL_NTTP_AUTO auto
#else // ^^^ C++17 ^^^ / vvv C++14 vvv
# define _CCCL_NO_NONTYPE_TEMPLATE_PARAMETER_AUTO
# define _CCCL_NTTP_AUTO unsigned long long int
#endif // _CCCL_STD_VER <= 2014

// concepts are only available from C++20 onwards
#if _CCCL_STD_VER <= 2017 || !defined(__cpp_concepts) || (__cpp_concepts < 201907L)
# define _CCCL_NO_CONCEPTS
Expand Down
112 changes: 94 additions & 18 deletions libcudacxx/include/cuda/std/__type_traits/type_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,100 @@ using __type_front = __type_at_c<0, _List>;
template <class _List>
using __type_back = __type_at_c<_List::__size - 1, _List>;

//! \brief A pair of types
//!
//! Holds two types and exposes them through the nested aliases \c __first and
//! \c __second.
//! \tparam _First The type exposed as \c __first
//! \tparam _Second The type exposed as \c __second
template <class _First, class _Second>
struct _CCCL_TYPE_VISIBILITY_DEFAULT __type_pair
{
using __first _CCCL_NODEBUG_ALIAS = _First;
using __second _CCCL_NODEBUG_ALIAS = _Second;
};

//! \brief Retrieve the first of a pair of types
//!
//! Evaluates to the nested member type \c _Pair::__first.
//! \pre \c _Pair is a specialization of \c __type_pair
template <class _Pair>
using __type_pair_first _CCCL_NODEBUG_ALIAS = typename _Pair::__first;

//! \brief Retrieve the second of a pair of types
//!
//! Evaluates to the nested member type \c _Pair::__second.
//! \pre \c _Pair is a specialization of \c __type_pair
template <class _Pair>
using __type_pair_second _CCCL_NODEBUG_ALIAS = typename _Pair::__second;

//! \brief The fallback case of a \c __type_switch: wraps the type \c _Value
//! that the switch yields when no \c __type_case label matches.
//! \tparam _Value The type exposed as the nested \c type alias
//! \see __type_switch
template <class _Value>
struct _CCCL_TYPE_VISIBILITY_DEFAULT __type_default
{
// A default case carries no label, so rebinding it to another label type is
// the identity: the alias names this same specialization.
template <class>
using __rebind _CCCL_NODEBUG_ALIAS = __type_default;

// The type selected when this case is chosen.
using type _CCCL_NODEBUG_ALIAS = _Value;
};

# if _CCCL_CUDACC_AT_LEAST(12, 0) || defined(_CCCL_DOXYGEN_INVOKED)

//! \brief A labelled case of a \c __type_switch: associates the value
//! \c _Label with the type \c _Value.
//! \tparam _Label The value this case is matched against
//! \tparam _Value The type exposed as the nested \c type alias
//! \see __type_switch
template <_CCCL_NTTP_AUTO _Label, class _Value>
struct _CCCL_TYPE_VISIBILITY_DEFAULT __type_case
{
// Re-express this case with its label converted to _OtherInt via static_cast,
// so that the label of a case can be given the same type as the label that is
// being looked up.
template <class _OtherInt>
using __rebind _CCCL_NODEBUG_ALIAS = __type_case<static_cast<_OtherInt>(_Label), _Value>;

// The type selected when this case is chosen.
using type = _Value;
};

# else // ^^^ CUDACC >= 12.0 || DOXYGEN ^^^ / vvv CUDACC < 12.0 && !DOXYGEN vvv

// In this configuration the label is carried as a type (an integral_constant)
// instead of as a non-type template argument of the struct itself.
// NOTE(review): presumably a workaround for a limitation of CUDA compilers
// older than 12.0 -- confirm and record the reason here.
template <class _Label, class _Value>
struct _CCCL_TYPE_VISIBILITY_DEFAULT __type_case_
{
// Same purpose as the rebind of the other branch: rebuild the label as an
// integral_constant whose value type is _OtherInt.
template <class _OtherInt>
using __rebind _CCCL_NODEBUG_ALIAS = __type_case_<integral_constant<_OtherInt, _Label::value>, _Value>;

// The type selected when this case is chosen.
using type = _Value;
};

// Keeps the spelling __type_case<label, value> available by wrapping the label
// value in an integral_constant of the label's own type.
template <_CCCL_NTTP_AUTO _Label, class _Value>
using __type_case _CCCL_NODEBUG_ALIAS = __type_case_<integral_constant<decltype(_Label), _Label>, _Value>;

# endif // CUDACC < 12.0 && !DOXYGEN

namespace __detail
{
// Declaration-only overload pair used to select a case of a type switch. Both
// are declared without a definition; the visible caller only names them inside
// decltype, with _Label supplied explicitly and _Value left to be deduced.

// Viable when the pointer argument converts to __type_case<_Label, _Value>*
// for some _Value. Its second parameter is int, an exact match for the
// literal 0 the caller passes, so it wins over the overload below when both
// are viable.
template <_CCCL_NTTP_AUTO _Label, class _Value>
_LIBCUDACXX_HIDE_FROM_ABI auto __type_switch_fn(__type_case<_Label, _Value>*, int) -> __type_case<_Label, _Value>;

// Fallback: viable when the pointer argument converts to
// __type_default<_Value>*. _Label does not appear in the signature; it exists
// so the same explicit template argument list works for both overloads. The
// long parameter ranks this overload below the one above for an argument of 0.
template <_CCCL_NTTP_AUTO _Label, class _Value>
_LIBCUDACXX_HIDE_FROM_ABI auto __type_switch_fn(__type_default<_Value>*, long) -> __type_default<_Value>;
} // namespace __detail

//! \brief Meta-callable behind \c __type_switch.
//!
//! Derives from every case in \c _Cases after rebinding it to the label type
//! \c _Type, so that a pointer to this class converts to a pointer to each
//! rebound case.
//! \tparam _Type The type every case label is rebound to
//! \tparam _Cases The \c __type_case / \c __type_default specializations
//! \see __type_switch
template <class _Type, class... _Cases>
struct _CCCL_TYPE_VISIBILITY_DEFAULT _LIBCUDACXX_DECLSPEC_EMPTY_BASES __type_switch_fn
: _Cases::template __rebind<_Type>...
{
// Looks up the case for _Label::value: overload resolution on
// __detail::__type_switch_fn, given a null pointer to this class, yields the
// matching case (or the default case) as the return type.
// NOTE(review): __type<...> is declared elsewhere; presumably it extracts the
// nested ::type of that case -- confirm.
template <class _Label>
using __call _CCCL_NODEBUG_ALIAS =
__type<decltype(__detail::__type_switch_fn<_Label::value>(static_cast<__type_switch_fn*>(nullptr), 0))>;
};

//! \brief Given an integral value \c _Label and a pack of "cases"
//! consisting of one or more specializations of \c __type_case and zero or
//! one specializations of \c __type_default, `__type_switch<_Label, _Cases...>`
//! returns the value associated with the first case whose label matches the
//! given label. If no such case exists, the value associated with the default
//! case is returned. If no default case exists, the type is ill-formed.
//!
//! Every case is rebound to \c decltype(_Label) before the lookup, so case
//! labels are compared after conversion to the type of \c _Label.
//!
//! \note NOTE(review): the lookup relies on deducing a template argument from
//! a base class of \c __type_switch_fn, which appears to fail (rather than
//! pick the first match) when two cases share a label -- confirm before
//! relying on duplicate labels.
//!
//! \p Example:
//! \code
//! using result = __type_switch<2,
//! __type_case<1, char>,
//! __type_case<2, double>,
//! __type_default<float>>;
//! static_assert(is_same_v<result, double>);
//! \endcode
template <_CCCL_NTTP_AUTO _Label, class... _Cases>
using __type_switch _CCCL_NODEBUG_ALIAS =
__type_call<__type_switch_fn<decltype(_Label), _Cases...>, integral_constant<decltype(_Label), _Label>>;

namespace __detail
{
# if _CCCL_COMPILER(MSVC, <, 19, 38)
Expand Down Expand Up @@ -907,24 +1001,6 @@ struct _CCCL_TYPE_VISIBILITY_DEFAULT __type_sizeof
using __call _CCCL_NODEBUG_ALIAS = integral_constant<size_t, sizeof(_Ty)>;
};

//! \brief A pair of types
//!
//! Exposes its two template arguments as the nested aliases \c __first and
//! \c __second.
template <class _First, class _Second>
struct _CCCL_TYPE_VISIBILITY_DEFAULT __type_pair
{
using __first _CCCL_NODEBUG_ALIAS = _First;
using __second _CCCL_NODEBUG_ALIAS = _Second;
};

//! \brief Retrieve the first of a pair of types
//! \pre \c _Pair is a specialization of \c __type_pair
template <class _Pair>
using __type_pair_first = typename _Pair::__first;

//! \brief Retrieve the second of a pair of types
//! \pre \c _Pair is a specialization of \c __type_pair
template <class _Pair>
using __type_pair_second = typename _Pair::__second;

//! \brief A list of compile-time values, and a meta-callable that accepts a
//! meta-callable and evaluates it with the values, each value wrapped in an
//! integral constant wrapper.
Expand Down
20 changes: 20 additions & 0 deletions libcudacxx/test/libcudacxx/cuda/type_list.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,26 @@ static_assert(
"");
#endif

// __type_switch
// All three checks use the same cases: 0 -> char, 1 -> double, default -> float.

// Label 0 matches the first case.
static_assert(::cuda::std::is_same<::cuda::std::__type_switch<0,
::cuda::std::__type_case<0, char>,
::cuda::std::__type_case<1, double>,
::cuda::std::__type_default<float>>,
char>::value,
"");
// Label 1 matches the second case.
static_assert(::cuda::std::is_same<::cuda::std::__type_switch<1,
::cuda::std::__type_case<0, char>,
::cuda::std::__type_case<1, double>,
::cuda::std::__type_default<float>>,
double>::value,
"");
// Label 2 matches no case, so the default case is selected.
static_assert(::cuda::std::is_same<::cuda::std::__type_switch<2,
::cuda::std::__type_case<0, char>,
::cuda::std::__type_case<1, double>,
::cuda::std::__type_default<float>>,
float>::value,
"");

// __type_concat
static_assert(::cuda::std::is_same<::cuda::std::__type_concat<>, ::cuda::std::__type_list<>>::value, "");

Expand Down

0 comments on commit 83aca35

Please sign in to comment.