Skip to content

Commit

Permalink
Specialize __is_extended_floating_point for FP8 types (NVIDIA#3470)
Browse files Browse the repository at this point in the history
Also ensure that FP8 support is only enabled when the FP16 and BF16 support it requires is available

Co-authored-by: Michael Schellenberger Costa <[email protected]>
  • Loading branch information
bernhardmgruber and miscco committed Jan 22, 2025
1 parent 432a060 commit 43e80f7
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 10 deletions.
20 changes: 10 additions & 10 deletions libcudacxx/include/cuda/std/__cccl/extended_floating_point.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,6 @@
#include <cuda/std/__cccl/diagnostic.h>
#include <cuda/std/__cccl/preprocessor.h>

// _CCCL_HAS_NVFP8() expands to 1 when <cuda_fp8.h> can be included and
// _CCCL_DISABLE_NVFP8_SUPPORT is not defined; otherwise it expands to 0.
#if defined(_CCCL_DISABLE_NVFP8_SUPPORT)
# define _CCCL_HAS_NVFP8() 0
#elif _CCCL_HAS_INCLUDE(<cuda_fp8.h>)
# define _CCCL_HAS_NVFP8() 1
#else // ^^^ <cuda_fp8.h> available ^^^ / vvv <cuda_fp8.h> not available vvv
# define _CCCL_HAS_NVFP8() 0
#endif // _CCCL_DISABLE_NVFP8_SUPPORT / _CCCL_HAS_INCLUDE(<cuda_fp8.h>)

#if !defined(_CCCL_HAS_NVFP16)
# if _CCCL_HAS_INCLUDE(<cuda_fp16.h>) && (_CCCL_HAS_CUDA_COMPILER || defined(LIBCUDACXX_ENABLE_HOST_NVFP16)) \
&& !defined(CCCL_DISABLE_FP16_SUPPORT)
Expand All @@ -49,4 +39,14 @@
# endif
#endif // !_CCCL_HAS_NVBF16

// _CCCL_HAS_NVFP8() expands to 1 only when all of the following hold:
//   * _CCCL_DISABLE_NVFP8_SUPPORT is not defined,
//   * <cuda_fp8.h> can be included,
//   * both _CCCL_HAS_NVFP16 and _CCCL_HAS_NVBF16 are defined.
// In every other configuration it expands to 0.
#if defined(_CCCL_DISABLE_NVFP8_SUPPORT)
# define _CCCL_HAS_NVFP8() 0
#elif _CCCL_HAS_INCLUDE(<cuda_fp8.h>) && defined(_CCCL_HAS_NVFP16) && defined(_CCCL_HAS_NVBF16)
# define _CCCL_HAS_NVFP8() 1
#else // ^^^ header and FP16/BF16 prerequisites present ^^^ / vvv something is missing vvv
# define _CCCL_HAS_NVFP8() 0
#endif // _CCCL_DISABLE_NVFP8_SUPPORT / FP8 prerequisites

#endif // __CCCL_EXTENDED_FLOATING_POINT_H
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ _CCCL_DIAG_SUPPRESS_CLANG("-Wunused-function")
_CCCL_DIAG_POP
#endif // _LIBCUDACXX_HAS_NVBF16

// Only include the FP8 header when the configuration reports FP8 support; the
// __nv_fp8_* specializations in this header are guarded by the same condition.
#if _CCCL_HAS_NVFP8()
# include <cuda_fp8.h>
#endif // _CCCL_HAS_NVFP8()

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <class _Tp>
Expand Down Expand Up @@ -71,6 +75,22 @@ _CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_bfloat16> =
# endif // !_CCCL_NO_INLINE_VARIABLES
#endif // _LIBCUDACXX_HAS_NVBF16

#if _CCCL_HAS_NVFP8()
// Mark both FP8 encodings (e5m2 and e4m3) as extended floating-point types.
template <>
struct __is_extended_floating_point<__nv_fp8_e5m2> : true_type {};
template <>
struct __is_extended_floating_point<__nv_fp8_e4m3> : true_type {};

# if !defined(_CCCL_NO_INLINE_VARIABLES)
// Matching explicit specializations of the variable template; these are only
// emitted when _CCCL_NO_INLINE_VARIABLES is not defined.
template <>
_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_fp8_e5m2> = true;
template <>
_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_fp8_e4m3> = true;
# endif // !_CCCL_NO_INLINE_VARIABLES
#endif // _CCCL_HAS_NVFP8()

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _LIBCUDACXX___TYPE_TRAITS_IS_EXTENDED_FLOATING_POINT_H
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ int main(int, char**)
#ifdef _LIBCUDACXX_HAS_NVBF16
test_is_floating_point<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
#if _CCCL_HAS_NVFP8()
// The FP8 types are only exercised when FP8 support is reported as available.
test_is_floating_point<__nv_fp8_e4m3>();
test_is_floating_point<__nv_fp8_e5m2>();
#endif // _CCCL_HAS_NVFP8()

test_is_not_floating_point<short>();
test_is_not_floating_point<unsigned short>();
Expand Down

0 comments on commit 43e80f7

Please sign in to comment.