From d0fca7590955c740d004738bd6acd56af9b8ffea Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Wed, 22 Jan 2025 09:45:10 +0100 Subject: [PATCH] Specialize __is_extended_floating_point_v for FP8 types --- .../is_extended_floating_point.h | 20 +++++++++++++++++++ .../meta.unary.cat/is_floating_point.pass.cpp | 5 +++++ 2 files changed, 25 insertions(+) diff --git a/libcudacxx/include/cuda/std/__type_traits/is_extended_floating_point.h b/libcudacxx/include/cuda/std/__type_traits/is_extended_floating_point.h index bb1afa4225b..f1e468039c3 100644 --- a/libcudacxx/include/cuda/std/__type_traits/is_extended_floating_point.h +++ b/libcudacxx/include/cuda/std/__type_traits/is_extended_floating_point.h @@ -33,6 +33,10 @@ _CCCL_DIAG_SUPPRESS_CLANG("-Wunused-function") _CCCL_DIAG_POP #endif // _LIBCUDACXX_HAS_NVBF16 +#if defined(_CCCL_HAS_NVFP8) +# include +#endif // _CCCL_HAS_NVFP8 + _LIBCUDACXX_BEGIN_NAMESPACE_STD template @@ -71,6 +75,22 @@ _CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_bfloat16> = # endif // !_CCCL_NO_INLINE_VARIABLES #endif // _LIBCUDACXX_HAS_NVBF16 +#if defined(_CCCL_HAS_NVFP8) +template <> +struct __is_extended_floating_point<__nv_fp8_e4m3> : true_type +{}; +template <> +struct __is_extended_floating_point<__nv_fp8_e5m2> : true_type +{}; + +# ifndef _CCCL_NO_INLINE_VARIABLES +template <> +_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_fp8_e4m3> = true; +template <> +_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_fp8_e5m2> = true; +# endif // !_CCCL_NO_INLINE_VARIABLES +#endif // _CCCL_HAS_NVFP8 + _LIBCUDACXX_END_NAMESPACE_STD #endif // _LIBCUDACXX___TYPE_TRAITS_IS_EXTENDED_FLOATING_POINT_H diff --git a/libcudacxx/test/libcudacxx/libcxx/utilities/meta/meta.unary/meta.unary.cat/is_floating_point.pass.cpp b/libcudacxx/test/libcudacxx/libcxx/utilities/meta/meta.unary/meta.unary.cat/is_floating_point.pass.cpp index 13bb443314a..2235a8e79a4 100644 --- a/libcudacxx/test/libcudacxx/libcxx/utilities/meta/meta.unary/meta.unary.cat/is_floating_point.pass.cpp +++ b/libcudacxx/test/libcudacxx/libcxx/utilities/meta/meta.unary/meta.unary.cat/is_floating_point.pass.cpp @@ -86,7 +86,12 @@ int main(int, char**) #ifdef _LIBCUDACXX_HAS_NVBF16 test_is_floating_point<__nv_bfloat16>(); #endif // _LIBCUDACXX_HAS_NVBF16 +#ifdef _CCCL_HAS_NVFP8 + test_is_floating_point<__nv_fp8_e4m3>(); + test_is_floating_point<__nv_fp8_e5m2>(); +#endif // _CCCL_HAS_NVFP8 + test_is_not_floating_point(); test_is_not_floating_point(); test_is_not_floating_point(); test_is_not_floating_point();