Skip to content

Commit

Permalink
mdspan fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
miscco committed Jan 22, 2025
1 parent 1680ef8 commit e221c2c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 28 deletions.
44 changes: 26 additions & 18 deletions libcudacxx/include/cuda/std/__mdspan/mdspan.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include <cuda/std/__type_traits/remove_pointer.h>
#include <cuda/std/__type_traits/remove_reference.h>
#include <cuda/std/__utility/as_const.h>
#include <cuda/std/__utility/declval.h>
#include <cuda/std/__utility/integer_sequence.h>
#include <cuda/std/__utility/move.h>
#include <cuda/std/array>
Expand Down Expand Up @@ -138,7 +139,8 @@ class mdspan
constexpr mdspan(mdspan&&) = default;

_CCCL_TEMPLATE(class... _OtherIndexTypes)
_CCCL_REQUIRES(((sizeof...(_OtherIndexTypes) == rank()) || (sizeof...(_OtherIndexTypes) == rank_dynamic()))
_CCCL_REQUIRES(((sizeof...(_OtherIndexTypes) == extents_type::rank())
|| (sizeof...(_OtherIndexTypes) == extents_type::rank_dynamic()))
_CCCL_AND _CCCL_FOLD_AND(_CCCL_TRAIT(is_convertible, _OtherIndexTypes, index_type))
_CCCL_AND _CCCL_FOLD_AND(_CCCL_TRAIT(is_nothrow_constructible, index_type, _OtherIndexTypes))
_CCCL_AND _CCCL_TRAIT(is_constructible, mapping_type, extents_type)
Expand All @@ -157,15 +159,15 @@ class mdspan
&& _CCCL_TRAIT(is_default_constructible, accessor_type);

_CCCL_TEMPLATE(class _OtherIndexType, size_t _Size)
_CCCL_REQUIRES((_Size == rank_dynamic()) _CCCL_AND __is_constructible_from<_OtherIndexType>)
_CCCL_REQUIRES((_Size == extents_type::rank_dynamic()) _CCCL_AND __is_constructible_from<_OtherIndexType>)
_LIBCUDACXX_HIDE_FROM_ABI constexpr mdspan(data_handle_type __p, const array<_OtherIndexType, _Size>& __exts)
: __ptr_(_CUDA_VSTD::move(__p))
, __map_(extents_type{__exts})
, __acc_{}
{}

_CCCL_TEMPLATE(class _OtherIndexType, size_t _Size)
_CCCL_REQUIRES((_Size == rank()) _CCCL_AND(_Size != rank_dynamic())
_CCCL_REQUIRES((_Size == extents_type::rank()) _CCCL_AND(_Size != extents_type::rank_dynamic())
_CCCL_AND __is_constructible_from<_OtherIndexType>)
_LIBCUDACXX_HIDE_FROM_ABI explicit constexpr mdspan(data_handle_type __p, const array<_OtherIndexType, _Size>& __exts)
: __ptr_(_CUDA_VSTD::move(__p))
Expand All @@ -174,15 +176,15 @@ class mdspan
{}

_CCCL_TEMPLATE(class _OtherIndexType, size_t _Size)
_CCCL_REQUIRES((_Size == rank_dynamic()) _CCCL_AND __is_constructible_from<_OtherIndexType>)
_CCCL_REQUIRES((_Size == extents_type::rank_dynamic()) _CCCL_AND __is_constructible_from<_OtherIndexType>)
_LIBCUDACXX_HIDE_FROM_ABI constexpr mdspan(data_handle_type __p, span<_OtherIndexType, _Size> __exts)
: __ptr_(_CUDA_VSTD::move(__p))
, __map_(extents_type{__exts})
, __acc_{}
{}

_CCCL_TEMPLATE(class _OtherIndexType, size_t _Size)
_CCCL_REQUIRES((_Size == rank()) _CCCL_AND(_Size != rank_dynamic())
_CCCL_REQUIRES((_Size == extents_type::rank()) _CCCL_AND(_Size != extents_type::rank_dynamic())
_CCCL_AND __is_constructible_from<_OtherIndexType>)
_LIBCUDACXX_HIDE_FROM_ABI explicit constexpr mdspan(data_handle_type __p, span<_OtherIndexType, _Size> __exts)
: __ptr_(_CUDA_VSTD::move(__p))
Expand Down Expand Up @@ -224,7 +226,7 @@ class mdspan
&& _CCCL_TRAIT(is_convertible, const _OtherAccessor&, accessor_type);

_CCCL_TEMPLATE(class _OtherElementType, class _OtherExtents, class _OtherLayoutPolicy, class _OtherAccessor)
_CCCL_REQUIRES((rank() > 0) //
_CCCL_REQUIRES((extents_type::rank() > 0) //
_CCCL_AND __is_convertible_from<_OtherExtents, _OtherLayoutPolicy, _OtherAccessor> //
_CCCL_AND __is_implicit_convertible_from<_OtherExtents, _OtherLayoutPolicy, _OtherAccessor>)
_LIBCUDACXX_HIDE_FROM_ABI constexpr mdspan(
Expand All @@ -243,7 +245,7 @@ class mdspan
// its extents() function returns a const reference to extents_type.
// The only way this can be triggered is if the mapping conversion constructor would for example
// always construct its extents() only from the dynamic extents, instead of from the other extents.
for (size_t __r = 0; __r < rank(); __r++)
for (size_t __r = 0; __r < extents_type::rank(); __r++)
{
// Not catching this could lead to out of bounds errors later
// e.g. mdspan<int, dextents<char,1>, non_checking_layout> m =
Expand All @@ -255,7 +257,7 @@ class mdspan
}

_CCCL_TEMPLATE(class _OtherElementType, class _OtherExtents, class _OtherLayoutPolicy, class _OtherAccessor)
_CCCL_REQUIRES((rank() == 0) //
_CCCL_REQUIRES((extents_type::rank() == 0) //
_CCCL_AND __is_convertible_from<_OtherExtents, _OtherLayoutPolicy, _OtherAccessor> //
_CCCL_AND __is_implicit_convertible_from<_OtherExtents, _OtherLayoutPolicy, _OtherAccessor>)
_LIBCUDACXX_HIDE_FROM_ABI constexpr mdspan(
Expand All @@ -270,7 +272,7 @@ class mdspan
"mdspan: incompatible extents for mdspan construction");
}
_CCCL_TEMPLATE(class _OtherElementType, class _OtherExtents, class _OtherLayoutPolicy, class _OtherAccessor)
_CCCL_REQUIRES((rank() > 0) //
_CCCL_REQUIRES((extents_type::rank() > 0) //
_CCCL_AND __is_convertible_from<_OtherExtents, _OtherLayoutPolicy, _OtherAccessor> //
_CCCL_AND(!__is_implicit_convertible_from<_OtherExtents, _OtherLayoutPolicy, _OtherAccessor>))
_LIBCUDACXX_HIDE_FROM_ABI explicit constexpr mdspan(
Expand All @@ -289,7 +291,7 @@ class mdspan
// its extents() function returns a const reference to extents_type.
// The only way this can be triggered is if the mapping conversion constructor would for example
// always construct its extents() only from the dynamic extents, instead of from the other extents.
for (size_t __r = 0; __r < rank(); __r++)
for (size_t __r = 0; __r < extents_type::rank(); __r++)
{
// Not catching this could lead to out of bounds errors later
// e.g. mdspan<int, dextents<char,1>, non_checking_layout> m =
Expand All @@ -301,7 +303,7 @@ class mdspan
}

_CCCL_TEMPLATE(class _OtherElementType, class _OtherExtents, class _OtherLayoutPolicy, class _OtherAccessor)
_CCCL_REQUIRES((rank() == 0) //
_CCCL_REQUIRES((extents_type::rank() == 0) //
_CCCL_AND __is_convertible_from<_OtherExtents, _OtherLayoutPolicy, _OtherAccessor> //
_CCCL_AND(!__is_implicit_convertible_from<_OtherExtents, _OtherLayoutPolicy, _OtherAccessor>))
_LIBCUDACXX_HIDE_FROM_ABI explicit constexpr mdspan(
Expand All @@ -324,7 +326,7 @@ class mdspan

# if defined(_LIBCUDACXX_HAS_MULTIARG_OPERATOR_BRACKETS)
_CCCL_TEMPLATE(class... _OtherIndexTypes)
_CCCL_REQUIRES((sizeof...(_OtherIndexTypes) == rank())
_CCCL_REQUIRES((sizeof...(_OtherIndexTypes) == extents_type::rank())
_CCCL_AND _CCCL_FOLD_AND(_CCCL_TRAIT(is_convertible, _OtherIndexTypes, index_type))
_CCCL_AND _CCCL_FOLD_AND(_CCCL_TRAIT(is_nothrow_constructible, index_type, _OtherIndexTypes)))
_LIBCUDACXX_HIDE_FROM_ABI constexpr reference operator[](_OtherIndexTypes... __indices) const
Expand All @@ -337,7 +339,7 @@ class mdspan
}
# else
_CCCL_TEMPLATE(class _OtherIndexType)
_CCCL_REQUIRES((rank() == 1) _CCCL_AND _CCCL_TRAIT(is_convertible, _OtherIndexType, index_type)
_CCCL_REQUIRES((extents_type::rank() == 1) _CCCL_AND _CCCL_TRAIT(is_convertible, _OtherIndexType, index_type)
_CCCL_AND _CCCL_TRAIT(is_nothrow_constructible, index_type, _OtherIndexType))
_LIBCUDACXX_HIDE_FROM_ABI constexpr reference operator[](_OtherIndexType __index) const
{
Expand All @@ -362,15 +364,16 @@ class mdspan
_CCCL_TEMPLATE(class _OtherIndexType)
_CCCL_REQUIRES(_CCCL_TRAIT(is_convertible, const _OtherIndexType&, index_type)
_CCCL_AND _CCCL_TRAIT(is_nothrow_constructible, index_type, const _OtherIndexType&))
_LIBCUDACXX_HIDE_FROM_ABI constexpr reference operator[](const array<_OtherIndexType, rank()>& __indices) const
_LIBCUDACXX_HIDE_FROM_ABI constexpr reference
operator[](const array<_OtherIndexType, extents_type::rank()>& __indices) const
{
return __acc_.access(__ptr_, __op_bracket(__indices, make_index_sequence<rank()>()));
}

_CCCL_TEMPLATE(class _OtherIndexType)
_CCCL_REQUIRES(_CCCL_TRAIT(is_convertible, const _OtherIndexType&, index_type)
_CCCL_AND _CCCL_TRAIT(is_nothrow_constructible, index_type, const _OtherIndexType&))
_LIBCUDACXX_HIDE_FROM_ABI constexpr reference operator[](span<_OtherIndexType, rank()> __indices) const
_LIBCUDACXX_HIDE_FROM_ABI constexpr reference operator[](span<_OtherIndexType, extents_type::rank()> __indices) const
{
return __acc_.access(__ptr_, __op_bracket(__indices, make_index_sequence<rank()>()));
}
Expand Down Expand Up @@ -446,28 +449,33 @@ class mdspan
return __acc_;
};

_LIBCUDACXX_HIDE_FROM_ABI static constexpr bool is_always_unique()
_LIBCUDACXX_HIDE_FROM_ABI static constexpr bool is_always_unique() noexcept(noexcept(mapping_type::is_always_unique()))
{
return mapping_type::is_always_unique();
};
_LIBCUDACXX_HIDE_FROM_ABI static constexpr bool is_always_exhaustive()
_LIBCUDACXX_HIDE_FROM_ABI static constexpr bool
is_always_exhaustive() noexcept(noexcept(mapping_type::is_always_exhaustive()))
{
return mapping_type::is_always_exhaustive();
};
_LIBCUDACXX_HIDE_FROM_ABI static constexpr bool is_always_strided()
_LIBCUDACXX_HIDE_FROM_ABI static constexpr bool
is_always_strided() noexcept(noexcept(mapping_type::is_always_strided()))
{
return mapping_type::is_always_strided();
};

_LIBCUDACXX_HIDE_FROM_ABI constexpr bool is_unique() const
noexcept(noexcept(_CUDA_VSTD::declval<const mapping_type&>().is_unique()))
{
return __map_.is_unique();
};
_LIBCUDACXX_HIDE_FROM_ABI constexpr bool is_exhaustive() const
noexcept(noexcept(_CUDA_VSTD::declval<const mapping_type&>().is_exhaustive()))
{
return __map_.is_exhaustive();
};
_LIBCUDACXX_HIDE_FROM_ABI constexpr bool is_strided() const
noexcept(noexcept(_CUDA_VSTD::declval<const mapping_type&>().is_strided()))
{
return __map_.is_strided();
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@
#include "test_macros.h"

// GCC warns about comma operator changing its meaning inside [] in C++23
#if defined(TEST_COMPILER_GCC)
#if _CCCL_COMPILER(GCC, >=, 10)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wcomma-subscript"
#endif // TEST_COMPILER_GCC
#endif // _CCCL_COMPILER(GCC, >=, 10)

template <class MDS>
__host__ __device__ constexpr auto& access(MDS mds, int64_t i0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,12 @@ __host__ __device__ constexpr void test_mdspan_types(const H& handle, const M& m
ASSERT_SAME_TYPE(decltype(m.is_unique()), bool);
ASSERT_SAME_TYPE(decltype(m.is_exhaustive()), bool);
ASSERT_SAME_TYPE(decltype(m.is_strided()), bool);
assert(!noexcept(MDS::is_always_unique()));
assert(!noexcept(MDS::is_always_exhaustive()));
assert(!noexcept(MDS::is_always_strided()));
assert(!noexcept(m.is_unique()));
assert(!noexcept(m.is_exhaustive()));
assert(!noexcept(m.is_strided()));
assert(noexcept(MDS::is_always_unique() == noexcept(M::is_always_unique())));
assert(noexcept(MDS::is_always_exhaustive()) == noexcept(M::is_always_exhaustive()));
assert(noexcept(MDS::is_always_strided()) == noexcept(M::is_always_strided()));
assert(noexcept(m.is_unique()) == noexcept(m.is_always_unique()));
assert(noexcept(m.is_exhaustive()) == noexcept(m.is_exhaustive()));
assert(noexcept(m.is_strided()) == noexcept(m.is_strided()));
assert(MDS::is_always_unique() == M::is_always_unique());
assert(MDS::is_always_exhaustive() == M::is_always_exhaustive());
assert(MDS::is_always_strided() == M::is_always_strided());
Expand Down Expand Up @@ -249,11 +249,11 @@ __host__ __device__ TEST_CONSTEXPR_CXX20 bool test_evil()
int main(int, char**)
{
test();
static_assert(test(), "");
// static_assert(test(), "");

test_evil();
#if TEST_STD_VER >= 2020
static_assert(test(), "");
static_assert(test_evil(), "");
#endif // TEST_STD_VER >= 2020

return 0;
Expand Down

0 comments on commit e221c2c

Please sign in to comment.