Skip to content

Commit

Permalink
Fix including <complex> when bad CUDA bfloat/half macros are used. (#2226)
Browse files Browse the repository at this point in the history

* Add <complex> test for bad macros being defined

* Fix <complex> failing upon inclusion when bad macros are defined

* Rather use explicit specializations and some evil hackery to get the complex interop to work

* Fix typos

* Inline everything

* Move workarounds together

* Use conversion functions instead of explicit specializations

* Drop unneeded conversions

---------

Co-authored-by: Michael Schellenberger Costa <[email protected]>
  • Loading branch information
wmaxey and miscco committed Aug 13, 2024
1 parent bd3d726 commit 886a290
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 22 deletions.
87 changes: 76 additions & 11 deletions libcudacxx/include/cuda/std/__complex/nvbf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,39 @@ struct __libcpp_complex_overload_traits<__nv_bfloat16, false, false>
typedef complex<__nv_bfloat16> _ComplexType;
};

// This is a workaround against the user defining macros __CUDA_NO_BFLOAT16_CONVERSIONS__ __CUDA_NO_BFLOAT16_OPERATORS__
// Force implicit convertibility between __nv_bfloat16 and float/double in
// complex interop, since the trait cannot rely on the (possibly disabled)
// built-in conversions of __nv_bfloat16.
template <>
struct __complex_can_implicitly_construct<__nv_bfloat16, float> : true_type
{};

template <>
struct __complex_can_implicitly_construct<__nv_bfloat16, double> : true_type
{};

template <>
struct __complex_can_implicitly_construct<float, __nv_bfloat16> : true_type
{};

template <>
struct __complex_can_implicitly_construct<double, __nv_bfloat16> : true_type
{};

// Convert a value to __nv_bfloat16 without relying on the implicit
// conversion operators, which may be disabled via
// __CUDA_NO_BFLOAT16_CONVERSIONS__. The generic overload covers types that
// already convert to __nv_bfloat16 (e.g. __nv_bfloat16 itself); float and
// double are routed through the explicit conversion intrinsics.
template <class _Tp>
inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16(const _Tp& __value) noexcept
{
return __value;
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16(const float& __value) noexcept
{
return __float2bfloat16(__value);
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 __convert_to_bfloat16(const double& __value) noexcept
{
return __double2bfloat16(__value);
}

template <>
class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__nv_bfloat16>
{
Expand All @@ -80,14 +113,14 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__

template <class _Up, __enable_if_t<__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY complex(const complex<_Up>& __c)
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
: __repr_(__convert_to_bfloat16(__c.real()), __convert_to_bfloat16(__c.imag()))
{}

template <class _Up,
__enable_if_t<!__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0,
__enable_if_t<_CCCL_TRAIT(is_constructible, value_type, _Up), int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<_Up>& __c)
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
: __repr_(__convert_to_bfloat16(__c.real()), __convert_to_bfloat16(__c.imag()))
{}

_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const value_type& __re)
Expand All @@ -100,8 +133,8 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
template <class _Up>
_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const complex<_Up>& __c)
{
__repr_.x = __c.real();
__repr_.y = __c.imag();
__repr_.x = __convert_to_bfloat16(__c.real());
__repr_.y = __convert_to_bfloat16(__c.imag());
return *this;
}

Expand Down Expand Up @@ -155,24 +188,24 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__

_LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const value_type& __re)
{
__repr_.x += __re;
__repr_.x = __hadd(__repr_.x, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const value_type& __re)
{
__repr_.x -= __re;
__repr_.x = __hsub(__repr_.x, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const value_type& __re)
{
__repr_.x *= __re;
__repr_.y *= __re;
__repr_.x = __hmul(__repr_.x, __re);
__repr_.y = __hmul(__repr_.y, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const value_type& __re)
{
__repr_.x /= __re;
__repr_.y /= __re;
__repr_.x = __hdiv(__repr_.x, __re);
__repr_.y = __hdiv(__repr_.y, __re);
return *this;
}

Expand All @@ -195,9 +228,41 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__nv_bfloat162)) complex<__
}
};

// Out-of-class definitions of the converting constructors and assignment
// operators from complex<__nv_bfloat16>, using the __bfloat162float
// intrinsic so they work even when __CUDA_NO_BFLOAT16_CONVERSIONS__ /
// __CUDA_NO_BFLOAT16_OPERATORS__ are defined.
template <> // complex<float>
template <> // complex<__nv_bfloat16>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>::complex(const complex<__nv_bfloat16>& __c)
: __re_(__bfloat162float(__c.real()))
, __im_(__bfloat162float(__c.imag()))
{}

template <> // complex<double>
template <> // complex<__nv_bfloat16>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>::complex(const complex<__nv_bfloat16>& __c)
: __re_(__bfloat162float(__c.real()))
, __im_(__bfloat162float(__c.imag()))
{}

template <> // complex<float>
template <> // complex<__nv_bfloat16>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>& complex<float>::operator=(const complex<__nv_bfloat16>& __c)
{
__re_ = __bfloat162float(__c.real());
__im_ = __bfloat162float(__c.imag());
return *this;
}

template <> // complex<double>
template <> // complex<__nv_bfloat16>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>& complex<double>::operator=(const complex<__nv_bfloat16>& __c)
{
__re_ = __bfloat162float(__c.real());
__im_ = __bfloat162float(__c.imag());
return *this;
}

inline _LIBCUDACXX_INLINE_VISIBILITY __nv_bfloat16 arg(__nv_bfloat16 __re)
{
return _CUDA_VSTD::atan2f(__nv_bfloat16(0), __re);
return _CUDA_VSTD::atan2(__int2bfloat16_rn(0), __re);
}

// We have performance issues with some trigonometric functions with __nv_bfloat16
Expand Down
87 changes: 76 additions & 11 deletions libcudacxx/include/cuda/std/__complex/nvfp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,39 @@ struct __libcpp_complex_overload_traits<__half, false, false>
typedef complex<__half> _ComplexType;
};

// This is a workaround against the user defining macros __CUDA_NO_HALF_CONVERSIONS__ __CUDA_NO_HALF_OPERATORS__
// Force implicit convertibility between __half and float/double in complex
// interop, since the trait cannot rely on the (possibly disabled) built-in
// conversions of __half.
template <>
struct __complex_can_implicitly_construct<__half, float> : true_type
{};

template <>
struct __complex_can_implicitly_construct<__half, double> : true_type
{};

template <>
struct __complex_can_implicitly_construct<float, __half> : true_type
{};

template <>
struct __complex_can_implicitly_construct<double, __half> : true_type
{};

// Convert a value to __half without relying on the implicit conversion
// operators, which may be disabled via __CUDA_NO_HALF_CONVERSIONS__. The
// generic overload covers types that already convert to __half (e.g. __half
// itself); float and double are routed through the explicit conversion
// intrinsics.
template <class _Tp>
inline _LIBCUDACXX_INLINE_VISIBILITY __half __convert_to_half(const _Tp& __value) noexcept
{
return __value;
}

inline _LIBCUDACXX_INLINE_VISIBILITY __half __convert_to_half(const float& __value) noexcept
{
return __float2half(__value);
}

inline _LIBCUDACXX_INLINE_VISIBILITY __half __convert_to_half(const double& __value) noexcept
{
return __double2half(__value);
}

template <>
class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
{
Expand All @@ -77,14 +110,14 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>

template <class _Up, __enable_if_t<__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY complex(const complex<_Up>& __c)
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
: __repr_(__convert_to_half(__c.real()), __convert_to_half(__c.imag()))
{}

template <class _Up,
__enable_if_t<!__complex_can_implicitly_construct<value_type, _Up>::value, int> = 0,
__enable_if_t<_CCCL_TRAIT(is_constructible, value_type, _Up), int> = 0>
_LIBCUDACXX_INLINE_VISIBILITY explicit complex(const complex<_Up>& __c)
: __repr_(static_cast<value_type>(__c.real()), static_cast<value_type>(__c.imag()))
: __repr_(__convert_to_half(__c.real()), __convert_to_half(__c.imag()))
{}

_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const value_type& __re)
Expand All @@ -97,8 +130,8 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
template <class _Up>
_LIBCUDACXX_INLINE_VISIBILITY complex& operator=(const complex<_Up>& __c)
{
__repr_.x = __c.real();
__repr_.y = __c.imag();
__repr_.x = __convert_to_half(__c.real());
__repr_.y = __convert_to_half(__c.imag());
return *this;
}

Expand Down Expand Up @@ -152,24 +185,24 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>

_LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const value_type& __re)
{
__repr_.x += __re;
__repr_.x = __hadd(__repr_.x, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const value_type& __re)
{
__repr_.x -= __re;
__repr_.x = __hsub(__repr_.x, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const value_type& __re)
{
__repr_.x *= __re;
__repr_.y *= __re;
__repr_.x = __hmul(__repr_.x, __re);
__repr_.y = __hmul(__repr_.y, __re);
return *this;
}
_LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const value_type& __re)
{
__repr_.x /= __re;
__repr_.y /= __re;
__repr_.x = __hdiv(__repr_.x, __re);
__repr_.y = __hdiv(__repr_.y, __re);
return *this;
}

Expand All @@ -192,9 +225,41 @@ class _LIBCUDACXX_TEMPLATE_VIS _CCCL_ALIGNAS(alignof(__half2)) complex<__half>
}
};

// Out-of-class definitions of the converting constructors and assignment
// operators from complex<__half>, using the __half2float intrinsic so they
// work even when __CUDA_NO_HALF_CONVERSIONS__ / __CUDA_NO_HALF_OPERATORS__
// are defined.
template <> // complex<float>
template <> // complex<__half>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>::complex(const complex<__half>& __c)
: __re_(__half2float(__c.real()))
, __im_(__half2float(__c.imag()))
{}

template <> // complex<double>
template <> // complex<__half>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>::complex(const complex<__half>& __c)
: __re_(__half2float(__c.real()))
, __im_(__half2float(__c.imag()))
{}

template <> // complex<float>
template <> // complex<__half>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<float>& complex<float>::operator=(const complex<__half>& __c)
{
__re_ = __half2float(__c.real());
__im_ = __half2float(__c.imag());
return *this;
}

template <> // complex<double>
template <> // complex<__half>
inline _LIBCUDACXX_INLINE_VISIBILITY complex<double>& complex<double>::operator=(const complex<__half>& __c)
{
__re_ = __half2float(__c.real());
__im_ = __half2float(__c.imag());
return *this;
}

inline _LIBCUDACXX_INLINE_VISIBILITY __half arg(__half __re)
{
return _CUDA_VSTD::atan2f(__half(0), __re);
return _CUDA_VSTD::atan2(__int2half_rn(0), __re);
}

// We have performance issues with some trigonometric functions with __half
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//===----------------------------------------------------------------------===//
//
// Part of the libcu++ Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#define __CUDA_NO_HALF_CONVERSIONS__ 1
#define __CUDA_NO_HALF_OPERATORS__ 1
#define __CUDA_NO_BFLOAT16_CONVERSIONS__ 1
#define __CUDA_NO_BFLOAT16_OPERATORS__ 1
#define __CUDA_NO_HALF2_OPERATORS__ 1
#define __CUDA_NO_BFLOAT162_OPERATORS__ 1

#include <cuda/std/cassert>
#include <cuda/std/complex>

#include "test_macros.h"

// Checks that complex<U> -> complex<T> converting construction and
// heterogeneous assignment both compile; passing the compile step is the
// test (the __CUDA_NO_*__ macros defined above disable the implicit
// half/bfloat16 conversions these operations would otherwise rely on).
template <class T, class U>
__host__ __device__ void test_assignment(cuda::std::complex<U> v = {})
{
  cuda::std::complex<T> constructed(v);

  cuda::std::complex<T> assigned{};
  assigned = v;
}

// Instantiates test_assignment for every mixed-precision pairing supported
// by the build: __half <-> float/double and __nv_bfloat16 <-> float/double.
// Success is determined by the file compiling with the restrictive macros
// defined at the top of this file.
__host__ __device__ void test()
{
#ifdef _LIBCUDACXX_HAS_NVFP16
test_assignment<__half, float>();
test_assignment<__half, double>();
test_assignment<float, __half>();
test_assignment<double, __half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#ifdef _LIBCUDACXX_HAS_NVBF16
test_assignment<__nv_bfloat16, float>();
test_assignment<__nv_bfloat16, double>();
test_assignment<float, __nv_bfloat16>();
test_assignment<double, __nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16
}

// Entry point: runs the host-side instantiation checks. Fixes the parameter
// typo `arg` -> `argc` and explicitly discards the unused parameters so the
// test also builds cleanly with -Wunused-parameter.
int main(int argc, char** argv)
{
  (void) argc;
  (void) argv;
  test();
  return 0;
}

0 comments on commit 886a290

Please sign in to comment.