Skip to content

Commit

Permalink
Specialize __is_extended_floating_point_v for FP8 types
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jan 22, 2025
1 parent d2857b1 commit 387e032
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ _CCCL_DIAG_SUPPRESS_CLANG("-Wunused-function")
_CCCL_DIAG_POP
#endif // _LIBCUDACXX_HAS_NVBF16

#if defined(_CCCL_HAS_NVFP8)
# include <cuda_fp8.h>
#endif // _CCCL_HAS_NVFP8

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <class _Tp>
Expand Down Expand Up @@ -71,6 +75,22 @@ _CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_bfloat16> =
# endif // !_CCCL_NO_INLINE_VARIABLES
#endif // _LIBCUDACXX_HAS_NVBF16

#if defined(_CCCL_HAS_NVFP8)
template <>
struct __is_extended_floating_point<__nv_fp8_e4m3> : true_type
{};
template <>
struct __is_extended_floating_point<__nv_fp8_e5m2> : true_type
{};

# ifndef _CCCL_NO_INLINE_VARIABLES
template <>
_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_fp8_e4m3> = true;
template <>
_CCCL_INLINE_VAR constexpr bool __is_extended_floating_point_v<__nv_fp8_e5m2> = true;
# endif // !_CCCL_NO_INLINE_VARIABLES
#endif // _CCCL_HAS_NVFP8

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _LIBCUDACXX___TYPE_TRAITS_IS_EXTENDED_FLOATING_POINT_H
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ int main(int, char**)
#ifdef _LIBCUDACXX_HAS_NVBF16
test_is_floating_point<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
#ifdef _CCCL_HAS_NVFP8
test_is_floating_point<__nv_fp8_e4m3>();
test_is_floating_point<__nv_fp8_e5m2>();
#endif // _CCCL_HAS_NVFP8

test_is_not_floating_point<short>();
test_is_not_floating_point<unsigned short>();
Expand Down

0 comments on commit 387e032

Please sign in to comment.