Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tuple protocol to cuda::std::complex from C++26 #2882

Merged
merged 14 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions libcudacxx/include/cuda/std/__complex/nvbf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ _CCCL_DIAG_POP

# include <cuda/std/__complex/vector_support.h>
# include <cuda/std/__cuda/cmath_nvbf16.h>
# include <cuda/std/__fwd/get.h>
# include <cuda/std/__type_traits/enable_if.h>
# include <cuda/std/__type_traits/integral_constant.h>
# include <cuda/std/__type_traits/is_constructible.h>
Expand Down Expand Up @@ -112,6 +113,9 @@ class _CCCL_TYPE_VISIBILITY_DEFAULT _CCCL_ALIGNAS(alignof(__nv_bfloat162)) compl
template <class _Up>
friend class complex;

template <class _Up>
friend struct __get_complex_impl;

public:
using value_type = __nv_bfloat16;

Expand Down Expand Up @@ -295,6 +299,34 @@ _LIBCUDACXX_HIDE_FROM_ABI complex<__nv_bfloat16> acos(const complex<__nv_bfloat1
return complex<__nv_bfloat16>{_CUDA_VSTD::acos(complex<float>{__x})};
}

template <>
struct __get_complex_impl<__nv_bfloat16>
{
template <size_t _Index>
miscco marked this conversation as resolved.
Show resolved Hide resolved
static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16& get(complex<__nv_bfloat16>& __z) noexcept
{
return (_Index == 0) ? __z.__repr_.x : __z.__repr_.y;
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16&& get(complex<__nv_bfloat16>&& __z) noexcept
{
return _CUDA_VSTD::move((_Index == 0) ? __z.__repr_.x : __z.__repr_.y);
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr const __nv_bfloat16& get(const complex<__nv_bfloat16>& __z) noexcept
{
return (_Index == 0) ? __z.__repr_.x : __z.__repr_.y;
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr const __nv_bfloat16&& get(const complex<__nv_bfloat16>&& __z) noexcept
{
return _CUDA_VSTD::move((_Index == 0) ? __z.__repr_.x : __z.__repr_.y);
}
};

# if !_CCCL_COMPILER(NVRTC)
template <class _CharT, class _Traits>
::std::basic_istream<_CharT, _Traits>&
Expand Down
32 changes: 32 additions & 0 deletions libcudacxx/include/cuda/std/__complex/nvfp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

# include <cuda/std/__complex/vector_support.h>
# include <cuda/std/__cuda/cmath_nvfp16.h>
# include <cuda/std/__fwd/get.h>
# include <cuda/std/__type_traits/enable_if.h>
# include <cuda/std/__type_traits/integral_constant.h>
# include <cuda/std/__type_traits/is_constructible.h>
Expand Down Expand Up @@ -109,6 +110,9 @@ class _CCCL_TYPE_VISIBILITY_DEFAULT _CCCL_ALIGNAS(alignof(__half2)) complex<__ha
template <class _Up>
friend class complex;

template <class _Up>
friend struct __get_complex_impl;

public:
using value_type = __half;

Expand Down Expand Up @@ -292,6 +296,34 @@ _LIBCUDACXX_HIDE_FROM_ABI complex<__half> acos(const complex<__half>& __x)
return complex<__half>{_CUDA_VSTD::acos(complex<float>{__x})};
}

template <>
struct __get_complex_impl<__half>
{
template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half& get(complex<__half>& __z) noexcept
{
return (_Index == 0) ? __z.__repr_.x : __z.__repr_.y;
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half&& get(complex<__half>&& __z) noexcept
{
return _CUDA_VSTD::move((_Index == 0) ? __z.__repr_.x : __z.__repr_.y);
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr const __half& get(const complex<__half>& __z) noexcept
{
return (_Index == 0) ? __z.__repr_.x : __z.__repr_.y;
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr const __half&& get(const complex<__half>&& __z) noexcept
{
return _CUDA_VSTD::move((_Index == 0) ? __z.__repr_.x : __z.__repr_.y);
}
};

# if !defined(_LIBCUDACXX_HAS_NO_LOCALIZATION) && !_CCCL_COMPILER(NVRTC)
template <class _CharT, class _Traits>
::std::basic_istream<_CharT, _Traits>& operator>>(::std::basic_istream<_CharT, _Traits>& __is, complex<__half>& __x)
Expand Down
30 changes: 30 additions & 0 deletions libcudacxx/include/cuda/std/__fwd/complex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023-24 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___FWD_COMPLEX_H
#define _LIBCUDACXX___FWD_COMPLEX_H

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

_LIBCUDACXX_BEGIN_NAMESPACE_STD

template <class _Tp>
class _CCCL_TYPE_VISIBILITY_DEFAULT complex;

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _LIBCUDACXX___FWD_COMPLEX_H
13 changes: 13 additions & 0 deletions libcudacxx/include/cuda/std/__fwd/get.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

#include <cuda/std/__concepts/copyable.h>
#include <cuda/std/__fwd/array.h>
#include <cuda/std/__fwd/complex.h>
#include <cuda/std/__fwd/pair.h>
#include <cuda/std/__fwd/subrange.h>
#include <cuda/std/__fwd/tuple.h>
Expand Down Expand Up @@ -70,6 +71,18 @@ _LIBCUDACXX_HIDE_FROM_ABI _CCCL_CONSTEXPR_CXX14 _Tp&& get(array<_Tp, _Size>&&) n
template <size_t _Ip, class _Tp, size_t _Size>
_LIBCUDACXX_HIDE_FROM_ABI _CCCL_CONSTEXPR_CXX14 const _Tp&& get(const array<_Tp, _Size>&&) noexcept;

template <size_t _Ip, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp& get(complex<_Tp>&) noexcept;

template <size_t _Ip, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp&& get(complex<_Tp>&&) noexcept;

template <size_t _Ip, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr const _Tp& get(const complex<_Tp>&) noexcept;

template <size_t _Ip, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr const _Tp&& get(const complex<_Tp>&&) noexcept;

_LIBCUDACXX_END_NAMESPACE_STD

#if _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ _CCCL_DIAG_SUPPRESS_CLANG("-Wmismatched-tags")
#endif // !_CCCL_COMPILER(NVRTC)

#include <cuda/std/__fwd/array.h>
#include <cuda/std/__fwd/complex.h>
#include <cuda/std/__fwd/pair.h>
#include <cuda/std/__fwd/subrange.h>
#include <cuda/std/__fwd/tuple.h>
Expand Down Expand Up @@ -87,6 +88,14 @@ struct tuple_element<_Ip, const volatile _CUDA_VSTD::array<_Tp, _Size>>
: _CUDA_VSTD::tuple_element<_Ip, const volatile _CUDA_VSTD::array<_Tp, _Size>>
{};

template <class _Tp>
struct tuple_size<_CUDA_VSTD::complex<_Tp>> : _CUDA_VSTD::tuple_size<_CUDA_VSTD::complex<_Tp>>
{};

template <size_t _Ip, class _Tp>
struct tuple_element<_Ip, _CUDA_VSTD::complex<_Tp>> : _CUDA_VSTD::tuple_element<_Ip, _CUDA_VSTD::complex<_Tp>>
{};

template <class _Tp, class _Up>
struct tuple_size<_CUDA_VSTD::pair<_Tp, _Up>> : _CUDA_VSTD::tuple_size<_CUDA_VSTD::pair<_Tp, _Up>>
{};
Expand Down
5 changes: 5 additions & 0 deletions libcudacxx/include/cuda/std/__tuple_dir/tuple_like.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#endif // no system header

#include <cuda/std/__fwd/array.h>
#include <cuda/std/__fwd/complex.h>
#include <cuda/std/__fwd/pair.h>
#include <cuda/std/__fwd/subrange.h>
#include <cuda/std/__fwd/tuple.h>
Expand Down Expand Up @@ -56,6 +57,10 @@ template <class _Tp, size_t _Size>
struct __tuple_like<array<_Tp, _Size>> : true_type
{};

template <class _Tp>
struct __tuple_like<complex<_Tp>> : true_type
{};

#if _CCCL_STD_VER >= 2017 && !_CCCL_COMPILER(MSVC2017)
template <class _Ip, class _Sp, _CUDA_VRANGES::subrange_kind _Kp>
struct __tuple_like<_CUDA_VRANGES::subrange<_Ip, _Sp, _Kp>> : true_type
Expand Down
5 changes: 5 additions & 0 deletions libcudacxx/include/cuda/std/__tuple_dir/tuple_like_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#endif // no system header

#include <cuda/std/__fwd/array.h>
#include <cuda/std/__fwd/complex.h>
#include <cuda/std/__fwd/pair.h>
#include <cuda/std/__fwd/tuple.h>
#include <cuda/std/__tuple_dir/tuple_types.h>
Expand Down Expand Up @@ -55,6 +56,10 @@ template <class _Tp, size_t _Size>
struct __tuple_like_ext<array<_Tp, _Size>> : true_type
{};

template <class _Tp>
struct __tuple_like_ext<complex<_Tp>> : true_type
{};

template <class... _Tp>
struct __tuple_like_ext<__tuple_types<_Tp...>> : true_type
{};
Expand Down
74 changes: 74 additions & 0 deletions libcudacxx/include/cuda/std/detail/libcxx/include/complex
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ template<class T> complex<T> tanh (const complex<T>&);
#endif // no system header

#include <cuda/std/__complex/vector_support.h>
#include <cuda/std/__fwd/get.h>
#include <cuda/std/__tuple_dir/tuple_element.h>
#include <cuda/std/__tuple_dir/tuple_size.h>
#include <cuda/std/__type_traits/enable_if.h>
#include <cuda/std/__type_traits/is_constructible.h>
#include <cuda/std/__type_traits/is_floating_point.h>
Expand Down Expand Up @@ -286,6 +289,9 @@ class _CCCL_TYPE_VISIBILITY_DEFAULT _LIBCUDACXX_COMPLEX_ALIGNAS complex
template <class _Up>
friend class complex;

template <class _Up>
friend struct __get_complex_impl;

public:
using value_type = _Tp;

Expand Down Expand Up @@ -1418,6 +1424,74 @@ _LIBCUDACXX_HIDE_FROM_ABI complex<_Tp> tan(const complex<_Tp>& __x)
return complex<_Tp>(__z.imag(), -__z.real());
}

template <class _Tp>
struct tuple_size<complex<_Tp>> : _CUDA_VSTD::integral_constant<size_t, 2>
{};

template <size_t _Index, class _Tp>
struct tuple_element<_Index, complex<_Tp>> : _CUDA_VSTD::enable_if < _Index<2, _Tp>
{};

template <class _Tp>
struct __get_complex_impl
{
template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp& get(complex<_Tp>& __z) noexcept
{
return (_Index == 0) ? __z.__re_ : __z.__im_;
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp&& get(complex<_Tp>&& __z) noexcept
{
return _CUDA_VSTD::move((_Index == 0) ? __z.__re_ : __z.__im_);
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr const _Tp& get(const complex<_Tp>& __z) noexcept
{
return (_Index == 0) ? __z.__re_ : __z.__im_;
}

template <size_t _Index>
static _LIBCUDACXX_HIDE_FROM_ABI constexpr const _Tp&& get(const complex<_Tp>&& __z) noexcept
{
return _CUDA_VSTD::move((_Index == 0) ? __z.__re_ : __z.__im_);
}
};

template <size_t _Index, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp& get(complex<_Tp>& __z) noexcept
{
static_assert(_Index < 2, "Index value is out of range");

return __get_complex_impl<_Tp>::template get<_Index>(__z);
}

template <size_t _Index, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp&& get(complex<_Tp>&& __z) noexcept
{
static_assert(_Index < 2, "Index value is out of range");

return __get_complex_impl<_Tp>::template get<_Index>(_CUDA_VSTD::move(__z));
}

template <size_t _Index, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr const _Tp& get(const complex<_Tp>& __z) noexcept
{
static_assert(_Index < 2, "Index value is out of range");

return __get_complex_impl<_Tp>::template get<_Index>(__z);
}

template <size_t _Index, class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr const _Tp&& get(const complex<_Tp>&& __z) noexcept
{
static_assert(_Index < 2, "Index value is out of range");

return __get_complex_impl<_Tp>::template get<_Index>(_CUDA_VSTD::move(__z));
}

#if !_CCCL_COMPILER(NVRTC)
template <class _Tp, class _CharT, class _Traits>
::std::basic_istream<_CharT, _Traits>& operator>>(::std::basic_istream<_CharT, _Traits>& __is, complex<_Tp>& __x)
Expand Down
1 change: 1 addition & 0 deletions libcudacxx/include/cuda/std/version
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#endif // !_CCCL_COMPILER(NVRTC)

#define __cccl_lib_to_underlying 202102L
// #define __cpp_lib_tuple_like 202311L // P2819R2 is implemented, but P2165R4 is not yet

#if _CCCL_STD_VER >= 2014
# define __cccl_lib_bit_cast 201806L
Expand Down
Loading
Loading