Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport PR #2046 - Fixing FP16 conversions. #2222

Merged
merged 2 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 76 additions & 11 deletions libcudacxx/include/cuda/std/__complex/nvbf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,39 @@ struct __libcpp_complex_overload_traits<__nv_bfloat16, false, false>
typedef complex<__nv_bfloat16> _ComplexType;
};

// Workaround for users defining the macros __CUDA_NO_BFLOAT16_CONVERSIONS__ and __CUDA_NO_BFLOAT16_OPERATORS__
// Explicit specializations of __complex_can_implicitly_construct that derive from true_type for the
// four float/double <-> __nv_bfloat16 pairings. complex<__nv_bfloat16> consults the trait as
// __complex_can_implicitly_construct<value_type, _Up> in the enable_if guards that pick between its
// implicit and its explicit converting constructor from complex<_Up>.

// First argument __nv_bfloat16, second argument float.
template <>
struct __complex_can_implicitly_construct<__nv_bfloat16, float> : true_type
{};

// First argument __nv_bfloat16, second argument double.
template <>
struct __complex_can_implicitly_construct<__nv_bfloat16, double> : true_type
{};

// First argument float, second argument __nv_bfloat16.
// NOTE(review): presumably read by the primary complex<float> template, which is not visible here - confirm.
template <>
struct __complex_can_implicitly_construct<float, __nv_bfloat16> : true_type
{};

// First argument double, second argument __nv_bfloat16.
// NOTE(review): presumably read by the primary complex<double> template, which is not visible here - confirm.
template <>
struct __complex_can_implicitly_construct<double, __nv_bfloat16> : true_type
{};

// Generic fallback: returns __value as __nv_bfloat16 through whatever implicit conversion _Tp provides.
// The non-template float and double overloads of the same name use explicit conversion intrinsics instead.
template <class _Tp>
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16(const _Tp& __value) noexcept
{
return __value;
}

// float overload: converts by calling the __float2bfloat16 intrinsic explicitly rather than relying on
// an implicit float -> __nv_bfloat16 conversion.
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16(const float& __value) noexcept
{
return __float2bfloat16(__value);
}

// double overload: converts by calling the __double2bfloat16 intrinsic explicitly rather than relying on
// an implicit double -> __nv_bfloat16 conversion.
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16(const double& __value) noexcept
{
return __double2bfloat16(__value);
}

template <>
class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__nv_bfloat16>
{
Expand All @@ -80,14 +113,14 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__

template <class _Up, __enable_if_t<__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY complex(const complex<_Up>& __c)
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
: __repr_(__convert_to_bfloat16(__c.real()), __convert_to_bfloat16(__c.imag()))
{}

template <class _Up,
__enable_if_t<!__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0,
__enable_if_t<_CCCL_TRAIT(is_constructible, value_type, _Up), int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<_Up>& __c)
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
: __repr_(__convert_to_bfloat16(__c.real()), __convert_to_bfloat16(__c.imag()))
{}

_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const value_type& __re)
Expand All @@ -100,8 +133,8 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
template <class _Up>
_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const complex<_Up>& __c)
{
__repr_.x = __c.real();
__repr_.y = __c.imag();
__repr_.x = __convert_to_bfloat16(__c.real());
__repr_.y = __convert_to_bfloat16(__c.imag());
return *this;
}

Expand Down Expand Up @@ -155,24 +188,24 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__

_LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const value_type& __re)
{
__repr_.x += __re;
__repr_.x = __hadd(__repr_.x, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const value_type& __re)
{
__repr_.x -= __re;
__repr_.x = __hsub(__repr_.x, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const value_type& __re)
{
__repr_.x *= __re;
__repr_.y *= __re;
__repr_.x = __hmul(__repr_.x, __re);
__repr_.y = __hmul(__repr_.y, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const value_type& __re)
{
__repr_.x /= __re;
__repr_.y /= __re;
__repr_.x = __hdiv(__repr_.x, __re);
__repr_.y = __hdiv(__repr_.y, __re);
return *this;
}

Expand All @@ -195,9 +228,41 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
}
};

// Explicit specialization of complex<float>'s converting constructor for a complex<__nv_bfloat16> source.
// Each component is widened with the __bfloat162float intrinsic.
template <> // complex<float>
template <> // complex<__nv_bfloat16>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>::complex(const complex<__nv_bfloat16>& __c)
: __re_(__bfloat162float(__c.real()))
, __im_(__bfloat162float(__c.imag()))
{}

// Explicit specialization of complex<double>'s converting constructor for a complex<__nv_bfloat16> source.
// Each component goes through __bfloat162float; the resulting float then initializes __re_ / __im_.
template <> // complex<double>
template <> // complex<__nv_bfloat16>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>::complex(const complex<__nv_bfloat16>& __c)
: __re_(__bfloat162float(__c.real()))
, __im_(__bfloat162float(__c.imag()))
{}

// Explicit specialization of complex<float>'s converting assignment for a complex<__nv_bfloat16> source.
// Both components are converted with __bfloat162float; returns *this.
template <> // complex<float>
template <> // complex<__nv_bfloat16>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>& complex<float>::operator=(const complex<__nv_bfloat16>& __c)
{
__re_ = __bfloat162float(__c.real());
__im_ = __bfloat162float(__c.imag());
return *this;
}

// Explicit specialization of complex<double>'s converting assignment for a complex<__nv_bfloat16> source.
// Both components go through __bfloat162float and the float result is assigned to __re_ / __im_; returns *this.
template <> // complex<double>
template <> // complex<__nv_bfloat16>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>& complex<double>::operator=(const complex<__nv_bfloat16>& __c)
{
__re_ = __bfloat162float(__c.real());
__im_ = __bfloat162float(__c.imag());
return *this;
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 arg(__nv_bfloat16 __re)
{
return _CUDA_VSTD::atan2f(__nv_bfloat16(0), __re);
return _CUDA_VSTD::atan2(__int2bfloat16_rn(0), __re);
}

// We have performance issues with some trigonometric functions with __nv_bfloat16
Expand Down
87 changes: 76 additions & 11 deletions libcudacxx/include/cuda/std/__complex/nvfp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,39 @@ struct __libcpp_complex_overload_traits<__half, false, false>
typedef complex<__half> _ComplexType;
};

// Workaround for users defining the macros __CUDA_NO_HALF_CONVERSIONS__ and __CUDA_NO_HALF_OPERATORS__
// Explicit specializations of __complex_can_implicitly_construct that derive from true_type for the
// four float/double <-> __half pairings. complex<__half> consults the trait as
// __complex_can_implicitly_construct<value_type, _Up> in the enable_if guards that pick between its
// implicit and its explicit converting constructor from complex<_Up>.

// First argument __half, second argument float.
template <>
struct __complex_can_implicitly_construct<__half, float> : true_type
{};

// First argument __half, second argument double.
template <>
struct __complex_can_implicitly_construct<__half, double> : true_type
{};

// First argument float, second argument __half.
// NOTE(review): presumably read by the primary complex<float> template, which is not visible here - confirm.
template <>
struct __complex_can_implicitly_construct<float, __half> : true_type
{};

// First argument double, second argument __half.
// NOTE(review): presumably read by the primary complex<double> template, which is not visible here - confirm.
template <>
struct __complex_can_implicitly_construct<double, __half> : true_type
{};

// Generic fallback: returns __value as __half through whatever implicit conversion _Tp provides.
// The non-template float and double overloads of the same name use explicit conversion intrinsics instead.
template <class _Tp>
inline _LIBCUDACXX_INLINE_VISIBILITY __half __convert_to_half(const _Tp& __value) noexcept
{
return __value;
}

// float overload: converts by calling the __float2half intrinsic explicitly rather than relying on
// an implicit float -> __half conversion.
inline _LIBCUDACXX_INLINE_VISIBILITY __half __convert_to_half(const float& __value) noexcept
{
return __float2half(__value);
}

// double overload: converts by calling the __double2half intrinsic explicitly rather than relying on
// an implicit double -> __half conversion.
inline _LIBCUDACXX_INLINE_VISIBILITY __half __convert_to_half(const double& __value) noexcept
{
return __double2half(__value);
}

template <>
class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
{
Expand All @@ -77,14 +110,14 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>

template <class _Up, __enable_if_t<__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY complex(const complex<_Up>& __c)
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
: __repr_(__convert_to_half(__c.real()), __convert_to_half(__c.imag()))
{}

template <class _Up,
__enable_if_t<!__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0,
__enable_if_t<_CCCL_TRAIT(is_constructible, value_type, _Up), int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<_Up>& __c)
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
: __repr_(__convert_to_half(__c.real()), __convert_to_half(__c.imag()))
{}

_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const value_type& __re)
Expand All @@ -97,8 +130,8 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
template <class _Up>
_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const complex<_Up>& __c)
{
__repr_.x = __c.real();
__repr_.y = __c.imag();
__repr_.x = __convert_to_half(__c.real());
__repr_.y = __convert_to_half(__c.imag());
return *this;
}

Expand Down Expand Up @@ -152,24 +185,24 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>

_LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const value_type& __re)
{
__repr_.x += __re;
__repr_.x = __hadd(__repr_.x, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const value_type& __re)
{
__repr_.x -= __re;
__repr_.x = __hsub(__repr_.x, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const value_type& __re)
{
__repr_.x *= __re;
__repr_.y *= __re;
__repr_.x = __hmul(__repr_.x, __re);
__repr_.y = __hmul(__repr_.y, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const value_type& __re)
{
__repr_.x /= __re;
__repr_.y /= __re;
__repr_.x = __hdiv(__repr_.x, __re);
__repr_.y = __hdiv(__repr_.y, __re);
return *this;
}

Expand All @@ -192,9 +225,41 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
}
};

// Explicit specialization of complex<float>'s converting constructor for a complex<__half> source.
// Each component is widened with the __half2float intrinsic.
template <> // complex<float>
template <> // complex<__half>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>::complex(const complex<__half>& __c)
: __re_(__half2float(__c.real()))
, __im_(__half2float(__c.imag()))
{}

// Explicit specialization of complex<double>'s converting constructor for a complex<__half> source.
// Each component goes through __half2float; the resulting float then initializes __re_ / __im_.
template <> // complex<double>
template <> // complex<__half>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>::complex(const complex<__half>& __c)
: __re_(__half2float(__c.real()))
, __im_(__half2float(__c.imag()))
{}

// Explicit specialization of complex<float>'s converting assignment for a complex<__half> source.
// Both components are converted with __half2float; returns *this.
template <> // complex<float>
template <> // complex<__half>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>& complex<float>::operator=(const complex<__half>& __c)
{
__re_ = __half2float(__c.real());
__im_ = __half2float(__c.imag());
return *this;
}

// Explicit specialization of complex<double>'s converting assignment for a complex<__half> source.
// Both components go through __half2float and the float result is assigned to __re_ / __im_; returns *this.
template <> // complex<double>
template <> // complex<__half>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>& complex<double>::operator=(const complex<__half>& __c)
{
__re_ = __half2float(__c.real());
__im_ = __half2float(__c.imag());
return *this;
}

inline _LIBCUDACXX_INLINE_VISIBILITY __half arg(__half __re)
{
return _CUDA_VSTD::atan2f(__half(0), __re);
return _CUDA_VSTD::atan2(__int2half_rn(0), __re);
}

// We have performance issues with some trigonometric functions with __half
Expand Down
20 changes: 10 additions & 10 deletions libcudacxx/include/cuda/std/__cuda/cmath_nvbf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,47 +37,47 @@ _LIBCUDACXX_BEGIN_NAMESPACE_STD
// trigonometric functions
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sin(__nv_bfloat16 __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsin(__v);), (return __nv_bfloat16(::sin(float(__v)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsin(__v);), (return __float2bfloat16(::sin(__bfloat162float(__v)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sinh(__nv_bfloat16 __v)
{
return __nv_bfloat16(::sinh(float(__v)));
return __float2bfloat16(::sinh(__bfloat162float(__v)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cos(__nv_bfloat16 __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hcos(__v);), (return __nv_bfloat16(::cos(float(__v)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hcos(__v);), (return __float2bfloat16(::cos(__bfloat162float(__v)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 cosh(__nv_bfloat16 __v)
{
return __nv_bfloat16(::cosh(float(__v)));
return __float2bfloat16(::cosh(__bfloat162float(__v)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 exp(__nv_bfloat16 __v)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hexp(__v);), (return __nv_bfloat16(::exp(float(__v)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hexp(__v);), (return __float2bfloat16(::exp(__bfloat162float(__v)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 hypot(__nv_bfloat16 __x, __nv_bfloat16 __y)
{
return __nv_bfloat16(::hypot(float(__x), float(__y)));
return __float2bfloat16(::hypot(__bfloat162float(__x), __bfloat162float(__y)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 atan2(__nv_bfloat16 __x, __nv_bfloat16 __y)
{
return __nv_bfloat16(::atan2(float(__x), float(__y)));
return __float2bfloat16(::atan2(__bfloat162float(__x), __bfloat162float(__y)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 log(__nv_bfloat16 __x)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hlog(__x);), (return __nv_bfloat16(::log(float(__x)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hlog(__x);), (return __float2bfloat16(::log(__bfloat162float(__x)));))
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 sqrt(__nv_bfloat16 __x)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __nv_bfloat16(::sqrt(float(__x)));))
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2bfloat16(::sqrt(__bfloat162float(__x)));))
}

// floating point helper
Expand Down Expand Up @@ -123,7 +123,7 @@ inline _LIBCUDACXX_INLINE_VISIBILITY bool isfinite(__nv_bfloat16 __v)

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __constexpr_copysign(__nv_bfloat16 __x, __nv_bfloat16 __y) noexcept
{
return __nv_bfloat16(::copysignf(float(__x), float(__y)));
return __float2bfloat16(::copysignf(__bfloat162float(__x), __bfloat162float(__y)));
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 copysign(__nv_bfloat16 __x, __nv_bfloat16 __y)
Expand Down
Loading
Loading