diff --git a/libcudacxx/include/cuda/std/__cccl/extended_floating_point.h b/libcudacxx/include/cuda/std/__cccl/extended_floating_point.h index fe83fbb03c1..dee553633d8 100644 --- a/libcudacxx/include/cuda/std/__cccl/extended_floating_point.h +++ b/libcudacxx/include/cuda/std/__cccl/extended_floating_point.h @@ -25,16 +25,6 @@ #include #include -#if !defined(_CCCL_DISABLE_NVFP8_SUPPORT) -# if _CCCL_HAS_INCLUDE() -# define _CCCL_HAS_NVFP8() 1 -# else -# define _CCCL_HAS_NVFP8() 0 -# endif // _CCCL_CUDACC_AT_LEAST(11, 8) && _CCCL_HAS_INCLUDE() -#else -# define _CCCL_HAS_NVFP8() 0 -#endif // !defined(_CCCL_DISABLE_NVFP8_SUPPORT) - #if !defined(_CCCL_HAS_NVFP16) # if _CCCL_HAS_INCLUDE() && (_CCCL_HAS_CUDA_COMPILER || defined(LIBCUDACXX_ENABLE_HOST_NVFP16)) \ && !defined(CCCL_DISABLE_FP16_SUPPORT) @@ -49,4 +39,14 @@ # endif #endif // !_CCCL_HAS_NVBF16 +#if !defined(_CCCL_DISABLE_NVFP8_SUPPORT) +# if _CCCL_HAS_INCLUDE() && defined(_CCCL_HAS_NVFP16) && defined(_CCCL_HAS_NVBF16) +# define _CCCL_HAS_NVFP8() 1 +# else +# define _CCCL_HAS_NVFP8() 0 +# endif // _CCCL_HAS_INCLUDE() +#else +# define _CCCL_HAS_NVFP8() 0 +#endif // !defined(_CCCL_DISABLE_NVFP8_SUPPORT) + #endif // __CCCL_EXTENDED_FLOATING_POINT_H diff --git a/libcudacxx/include/cuda/std/__type_traits/is_extended_floating_point.h b/libcudacxx/include/cuda/std/__type_traits/is_extended_floating_point.h index bb1afa4225b..b9700a87066 100644 --- a/libcudacxx/include/cuda/std/__type_traits/is_extended_floating_point.h +++ b/libcudacxx/include/cuda/std/__type_traits/is_extended_floating_point.h @@ -33,6 +33,10 @@ _CCCL_DIAG_SUPPRESS_CLANG("-Wunused-function") _CCCL_DIAG_POP #endif // _LIBCUDACXX_HAS_NVBF16 +#if _CCCL_HAS_NVFP8() +# include +#endif // _CCCL_HAS_NVFP8() + _LIBCUDACXX_BEGIN_NAMESPACE_STD template @@ -71,6 +75,22 @@ _CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_bfloat16> = # endif // !_CCCL_NO_INLINE_VARIABLES #endif // _LIBCUDACXX_HAS_NVBF16 +#if _CCCL_HAS_NVFP8() +template <> +struct __is_extended_floating_point<__nv_fp8_e4m3> : true_type +{}; +template <> +struct __is_extended_floating_point<__nv_fp8_e5m2> : true_type +{}; + +# ifndef _CCCL_NO_INLINE_VARIABLES +template <> +_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_fp8_e4m3> = true; +template <> +_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_fp8_e5m2> = true; +# endif // !_CCCL_NO_INLINE_VARIABLES +#endif // _CCCL_HAS_NVFP8() + _LIBCUDACXX_END_NAMESPACE_STD #endif // _LIBCUDACXX___TYPE_TRAITS_IS_EXTENDED_FLOATING_POINT_H diff --git a/libcudacxx/test/libcudacxx/libcxx/utilities/meta/meta.unary/meta.unary.cat/is_floating_point.pass.cpp b/libcudacxx/test/libcudacxx/libcxx/utilities/meta/meta.unary/meta.unary.cat/is_floating_point.pass.cpp index 13bb443314a..b0b7a3f3b69 100644 --- a/libcudacxx/test/libcudacxx/libcxx/utilities/meta/meta.unary/meta.unary.cat/is_floating_point.pass.cpp +++ b/libcudacxx/test/libcudacxx/libcxx/utilities/meta/meta.unary/meta.unary.cat/is_floating_point.pass.cpp @@ -86,6 +86,10 @@ int main(int, char**) #ifdef _LIBCUDACXX_HAS_NVBF16 test_is_floating_point<__nv_bfloat16>(); #endif // _LIBCUDACXX_HAS_NVBF16 +#if _CCCL_HAS_NVFP8() + test_is_floating_point<__nv_fp8_e4m3>(); + test_is_floating_point<__nv_fp8_e5m2>(); +#endif // ()) test_is_not_floating_point(); test_is_not_floating_point();