Skip to content

Commit

Permalink
Support FP16 traits on CTK 12.0
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jan 27, 2025
1 parent abfb7b4 commit 8653d48
Show file tree
Hide file tree
Showing 37 changed files with 152 additions and 152 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ _CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v
# endif // !_CCCL_NO_INLINE_VARIABLES
#endif // !_CCCL_NO_VARIABLE_TEMPLATES

#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
template <>
struct __is_extended_floating_point<__half> : true_type
{};
Expand All @@ -62,9 +62,9 @@ struct __is_extended_floating_point<__half> : true_type
template <>
_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__half> = true;
# endif // !_CCCL_NO_INLINE_VARIABLES
#endif // _LIBCUDACXX_HAS_NVFP16
#endif // _CCCL_HAS_FP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
#if defined(_CCCL_HAS_BF16)
template <>
struct __is_extended_floating_point<__nv_bfloat16> : true_type
{};
Expand All @@ -73,7 +73,7 @@ struct __is_extended_floating_point<__nv_bfloat16> : true_type
template <>
_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_bfloat16> = true;
# endif // !_CCCL_NO_INLINE_VARIABLES
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16

#if _CCCL_HAS_NVFP8()
template <>
Expand Down
8 changes: 4 additions & 4 deletions libcudacxx/include/cuda/std/limits
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ public:
#endif // !_LIBCUDACXX_HAS_NO_LONG_DOUBLE
};

#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
template <>
class __numeric_limits_impl<__half, __numeric_limits_type::__floating_point>
{
Expand Down Expand Up @@ -681,9 +681,9 @@ public:
static constexpr bool tinyness_before = false;
static constexpr float_round_style round_style = round_to_nearest;
};
#endif // _LIBCUDACXX_HAS_NVFP16
#endif // _CCCL_HAS_FP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
#if defined(_CCCL_HAS_BF16)
template <>
class __numeric_limits_impl<__nv_bfloat16, __numeric_limits_type::__floating_point>
{
Expand Down Expand Up @@ -756,7 +756,7 @@ public:
static constexpr bool tinyness_before = false;
static constexpr float_round_style round_style = round_to_nearest;
};
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16

#if _CCCL_HAS_NVFP8()
# if defined(_CCCL_BUILTIN_BIT_CAST) || _CCCL_STD_VER >= 2014
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ int main(int, char**)
test_is_floating_point<float>();
test_is_floating_point<double>();
test_is_floating_point<long double>();
#ifdef _LIBCUDACXX_HAS_NVFP16
#ifdef _CCCL_HAS_FP16
test_is_floating_point<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#ifdef _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_FP16
#ifdef _CCCL_HAS_BF16
test_is_floating_point<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test_is_floating_point<__nv_fp8_e4m3>();
test_is_floating_point<__nv_fp8_e5m2>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16

static_assert(!cuda::std::numeric_limits<cuda::std::complex<double>>::is_specialized,
"!cuda::std::numeric_limits<cuda::std::complex<double> >::is_specialized");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,18 @@ __host__ __device__ inline bool float_eq(__nv_fp8_e5m2 x, __nv_fp8_e5m2 y)
}
#endif // _CCCL_HAS_NVFP8

#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
__host__ __device__ inline bool float_eq(__half x, __half y)
{
return __heq(x, y);
}
#endif // _LIBCUDACXX_HAS_NVFP16
#endif // _CCCL_HAS_FP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
#if defined(_CCCL_HAS_BF16)
__host__ __device__ inline bool float_eq(__nv_bfloat16 x, __nv_bfloat16 y)
{
return __heq(x, y);
}
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16

#endif // NUMERIC_LIMITS_MEMBERS_COMMON_H
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test_type<long double>();
#endif // _LIBCUDACXX_HAS_NO_LONG_DOUBLE
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test_type<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test_type<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test_type<__nv_fp8_e4m3>();
test_type<__nv_fp8_e5m2>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ int main(int, char**)
test<long double>(LDBL_TRUE_MIN);
# endif
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test<__half>(__double2half(5.9604644775390625e-08));
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16>(__double2bfloat16(9.18354961579912115600575419705e-41));
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3>(make_fp8_e4m3(0.001953125));
test<__nv_fp8_e5m2>(make_fp8_e5m2(0.0000152587890625));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double, LDBL_MANT_DIG>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test<__half, 11>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16, 8>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, 3>();
test<__nv_fp8_e5m2, 2>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3>();
test<__nv_fp8_e5m2>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double>(LDBL_EPSILON);
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test<__half>(__double2half(0.0009765625));
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16>(__double2bfloat16(0.0078125));
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3>(make_fp8_e4m3(0.125));
test<__nv_fp8_e5m2>(make_fp8_e5m2(0.25));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double, cuda::std::denorm_present>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test<__half, cuda::std::denorm_present>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16, cuda::std::denorm_present>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, cuda::std::denorm_present>();
test<__nv_fp8_e5m2, cuda::std::denorm_present>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double, false>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test<__half, false>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16, false>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, false>();
test<__nv_fp8_e5m2, false>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double, true>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test<__half, true>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16, true>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, false>();
test<__nv_fp8_e5m2, true>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double, true>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test<__half, true>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16, true>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, true>();
test<__nv_fp8_e5m2, true>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double, true>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test<__half, true>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16, true>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, false>();
test<__nv_fp8_e5m2, true>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ int main(int, char**)
# ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double>(1. / 0.);
# endif
# if defined(_LIBCUDACXX_HAS_NVFP16)
# if defined(_CCCL_HAS_FP16)
test<__half>(__double2half(1.0 / 0.0));
# endif // _LIBCUDACXX_HAS_NVFP16
# if defined(_LIBCUDACXX_HAS_NVBF16)
# endif // _CCCL_HAS_FP16
# if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16>(__double2bfloat16(1.0 / 0.0));
# endif // _LIBCUDACXX_HAS_NVBF16
# endif // _CCCL_HAS_BF16
# if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3>(__nv_fp8_e4m3{});
test<__nv_fp8_e5m2>(make_fp8_e5m2(1.0 / 0.0));
Expand All @@ -81,12 +81,12 @@ int main(int, char**)
# ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double>(INFINITY);
# endif
# if defined(_LIBCUDACXX_HAS_NVFP16)
# if defined(_CCCL_HAS_FP16)
test<__half>(__double2half(INFINITY));
# endif // _LIBCUDACXX_HAS_NVFP16
# if defined(_LIBCUDACXX_HAS_NVBF16)
# endif // _CCCL_HAS_FP16
# if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16>(__double2bfloat16(INFINITY));
# endif // _LIBCUDACXX_HAS_NVBF16
# endif // _CCCL_HAS_BF16
# if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3>(__nv_fp8_e4m3{});
test<__nv_fp8_e5m2>(make_fp8_e5m2(INFINITY));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double, true>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test<__half, true>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16, true>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, true>();
test<__nv_fp8_e5m2, true>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double, false>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test<__half, false>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16, false>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, false>();
test<__nv_fp8_e5m2, false>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double, true>();
#endif // _LIBCUDACXX_HAS_NO_LONG_DOUBLE
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test<__half, true>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16, true>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, false>();
test<__nv_fp8_e5m2, false>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double, false>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
#if defined(_CCCL_HAS_FP16)
test<__half, false>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
#endif // _CCCL_HAS_FP16
#if defined(_CCCL_HAS_BF16)
test<__nv_bfloat16, false>();
#endif // _LIBCUDACXX_HAS_NVBF16
#endif // _CCCL_HAS_BF16
#if _CCCL_HAS_NVFP8()
test<__nv_fp8_e4m3, false>();
test<__nv_fp8_e5m2, false>();
Expand Down
Loading

0 comments on commit 8653d48

Please sign in to comment.