Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
miscco committed Jan 21, 2025
1 parent 231da58 commit 968018b
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 148 deletions.
38 changes: 18 additions & 20 deletions libcudacxx/include/cuda/std/__mdspan/extents.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ struct __static_partial_sums
// The position of a dynamic value is indicated through a tag value.
template <class _TDynamic, class _TStatic, _TStatic _DynTag, _TStatic... _Values>
struct __maybe_static_array
: private __possibly_empty_array<_TDynamic, _CCCL_FOLD_PLUS(size_t(0), static_cast<size_t>(_Values == _DynTag))>
{
static_assert(is_convertible<_TStatic, _TDynamic>::value,
"__maybe_static_array: _TStatic must be convertible to _TDynamic");
Expand All @@ -168,9 +169,6 @@ struct __maybe_static_array
using _StaticValues = __static_array<_TStatic, _Values...>;
using _DynamicValues = __possibly_empty_array<_TDynamic, __size_dynamic_>;

// Dynamic values member
_CCCL_NO_UNIQUE_ADDRESS _DynamicValues __dyn_vals_;

// static mapping of indices to the position in the dynamic values array
using _DynamicIdxMap = __static_partial_sums<static_cast<size_t>(_Values == _DynTag)...>;

Expand All @@ -182,20 +180,18 @@ struct __maybe_static_array

public:
// Default constructor.
// Initializes the dynamic-value storage through the `_DynamicValues` base-class
// subobject rather than through a data member. The initializer is whatever
// `__zeros` produces for an index sequence of length `__size_dynamic_`
// (presumably one zero per dynamic slot -- `__zeros` is defined outside this
// view, confirm there).
_LIBCUDACXX_HIDE_FROM_ABI constexpr __maybe_static_array() noexcept
    : _DynamicValues{__zeros(make_index_sequence<__size_dynamic_>())}
{}

template <class _Tp, size_t _Size>
_LIBCUDACXX_HIDE_FROM_ABI constexpr __maybe_static_array(span<_Tp, _Size> __vals) noexcept
# if _CCCL_STD_VER <= 2017 // NVCC complains that this constructor would not be constexpr without it
: __dyn_vals_{}
# endif // _CCCL_STD_VER <= 2017
: _DynamicValues{}
{
if constexpr (_Size == __size_dynamic_)
{
for (size_t __i = 0; __i != _Size; __i++)
{
__dyn_vals_[__i] = static_cast<_TDynamic>(__vals[__i]);
(*static_cast<_DynamicValues*>(this))[__i] = static_cast<_TDynamic>(__vals[__i]);
}
}
else
Expand All @@ -205,7 +201,7 @@ struct __maybe_static_array
_TStatic __static_val = _StaticValues::__get(__i);
if (__static_val == _DynTag)
{
__dyn_vals_[_DynamicIdxMap::__get(__i)] = static_cast<_TDynamic>(__vals[__i]);
(*static_cast<_DynamicValues*>(this))[_DynamicIdxMap::__get(__i)] = static_cast<_TDynamic>(__vals[__i]);
}
else
{
Expand All @@ -224,13 +220,14 @@ struct __maybe_static_array
// Constructor taking exactly one argument per dynamic slot.
// Participates only when the number of arguments equals `__size_dynamic_` and
// the arguments are not all spans (see the requires-clause). Each argument is
// converted to `_TDynamic` and forwarded, in order, to the `_DynamicValues`
// base-class subobject that holds the dynamic values.
_CCCL_TEMPLATE(class... _DynVals)
_CCCL_REQUIRES((sizeof...(_DynVals) == __size_dynamic_) && (!__all<__is_std_span<_DynVals>...>::value))
_LIBCUDACXX_HIDE_FROM_ABI constexpr __maybe_static_array(_DynVals... __vals) noexcept
    : _DynamicValues{static_cast<_TDynamic>(__vals)...}
{}

// constructors from all values -- here rank will be greater than 0
_CCCL_TEMPLATE(class... _DynVals)
_CCCL_REQUIRES((sizeof...(_DynVals) != __size_dynamic_) && (!__all<__is_std_span<_DynVals>...>::value))
_LIBCUDACXX_HIDE_FROM_ABI constexpr __maybe_static_array(_DynVals... __vals)
: _DynamicValues{}
{
static_assert((sizeof...(_DynVals) == __size_), "Invalid number of values.");
_TDynamic __values[__size_] = {static_cast<_TDynamic>(__vals)...};
Expand All @@ -239,7 +236,7 @@ struct __maybe_static_array
_TStatic __static_val = _StaticValues::__get(__i);
if (__static_val == _DynTag)
{
__dyn_vals_[_DynamicIdxMap::__get(__i)] = __values[__i];
(*static_cast<_DynamicValues*>(this))[_DynamicIdxMap::__get(__i)] = __values[__i];
}
else
{
Expand Down Expand Up @@ -270,7 +267,9 @@ struct __maybe_static_array
_CCCL_ASSERT(__i < __size_, "extents access: index must be less than rank");
}
_TStatic __static_val = _StaticValues::__get(__i);
return __static_val == _DynTag ? __dyn_vals_[_DynamicIdxMap::__get(__i)] : static_cast<_TDynamic>(__static_val);
return __static_val == _DynTag
? (*static_cast<const _DynamicValues*>(this))[_DynamicIdxMap::__get(__i)]
: static_cast<_TDynamic>(__static_val);
}
_LIBCUDACXX_HIDE_FROM_ABI constexpr _TDynamic operator[](size_t __i) const
{
Expand Down Expand Up @@ -389,7 +388,7 @@ struct __extent_delegate_tag
// Used by mdspan, mdarray and layout mappings.
// See ISO C++ standard [mdspan.extents]
template <class _IndexType, size_t... _Extents>
class extents
class extents : private __mdspan_detail::__maybe_static_array<_IndexType, size_t, dynamic_extent, _Extents...>
{
public:
// typedefs for integral types used
Expand All @@ -410,7 +409,6 @@ class extents

// internal storage type using __maybe_static_array
using _Values = __mdspan_detail::__maybe_static_array<_IndexType, size_t, dynamic_extent, _Extents...>;
_CCCL_NO_UNIQUE_ADDRESS _Values __vals_;

public:
// [mdspan.extents.obs], observers of multidimensional index space
Expand All @@ -425,7 +423,7 @@ class extents

// Returns the extent of dimension `__r`.
// The lookup is delegated to `__value` of the `__maybe_static_array` base class;
// `this->` is required because that base depends on the template parameters,
// so an unqualified `__value` would not be found during name lookup.
_LIBCUDACXX_HIDE_FROM_ABI constexpr index_type extent(rank_type __r) const noexcept
{
  return this->__value(__r);
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr size_t static_extent(rank_type __r) noexcept
{
Expand All @@ -446,7 +444,7 @@ class extents
_CCCL_REQUIRES((sizeof...(_OtherIndexTypes) == __rank_ || sizeof...(_OtherIndexTypes) == __rank_dynamic_)
_CCCL_AND __all_convertible_to_index_type<_OtherIndexTypes...>)
_LIBCUDACXX_HIDE_FROM_ABI constexpr explicit extents(_OtherIndexTypes... __dynvals) noexcept
: __vals_(static_cast<index_type>(__dynvals)...)
: _Values(static_cast<index_type>(__dynvals)...)
{
// Not catching this could lead to out of bounds errors later
// e.g. mdspan m(ptr, dextents<char, 1>(200u)); leads to an extent of -56 on m
Expand Down Expand Up @@ -475,7 +473,7 @@ class extents
_CCCL_TEMPLATE(class _OtherIndexType, size_t _Size)
_CCCL_REQUIRES((_Size == __rank_dynamic_) _CCCL_AND __is_convertible_to_index_type<_OtherIndexType>)
_LIBCUDACXX_HIDE_FROM_ABI constexpr extents(span<_OtherIndexType, _Size> __exts) noexcept
: __vals_(__exts)
: _Values(__exts)
{
// Not catching this could lead to out of bounds errors later
// e.g. array a{200u}; mdspan<int, dextents<char,1>> m(ptr, extents(span<unsigned,1>(a))); leads to an extent of -56
Expand All @@ -488,7 +486,7 @@ class extents
_CCCL_REQUIRES((_Size != __rank_dynamic_) _CCCL_AND(_Size == __rank_)
_CCCL_AND __is_convertible_to_index_type<_OtherIndexType>)
_LIBCUDACXX_HIDE_FROM_ABI explicit constexpr extents(span<_OtherIndexType, _Size> __exts) noexcept
: __vals_(__exts)
: _Values(__exts)
{
// Not catching this could lead to out of bounds errors later
// e.g. array a{200u}; mdspan<int, dextents<char,1>> m(ptr, extents(span<unsigned,1>(a))); leads to an extent of -56
Expand Down Expand Up @@ -546,7 +544,7 @@ class extents
_CCCL_REQUIRES((rank() > 0) _CCCL_AND __potentially_narrowing<_OtherIndexType>)
_LIBCUDACXX_HIDE_FROM_ABI constexpr extents(__extent_delegate_tag,
const extents<_OtherIndexType, _OtherExtents...>& __other) noexcept
: __vals_(__construct_vals_from_extents(integral_constant<size_t, 0>(), integral_constant<size_t, 0>(), __other))
: _Values(__construct_vals_from_extents(integral_constant<size_t, 0>(), integral_constant<size_t, 0>(), __other))
{
for (size_t __r = 0; __r < rank(); __r++)
{
Expand All @@ -568,7 +566,7 @@ class extents
_CCCL_REQUIRES((rank() > 0) _CCCL_AND(!__potentially_narrowing<_OtherIndexType>))
_LIBCUDACXX_HIDE_FROM_ABI constexpr extents(__extent_delegate_tag,
const extents<_OtherIndexType, _OtherExtents...>& __other) noexcept
: __vals_(__construct_vals_from_extents(integral_constant<size_t, 0>(), integral_constant<size_t, 0>(), __other))
: _Values(__construct_vals_from_extents(integral_constant<size_t, 0>(), integral_constant<size_t, 0>(), __other))
{
for (size_t __r = 0; __r < rank(); __r++)
{
Expand Down
33 changes: 12 additions & 21 deletions libcudacxx/include/cuda/std/__mdspan/layout_left.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,27 +65,22 @@ class layout_left::mapping
return x && ((*res / x) != y);
}

// Checks whether the product of all extents of `__ext` can be computed in
// `index_type` without `__mul_overflow` reporting an overflow.
//
// A single function replaces the former pair of enable_if overloads:
// for rank 0 the `if constexpr` branch is discarded (so `extent(0)` is never
// instantiated) and the function trivially returns true.
//
// Returns false as soon as one multiplication overflows, true otherwise.
_LIBCUDACXX_HIDE_FROM_ABI static constexpr bool __required_span_size_is_representable(const extents_type& __ext)
{
  if constexpr (extents_type::rank() != 0)
  {
    // Running product, seeded with the first extent and accumulated in place.
    index_type __prod = __ext.extent(0);
    for (rank_type __r = 1; __r < extents_type::rank(); __r++)
    {
      if (__mul_overflow(__prod, __ext.extent(__r), &__prod))
      {
        return false;
      }
    }
  }
  return true;
}

static_assert((extents_type::rank_dynamic() > 0) || __required_span_size_is_representable(extents_type()),
"layout_left::mapping product of static extents must be representable as index_type.");

Expand Down Expand Up @@ -203,23 +198,19 @@ class layout_left::mapping
return __extents_;
}

// Returns the product of all extents of the stored `__extents_`.
//
// A single function replaces the former pair of enable_if overloads:
// for rank 0 the `if constexpr` branch is discarded and the initial value 1
// is returned unchanged, matching the old rank-0 overload.
_LIBCUDACXX_HIDE_FROM_ABI constexpr index_type required_span_size() const noexcept
{
  index_type __size = 1;
  if constexpr (extents_type::rank() != 0)
  {
    for (size_t __r = 0; __r != extents_type::rank(); __r++)
    {
      __size *= __extents_.extent(__r);
    }
  }
  return __size;
}

template <size_t... _Pos>
_LIBCUDACXX_HIDE_FROM_ABI constexpr index_type
__op_index(const array<index_type, _Extents::rank()>& __idx_a, index_sequence<_Pos...>) const noexcept
Expand Down
19 changes: 7 additions & 12 deletions libcudacxx/include/cuda/std/__mdspan/layout_right.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,27 +66,22 @@ class layout_right::mapping
return __x && ((*__res / __x) != __y);
}

// Checks whether the product of all extents of `__ext` can be computed in
// `index_type` without `__mul_overflow` reporting an overflow.
//
// A single function replaces the former pair of enable_if overloads:
// for rank 0 the `if constexpr` branch is discarded (so `extent(0)` is never
// instantiated) and the function trivially returns true.
//
// Returns false as soon as one multiplication overflows, true otherwise.
_LIBCUDACXX_HIDE_FROM_ABI static constexpr bool __required_span_size_is_representable(const extents_type& __ext)
{
  if constexpr (extents_type::rank() != 0)
  {
    // Running product, seeded with the first extent and accumulated in place.
    index_type __prod = __ext.extent(0);
    for (rank_type __r = 1; __r < extents_type::rank(); __r++)
    {
      if (__mul_overflow(__prod, __ext.extent(__r), &__prod))
      {
        return false;
      }
    }
  }
  return true;
}

static_assert((extents_type::rank_dynamic() > 0) || __required_span_size_is_representable(extents_type()),
"layout_right::mapping product of static extents must be representable as index_type.");

Expand Down
Loading

0 comments on commit 968018b

Please sign in to comment.