diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 08eed3038..b9bb2bace 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1005,7 +1005,7 @@ where /// The last view may have less elements if `size` does not divide /// the axis' dimension. /// - /// **Panics** if `axis` is out of bounds. + /// **Panics** if `axis` is out of bounds or if `size` is zero. /// /// ``` /// use ndarray::Array; @@ -1036,7 +1036,7 @@ where /// /// Iterator element is `ArrayViewMut` /// - /// **Panics** if `axis` is out of bounds. + /// **Panics** if `axis` is out of bounds or if `size` is zero. pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize) -> AxisChunksIterMut<'_, A, D> where S: DataMut, diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index b98d944f8..0f15c4b50 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -738,11 +738,19 @@ where #[derive(Debug)] pub struct AxisIterCore { + /// Index along the axis of the value of `.next()`, relative to the start + /// of the axis. index: Ix, - len: Ix, + /// (Exclusive) upper bound on `index`. Initially, this is equal to the + /// length of the axis. + end: Ix, + /// Stride along the axis (offset between consecutive pointers). stride: Ixs, + /// Shape of the iterator's items. inner_dim: D, + /// Strides of the iterator's items. inner_strides: D, + /// Pointer corresponding to `index == 0`. ptr: *mut A, } @@ -751,7 +759,7 @@ clone_bounds!( AxisIterCore[A, D] { @copy { index, - len, + end, stride, ptr, } @@ -767,54 +775,53 @@ impl AxisIterCore { Di: RemoveAxis, S: Data, { - let shape = v.shape()[axis.index()]; - let stride = v.strides()[axis.index()]; AxisIterCore { index: 0, - len: shape, - stride, + end: v.len_of(axis), + stride: v.stride_of(axis), inner_dim: v.dim.remove_axis(axis), inner_strides: v.strides.remove_axis(axis), ptr: v.ptr, } } + #[inline] unsafe fn offset(&self, index: usize) -> *mut A { debug_assert!( - index <= self.len, - "index={}, len={}, stride={}", + index < self.end, + "index={}, end={}, stride={}", index, - self.len, + self.end, self.stride ); self.ptr.offset(index as isize * self.stride) } - /// Split the iterator at index, yielding two disjoint iterators. + /// Splits the iterator at `index`, yielding two disjoint iterators. /// - /// **Panics** if `index` is strictly greater than the iterator's length + /// `index` is relative to the current state of the iterator (which is not + /// necessarily the start of the axis). + /// + /// **Panics** if `index` is strictly greater than the iterator's remaining + /// length. fn split_at(self, index: usize) -> (Self, Self) { - assert!(index <= self.len); - let right_ptr = if index != self.len { - unsafe { self.offset(index) } - } else { - self.ptr - }; + assert!(index <= self.len()); + let mid = self.index + index; let left = AxisIterCore { - index: 0, - len: index, + index: self.index, + end: mid, stride: self.stride, inner_dim: self.inner_dim.clone(), inner_strides: self.inner_strides.clone(), ptr: self.ptr, }; let right = AxisIterCore { - index: 0, - len: self.len - index, + index: mid, + end: self.end, stride: self.stride, inner_dim: self.inner_dim, inner_strides: self.inner_strides, - ptr: right_ptr, + ptr: self.ptr, }; (left, right) } @@ -827,7 +834,7 @@ where type Item = *mut A; fn next(&mut self) -> Option { - if self.index >= self.len { + if self.index >= self.end { None } else { let ptr = unsafe { self.offset(self.index) }; @@ -837,7 +844,7 @@ where } fn size_hint(&self) -> (usize, Option) { - let len = self.len - self.index; + let len = self.len(); (len, Some(len)) } } @@ -847,16 +854,25 @@ where D: Dimension, { fn next_back(&mut self) -> Option { - if self.index >= self.len { + if self.index >= self.end { None } else { - self.len -= 1; - let ptr = unsafe { self.offset(self.len) }; + let ptr = unsafe { self.offset(self.end - 1) }; + self.end -= 1; Some(ptr) } } } +impl ExactSizeIterator for AxisIterCore +where + D: Dimension, +{ + fn len(&self) -> usize { + self.end - self.index + } +} + /// An iterator that traverses over an axis and /// and yields each subview. /// @@ -899,9 +915,13 @@ impl<'a, A, D: Dimension> AxisIter<'a, A, D> { } } - /// Split the iterator at index, yielding two disjoint iterators. + /// Splits the iterator at `index`, yielding two disjoint iterators. /// - /// **Panics** if `index` is strictly greater than the iterator's length + /// `index` is relative to the current state of the iterator (which is not + /// necessarily the start of the axis). + /// + /// **Panics** if `index` is strictly greater than the iterator's remaining + /// length. pub fn split_at(self, index: usize) -> (Self, Self) { let (left, right) = self.iter.split_at(index); ( @@ -946,7 +966,7 @@ where D: Dimension, { fn len(&self) -> usize { - self.size_hint().0 + self.iter.len() } } @@ -981,9 +1001,13 @@ impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> { } } - /// Split the iterator at index, yielding two disjoint iterators. + /// Splits the iterator at `index`, yielding two disjoint iterators. /// - /// **Panics** if `index` is strictly greater than the iterator's length + /// `index` is relative to the current state of the iterator (which is not + /// necessarily the start of the axis). + /// + /// **Panics** if `index` is strictly greater than the iterator's remaining + /// length. pub fn split_at(self, index: usize) -> (Self, Self) { let (left, right) = self.iter.split_at(index); ( @@ -1028,7 +1052,7 @@ where D: Dimension, { fn len(&self) -> usize { - self.size_hint().0 + self.iter.len() } } @@ -1048,7 +1072,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> { } #[doc(hidden)] fn as_ptr(&self) -> Self::Ptr { - self.iter.ptr + if self.len() > 0 { + // `self.iter.index` is guaranteed to be in-bounds if any of the + // iterator remains (i.e. if `self.len() > 0`). + unsafe { self.iter.offset(self.iter.index) } + } else { + // In this case, `self.iter.index` may be past the end, so we must + // not call `.offset()`. It's okay to return a dangling pointer + // because it will never be used in the length 0 case. + std::ptr::NonNull::dangling().as_ptr() + } } fn contiguous_stride(&self) -> isize { @@ -1065,7 +1098,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> { } #[doc(hidden)] unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr { - self.iter.ptr.offset(self.iter.stride * i[0] as isize) + self.iter.offset(self.iter.index + i[0]) } #[doc(hidden)] @@ -1096,7 +1129,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> { } #[doc(hidden)] fn as_ptr(&self) -> Self::Ptr { - self.iter.ptr + if self.len() > 0 { + // `self.iter.index` is guaranteed to be in-bounds if any of the + // iterator remains (i.e. if `self.len() > 0`). + unsafe { self.iter.offset(self.iter.index) } + } else { + // In this case, `self.iter.index` may be past the end, so we must + // not call `.offset()`. It's okay to return a dangling pointer + // because it will never be used in the length 0 case. + std::ptr::NonNull::dangling().as_ptr() + } } fn contiguous_stride(&self) -> isize { @@ -1113,7 +1155,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> { } #[doc(hidden)] unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr { - self.iter.ptr.offset(self.iter.stride * i[0] as isize) + self.iter.offset(self.iter.index + i[0]) } #[doc(hidden)] @@ -1164,13 +1206,15 @@ clone_bounds!( /// /// Returns an axis iterator with the correct stride to move between chunks, /// the number of chunks, and the shape of the last chunk. +/// +/// **Panics** if `size == 0`. fn chunk_iter_parts( v: ArrayView<'_, A, D>, axis: Axis, size: usize, ) -> (AxisIterCore, usize, D) { + assert_ne!(size, 0, "Chunk size must be nonzero."); let axis_len = v.len_of(axis); - let size = if size > axis_len { axis_len } else { size }; let n_whole_chunks = axis_len / size; let chunk_remainder = axis_len % size; let iter_len = if chunk_remainder == 0 { @@ -1178,7 +1222,12 @@ fn chunk_iter_parts( } else { n_whole_chunks + 1 }; - let stride = v.stride_of(axis) * size as isize; + let stride = if n_whole_chunks == 0 { + // This case avoids potential overflow when `size > axis_len`. + 0 + } else { + v.stride_of(axis) * size as isize + }; let axis = axis.index(); let mut inner_dim = v.dim.clone(); @@ -1193,7 +1242,7 @@ fn chunk_iter_parts( let iter = AxisIterCore { index: 0, - len: iter_len, + end: iter_len, stride, inner_dim, inner_strides: v.strides, @@ -1270,7 +1319,7 @@ macro_rules! chunk_iter_impl { D: Dimension, { fn next_back(&mut self) -> Option { - let is_uneven = self.iter.len > self.n_whole_chunks; + let is_uneven = self.iter.end > self.n_whole_chunks; let res = self.iter.next_back(); self.get_subview(res, is_uneven) } diff --git a/tests/iterators.rs b/tests/iterators.rs index a1c19c39a..6408b2f8d 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -7,7 +7,7 @@ use ndarray::prelude::*; use ndarray::Ix; -use ndarray::{arr2, arr3, aview1, indices, s, Axis, Data, Dimension, Slice}; +use ndarray::{arr2, arr3, aview1, indices, s, Axis, Data, Dimension, Slice, Zip}; use itertools::assert_equal; use itertools::{enumerate, rev}; @@ -260,6 +260,68 @@ fn axis_iter() { ); } +#[test] +fn axis_iter_split_at() { + let a = Array::from_iter(0..5); + let iter = a.axis_iter(Axis(0)); + let all: Vec<_> = iter.clone().collect(); + for mid in 0..=all.len() { + let (left, right) = iter.clone().split_at(mid); + assert_eq!(&all[..mid], &left.collect::>()[..]); + assert_eq!(&all[mid..], &right.collect::>()[..]); + } +} + +#[test] +fn axis_iter_split_at_partially_consumed() { + let a = Array::from_iter(0..5); + let mut iter = a.axis_iter(Axis(0)); + while iter.next().is_some() { + let remaining: Vec<_> = iter.clone().collect(); + for mid in 0..=remaining.len() { + let (left, right) = iter.clone().split_at(mid); + assert_eq!(&remaining[..mid], &left.collect::>()[..]); + assert_eq!(&remaining[mid..], &right.collect::>()[..]); + } + } +} + +#[test] +fn axis_iter_zip() { + let a = Array::from_iter(0..5); + let iter = a.axis_iter(Axis(0)); + let mut b = Array::zeros(5); + Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]); + assert_eq!(a, b); +} + +#[test] +fn axis_iter_zip_partially_consumed() { + let a = Array::from_iter(0..5); + let mut iter = a.axis_iter(Axis(0)); + let mut consumed = 0; + while iter.next().is_some() { + consumed += 1; + let mut b = Array::zeros(a.len() - consumed); + Zip::from(&mut b).and(iter.clone()).apply(|b, a| *b = a[()]); + assert_eq!(a.slice(s![consumed..]), b); + } +} + +#[test] +fn axis_iter_zip_partially_consumed_discontiguous() { + let a = Array::from_iter(0..5); + let mut iter = a.axis_iter(Axis(0)); + let mut consumed = 0; + while iter.next().is_some() { + consumed += 1; + let mut b = Array::zeros((a.len() - consumed) * 2); + b.slice_collapse(s![..;2]); + Zip::from(&mut b).and(iter.clone()).apply(|b, a| *b = a[()]); + assert_eq!(a.slice(s![consumed..]), b); + } +} + #[test] fn outer_iter_corner_cases() { let a2 = ArcArray::::zeros((0, 3)); @@ -364,6 +426,89 @@ fn axis_chunks_iter() { assert_equal(it, vec![a.view()]); } +#[test] +fn axis_iter_mut_split_at() { + let mut a = Array::from_iter(0..5); + let mut a_clone = a.clone(); + let all: Vec<_> = a_clone.axis_iter_mut(Axis(0)).collect(); + for mid in 0..=all.len() { + let (left, right) = a.axis_iter_mut(Axis(0)).split_at(mid); + assert_eq!(&all[..mid], &left.collect::>()[..]); + assert_eq!(&all[mid..], &right.collect::>()[..]); + } +} + +#[test] +fn axis_iter_mut_split_at_partially_consumed() { + let mut a = Array::from_iter(0..5); + for consumed in 1..=a.len() { + for mid in 0..=(a.len() - consumed) { + let mut a_clone = a.clone(); + let remaining: Vec<_> = { + let mut iter = a_clone.axis_iter_mut(Axis(0)); + for _ in 0..consumed { + iter.next(); + } + iter.collect() + }; + let (left, right) = { + let mut iter = a.axis_iter_mut(Axis(0)); + for _ in 0..consumed { + iter.next(); + } + iter.split_at(mid) + }; + assert_eq!(&remaining[..mid], &left.collect::>()[..]); + assert_eq!(&remaining[mid..], &right.collect::>()[..]); + } + } +} + +#[test] +fn axis_iter_mut_zip() { + let orig = Array::from_iter(0..5); + let mut cloned = orig.clone(); + let iter = cloned.axis_iter_mut(Axis(0)); + let mut b = Array::zeros(5); + Zip::from(&mut b).and(iter).apply(|b, mut a| { + a[()] += 1; + *b = a[()]; + }); + assert_eq!(cloned, b); + assert_eq!(cloned, orig + 1); +} + +#[test] +fn axis_iter_mut_zip_partially_consumed() { + let mut a = Array::from_iter(0..5); + for consumed in 1..=a.len() { + let remaining = a.len() - consumed; + let mut iter = a.axis_iter_mut(Axis(0)); + for _ in 0..consumed { + iter.next(); + } + let mut b = Array::zeros(remaining); + Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]); + assert_eq!(a.slice(s![consumed..]), b); + } +} + +#[test] +fn axis_iter_mut_zip_partially_consumed_discontiguous() { + let mut a = Array::from_iter(0..5); + for consumed in 1..=a.len() { + let remaining = a.len() - consumed; + let mut iter = a.axis_iter_mut(Axis(0)); + for _ in 0..consumed { + iter.next(); + } + let mut b = Array::zeros(remaining * 2); + b.slice_collapse(s![..;2]); + Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]); + assert_eq!(a.slice(s![consumed..]), b); + } +} + #[test] fn axis_chunks_iter_corner_cases() { // examples provided by @bluss in PR #65 @@ -427,6 +572,19 @@ fn axis_chunks_iter_zero_stride() { } } +#[should_panic] +#[test] +fn axis_chunks_iter_zero_chunk_size() { + let a = Array::from_iter(0..5); + a.axis_chunks_iter(Axis(0), 0); +} + +#[test] +fn axis_chunks_iter_zero_axis_len() { + let a = Array::from_iter(0..0); + assert!(a.axis_chunks_iter(Axis(0), 5).next().is_none()); +} + #[test] fn axis_chunks_iter_mut() { let a = ArcArray::from_iter(0..24); @@ -438,6 +596,19 @@ fn axis_chunks_iter_mut() { assert_eq!(col0, arr3(&[[[42, 1], [2, 3]], [[12, 13], [14, 15]]])); } +#[should_panic] +#[test] +fn axis_chunks_iter_mut_zero_chunk_size() { + let mut a = Array::from_iter(0..5); + a.axis_chunks_iter_mut(Axis(0), 0); +} + +#[test] +fn axis_chunks_iter_mut_zero_axis_len() { + let mut a = Array::from_iter(0..0); + assert!(a.axis_chunks_iter_mut(Axis(0), 5).next().is_none()); +} + #[test] fn outer_iter_size_hint() { // Check that the size hint is correctly computed