diff --git a/libcudacxx/include/cuda/std/__type_traits/is_extended_floating_point.h b/libcudacxx/include/cuda/std/__type_traits/is_extended_floating_point.h index bb1afa4225b..f1e468039c3 100644 --- a/libcudacxx/include/cuda/std/__type_traits/is_extended_floating_point.h +++ b/libcudacxx/include/cuda/std/__type_traits/is_extended_floating_point.h @@ -33,6 +33,10 @@ _CCCL_DIAG_SUPPRESS_CLANG("-Wunused-function") _CCCL_DIAG_POP #endif // _LIBCUDACXX_HAS_NVBF16 +#if defined(_CCCL_HAS_NVFP8) +# include +#endif // _CCCL_HAS_NVFP8 + _LIBCUDACXX_BEGIN_NAMESPACE_STD template @@ -71,6 +75,22 @@ _CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_bfloat16> = # endif // !_CCCL_NO_INLINE_VARIABLES #endif // _LIBCUDACXX_HAS_NVBF16 +#if defined(_CCCL_HAS_NVFP8) +template <> +struct __is_extended_floating_point<__nv_fp8_e4m3> : true_type +{}; +template <> +struct __is_extended_floating_point<__nv_fp8_e5m2> : true_type +{}; + +# ifndef _CCCL_NO_INLINE_VARIABLES +template <> +_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_fp8_e4m3> = true; +template <> +_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_fp8_e5m2> = true; +# endif // !_CCCL_NO_INLINE_VARIABLES +#endif // _CCCL_HAS_NVFP8 + _LIBCUDACXX_END_NAMESPACE_STD #endif // _LIBCUDACXX___TYPE_TRAITS_IS_EXTENDED_FLOATING_POINT_H