Skip to content

Commit

Permalink
Implement tuple interface for cuda vector types
Browse files Browse the repository at this point in the history
This specializes the `std::tuple_size` and `std::tuple_element` traits so that they are usable with cuda vector types.

We also provide overloads for `std::get`, which together enables structured bindings support in C++17 onwards.

Fixes NVIDIA#1406
  • Loading branch information
miscco committed Feb 21, 2024
1 parent f6903bf commit b7ae902
Show file tree
Hide file tree
Showing 6 changed files with 887 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___TUPLE_VECTOR_TYPES_H
#define _LIBCUDACXX___TUPLE_VECTOR_TYPES_H

#ifndef __cuda_std__
# include <__config>
#endif // __cuda_std__

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#ifdef __cuda_std__
# if defined(_CCCL_CUDA_COMPILER)

_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_CLANG("-Wmismatched-tags")

# if !defined(_CCCL_COMPILER_NVRTC)
// Fetch utility to get primary template for ::std::tuple_size necessary for the specialization of
// ::std::tuple_size<cuda::std::tuple> to enable structured bindings.
// See https://github.com/NVIDIA/libcudacxx/issues/316
# include <utility>
# endif

# include "../__fwd/get.h"
# include "../__tuple_dir/structured_bindings.h"
# include "../__tuple_dir/tuple_element.h"
# include "../__tuple_dir/tuple_size.h"
# include "../__type_traits/integral_constant.h"
# include "../__type_traits/enable_if.h"
# include "../__utility/forward.h"
# include "../__utility/move.h"

# define _LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE(__name, __type, __size) \
template <> \
struct tuple_size<__name##__size> : _CUDA_VSTD::integral_constant<size_t, __size> \
{}; \
template <> \
struct tuple_size<const __name##__size> : _CUDA_VSTD::integral_constant<size_t, __size> \
{}; \
template <> \
struct tuple_size<volatile __name##__size> : _CUDA_VSTD::integral_constant<size_t, __size> \
{}; \
template <> \
struct tuple_size<const volatile __name##__size> : _CUDA_VSTD::integral_constant<size_t, __size> \
{}; \
\
template <size_t _Ip> \
struct tuple_element<_Ip, __name##__size> \
{ \
static_assert(_Ip < __size, "tuple_element index out of range"); \
using type = __type; \
}; \
template <size_t _Ip> \
struct tuple_element<_Ip, const __name##__size> \
{ \
static_assert(_Ip < __size, "tuple_element index out of range"); \
using type = const __type; \
}; \
template <size_t _Ip> \
struct tuple_element<_Ip, volatile __name##__size> \
{ \
static_assert(_Ip < __size, "tuple_element index out of range"); \
using type = volatile __type; \
}; \
template <size_t _Ip> \
struct tuple_element<_Ip, const volatile __name##__size> \
{ \
static_assert(_Ip < __size, "tuple_element index out of range"); \
using type = const volatile __type; \
};

# define _LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(__name, __type) \
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE(__name, __type, 1) \
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE(__name, __type, 2) \
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE(__name, __type, 3) \
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE(__name, __type, 4)

# define _LIBCUDACXX_SPECIALIZE_GET(__name, __base_type) \
template <size_t _Ip> \
_LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 __base_type& get(__name& __val) noexcept \
{ \
return _CUDA_VSTD::__get_element<_Ip>::template get<__name, __base_type>(__val); \
} \
template <size_t _Ip> \
_LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const __base_type& get( \
const __name& __val) noexcept \
{ \
return _CUDA_VSTD::__get_element<_Ip>::template get<__name, __base_type>(__val); \
} \
template <size_t _Ip> \
_LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 __base_type&& get(__name&& __val) noexcept \
{ \
return _CUDA_VSTD::__get_element<_Ip>::template get<__name, __base_type>(static_cast<__name&&>(__val)); \
} \
template <size_t _Ip> \
_LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const __base_type&& get( \
const __name&& __val) noexcept \
{ \
return _CUDA_VSTD::__get_element<_Ip>::template get<__name, __base_type>(static_cast<const __name&&>(__val)); \
}

# define _LIBCUDACXX_SPECIALIZE_GET_VECTOR(__name, __base_type) \
_LIBCUDACXX_SPECIALIZE_GET(__name##1, __base_type) \
_LIBCUDACXX_SPECIALIZE_GET(__name##2, __base_type) \
_LIBCUDACXX_SPECIALIZE_GET(__name##3, __base_type) \
_LIBCUDACXX_SPECIALIZE_GET(__name##4, __base_type)

_LIBCUDACXX_BEGIN_NAMESPACE_STD

_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(char, signed char)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(uchar, unsigned char)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(short, short)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(ushort, unsigned short)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(int, int)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(uint, unsigned int)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(long, long)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(ulong, unsigned long)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(longlong, long long)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(ulonglong, unsigned long long)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(float, float)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(double, double)

template <size_t _Ip>
struct __get_element;

template <>
struct __get_element<0>
{
template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 _BaseType& get(_Vec& __val) noexcept
{
return __val.x;
}

template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const _BaseType& get(const _Vec& __val) noexcept
{
return __val.x;
}

template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 _BaseType&& get(_Vec&& __val) noexcept
{
return static_cast<_BaseType&&>(__val.x);
}

template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const _BaseType&&
get(const _Vec&& __val) noexcept
{
return static_cast<const _BaseType&&>(__val.x);
}
};

template <>
struct __get_element<1>
{
template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 _BaseType& get(_Vec& __val) noexcept
{
return __val.y;
}

template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const _BaseType& get(const _Vec& __val) noexcept
{
return __val.y;
}

template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 _BaseType&& get(_Vec&& __val) noexcept
{
return static_cast<_BaseType&&>(__val.y);
}

template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const _BaseType&&
get(const _Vec&& __val) noexcept
{
return static_cast<const _BaseType&&>(__val.y);
}
};
template <>
struct __get_element<2>
{
template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 _BaseType& get(_Vec& __val) noexcept
{
return __val.z;
}

template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const _BaseType& get(const _Vec& __val) noexcept
{
return __val.z;
}

template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 _BaseType&& get(_Vec&& __val) noexcept
{
return static_cast<_BaseType&&>(__val.z);
}

template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const _BaseType&&
get(const _Vec&& __val) noexcept
{
return static_cast<const _BaseType&&>(__val.z);
}
};

template <>
struct __get_element<3>
{
template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 _BaseType& get(_Vec& __val) noexcept
{
return __val.w;
}

template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const _BaseType& get(const _Vec& __val) noexcept
{
return __val.w;
}

template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 _BaseType&& get(_Vec&& __val) noexcept
{
return static_cast<_BaseType&&>(__val.w);
}

template <class _Vec, class _BaseType>
static _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 const _BaseType&&
get(const _Vec&& __val) noexcept
{
return static_cast<const _BaseType&&>(__val.w);
}
};

_LIBCUDACXX_SPECIALIZE_GET_VECTOR(char, signed char)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(uchar, unsigned char)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(short, short)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(ushort, unsigned short)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(int, int)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(uint, unsigned int)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(long, long)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(ulong, unsigned long)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(longlong, long long)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(ulonglong, unsigned long long)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(float, float)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(double, double)

_LIBCUDACXX_END_NAMESPACE_STD

// Those need to be defined in the global namespace because we need ADL to find them
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(char, signed char)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(uchar, unsigned char)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(short, short)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(ushort, unsigned short)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(int, int)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(uint, unsigned int)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(long, long)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(ulong, unsigned long)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(longlong, long long)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(ulonglong, unsigned long long)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(float, float)
_LIBCUDACXX_SPECIALIZE_GET_VECTOR(double, double)

// This is a workaround for the fact that structured bindings require that the specializations of
// `tuple_size` and `tuple_element` reside in namespace std (https://eel.is/c++draft/dcl.struct.bind#4).
// See https://github.com/NVIDIA/libcudacxx/issues/316 for a short discussion
# if _CCCL_STD_VER >= 2017
namespace std
{

_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(char, signed char)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(uchar, unsigned char)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(short, short)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(ushort, unsigned short)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(int, int)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(uint, unsigned int)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(long, long)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(ulong, unsigned long)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(longlong, long long)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(ulonglong, unsigned long long)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(float, float)
_LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR(double, double)

} // namespace std

# endif // _CCCL_STD_VER >= 2017

# undef _LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE
# undef _LIBCUDACXX_SPECIALIZE_TUPLE_INTERFACE_VECTOR
# undef _LIBCUDACXX_SPECIALIZE_GET
# undef _LIBCUDACXX_SPECIALIZE_GET_VECTOR

_CCCL_DIAG_POP

# endif // _CCCL_CUDA_COMPILER
#endif // __cuda_std__

#endif // _LIBCUDACXX___TUPLE_VECTOR_TYPES_H
1 change: 1 addition & 0 deletions libcudacxx/include/cuda/std/detail/libcxx/include/tuple
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ template <class... Types>
#include "__tuple_dir/tuple_like.h"
#include "__tuple_dir/tuple_size.h"
#include "__tuple_dir/tuple_types.h"
#include "__tuple_dir/vector_types.h"
#include "__type_traits/maybe_const.h"
#include "__utility/forward.h"
#include "__utility/integer_sequence.h"
Expand Down
Loading

0 comments on commit b7ae902

Please sign in to comment.