Skip to content

Commit

Permalink
Fix <complex> failing upon inclusion when bad macros are defined
Browse files Browse the repository at this point in the history
  • Loading branch information
wmaxey committed Aug 13, 2024
1 parent 57cc989 commit 6173274
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 14 deletions.
23 changes: 16 additions & 7 deletions libcudacxx/include/cuda/std/__complex/nvbf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
{}

// Explicit converting constructor from complex<_Up>.
// Participates in overload resolution only when all of the following hold:
//   - __complex_can_implicitly_construct<value_type, _Up> is false,
//   - value_type is not constructible from _Up,
//   - _Up is float or double.
// Each component is converted with the __float2bfloat16 intrinsic instead of a static_cast,
// presumably so the header keeps compiling when the implicit __nv_bfloat16 conversions are
// disabled by user macros (per the commit title) — TODO confirm.
// NOTE(review): for _Up == double the component is also passed to __float2bfloat16; if that
// intrinsic only accepts float, the value is narrowed to float first — confirm this is
// intended rather than a dedicated double conversion.
template <class _Up,
__enable_if_t<!__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0,
__enable_if_t<!_CCCL_TRAIT(is_constructible, value_type, _Up)
&& (_CCCL_TRAIT(is_same, _Up, float) || _CCCL_TRAIT(is_same, _Up, double)),
int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<_Up>& __c)
: __repr_(__float2bfloat16(__c.real()), __float2bfloat16(__c.imag()))
{}

_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const value_type& __re)
{
__repr_.x = __re;
Expand Down Expand Up @@ -155,24 +164,24 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__

// Adds the real scalar __re to the real component (__repr_.x); __repr_.y is not modified.
// Returns *this.
_LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const value_type& __re)
{
// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. The
// operator-based statement below appears to be the pre-change line and the __hadd statement
// after it its replacement; only one of the two belongs in the resulting function.
__repr_.x += __re;
__repr_.x = __hadd(__repr_.x, __re);
return *this;
}
// Subtracts the real scalar __re from the real component (__repr_.x); __repr_.y is not
// modified. Returns *this.
_LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const value_type& __re)
{
// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. The
// operator-based statement below appears to be the pre-change line and the __hsub statement
// after it its replacement; only one of the two belongs in the resulting function.
__repr_.x -= __re;
__repr_.x = __hsub(__repr_.x, __re);
return *this;
}
// Scales both components (__repr_.x and __repr_.y) by the real scalar __re. Returns *this.
_LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const value_type& __re)
{
// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. The two
// operator-based statements below appear to be the pre-change lines and the two __hmul
// statements after them their replacements; only one pair belongs in the resulting function.
__repr_.x *= __re;
__repr_.y *= __re;
__repr_.x = __hmul(__repr_.x, __re);
__repr_.y = __hmul(__repr_.y, __re);
return *this;
}
// Divides both components (__repr_.x and __repr_.y) by the real scalar __re. Returns *this.
_LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const value_type& __re)
{
// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. The two
// operator-based statements below appear to be the pre-change lines and the two __hdiv
// statements after them their replacements; only one pair belongs in the resulting function.
__repr_.x /= __re;
__repr_.y /= __re;
__repr_.x = __hdiv(__repr_.x, __re);
__repr_.y = __hdiv(__repr_.y, __re);
return *this;
}

Expand All @@ -197,7 +206,7 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__

// arg() overload for a real-valued __nv_bfloat16: evaluates a two-argument arctangent with a
// zero first argument and __re as the second.
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 arg(__nv_bfloat16 __re)
{
// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. The atan2f
// return appears to be the pre-change line and the atan2 return (zero built with
// __int2bfloat16_rn instead of a constructor call) its replacement; as written the first
// return makes the second unreachable.
return _CUDA_VSTD::atan2f(__nv_bfloat16(0), __re);
return _CUDA_VSTD::atan2(__int2bfloat16_rn(0), __re);
}

// We have performance issues with some trigonometric functions with __nv_bfloat16
Expand Down
23 changes: 16 additions & 7 deletions libcudacxx/include/cuda/std/__complex/nvfp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
{}

// Explicit converting constructor from complex<_Up>.
// Participates in overload resolution only when all of the following hold:
//   - __complex_can_implicitly_construct<value_type, _Up> is false,
//   - value_type is not constructible from _Up,
//   - _Up is float or double.
// Each component is converted with the __float2half intrinsic instead of a static_cast,
// presumably so the header keeps compiling when the implicit __half conversions are disabled
// by user macros (per the commit title) — TODO confirm.
// NOTE(review): for _Up == double the component is also passed to __float2half; if that
// intrinsic only accepts float, the value is narrowed to float first — confirm this is
// intended rather than a dedicated double conversion.
template <class _Up,
__enable_if_t<!__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0,
__enable_if_t<!_CCCL_TRAIT(is_constructible, value_type, _Up)
&& (_CCCL_TRAIT(is_same, _Up, float) || _CCCL_TRAIT(is_same, _Up, double)),
int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<_Up>& __c)
: __repr_(__float2half(__c.real()), __float2half(__c.imag()))
{}

_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const value_type& __re)
{
__repr_.x = __re;
Expand Down Expand Up @@ -152,24 +161,24 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>

// Adds the real scalar __re to the real component (__repr_.x); __repr_.y is not modified.
// Returns *this.
_LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const value_type& __re)
{
// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. The
// operator-based statement below appears to be the pre-change line and the __hadd statement
// after it its replacement; only one of the two belongs in the resulting function.
__repr_.x += __re;
__repr_.x = __hadd(__repr_.x, __re);
return *this;
}
// Subtracts the real scalar __re from the real component (__repr_.x); __repr_.y is not
// modified. Returns *this.
_LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const value_type& __re)
{
// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. The
// operator-based statement below appears to be the pre-change line and the __hsub statement
// after it its replacement; only one of the two belongs in the resulting function.
__repr_.x -= __re;
__repr_.x = __hsub(__repr_.x, __re);
return *this;
}
// Scales both components (__repr_.x and __repr_.y) by the real scalar __re. Returns *this.
_LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const value_type& __re)
{
// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. The two
// operator-based statements below appear to be the pre-change lines and the two __hmul
// statements after them their replacements; only one pair belongs in the resulting function.
__repr_.x *= __re;
__repr_.y *= __re;
__repr_.x = __hmul(__repr_.x, __re);
__repr_.y = __hmul(__repr_.y, __re);
return *this;
}
// Divides both components (__repr_.x and __repr_.y) by the real scalar __re. Returns *this.
_LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const value_type& __re)
{
// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. The two
// operator-based statements below appear to be the pre-change lines and the two __hdiv
// statements after them their replacements; only one pair belongs in the resulting function.
__repr_.x /= __re;
__repr_.y /= __re;
__repr_.x = __hdiv(__repr_.x, __re);
__repr_.y = __hdiv(__repr_.y, __re);
return *this;
}

Expand All @@ -194,7 +203,7 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>

// arg() overload for a real-valued __half: evaluates a two-argument arctangent with a zero
// first argument and __re as the second.
inline _LIBCUDACXX_INLINE_VISIBILITY __half arg(__half __re)
{
// NOTE(review): this listing is a rendered diff whose +/- markers were stripped. The atan2f
// return appears to be the pre-change line and the atan2 return (zero built with
// __int2half_rn instead of a constructor call) its replacement; as written the first return
// makes the second unreachable.
return _CUDA_VSTD::atan2f(__half(0), __re);
return _CUDA_VSTD::atan2(__int2half_rn(0), __re);
}

// We have performance issues with some trigonometric functions with __half
Expand Down
18 changes: 18 additions & 0 deletions libcudacxx/include/cuda/std/detail/libcxx/include/complex
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,24 @@ public:
, __im_(static_cast<_Tp>(__c.imag()))
{}

#ifdef _LIBCUDACXX_HAS_NVFP16
// Explicit converting constructor from complex<__half>, compiled only when
// _LIBCUDACXX_HAS_NVFP16 is defined. It is enabled only when _Tp is not constructible from
// _Up and _Up is exactly __half. Each component is passed through __half2float and the result
// initializes __re_ / __im_.
// NOTE(review): the constructor is declared constexpr; whether __half2float (and the
// complex<__half> accessors) can be used in a constant expression is not visible here —
// confirm.
template <class _Up,
__enable_if_t<!_CCCL_TRAIT(is_constructible, _Tp, _Up) && _CCCL_TRAIT(is_same, _Up, __half), int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY explicit constexpr complex(const complex<_Up>& __c)
: __re_(__half2float(__c.real()))
, __im_(__half2float(__c.imag()))
{}
#endif // _LIBCUDACXX_HAS_NVFP16

#ifdef _LIBCUDACXX_HAS_NVBF16
// Explicit converting constructor from complex<__nv_bfloat16>, compiled only when
// _LIBCUDACXX_HAS_NVBF16 is defined. It is enabled only when _Tp is not constructible from
// _Up and _Up is exactly __nv_bfloat16. Each component is passed through __bfloat162float and
// the result initializes __re_ / __im_.
// NOTE(review): the constructor is declared constexpr; whether __bfloat162float (and the
// complex<__nv_bfloat16> accessors) can be used in a constant expression is not visible
// here — confirm.
template <class _Up,
__enable_if_t<!_CCCL_TRAIT(is_constructible, _Tp, _Up) && _CCCL_TRAIT(is_same, _Up, __nv_bfloat16), int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY explicit constexpr complex(const complex<_Up>& __c)
: __re_(__bfloat162float(__c.real()))
, __im_(__bfloat162float(__c.imag()))
{}
#endif // _LIBCUDACXX_HAS_NVBF16

_LIBCUDACXX_INLINE_VISIBILITY _CCCL_CONSTEXPR_CXX14 complex& operator=(const value_type& __re)
{
__re_ = __re;
Expand Down

0 comments on commit 6173274

Please sign in to comment.