Skip to content

Commit

Permalink
Add an opt-out from including bf16, and respect CUB's opt-out.
Browse files: browse the repository at this point in the history
  • Loading branch information
griwes committed Jan 27, 2024
1 parent 7389b25 commit c2d87c2
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 0 deletions.
6 changes: 6 additions & 0 deletions libcudacxx/include/cuda/std/detail/libcxx/include/__config
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,12 @@ typedef __char32_t char32_t;
#endif
#endif // _LIBCUDACXX_HAS_NO_LONG_DOUBLE

// bf16 opt-out: unless _LIBCUDACXX_HAS_NO_NVBF16 has already been defined,
// define it whenever CUB_DISABLE_BF16_SUPPORT is defined.
#if !defined(_LIBCUDACXX_HAS_NO_NVBF16) && defined(CUB_DISABLE_BF16_SUPPORT)
#  define _LIBCUDACXX_HAS_NO_NVBF16
#endif // !_LIBCUDACXX_HAS_NO_NVBF16 && CUB_DISABLE_BF16_SUPPORT

#ifndef _LIBCUDACXX_HAS_NO_ATTRIBUTE_NO_UNIQUE_ADDRESS
#if __has_cpp_attribute(msvc::no_unique_address)
// MSVC implements [[no_unique_address]] as a silent no-op currently.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@

#if defined(__cuda_std__) && defined(_LIBCUDACXX_CUDACC)
#include <cuda_fp16.h>
#ifndef _LIBCUDACXX_HAS_NO_NVBF16
#include <cuda_bf16.h>
#endif
#endif

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
Expand All @@ -40,7 +42,9 @@ struct __numeric_type
_LIBCUDACXX_INLINE_VISIBILITY static void __test(...);
// __test overloads for the CUDA extended floating-point types. They are only
// declared when both __cuda_std__ and _LIBCUDACXX_CUDACC are defined.
#if defined(__cuda_std__) && defined(_LIBCUDACXX_CUDACC)
_LIBCUDACXX_INLINE_VISIBILITY static __half __test(__half);
// The __nv_bfloat16 overload is left out when _LIBCUDACXX_HAS_NO_NVBF16 is defined.
#ifndef _LIBCUDACXX_HAS_NO_NVBF16
_LIBCUDACXX_INLINE_VISIBILITY static __nv_bfloat16 __test(__nv_bfloat16);
#endif
#endif
_LIBCUDACXX_INLINE_VISIBILITY static float __test(float);
_LIBCUDACXX_INLINE_VISIBILITY static double __test(char);
Expand Down
38 changes: 38 additions & 0 deletions libcudacxx/include/cuda/std/detail/libcxx/include/cmath
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,10 @@ long double truncl(long double x);

#ifdef __cuda_std__
#include <cuda_fp16.h>
#ifndef _LIBCUDACXX_HAS_NO_NVBF16
#include <cuda_bf16.h>
#endif
#endif

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
Expand Down Expand Up @@ -652,6 +654,7 @@ __half sin(__half __v)
)
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
inline _LIBCUDACXX_INLINE_VISIBILITY
__nv_bfloat16 sin(__nv_bfloat16 __v)
{
Expand All @@ -660,18 +663,21 @@ __nv_bfloat16 sin(__nv_bfloat16 __v)
(return __nv_bfloat16(_CUDA_VSTD::sin(float(__v)));)
)
}
#endif

// sinh for __half: converts the argument to float, calls _CUDA_VSTD::sinh on
// that value, and converts the result back to __half.
inline _LIBCUDACXX_INLINE_VISIBILITY
__half sinh(__half __v)
{
  const float __as_float = float(__v);
  return __half(_CUDA_VSTD::sinh(__as_float));
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
// sinh for __nv_bfloat16: converts the argument to float, calls
// _CUDA_VSTD::sinh on that value, and converts the result back to __nv_bfloat16.
// Compiled out when _LIBCUDACXX_HAS_NO_NVBF16 is defined.
inline _LIBCUDACXX_INLINE_VISIBILITY
__nv_bfloat16 sinh(__nv_bfloat16 __v)
{
  const float __as_float = float(__v);
  return __nv_bfloat16(_CUDA_VSTD::sinh(__as_float));
}
#endif // _LIBCUDACXX_HAS_NO_NVBF16

inline _LIBCUDACXX_INLINE_VISIBILITY
__half cos(__half __v)
Expand Down Expand Up @@ -699,6 +705,7 @@ __half cos(__half __v)
)
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
inline _LIBCUDACXX_INLINE_VISIBILITY
__nv_bfloat16 cos(__nv_bfloat16 __v)
{
Expand All @@ -707,18 +714,21 @@ __nv_bfloat16 cos(__nv_bfloat16 __v)
(return __nv_bfloat16(_CUDA_VSTD::cos(float(__v)));)
)
}
#endif

// cosh for __half: converts the argument to float, calls _CUDA_VSTD::cosh on
// that value, and converts the result back to __half.
inline _LIBCUDACXX_INLINE_VISIBILITY
__half cosh(__half __v)
{
  const float __as_float = float(__v);
  return __half(_CUDA_VSTD::cosh(__as_float));
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
// cosh for __nv_bfloat16: converts the argument to float, calls
// _CUDA_VSTD::cosh on that value, and converts the result back to __nv_bfloat16.
// Compiled out when _LIBCUDACXX_HAS_NO_NVBF16 is defined.
inline _LIBCUDACXX_INLINE_VISIBILITY
__nv_bfloat16 cosh(__nv_bfloat16 __v)
{
  const float __as_float = float(__v);
  return __nv_bfloat16(_CUDA_VSTD::cosh(__as_float));
}
#endif // _LIBCUDACXX_HAS_NO_NVBF16

inline _LIBCUDACXX_INLINE_VISIBILITY
__half exp(__half __v)
Expand Down Expand Up @@ -746,6 +756,7 @@ __half exp(__half __v)
)
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
inline _LIBCUDACXX_INLINE_VISIBILITY
__nv_bfloat16 exp(__nv_bfloat16 __v)
{
Expand All @@ -754,30 +765,35 @@ __nv_bfloat16 exp(__nv_bfloat16 __v)
(return __nv_bfloat16(_CUDA_VSTD::exp(float(__v)));)
)
}
#endif

// signbit for __half: converts the argument with __half2float and returns what
// the global-namespace ::signbit reports for that float.
inline _LIBCUDACXX_INLINE_VISIBILITY
bool signbit(__half __v)
{
  const float __as_float = __half2float(__v);
  return ::signbit(__as_float);
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
// signbit for __nv_bfloat16: converts the argument with __bfloat162float and
// returns what the global-namespace ::signbit reports for that float.
// Compiled out when _LIBCUDACXX_HAS_NO_NVBF16 is defined.
inline _LIBCUDACXX_INLINE_VISIBILITY
bool signbit(__nv_bfloat16 __v)
{
  const float __as_float = __bfloat162float(__v);
  return ::signbit(__as_float);
}
#endif // _LIBCUDACXX_HAS_NO_NVBF16

// atan2 for __half: both arguments are converted to float and forwarded, in the
// order received, to _CUDA_VSTD::atan2; the result is converted back to __half.
// Note: despite its name, __x is passed in the first position (the ordinate in
// the standard atan2(y, x) signature).
inline _LIBCUDACXX_INLINE_VISIBILITY
__half atan2(__half __x, __half __y)
{
  const float __first  = float(__x);
  const float __second = float(__y);
  return __half(_CUDA_VSTD::atan2(__first, __second));
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
// atan2 for __nv_bfloat16: both arguments are converted to float and forwarded,
// in the order received, to _CUDA_VSTD::atan2; the result is converted back to
// __nv_bfloat16. Note: despite its name, __x is passed in the first position
// (the ordinate in the standard atan2(y, x) signature).
// Compiled out when _LIBCUDACXX_HAS_NO_NVBF16 is defined.
inline _LIBCUDACXX_INLINE_VISIBILITY
__nv_bfloat16 atan2(__nv_bfloat16 __x, __nv_bfloat16 __y)
{
  const float __first  = float(__x);
  const float __second = float(__y);
  return __nv_bfloat16(_CUDA_VSTD::atan2(__first, __second));
}
#endif // _LIBCUDACXX_HAS_NO_NVBF16

inline _LIBCUDACXX_INLINE_VISIBILITY
__half log(__half __x)
Expand All @@ -804,6 +820,7 @@ __half log(__half __x)
)
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
inline _LIBCUDACXX_INLINE_VISIBILITY
__nv_bfloat16 log(__nv_bfloat16 __x)
{
Expand All @@ -812,6 +829,7 @@ __nv_bfloat16 log(__nv_bfloat16 __x)
(return __nv_bfloat16(_CUDA_VSTD::log(float(__x)));)
)
}
#endif

inline _LIBCUDACXX_INLINE_VISIBILITY
__half sqrt(__half __x)
Expand All @@ -822,6 +840,7 @@ __half sqrt(__half __x)
)
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
inline _LIBCUDACXX_INLINE_VISIBILITY
__nv_bfloat16 sqrt(__nv_bfloat16 __x)
{
Expand All @@ -831,6 +850,7 @@ __nv_bfloat16 sqrt(__nv_bfloat16 __x)
)
}
#endif
#endif

template <class _A1>
_LIBCUDACXX_INLINE_VISIBILITY
Expand Down Expand Up @@ -867,18 +887,22 @@ bool isnan(__half __v)
return __constexpr_isnan(__v);
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
// NaN classification for __nv_bfloat16. Both functions are compiled out when
// _LIBCUDACXX_HAS_NO_NVBF16 is defined.

// Returns the result of __hisnan for the argument, converted to bool.
inline _LIBCUDACXX_INLINE_VISIBILITY
bool __constexpr_isnan(__nv_bfloat16 __x) noexcept
{
  const bool __is_nan = __hisnan(__x);
  return __is_nan;
}

// Public entry point; forwards to __constexpr_isnan.
inline _LIBCUDACXX_INLINE_VISIBILITY
bool isnan(__nv_bfloat16 __v)
{
  const bool __is_nan = __constexpr_isnan(__v);
  return __is_nan;
}
#endif // _LIBCUDACXX_HAS_NO_NVBF16
#endif

template <class _A1>
_LIBCUDACXX_INLINE_VISIBILITY
Expand Down Expand Up @@ -914,6 +938,7 @@ bool __constexpr_isinf(__half __x) noexcept {
#endif
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
inline _LIBCUDACXX_INLINE_VISIBILITY
bool __constexpr_isinf(__nv_bfloat16 __x) noexcept {
#if _LIBCUDACXX_STD_VER >= 20
Expand All @@ -924,31 +949,36 @@ bool __constexpr_isinf(__nv_bfloat16 __x) noexcept {
return __hisinf(__x) != 0;
#endif
}
#endif

// isinf for __half: forwards to __constexpr_isinf and returns its result.
inline _LIBCUDACXX_INLINE_VISIBILITY
bool isinf(__half __v)
{
  const bool __is_inf = __constexpr_isinf(__v);
  return __is_inf;
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
// isinf for __nv_bfloat16: forwards to __constexpr_isinf and returns its result.
// Compiled out when _LIBCUDACXX_HAS_NO_NVBF16 is defined.
inline _LIBCUDACXX_INLINE_VISIBILITY
bool isinf(__nv_bfloat16 __v)
{
  const bool __is_inf = __constexpr_isinf(__v);
  return __is_inf;
}
#endif // _LIBCUDACXX_HAS_NO_NVBF16

// hypot for __half: both arguments are converted to float and passed to
// _CUDA_VSTD::hypot; the result is converted back to __half.
inline _LIBCUDACXX_INLINE_VISIBILITY
__half hypot(__half __x, __half __y)
{
  const float __first  = float(__x);
  const float __second = float(__y);
  return __half(_CUDA_VSTD::hypot(__first, __second));
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
// hypot for __nv_bfloat16: both arguments are converted to float and passed to
// _CUDA_VSTD::hypot; the result is converted back to __nv_bfloat16.
// Compiled out when _LIBCUDACXX_HAS_NO_NVBF16 is defined.
inline _LIBCUDACXX_INLINE_VISIBILITY
__nv_bfloat16 hypot(__nv_bfloat16 __x, __nv_bfloat16 __y)
{
  const float __first  = float(__x);
  const float __second = float(__y);
  return __nv_bfloat16(_CUDA_VSTD::hypot(__first, __second));
}
#endif // _LIBCUDACXX_HAS_NO_NVBF16
#endif

template <class _A1>
_LIBCUDACXX_INLINE_VISIBILITY
Expand Down Expand Up @@ -984,6 +1014,7 @@ bool isfinite(__half __v)
return __constexpr_isfinite(__v);
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
inline _LIBCUDACXX_INLINE_VISIBILITY
bool __constexpr_isfinite(__nv_bfloat16 __x) noexcept {
return !__constexpr_isnan(__x) && !__constexpr_isinf(__x);
Expand All @@ -995,6 +1026,7 @@ bool isfinite(__nv_bfloat16 __v)
return __constexpr_isfinite(__v);
}
#endif
#endif

#if defined(_MSC_VER) || defined(__CUDACC_RTC__) || defined(_LIBCUDACXX_COMPILER_CLANG_CUDA)
template <class _A1>
Expand Down Expand Up @@ -1043,6 +1075,7 @@ __half copysign(__half __x, __half __y)
return __constexpr_copysign(__x, __y);
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
inline _LIBCUDACXX_INLINE_VISIBILITY
__nv_bfloat16 __constexpr_copysign(__nv_bfloat16 __x, __nv_bfloat16 __y) noexcept
{
Expand All @@ -1055,6 +1088,7 @@ __nv_bfloat16 copysign(__nv_bfloat16 __x, __nv_bfloat16 __y)
return __constexpr_copysign(__x, __y);
}
#endif
#endif

#if defined(_MSC_VER) || defined(__CUDACC_RTC__) || defined(_LIBCUDACXX_COMPILER_CLANG_CUDA)
template <class _A1>
Expand Down Expand Up @@ -1105,6 +1139,7 @@ __half abs(__half __x)
return __constexpr_fabs(__x);
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
inline _LIBCUDACXX_INLINE_VISIBILITY
__nv_bfloat16 __constexpr_fabs(__nv_bfloat16 __x) noexcept
{
Expand All @@ -1123,6 +1158,7 @@ __nv_bfloat16 abs(__nv_bfloat16 __x)
return __constexpr_fabs(__x);
}
#endif
#endif

#if defined(_MSC_VER) || defined(__CUDACC_RTC__) || defined(_LIBCUDACXX_COMPILER_CLANG_CUDA)
template <class _A1>
Expand Down Expand Up @@ -1190,12 +1226,14 @@ __half __constexpr_fmax(__half __x, __half __y) noexcept
return __hmax(__x, __y);
}

#ifndef _LIBCUDACXX_HAS_NO_NVBF16
// fmax helper for __nv_bfloat16: returns whatever __hmax yields for the two
// arguments. Compiled out when _LIBCUDACXX_HAS_NO_NVBF16 is defined.
inline _LIBCUDACXX_INLINE_VISIBILITY
__nv_bfloat16 __constexpr_fmax(__nv_bfloat16 __x, __nv_bfloat16 __y) noexcept
{
  const __nv_bfloat16 __result = __hmax(__x, __y);
  return __result;
}
#endif // _LIBCUDACXX_HAS_NO_NVBF16
#endif

#if defined(_MSC_VER) || defined(__CUDACC_RTC__) || defined(_LIBCUDACXX_COMPILER_CLANG_CUDA)
template <class _A1>
Expand Down
Loading

0 comments on commit c2d87c2

Please sign in to comment.