diff --git a/libcudacxx/include/cuda/std/detail/libcxx/include/__config b/libcudacxx/include/cuda/std/detail/libcxx/include/__config index d2c2be623a3..b2babea4617 100644 --- a/libcudacxx/include/cuda/std/detail/libcxx/include/__config +++ b/libcudacxx/include/cuda/std/detail/libcxx/include/__config @@ -1152,6 +1152,12 @@ typedef __char32_t char32_t; #endif #endif // _LIBCUDACXX_HAS_NO_LONG_DOUBLE +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 +#if defined(CUB_DISABLE_BF16_SUPPORT) +# define _LIBCUDACXX_HAS_NO_NVBF16 +#endif +#endif // _LIBCUDACXX_HAS_NO_NVBF16 + #ifndef _LIBCUDACXX_HAS_NO_ATTRIBUTE_NO_UNIQUE_ADDRESS #if __has_cpp_attribute(msvc::no_unique_address) // MSVC implements [[no_unique_address]] as a silent no-op currently. diff --git a/libcudacxx/include/cuda/std/detail/libcxx/include/__type_traits/promote.h b/libcudacxx/include/cuda/std/detail/libcxx/include/__type_traits/promote.h index fc2ea6429fe..40825c79674 100644 --- a/libcudacxx/include/cuda/std/detail/libcxx/include/__type_traits/promote.h +++ b/libcudacxx/include/cuda/std/detail/libcxx/include/__type_traits/promote.h @@ -21,8 +21,10 @@ #if defined(__cuda_std__) && defined(_LIBCUDACXX_CUDACC) #include +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 #include #endif +#endif #if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) # pragma GCC system_header @@ -40,7 +42,9 @@ struct __numeric_type _LIBCUDACXX_INLINE_VISIBILITY static void __test(...); #if defined(__cuda_std__) && defined(_LIBCUDACXX_CUDACC) _LIBCUDACXX_INLINE_VISIBILITY static __half __test(__half); +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 _LIBCUDACXX_INLINE_VISIBILITY static __nv_bfloat16 __test(__nv_bfloat16); +#endif #endif _LIBCUDACXX_INLINE_VISIBILITY static float __test(float); _LIBCUDACXX_INLINE_VISIBILITY static double __test(char); diff --git a/libcudacxx/include/cuda/std/detail/libcxx/include/cmath b/libcudacxx/include/cuda/std/detail/libcxx/include/cmath index f1e2c44f4bd..973379c5215 100644 --- a/libcudacxx/include/cuda/std/detail/libcxx/include/cmath +++ b/libcudacxx/include/cuda/std/detail/libcxx/include/cmath @@ -324,8 +324,10 @@ long double truncl(long double x); #ifdef __cuda_std__ #include +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 #include #endif +#endif #if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC) # pragma GCC system_header @@ -652,6 +654,7 @@ __half sin(__half __v) ) } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sin(__nv_bfloat16 __v) { @@ -660,6 +663,7 @@ __nv_bfloat16 sin(__nv_bfloat16 __v) (return __nv_bfloat16(_CUDA_VSTD::sin(float(__v)));) ) } +#endif inline _LIBCUDACXX_INLINE_VISIBILITY __half sinh(__half __v) @@ -667,11 +671,13 @@ __half sinh(__half __v) return __half(_CUDA_VSTD::sinh(float(__v))); } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sinh(__nv_bfloat16 __v) { return __nv_bfloat16(_CUDA_VSTD::sinh(float(__v))); } +#endif inline _LIBCUDACXX_INLINE_VISIBILITY __half cos(__half __v) @@ -699,6 +705,7 @@ __half cos(__half __v) ) } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cos(__nv_bfloat16 __v) { @@ -707,6 +714,7 @@ __nv_bfloat16 cos(__nv_bfloat16 __v) (return __nv_bfloat16(_CUDA_VSTD::cos(float(__v)));) ) } +#endif inline _LIBCUDACXX_INLINE_VISIBILITY __half cosh(__half __v) @@ -714,11 +722,13 @@ __half cosh(__half __v) return __half(_CUDA_VSTD::cosh(float(__v))); } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cosh(__nv_bfloat16 __v) { return __nv_bfloat16(_CUDA_VSTD::cosh(float(__v))); } +#endif inline _LIBCUDACXX_INLINE_VISIBILITY __half exp(__half __v) @@ -746,6 +756,7 @@ __half exp(__half __v) ) } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 exp(__nv_bfloat16 __v) { @@ -754,6 +765,7 @@ __nv_bfloat16 exp(__nv_bfloat16 __v) (return __nv_bfloat16(_CUDA_VSTD::exp(float(__v)));) ) } +#endif inline _LIBCUDACXX_INLINE_VISIBILITY bool signbit(__half __v) @@ -761,11 +773,13 @@ bool signbit(__half __v) return ::signbit(__half2float(__v)); } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY bool signbit(__nv_bfloat16 __v) { return ::signbit(__bfloat162float(__v)); } +#endif inline _LIBCUDACXX_INLINE_VISIBILITY __half atan2(__half __x, __half __y) @@ -773,11 +787,13 @@ __half atan2(__half __x, __half __y) return __half(_CUDA_VSTD::atan2(float(__x), float(__y))); } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 atan2(__nv_bfloat16 __x, __nv_bfloat16 __y) { return __nv_bfloat16(_CUDA_VSTD::atan2(float(__x), float(__y))); } +#endif inline _LIBCUDACXX_INLINE_VISIBILITY __half log(__half __x) @@ -804,6 +820,7 @@ __half log(__half __x) ) } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 log(__nv_bfloat16 __x) { @@ -812,6 +829,7 @@ __nv_bfloat16 log(__nv_bfloat16 __x) (return __nv_bfloat16(_CUDA_VSTD::log(float(__x)));) ) } +#endif inline _LIBCUDACXX_INLINE_VISIBILITY __half sqrt(__half __x) @@ -822,6 +840,7 @@ __half sqrt(__half __x) ) } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sqrt(__nv_bfloat16 __x) { @@ -831,6 +850,7 @@ __nv_bfloat16 sqrt(__nv_bfloat16 __x) ) } #endif +#endif template _LIBCUDACXX_INLINE_VISIBILITY @@ -867,18 +887,22 @@ bool isnan(__half __v) return __constexpr_isnan(__v); } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY bool __constexpr_isnan(__nv_bfloat16 __x) noexcept { return __hisnan(__x); } +#endif +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY bool isnan(__nv_bfloat16 __v) { return __constexpr_isnan(__v); } #endif +#endif template _LIBCUDACXX_INLINE_VISIBILITY @@ -914,6 +938,7 @@ bool __constexpr_isinf(__half __x) noexcept { #endif } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY bool __constexpr_isinf(__nv_bfloat16 __x) noexcept { #if _LIBCUDACXX_STD_VER >= 20 @@ -924,6 +949,7 @@ bool __constexpr_isinf(__nv_bfloat16 __x) noexcept { return __hisinf(__x) != 0; #endif } +#endif inline _LIBCUDACXX_INLINE_VISIBILITY bool isinf(__half __v) @@ -931,11 +957,13 @@ bool isinf(__half __v) return __constexpr_isinf(__v); } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY bool isinf(__nv_bfloat16 __v) { return __constexpr_isinf(__v); } +#endif inline _LIBCUDACXX_INLINE_VISIBILITY __half hypot(__half __x, __half __y) @@ -943,12 +971,14 @@ __half hypot(__half __x, __half __y) return __half(_CUDA_VSTD::hypot(float(__x), float(__y))); } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 hypot(__nv_bfloat16 __x, __nv_bfloat16 __y) { return __nv_bfloat16(_CUDA_VSTD::hypot(float(__x), float(__y))); } #endif +#endif template _LIBCUDACXX_INLINE_VISIBILITY @@ -984,6 +1014,7 @@ bool isfinite(__half __v) return __constexpr_isfinite(__v); } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY bool __constexpr_isfinite(__nv_bfloat16 __x) noexcept { return !__constexpr_isnan(__x) && !__constexpr_isinf(__x); @@ -995,6 +1026,7 @@ bool isfinite(__nv_bfloat16 __v) return __constexpr_isfinite(__v); } #endif +#endif #if defined(_MSC_VER) || defined(__CUDACC_RTC__) || defined(_LIBCUDACXX_COMPILER_CLANG_CUDA) template @@ -1043,6 +1075,7 @@ __half copysign(__half __x, __half __y) return __constexpr_copysign(__x, __y); } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __constexpr_copysign(__nv_bfloat16 __x, __nv_bfloat16 __y) noexcept { @@ -1055,6 +1088,7 @@ __nv_bfloat16 copysign(__nv_bfloat16 __x, __nv_bfloat16 __y) return __constexpr_copysign(__x, __y); } #endif +#endif #if defined(_MSC_VER) || defined(__CUDACC_RTC__) || defined(_LIBCUDACXX_COMPILER_CLANG_CUDA) template @@ -1105,6 +1139,7 @@ __half abs(__half __x) return __constexpr_fabs(__x); } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __constexpr_fabs(__nv_bfloat16 __x) noexcept { @@ -1123,6 +1158,7 @@ __nv_bfloat16 abs(__nv_bfloat16 __x) return __constexpr_fabs(__x); } #endif +#endif #if defined(_MSC_VER) || defined(__CUDACC_RTC__) || defined(_LIBCUDACXX_COMPILER_CLANG_CUDA) template @@ -1190,12 +1226,14 @@ __half __constexpr_fmax(__half __x, __half __y) noexcept return __hmax(__x, __y); } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __constexpr_fmax(__nv_bfloat16 __x, __nv_bfloat16 __y) noexcept { return __hmax(__x, __y); } #endif +#endif #if defined(_MSC_VER) || defined(__CUDACC_RTC__) || defined(_LIBCUDACXX_COMPILER_CLANG_CUDA) template diff --git a/libcudacxx/include/cuda/std/detail/libcxx/include/complex b/libcudacxx/include/cuda/std/detail/libcxx/include/complex index fc1dea38e5f..74e497dc91f 100644 --- a/libcudacxx/include/cuda/std/detail/libcxx/include/complex +++ b/libcudacxx/include/cuda/std/detail/libcxx/include/complex @@ -297,7 +297,9 @@ struct __is_complex_float { static constexpr auto value = is_floating_point<_Tp>::value #ifdef __cuda_std__ || is_same<_Tp, __half>::value +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 || is_same<_Tp, __nv_bfloat16>::value +#endif #endif ; }; @@ -517,6 +519,7 @@ public: } }; +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 template<> class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(alignof(__nv_bfloat162)) complex<__nv_bfloat16> { @@ -617,6 +620,7 @@ public: } }; #endif +#endif template<> class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(2*sizeof(float)) complex @@ -631,8 +635,10 @@ public: #ifdef __cuda_std__ _LIBCUDACXX_INLINE_VISIBILITY complex(const complex<__half> & __c) : __re_(__c.real()), __im_(__c.imag()) {} +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 _LIBCUDACXX_INLINE_VISIBILITY complex(const complex<__nv_bfloat16> & __c) : __re_(__c.real()), __im_(__c.imag()) {} +#endif #endif _LIBCUDACXX_INLINE_VISIBILITY explicit constexpr complex(const complex& __c); @@ -734,8 +740,10 @@ public: #ifdef __cuda_std__ _LIBCUDACXX_INLINE_VISIBILITY complex(const complex<__half> & __c) : __re_(__c.real()), __im_(__c.imag()) {} +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 _LIBCUDACXX_INLINE_VISIBILITY complex(const complex<__nv_bfloat16> & __c) : __re_(__c.real()), __im_(__c.imag()) {} +#endif #endif _LIBCUDACXX_INLINE_VISIBILITY constexpr complex(const complex& __c); @@ -1072,8 +1080,10 @@ __complex_piecewise_mul(_Tp __x1, _Tp __y1, _Tp __x2, _Tp __y2) #ifdef __cuda_std__ template<> struct __type_to_vector<__half> { using __type = __half2; }; +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 template<> struct __type_to_vector<__nv_bfloat16> { using __type = __nv_bfloat162; }; +#endif template _LIBCUDACXX_INLINE_VISIBILITY @@ -1477,6 +1487,7 @@ struct __libcpp_complex_overload_traits<__half, false, false> typedef complex<__half> _ComplexType; }; +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 template <> struct __libcpp_complex_overload_traits<__nv_bfloat16, false, false> { @@ -1484,6 +1495,7 @@ struct __libcpp_complex_overload_traits<__nv_bfloat16, false, false> typedef complex<__nv_bfloat16> _ComplexType; }; #endif +#endif // real @@ -1590,6 +1602,7 @@ arg(_Tp __re) return _CUDA_VSTD::atan2f(__half(0), __re); } +#ifndef _LIBCUDACXX_HAS_NO_NVBF16 template inline _LIBCUDACXX_INLINE_VISIBILITY __enable_if_t< @@ -1601,6 +1614,7 @@ arg(_Tp __re) return _CUDA_VSTD::atan2f(__nv_bfloat16(0), __re); } #endif +#endif // norm