Skip to content

Commit

Permalink
Merge pull request #669 from jturner314/fix-axis-iter
Browse files Browse the repository at this point in the history
Fix axis iterators
  • Loading branch information
jturner314 authored Aug 20, 2019
2 parents ce80d38 + 1443df8 commit bc795b8
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 44 deletions.
4 changes: 2 additions & 2 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ where
/// The last view may have less elements if `size` does not divide
/// the axis' dimension.
///
/// **Panics** if `axis` is out of bounds.
/// **Panics** if `axis` is out of bounds or if `size` is zero.
///
/// ```
/// use ndarray::Array;
Expand Down Expand Up @@ -1036,7 +1036,7 @@ where
///
/// Iterator element is `ArrayViewMut<A, D>`
///
/// **Panics** if `axis` is out of bounds.
/// **Panics** if `axis` is out of bounds or if `size` is zero.
pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize) -> AxisChunksIterMut<'_, A, D>
where
S: DataMut,
Expand Down
131 changes: 90 additions & 41 deletions src/iterators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -738,11 +738,19 @@ where

#[derive(Debug)]
pub struct AxisIterCore<A, D> {
/// Index along the axis of the value of `.next()`, relative to the start
/// of the axis.
index: Ix,
len: Ix,
/// (Exclusive) upper bound on `index`. Initially, this is equal to the
/// length of the axis.
end: Ix,
/// Stride along the axis (offset between consecutive pointers).
stride: Ixs,
/// Shape of the iterator's items.
inner_dim: D,
/// Strides of the iterator's items.
inner_strides: D,
/// Pointer corresponding to `index == 0`.
ptr: *mut A,
}

Expand All @@ -751,7 +759,7 @@ clone_bounds!(
AxisIterCore[A, D] {
@copy {
index,
len,
end,
stride,
ptr,
}
Expand All @@ -767,54 +775,53 @@ impl<A, D: Dimension> AxisIterCore<A, D> {
Di: RemoveAxis<Smaller = D>,
S: Data<Elem = A>,
{
let shape = v.shape()[axis.index()];
let stride = v.strides()[axis.index()];
AxisIterCore {
index: 0,
len: shape,
stride,
end: v.len_of(axis),
stride: v.stride_of(axis),
inner_dim: v.dim.remove_axis(axis),
inner_strides: v.strides.remove_axis(axis),
ptr: v.ptr,
}
}

#[inline]
unsafe fn offset(&self, index: usize) -> *mut A {
debug_assert!(
index <= self.len,
"index={}, len={}, stride={}",
index < self.end,
"index={}, end={}, stride={}",
index,
self.len,
self.end,
self.stride
);
self.ptr.offset(index as isize * self.stride)
}

/// Split the iterator at index, yielding two disjoint iterators.
/// Splits the iterator at `index`, yielding two disjoint iterators.
///
/// **Panics** if `index` is strictly greater than the iterator's length
/// `index` is relative to the current state of the iterator (which is not
/// necessarily the start of the axis).
///
/// **Panics** if `index` is strictly greater than the iterator's remaining
/// length.
fn split_at(self, index: usize) -> (Self, Self) {
assert!(index <= self.len);
let right_ptr = if index != self.len {
unsafe { self.offset(index) }
} else {
self.ptr
};
assert!(index <= self.len());
let mid = self.index + index;
let left = AxisIterCore {
index: 0,
len: index,
index: self.index,
end: mid,
stride: self.stride,
inner_dim: self.inner_dim.clone(),
inner_strides: self.inner_strides.clone(),
ptr: self.ptr,
};
let right = AxisIterCore {
index: 0,
len: self.len - index,
index: mid,
end: self.end,
stride: self.stride,
inner_dim: self.inner_dim,
inner_strides: self.inner_strides,
ptr: right_ptr,
ptr: self.ptr,
};
(left, right)
}
Expand All @@ -827,7 +834,7 @@ where
type Item = *mut A;

fn next(&mut self) -> Option<Self::Item> {
if self.index >= self.len {
if self.index >= self.end {
None
} else {
let ptr = unsafe { self.offset(self.index) };
Expand All @@ -837,7 +844,7 @@ where
}

fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.len - self.index;
let len = self.len();
(len, Some(len))
}
}
Expand All @@ -847,16 +854,25 @@ where
D: Dimension,
{
fn next_back(&mut self) -> Option<Self::Item> {
if self.index >= self.len {
if self.index >= self.end {
None
} else {
self.len -= 1;
let ptr = unsafe { self.offset(self.len) };
let ptr = unsafe { self.offset(self.end - 1) };
self.end -= 1;
Some(ptr)
}
}
}

impl<A, D> ExactSizeIterator for AxisIterCore<A, D>
where
D: Dimension,
{
fn len(&self) -> usize {
self.end - self.index
}
}

/// An iterator that traverses over an axis and
/// and yields each subview.
///
Expand Down Expand Up @@ -899,9 +915,13 @@ impl<'a, A, D: Dimension> AxisIter<'a, A, D> {
}
}

/// Split the iterator at index, yielding two disjoint iterators.
/// Splits the iterator at `index`, yielding two disjoint iterators.
///
/// **Panics** if `index` is strictly greater than the iterator's length
/// `index` is relative to the current state of the iterator (which is not
/// necessarily the start of the axis).
///
/// **Panics** if `index` is strictly greater than the iterator's remaining
/// length.
pub fn split_at(self, index: usize) -> (Self, Self) {
let (left, right) = self.iter.split_at(index);
(
Expand Down Expand Up @@ -946,7 +966,7 @@ where
D: Dimension,
{
fn len(&self) -> usize {
self.size_hint().0
self.iter.len()
}
}

Expand Down Expand Up @@ -981,9 +1001,13 @@ impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> {
}
}

/// Split the iterator at index, yielding two disjoint iterators.
/// Splits the iterator at `index`, yielding two disjoint iterators.
///
/// **Panics** if `index` is strictly greater than the iterator's length
/// `index` is relative to the current state of the iterator (which is not
/// necessarily the start of the axis).
///
/// **Panics** if `index` is strictly greater than the iterator's remaining
/// length.
pub fn split_at(self, index: usize) -> (Self, Self) {
let (left, right) = self.iter.split_at(index);
(
Expand Down Expand Up @@ -1028,7 +1052,7 @@ where
D: Dimension,
{
fn len(&self) -> usize {
self.size_hint().0
self.iter.len()
}
}

Expand All @@ -1048,7 +1072,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
}
#[doc(hidden)]
fn as_ptr(&self) -> Self::Ptr {
self.iter.ptr
if self.len() > 0 {
// `self.iter.index` is guaranteed to be in-bounds if any of the
// iterator remains (i.e. if `self.len() > 0`).
unsafe { self.iter.offset(self.iter.index) }
} else {
// In this case, `self.iter.index` may be past the end, so we must
// not call `.offset()`. It's okay to return a dangling pointer
// because it will never be used in the length 0 case.
std::ptr::NonNull::dangling().as_ptr()
}
}

fn contiguous_stride(&self) -> isize {
Expand All @@ -1065,7 +1098,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
}
#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
self.iter.ptr.offset(self.iter.stride * i[0] as isize)
self.iter.offset(self.iter.index + i[0])
}

#[doc(hidden)]
Expand Down Expand Up @@ -1096,7 +1129,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
}
#[doc(hidden)]
fn as_ptr(&self) -> Self::Ptr {
self.iter.ptr
if self.len() > 0 {
// `self.iter.index` is guaranteed to be in-bounds if any of the
// iterator remains (i.e. if `self.len() > 0`).
unsafe { self.iter.offset(self.iter.index) }
} else {
// In this case, `self.iter.index` may be past the end, so we must
// not call `.offset()`. It's okay to return a dangling pointer
// because it will never be used in the length 0 case.
std::ptr::NonNull::dangling().as_ptr()
}
}

fn contiguous_stride(&self) -> isize {
Expand All @@ -1113,7 +1155,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
}
#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
self.iter.ptr.offset(self.iter.stride * i[0] as isize)
self.iter.offset(self.iter.index + i[0])
}

#[doc(hidden)]
Expand Down Expand Up @@ -1164,21 +1206,28 @@ clone_bounds!(
///
/// Returns an axis iterator with the correct stride to move between chunks,
/// the number of chunks, and the shape of the last chunk.
///
/// **Panics** if `size == 0`.
fn chunk_iter_parts<A, D: Dimension>(
v: ArrayView<'_, A, D>,
axis: Axis,
size: usize,
) -> (AxisIterCore<A, D>, usize, D) {
assert_ne!(size, 0, "Chunk size must be nonzero.");
let axis_len = v.len_of(axis);
let size = if size > axis_len { axis_len } else { size };
let n_whole_chunks = axis_len / size;
let chunk_remainder = axis_len % size;
let iter_len = if chunk_remainder == 0 {
n_whole_chunks
} else {
n_whole_chunks + 1
};
let stride = v.stride_of(axis) * size as isize;
let stride = if n_whole_chunks == 0 {
// This case avoids potential overflow when `size > axis_len`.
0
} else {
v.stride_of(axis) * size as isize
};

let axis = axis.index();
let mut inner_dim = v.dim.clone();
Expand All @@ -1193,7 +1242,7 @@ fn chunk_iter_parts<A, D: Dimension>(

let iter = AxisIterCore {
index: 0,
len: iter_len,
end: iter_len,
stride,
inner_dim,
inner_strides: v.strides,
Expand Down Expand Up @@ -1270,7 +1319,7 @@ macro_rules! chunk_iter_impl {
D: Dimension,
{
fn next_back(&mut self) -> Option<Self::Item> {
let is_uneven = self.iter.len > self.n_whole_chunks;
let is_uneven = self.iter.end > self.n_whole_chunks;
let res = self.iter.next_back();
self.get_subview(res, is_uneven)
}
Expand Down
Loading

0 comments on commit bc795b8

Please sign in to comment.