Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support FP16 traits on CTK 12.0 #3535

Merged
merged 3 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@

#include <cuda/std/__type_traits/integral_constant.h>

#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_NVFP16)
# include <cuda_fp16.h>
#endif // _LIBCUDACXX_HAS_NVFP16
#endif // _CCCL_HAS_NVFP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
#if defined(_CCCL_HAS_NVBF16)
_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_CLANG("-Wunused-function")
# include <cuda_bf16.h>
_CCCL_DIAG_POP
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_NVBF16

#if _CCCL_HAS_NVFP8()
# include <cuda_fp8.h>
Expand All @@ -53,7 +53,7 @@ _CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v
# endif // !_CCCL_NO_INLINE_VARIABLES
#endif // !_CCCL_NO_VARIABLE_TEMPLATES

#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_NVFP16)
// Explicit specialization marking __half as an extended floating-point type
// by deriving from true_type.
template <>
struct __is_extended_floating_point<__half> : true_type
{};
Expand All @@ -62,9 +62,9 @@ struct __is_extended_floating_point<__half> : true_type
// Variable-template specialization: __is_extended_floating_point_v<__half> is true.
template <>
_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__half> = true;
# endif // !_CCCL_NO_INLINE_VARIABLES
#endif // _LIBCUDACXX_HAS_NVFP16
#endif // _CCCL_HAS_NVFP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
#if defined(_CCCL_HAS_NVBF16)
// Explicit specialization marking __nv_bfloat16 as an extended floating-point
// type by deriving from true_type.
template <>
struct __is_extended_floating_point<__nv_bfloat16> : true_type
{};
Expand All @@ -73,7 +73,7 @@ struct __is_extended_floating_point<__nv_bfloat16> : true_type
// Variable-template specialization: __is_extended_floating_point_v<__nv_bfloat16> is true.
template <>
_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_bfloat16> = true;
# endif // !_CCCL_NO_INLINE_VARIABLES
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_NVBF16

#if _CCCL_HAS_NVFP8()
template <>
Expand Down
58 changes: 36 additions & 22 deletions libcudacxx/include/cuda/std/limits
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,13 @@ public:
#endif // !_LIBCUDACXX_HAS_NO_LONG_DOUBLE
};

#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_NVFP16)
// _LIBCUDACXX_FP16_CONSTEXPR is the specifier applied to the value-returning
// members of the __half numeric_limits specialization below. It expands to
// `constexpr` when _LIBCUDACXX_HAS_NVFP16 is defined and to nothing otherwise,
// so the specialization still exists when only _CCCL_HAS_NVFP16 is available.
// NOTE(review): presumably __half cannot be constructed in a constant
// expression in that configuration — confirm against the toolkit headers.
# ifdef _LIBCUDACXX_HAS_NVFP16
# define _LIBCUDACXX_FP16_CONSTEXPR constexpr
# else // _LIBCUDACXX_HAS_NVFP16
# define _LIBCUDACXX_FP16_CONSTEXPR
# endif // _LIBCUDACXX_HAS_NVFP16

template <>
class __numeric_limits_impl<__half, __numeric_limits_type::__floating_point>
{
Expand All @@ -621,27 +627,27 @@ public:
static constexpr int digits = 11;
static constexpr int digits10 = 3;
static constexpr int max_digits10 = 5;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type min() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type min() noexcept
{
return type(__half_raw{0x0400u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type max() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type max() noexcept
{
return type(__half_raw{0x7bffu});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type lowest() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type lowest() noexcept
{
return type(__half_raw{0xfbffu});
}

static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr int radix = __FLT_RADIX__;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type epsilon() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type epsilon() noexcept
{
return type(__half_raw{0x1400u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type round_error() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type round_error() noexcept
{
return type(__half_raw{0x3800u});
}
Expand All @@ -656,19 +662,19 @@ public:
static constexpr bool has_signaling_NaN = true;
static constexpr float_denorm_style has_denorm = denorm_present;
static constexpr bool has_denorm_loss = false;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type infinity() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type infinity() noexcept
{
return type(__half_raw{0x7c00u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type quiet_NaN() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type quiet_NaN() noexcept
{
return type(__half_raw{0x7e00u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type signaling_NaN() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type signaling_NaN() noexcept
{
return type(__half_raw{0x7d00u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type denorm_min() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type denorm_min() noexcept
{
return type(__half_raw{0x0001u});
}
Expand All @@ -681,9 +687,16 @@ public:
static constexpr bool tinyness_before = false;
static constexpr float_round_style round_style = round_to_nearest;
};
#endif // _LIBCUDACXX_HAS_NVFP16
# undef _LIBCUDACXX_FP16_CONSTEXPR
#endif // _CCCL_HAS_NVFP16

#if defined(_CCCL_HAS_NVBF16)
// _LIBCUDACXX_BF16_CONSTEXPR is the specifier applied to the value-returning
// members of the __nv_bfloat16 numeric_limits specialization below. It expands
// to `constexpr` when _LIBCUDACXX_HAS_NVBF16 is defined and to nothing
// otherwise, so the specialization still exists when only _CCCL_HAS_NVBF16 is
// available.
// NOTE(review): presumably __nv_bfloat16 cannot be constructed in a constant
// expression in that configuration — confirm against the toolkit headers.
# ifdef _LIBCUDACXX_HAS_NVBF16
# define _LIBCUDACXX_BF16_CONSTEXPR constexpr
# else // _LIBCUDACXX_HAS_NVBF16
# define _LIBCUDACXX_BF16_CONSTEXPR
# endif // _LIBCUDACXX_HAS_NVBF16

#if defined(_LIBCUDACXX_HAS_NVBF16)
template <>
class __numeric_limits_impl<__nv_bfloat16, __numeric_limits_type::__floating_point>
{
Expand All @@ -696,27 +709,27 @@ public:
static constexpr int digits = 8;
static constexpr int digits10 = 2;
static constexpr int max_digits10 = 4;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type min() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type min() noexcept
{
return type(__nv_bfloat16_raw{0x0080u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type max() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type max() noexcept
{
return type(__nv_bfloat16_raw{0x7f7fu});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type lowest() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type lowest() noexcept
{
return type(__nv_bfloat16_raw{0xff7fu});
}

static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr int radix = __FLT_RADIX__;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type epsilon() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type epsilon() noexcept
{
return type(__nv_bfloat16_raw{0x3c00u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type round_error() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type round_error() noexcept
{
return type(__nv_bfloat16_raw{0x3f00u});
}
Expand All @@ -731,19 +744,19 @@ public:
static constexpr bool has_signaling_NaN = true;
static constexpr float_denorm_style has_denorm = denorm_present;
static constexpr bool has_denorm_loss = false;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type infinity() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type infinity() noexcept
{
return type(__nv_bfloat16_raw{0x7f80u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type quiet_NaN() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type quiet_NaN() noexcept
{
return type(__nv_bfloat16_raw{0x7fc0u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type signaling_NaN() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type signaling_NaN() noexcept
{
return type(__nv_bfloat16_raw{0x7fa0u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type denorm_min() noexcept
_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type denorm_min() noexcept
{
return type(__nv_bfloat16_raw{0x0001u});
}
Expand All @@ -756,7 +769,8 @@ public:
static constexpr bool tinyness_before = false;
static constexpr float_round_style round_style = round_to_nearest;
};
#endif // _LIBCUDACXX_HAS_NVBF16
# undef _LIBCUDACXX_BF16_CONSTEXPR
#endif // _CCCL_HAS_NVBF16

#if _CCCL_HAS_NVFP8()
# if defined(_CCCL_BUILTIN_BIT_CAST) || _CCCL_STD_VER >= 2014
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ int main(int, char**)
test_is_floating_point<float>();
test_is_floating_point<double>();
test_is_floating_point<long double>();
#ifdef _LIBCUDACXX_HAS_NVFP16
#ifdef _CCCL_HAS_NVFP16
test_is_floating_point<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#ifdef _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_NVFP16
#ifdef _CCCL_HAS_NVBF16
test_is_floating_point<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test_is_floating_point<__nv_fp8_e4m3>();
test_is_floating_point<__nv_fp8_e5m2>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_NVFP16)
test<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_NVFP16
#if defined(_CCCL_HAS_NVBF16)
test<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_NVBF16

static_assert(!cuda::std::numeric_limits<cuda::std::complex<double>>::is_specialized,
"!cuda::std::numeric_limits<cuda::std::complex<double> >::is_specialized");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#define __CUDA_NO_BFLOAT16_CONVERSIONS__ 1
#define __CUDA_NO_BFLOAT16_OPERATORS__ 1

#include <cuda/std/__bit/bit_cast.h>
#include <cuda/std/limits>

template <class T>
Expand All @@ -42,27 +43,43 @@ __host__ __device__ inline __nv_fp8_e5m2 make_fp8_e5m2(double x, __nv_saturation

// Equality helper for __nv_fp8_e4m3 test values.
// With CUDA 12.2 or newer both operands are converted to __half through
// __nv_cvt_fp8_to_halfraw (E4M3 interpretation) and handed to float_eq.
// On older toolkits the operands are compared via their unsigned char bit
// patterns instead.
// NOTE(review): the fallback is a bit-pattern comparison, so it may disagree
// with the value comparison for some encodings — confirm this is acceptable.
__host__ __device__ inline bool float_eq(__nv_fp8_e4m3 x, __nv_fp8_e4m3 y)
{
# if _CCCL_CUDACC_AT_LEAST(12, 2)
  const __half lhs{__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)};
  const __half rhs{__nv_cvt_fp8_to_halfraw(y.__x, __NV_E4M3)};
  return float_eq(lhs, rhs);
# else
  const auto lhs_bits = ::cuda::std::bit_cast<unsigned char>(x);
  const auto rhs_bits = ::cuda::std::bit_cast<unsigned char>(y);
  return lhs_bits == rhs_bits;
# endif
}

// Equality helper for __nv_fp8_e5m2 test values.
// With CUDA 12.2 or newer both operands are converted to __half through
// __nv_cvt_fp8_to_halfraw (E5M2 interpretation) and handed to float_eq.
// On older toolkits the operands are compared via their unsigned char bit
// patterns instead.
// NOTE(review): the fallback is a bit-pattern comparison, so it may disagree
// with the value comparison for some encodings — confirm this is acceptable.
__host__ __device__ inline bool float_eq(__nv_fp8_e5m2 x, __nv_fp8_e5m2 y)
{
# if _CCCL_CUDACC_AT_LEAST(12, 2)
  const __half lhs{__nv_cvt_fp8_to_halfraw(x.__x, __NV_E5M2)};
  const __half rhs{__nv_cvt_fp8_to_halfraw(y.__x, __NV_E5M2)};
  return float_eq(lhs, rhs);
# else
  const auto lhs_bits = ::cuda::std::bit_cast<unsigned char>(x);
  const auto rhs_bits = ::cuda::std::bit_cast<unsigned char>(y);
  return lhs_bits == rhs_bits;
# endif
}
#endif // _CCCL_HAS_NVFP8

#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_NVFP16)
// Equality helper for __half test values.
// CUDA 12.2 or newer: uses the __heq comparison directly.
// Older toolkits: both operands are widened with __half2float and the
// resulting floats are compared.
__host__ __device__ inline bool float_eq(__half x, __half y)
{
# if _CCCL_CUDACC_AT_LEAST(12, 2)
  const bool equal = __heq(x, y);
# else
  const float lhs  = __half2float(x);
  const float rhs  = __half2float(y);
  const bool equal = (lhs == rhs);
# endif
  return equal;
}
#endif // _LIBCUDACXX_HAS_NVFP16
#endif // _CCCL_HAS_NVFP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
#if defined(_CCCL_HAS_NVBF16)
// Equality helper for __nv_bfloat16 test values.
// CUDA 12.2 or newer: uses the __heq comparison directly.
// Older toolkits: both operands are widened with __bfloat162float and the
// resulting floats are compared.
__host__ __device__ inline bool float_eq(__nv_bfloat16 x, __nv_bfloat16 y)
{
# if _CCCL_CUDACC_AT_LEAST(12, 2)
  const bool equal = __heq(x, y);
# else
  const float lhs  = __bfloat162float(x);
  const float rhs  = __bfloat162float(y);
  const bool equal = (lhs == rhs);
# endif
  return equal;
}
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_NVBF16

#endif // NUMERIC_LIMITS_MEMBERS_COMMON_H
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test_type<long double>();
#endif // _LIBCUDACXX_HAS_NO_LONG_DOUBLE
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_NVFP16)
test_type<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_NVFP16
#if defined(_CCCL_HAS_NVBF16)
test_type<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test_type<__nv_fp8_e4m3>();
test_type<__nv_fp8_e5m2>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ int main(int, char**)
test<long double>(LDBL_TRUE_MIN);
# endif
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_NVFP16)
test<__half>(__double2half(5.9604644775390625e-08));
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_NVFP16
#if defined(_CCCL_HAS_NVBF16)
test<__nv_bfloat16>(__double2bfloat16(9.18354961579912115600575419705e-41));
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3>(make_fp8_e4m3(0.001953125));
test<__nv_fp8_e5m2>(make_fp8_e5m2(0.0000152587890625));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double, LDBL_MANT_DIG>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_NVFP16)
test<__half, 11>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_NVFP16
#if defined(_CCCL_HAS_NVBF16)
test<__nv_bfloat16, 8>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, 3>();
test<__nv_fp8_e5m2, 2>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_NVFP16)
test<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_NVFP16
#if defined(_CCCL_HAS_NVBF16)
test<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3>();
test<__nv_fp8_e5m2>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double>(LDBL_EPSILON);
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_NVFP16)
test<__half>(__double2half(0.0009765625));
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_NVFP16
#if defined(_CCCL_HAS_NVBF16)
test<__nv_bfloat16>(__double2bfloat16(0.0078125));
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3>(make_fp8_e4m3(0.125));
test<__nv_fp8_e5m2>(make_fp8_e5m2(0.25));
Expand Down
Loading
Loading