Support FP16 traits on CTK 12.0 (#3535)
* Support FP16 traits on CTK 12.0
* Only enable constexpr limits when supported
* Support float_eq on CTK < 12.2
bernhardmgruber authored Jan 30, 2025
1 parent c02e845 commit a654bc6
Showing 37 changed files with 205 additions and 174 deletions.
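The diffs below switch the feature gates from the libcu++-internal macros (_LIBCUDACXX_HAS_NVFP16 / _LIBCUDACXX_HAS_NVBF16) to the CCCL-wide macros (_CCCL_HAS_NVFP16 / _CCCL_HAS_NVBF16), so the FP16/BF16 traits and numeric_limits specializations are usable whenever the CUDA headers provide __half / __nv_bfloat16 (CTK 12.0), while constexpr members stay gated on the stricter libcu++ macro. As a minimal illustrative sketch of that pattern (not part of the diff; the meaning of the two macros is inferred from the changes below):

// Sketch of the conditional-constexpr gating applied in <cuda/std/limits>.
// Assumption: _LIBCUDACXX_HAS_NVFP16 means __half can be constructed constexpr from
// __half_raw (newer CTKs), while _CCCL_HAS_NVFP16 only means <cuda_fp16.h> is available.
#if defined(_CCCL_HAS_NVFP16)
#  ifdef _LIBCUDACXX_HAS_NVFP16
#    define _LIBCUDACXX_FP16_CONSTEXPR constexpr
#  else
#    define _LIBCUDACXX_FP16_CONSTEXPR // members stay available, just not constexpr
#  endif

// A member then reads:
//   _LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type max() noexcept;

#  undef _LIBCUDACXX_FP16_CONSTEXPR // dropped again after the specialization
#endif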
@@ -22,16 +22,16 @@
 
 #include <cuda/std/__type_traits/integral_constant.h>
 
-#if defined(_LIBCUDACXX_HAS_NVFP16)
+#if defined(_CCCL_HAS_NVFP16)
 #  include <cuda_fp16.h>
-#endif // _LIBCUDACXX_HAS_NVFP16
+#endif // _CCCL_HAS_NVFP16
 
-#if defined(_LIBCUDACXX_HAS_NVBF16)
+#if defined(_CCCL_HAS_NVBF16)
 _CCCL_DIAG_PUSH
 _CCCL_DIAG_SUPPRESS_CLANG("-Wunused-function")
 #  include <cuda_bf16.h>
 _CCCL_DIAG_POP
-#endif // _LIBCUDACXX_HAS_NVBF16
+#endif // _CCCL_HAS_NVBF16
 
 #if _CCCL_HAS_NVFP8()
 #  include <cuda_fp8.h>
@@ -53,7 +53,7 @@ _CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v
 #  endif // !_CCCL_NO_INLINE_VARIABLES
 #endif // !_CCCL_NO_VARIABLE_TEMPLATES
 
-#if defined(_LIBCUDACXX_HAS_NVFP16)
+#if defined(_CCCL_HAS_NVFP16)
 template <>
 struct __is_extended_floating_point<__half> : true_type
 {};
@@ -62,9 +62,9 @@ struct __is_extended_floating_point<__half> : true_type
 template <>
 _CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__half> = true;
 #  endif // !_CCCL_NO_INLINE_VARIABLES
-#endif // _LIBCUDACXX_HAS_NVFP16
+#endif // _CCCL_HAS_NVFP16
 
-#if defined(_LIBCUDACXX_HAS_NVBF16)
+#if defined(_CCCL_HAS_NVBF16)
 template <>
 struct __is_extended_floating_point<__nv_bfloat16> : true_type
 {};
@@ -73,7 +73,7 @@ struct __is_extended_floating_point<__nv_bfloat16> : true_type
 template <>
 _CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_bfloat16> = true;
 #  endif // !_CCCL_NO_INLINE_VARIABLES
-#endif // _LIBCUDACXX_HAS_NVBF16
+#endif // _CCCL_HAS_NVBF16
 
 #if _CCCL_HAS_NVFP8()
 template <>
libcudacxx/include/cuda/std/limits: 58 changes (36 additions, 22 deletions)
@@ -608,7 +608,13 @@ public:
 #endif // !_LIBCUDACXX_HAS_NO_LONG_DOUBLE
 };
 
-#if defined(_LIBCUDACXX_HAS_NVFP16)
+#if defined(_CCCL_HAS_NVFP16)
+#  ifdef _LIBCUDACXX_HAS_NVFP16
+#    define _LIBCUDACXX_FP16_CONSTEXPR constexpr
+#  else //_LIBCUDACXX_HAS_NVFP16
+#    define _LIBCUDACXX_FP16_CONSTEXPR
+#  endif //_LIBCUDACXX_HAS_NVFP16
+
 template <>
 class __numeric_limits_impl<__half, __numeric_limits_type::__floating_point>
 {
@@ -621,27 +627,27 @@ public:
 static constexpr int digits = 11;
 static constexpr int digits10 = 3;
 static constexpr int max_digits10 = 5;
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type min() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type min() noexcept
 {
 return type(__half_raw{0x0400u});
 }
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type max() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type max() noexcept
 {
 return type(__half_raw{0x7bffu});
 }
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type lowest() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type lowest() noexcept
 {
 return type(__half_raw{0xfbffu});
 }
 
 static constexpr bool is_integer = false;
 static constexpr bool is_exact = false;
 static constexpr int radix = __FLT_RADIX__;
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type epsilon() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type epsilon() noexcept
 {
 return type(__half_raw{0x1400u});
 }
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type round_error() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type round_error() noexcept
 {
 return type(__half_raw{0x3800u});
 }
@@ -656,19 +662,19 @@ public:
 static constexpr bool has_signaling_NaN = true;
 static constexpr float_denorm_style has_denorm = denorm_present;
 static constexpr bool has_denorm_loss = false;
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type infinity() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type infinity() noexcept
 {
 return type(__half_raw{0x7c00u});
 }
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type quiet_NaN() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type quiet_NaN() noexcept
 {
 return type(__half_raw{0x7e00u});
 }
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type signaling_NaN() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type signaling_NaN() noexcept
 {
 return type(__half_raw{0x7d00u});
 }
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type denorm_min() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_FP16_CONSTEXPR type denorm_min() noexcept
 {
 return type(__half_raw{0x0001u});
 }
@@ -681,9 +687,16 @@ public:
 static constexpr bool tinyness_before = false;
 static constexpr float_round_style round_style = round_to_nearest;
 };
-#endif // _LIBCUDACXX_HAS_NVFP16
+#  undef _LIBCUDACXX_FP16_CONSTEXPR
+#endif // _CCCL_HAS_NVFP16
 
+#if defined(_CCCL_HAS_NVBF16)
+#  ifdef _LIBCUDACXX_HAS_NVBF16
+#    define _LIBCUDACXX_BF16_CONSTEXPR constexpr
+#  else //_LIBCUDACXX_HAS_NVBF16
+#    define _LIBCUDACXX_BF16_CONSTEXPR
+#  endif //_LIBCUDACXX_HAS_NVBF16
+
-#if defined(_LIBCUDACXX_HAS_NVBF16)
 template <>
 class __numeric_limits_impl<__nv_bfloat16, __numeric_limits_type::__floating_point>
 {
@@ -696,27 +709,27 @@ public:
 static constexpr int digits = 8;
 static constexpr int digits10 = 2;
 static constexpr int max_digits10 = 4;
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type min() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type min() noexcept
 {
 return type(__nv_bfloat16_raw{0x0080u});
 }
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type max() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type max() noexcept
 {
 return type(__nv_bfloat16_raw{0x7f7fu});
 }
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type lowest() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type lowest() noexcept
 {
 return type(__nv_bfloat16_raw{0xff7fu});
 }
 
 static constexpr bool is_integer = false;
 static constexpr bool is_exact = false;
 static constexpr int radix = __FLT_RADIX__;
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type epsilon() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type epsilon() noexcept
 {
 return type(__nv_bfloat16_raw{0x3c00u});
 }
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type round_error() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type round_error() noexcept
 {
 return type(__nv_bfloat16_raw{0x3f00u});
 }
@@ -731,19 +744,19 @@ public:
 static constexpr bool has_signaling_NaN = true;
 static constexpr float_denorm_style has_denorm = denorm_present;
 static constexpr bool has_denorm_loss = false;
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type infinity() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type infinity() noexcept
 {
 return type(__nv_bfloat16_raw{0x7f80u});
 }
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type quiet_NaN() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type quiet_NaN() noexcept
 {
 return type(__nv_bfloat16_raw{0x7fc0u});
 }
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type signaling_NaN() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type signaling_NaN() noexcept
 {
 return type(__nv_bfloat16_raw{0x7fa0u});
 }
-_LIBCUDACXX_HIDE_FROM_ABI static constexpr type denorm_min() noexcept
+_LIBCUDACXX_HIDE_FROM_ABI static _LIBCUDACXX_BF16_CONSTEXPR type denorm_min() noexcept
 {
 return type(__nv_bfloat16_raw{0x0001u});
 }
@@ -756,7 +769,8 @@ public:
 static constexpr bool tinyness_before = false;
 static constexpr float_round_style round_style = round_to_nearest;
 };
-#endif // _LIBCUDACXX_HAS_NVBF16
+#  undef _LIBCUDACXX_BF16_CONSTEXPR
+#endif // _CCCL_HAS_NVBF16
 
 #if _CCCL_HAS_NVFP8()
 #  if defined(_CCCL_BUILTIN_BIT_CAST) || _CCCL_STD_VER >= 2014
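In practice, the change to <cuda/std/limits> above means cuda::std::numeric_limits<__half> and <__nv_bfloat16> are usable on CTK 12.0; only the constexpr-ness of their members depends on the newer toolkit. A hedged usage sketch (the 12.2 cutoff for constexpr construction is an inference from the version guards in this commit, not stated explicitly):

#include <cuda_fp16.h>
#include <cuda/std/limits>

__global__ void fill_extremes(__half* out)
{
  // Available on CTK 12.0+ under _CCCL_HAS_NVFP16; evaluated at run time if needed.
  out[0] = cuda::std::numeric_limits<__half>::lowest();
  out[1] = cuda::std::numeric_limits<__half>::max();
#if _CCCL_CUDACC_AT_LEAST(12, 2)
  // Assumed: on newer CTKs the members remain constexpr and usable in constant expressions.
  constexpr __half eps = cuda::std::numeric_limits<__half>::epsilon();
  out[2] = eps;
#endif
}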
@@ -80,12 +80,12 @@ int main(int, char**)
 test_is_floating_point<float>();
 test_is_floating_point<double>();
 test_is_floating_point<long double>();
-#ifdef _LIBCUDACXX_HAS_NVFP16
+#ifdef _CCCL_HAS_NVFP16
 test_is_floating_point<__half>();
-#endif // _LIBCUDACXX_HAS_NVFP16
-#ifdef _LIBCUDACXX_HAS_NVBF16
+#endif // _CCCL_HAS_NVFP16
+#ifdef _CCCL_HAS_NVBF16
 test_is_floating_point<__nv_bfloat16>();
-#endif // _LIBCUDACXX_HAS_NVBF16
+#endif // _CCCL_HAS_NVBF16
 #if _CCCL_HAS_NVFP8()
 test_is_floating_point<__nv_fp8_e4m3>();
 test_is_floating_point<__nv_fp8_e5m2>();
@@ -68,12 +68,12 @@ int main(int, char**)
 #ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
 test<long double>();
 #endif
-#if defined(_LIBCUDACXX_HAS_NVFP16)
+#if defined(_CCCL_HAS_NVFP16)
 test<__half>();
-#endif // _LIBCUDACXX_HAS_NVFP16
-#if defined(_LIBCUDACXX_HAS_NVBF16)
+#endif // _CCCL_HAS_NVFP16
+#if defined(_CCCL_HAS_NVBF16)
 test<__nv_bfloat16>();
-#endif // _LIBCUDACXX_HAS_NVBF16
+#endif // _CCCL_HAS_NVBF16
 
 static_assert(!cuda::std::numeric_limits<cuda::std::complex<double>>::is_specialized,
 "!cuda::std::numeric_limits<cuda::std::complex<double> >::is_specialized");
@@ -17,6 +17,7 @@
 #define __CUDA_NO_BFLOAT16_CONVERSIONS__ 1
 #define __CUDA_NO_BFLOAT16_OPERATORS__ 1
 
+#include <cuda/std/__bit/bit_cast.h>
 #include <cuda/std/limits>
 
 template <class T>
@@ -42,27 +43,43 @@ __host__ __device__ inline __nv_fp8_e5m2 make_fp8_e5m2(double x, __nv_saturation
 
 __host__ __device__ inline bool float_eq(__nv_fp8_e4m3 x, __nv_fp8_e4m3 y)
 {
+#  if _CCCL_CUDACC_AT_LEAST(12, 2)
 return float_eq(__half{__nv_cvt_fp8_to_halfraw(x.__x, __NV_E4M3)}, __half{__nv_cvt_fp8_to_halfraw(y.__x, __NV_E4M3)});
+#  else
+return ::cuda::std::bit_cast<unsigned char>(x) == ::cuda::std::bit_cast<unsigned char>(y);
+#  endif
 }
 
 __host__ __device__ inline bool float_eq(__nv_fp8_e5m2 x, __nv_fp8_e5m2 y)
 {
+#  if _CCCL_CUDACC_AT_LEAST(12, 2)
 return float_eq(__half{__nv_cvt_fp8_to_halfraw(x.__x, __NV_E5M2)}, __half{__nv_cvt_fp8_to_halfraw(y.__x, __NV_E5M2)});
+#  else
+return ::cuda::std::bit_cast<unsigned char>(x) == ::cuda::std::bit_cast<unsigned char>(y);
+#  endif
 }
 #endif // _CCCL_HAS_NVFP8
 
-#if defined(_LIBCUDACXX_HAS_NVFP16)
+#if defined(_CCCL_HAS_NVFP16)
 __host__ __device__ inline bool float_eq(__half x, __half y)
 {
+#  if _CCCL_CUDACC_AT_LEAST(12, 2)
 return __heq(x, y);
+#  else
+return __half2float(x) == __half2float(y);
+#  endif
 }
-#endif // _LIBCUDACXX_HAS_NVFP16
+#endif // _CCCL_HAS_NVFP16
 
-#if defined(_LIBCUDACXX_HAS_NVBF16)
+#if defined(_CCCL_HAS_NVBF16)
 __host__ __device__ inline bool float_eq(__nv_bfloat16 x, __nv_bfloat16 y)
 {
+#  if _CCCL_CUDACC_AT_LEAST(12, 2)
 return __heq(x, y);
+#  else
+return __bfloat162float(x) == __bfloat162float(y);
+#  endif
 }
-#endif // _LIBCUDACXX_HAS_NVBF16
+#endif // _CCCL_HAS_NVBF16
 
 #endif // NUMERIC_LIMITS_MEMBERS_COMMON_H
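The fallbacks added above let the float_eq test helper compile on CTK < 12.2, presumably because the comparison paths it previously relied on (__heq on host, __nv_cvt_fp8_to_halfraw) are only usable from CTK 12.2 on; that reading is inferred from the version guard, not stated in the diff. A hedged sketch of how a test might call these helpers (assumes the common.h above is included; the epsilon value mirrors the one used elsewhere in this commit):

#include <cassert>
#include <cuda/std/limits>

__host__ __device__ void check_half_members()
{
  // float_eq works on CTK < 12.2 too, where it compares via __half2float
  // instead of __heq; on 12.2+ it uses the native half comparison.
  assert(float_eq(cuda::std::numeric_limits<__half>::epsilon(), __double2half(0.0009765625)));
  assert(float_eq(cuda::std::numeric_limits<__half>::round_error(), __double2half(0.5)));
}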
@@ -110,12 +110,12 @@ int main(int, char**)
 #ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
 test_type<long double>();
 #endif // _LIBCUDACXX_HAS_NO_LONG_DOUBLE
-#if defined(_LIBCUDACXX_HAS_NVFP16)
+#if defined(_CCCL_HAS_NVFP16)
 test_type<__half>();
-#endif // _LIBCUDACXX_HAS_NVFP16
-#if defined(_LIBCUDACXX_HAS_NVBF16)
+#endif // _CCCL_HAS_NVFP16
+#if defined(_CCCL_HAS_NVBF16)
 test_type<__nv_bfloat16>();
-#endif // _LIBCUDACXX_HAS_NVBF16
+#endif // _CCCL_HAS_NVBF16
 #if _CCCL_HAS_NVFP8()
 test_type<__nv_fp8_e4m3>();
 test_type<__nv_fp8_e5m2>();
@@ -66,12 +66,12 @@ int main(int, char**)
 test<long double>(LDBL_TRUE_MIN);
 #  endif
 #endif
-#if defined(_LIBCUDACXX_HAS_NVFP16)
+#if defined(_CCCL_HAS_NVFP16)
 test<__half>(__double2half(5.9604644775390625e-08));
-#endif // _LIBCUDACXX_HAS_NVFP16
-#if defined(_LIBCUDACXX_HAS_NVBF16)
+#endif // _CCCL_HAS_NVFP16
+#if defined(_CCCL_HAS_NVBF16)
 test<__nv_bfloat16>(__double2bfloat16(9.18354961579912115600575419705e-41));
-#endif // _LIBCUDACXX_HAS_NVBF16
+#endif // _CCCL_HAS_NVBF16
 #if _CCCL_HAS_NVFP8()
 test<__nv_fp8_e4m3>(make_fp8_e4m3(0.001953125));
 test<__nv_fp8_e5m2>(make_fp8_e5m2(0.0000152587890625));
@@ -55,12 +55,12 @@ int main(int, char**)
 #ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
 test<long double, LDBL_MANT_DIG>();
 #endif
-#if defined(_LIBCUDACXX_HAS_NVFP16)
+#if defined(_CCCL_HAS_NVFP16)
 test<__half, 11>();
-#endif // _LIBCUDACXX_HAS_NVFP16
-#if defined(_LIBCUDACXX_HAS_NVBF16)
+#endif // _CCCL_HAS_NVFP16
+#if defined(_CCCL_HAS_NVBF16)
 test<__nv_bfloat16, 8>();
-#endif // _LIBCUDACXX_HAS_NVBF16
+#endif // _CCCL_HAS_NVBF16
 #if _CCCL_HAS_NVFP8()
 test<__nv_fp8_e4m3, 3>();
 test<__nv_fp8_e5m2, 2>();
@@ -74,12 +74,12 @@ int main(int, char**)
 #ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
 test<long double>();
 #endif
-#if defined(_LIBCUDACXX_HAS_NVFP16)
+#if defined(_CCCL_HAS_NVFP16)
 test<__half>();
-#endif // _LIBCUDACXX_HAS_NVFP16
-#if defined(_LIBCUDACXX_HAS_NVBF16)
+#endif // _CCCL_HAS_NVFP16
+#if defined(_CCCL_HAS_NVBF16)
 test<__nv_bfloat16>();
-#endif // _LIBCUDACXX_HAS_NVBF16
+#endif // _CCCL_HAS_NVBF16
 #if _CCCL_HAS_NVFP8()
 test<__nv_fp8_e4m3>();
 test<__nv_fp8_e5m2>();
@@ -57,12 +57,12 @@ int main(int, char**)
 #ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
 test<long double>(LDBL_EPSILON);
 #endif
-#if defined(_LIBCUDACXX_HAS_NVFP16)
+#if defined(_CCCL_HAS_NVFP16)
 test<__half>(__double2half(0.0009765625));
-#endif // _LIBCUDACXX_HAS_NVFP16
-#if defined(_LIBCUDACXX_HAS_NVBF16)
+#endif // _CCCL_HAS_NVFP16
+#if defined(_CCCL_HAS_NVBF16)
 test<__nv_bfloat16>(__double2bfloat16(0.0078125));
-#endif // _LIBCUDACXX_HAS_NVBF16
+#endif // _CCCL_HAS_NVBF16
 #if _CCCL_HAS_NVFP8()
 test<__nv_fp8_e4m3>(make_fp8_e4m3(0.125));
 test<__nv_fp8_e5m2>(make_fp8_e5m2(0.25));
(Diffs of the remaining changed files are not shown here.)
