diff --git a/src/impl_methods.rs b/src/impl_methods.rs
index 08eed3038..b9bb2bace 100644
--- a/src/impl_methods.rs
+++ b/src/impl_methods.rs
@@ -1005,7 +1005,7 @@ where
/// The last view may have less elements if `size` does not divide
/// the axis' dimension.
///
- /// **Panics** if `axis` is out of bounds.
+ /// **Panics** if `axis` is out of bounds or if `size` is zero.
///
/// ```
/// use ndarray::Array;
@@ -1036,7 +1036,7 @@ where
///
/// Iterator element is `ArrayViewMut`
///
- /// **Panics** if `axis` is out of bounds.
+ /// **Panics** if `axis` is out of bounds or if `size` is zero.
pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize) -> AxisChunksIterMut<'_, A, D>
where
S: DataMut,
diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs
index b98d944f8..0f15c4b50 100644
--- a/src/iterators/mod.rs
+++ b/src/iterators/mod.rs
@@ -738,11 +738,19 @@ where
#[derive(Debug)]
pub struct AxisIterCore {
+ /// Index along the axis of the value of `.next()`, relative to the start
+ /// of the axis.
index: Ix,
- len: Ix,
+ /// (Exclusive) upper bound on `index`. Initially, this is equal to the
+ /// length of the axis.
+ end: Ix,
+ /// Stride along the axis (offset between consecutive pointers).
stride: Ixs,
+ /// Shape of the iterator's items.
inner_dim: D,
+ /// Strides of the iterator's items.
inner_strides: D,
+ /// Pointer corresponding to `index == 0`.
ptr: *mut A,
}
@@ -751,7 +759,7 @@ clone_bounds!(
AxisIterCore[A, D] {
@copy {
index,
- len,
+ end,
stride,
ptr,
}
@@ -767,54 +775,53 @@ impl AxisIterCore {
Di: RemoveAxis,
S: Data,
{
- let shape = v.shape()[axis.index()];
- let stride = v.strides()[axis.index()];
AxisIterCore {
index: 0,
- len: shape,
- stride,
+ end: v.len_of(axis),
+ stride: v.stride_of(axis),
inner_dim: v.dim.remove_axis(axis),
inner_strides: v.strides.remove_axis(axis),
ptr: v.ptr,
}
}
+ #[inline]
unsafe fn offset(&self, index: usize) -> *mut A {
debug_assert!(
- index <= self.len,
- "index={}, len={}, stride={}",
+ index < self.end,
+ "index={}, end={}, stride={}",
index,
- self.len,
+ self.end,
self.stride
);
self.ptr.offset(index as isize * self.stride)
}
- /// Split the iterator at index, yielding two disjoint iterators.
+ /// Splits the iterator at `index`, yielding two disjoint iterators.
///
- /// **Panics** if `index` is strictly greater than the iterator's length
+ /// `index` is relative to the current state of the iterator (which is not
+ /// necessarily the start of the axis).
+ ///
+ /// **Panics** if `index` is strictly greater than the iterator's remaining
+ /// length.
fn split_at(self, index: usize) -> (Self, Self) {
- assert!(index <= self.len);
- let right_ptr = if index != self.len {
- unsafe { self.offset(index) }
- } else {
- self.ptr
- };
+ assert!(index <= self.len());
+ let mid = self.index + index;
let left = AxisIterCore {
- index: 0,
- len: index,
+ index: self.index,
+ end: mid,
stride: self.stride,
inner_dim: self.inner_dim.clone(),
inner_strides: self.inner_strides.clone(),
ptr: self.ptr,
};
let right = AxisIterCore {
- index: 0,
- len: self.len - index,
+ index: mid,
+ end: self.end,
stride: self.stride,
inner_dim: self.inner_dim,
inner_strides: self.inner_strides,
- ptr: right_ptr,
+ ptr: self.ptr,
};
(left, right)
}
@@ -827,7 +834,7 @@ where
type Item = *mut A;
fn next(&mut self) -> Option {
- if self.index >= self.len {
+ if self.index >= self.end {
None
} else {
let ptr = unsafe { self.offset(self.index) };
@@ -837,7 +844,7 @@ where
}
fn size_hint(&self) -> (usize, Option) {
- let len = self.len - self.index;
+ let len = self.len();
(len, Some(len))
}
}
@@ -847,16 +854,25 @@ where
D: Dimension,
{
fn next_back(&mut self) -> Option {
- if self.index >= self.len {
+ if self.index >= self.end {
None
} else {
- self.len -= 1;
- let ptr = unsafe { self.offset(self.len) };
+ let ptr = unsafe { self.offset(self.end - 1) };
+ self.end -= 1;
Some(ptr)
}
}
}
+impl ExactSizeIterator for AxisIterCore
+where
+ D: Dimension,
+{
+ fn len(&self) -> usize {
+ self.end - self.index
+ }
+}
+
/// An iterator that traverses over an axis and
/// and yields each subview.
///
@@ -899,9 +915,13 @@ impl<'a, A, D: Dimension> AxisIter<'a, A, D> {
}
}
- /// Split the iterator at index, yielding two disjoint iterators.
+ /// Splits the iterator at `index`, yielding two disjoint iterators.
///
- /// **Panics** if `index` is strictly greater than the iterator's length
+ /// `index` is relative to the current state of the iterator (which is not
+ /// necessarily the start of the axis).
+ ///
+ /// **Panics** if `index` is strictly greater than the iterator's remaining
+ /// length.
pub fn split_at(self, index: usize) -> (Self, Self) {
let (left, right) = self.iter.split_at(index);
(
@@ -946,7 +966,7 @@ where
D: Dimension,
{
fn len(&self) -> usize {
- self.size_hint().0
+ self.iter.len()
}
}
@@ -981,9 +1001,13 @@ impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> {
}
}
- /// Split the iterator at index, yielding two disjoint iterators.
+ /// Splits the iterator at `index`, yielding two disjoint iterators.
///
- /// **Panics** if `index` is strictly greater than the iterator's length
+ /// `index` is relative to the current state of the iterator (which is not
+ /// necessarily the start of the axis).
+ ///
+ /// **Panics** if `index` is strictly greater than the iterator's remaining
+ /// length.
pub fn split_at(self, index: usize) -> (Self, Self) {
let (left, right) = self.iter.split_at(index);
(
@@ -1028,7 +1052,7 @@ where
D: Dimension,
{
fn len(&self) -> usize {
- self.size_hint().0
+ self.iter.len()
}
}
@@ -1048,7 +1072,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
}
#[doc(hidden)]
fn as_ptr(&self) -> Self::Ptr {
- self.iter.ptr
+ if self.len() > 0 {
+ // `self.iter.index` is guaranteed to be in-bounds if any of the
+ // iterator remains (i.e. if `self.len() > 0`).
+ unsafe { self.iter.offset(self.iter.index) }
+ } else {
+ // In this case, `self.iter.index` may be past the end, so we must
+ // not call `.offset()`. It's okay to return a dangling pointer
+ // because it will never be used in the length 0 case.
+ std::ptr::NonNull::dangling().as_ptr()
+ }
}
fn contiguous_stride(&self) -> isize {
@@ -1065,7 +1098,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> {
}
#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
- self.iter.ptr.offset(self.iter.stride * i[0] as isize)
+ self.iter.offset(self.iter.index + i[0])
}
#[doc(hidden)]
@@ -1096,7 +1129,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
}
#[doc(hidden)]
fn as_ptr(&self) -> Self::Ptr {
- self.iter.ptr
+ if self.len() > 0 {
+ // `self.iter.index` is guaranteed to be in-bounds if any of the
+ // iterator remains (i.e. if `self.len() > 0`).
+ unsafe { self.iter.offset(self.iter.index) }
+ } else {
+ // In this case, `self.iter.index` may be past the end, so we must
+ // not call `.offset()`. It's okay to return a dangling pointer
+ // because it will never be used in the length 0 case.
+ std::ptr::NonNull::dangling().as_ptr()
+ }
}
fn contiguous_stride(&self) -> isize {
@@ -1113,7 +1155,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> {
}
#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
- self.iter.ptr.offset(self.iter.stride * i[0] as isize)
+ self.iter.offset(self.iter.index + i[0])
}
#[doc(hidden)]
@@ -1164,13 +1206,15 @@ clone_bounds!(
///
/// Returns an axis iterator with the correct stride to move between chunks,
/// the number of chunks, and the shape of the last chunk.
+///
+/// **Panics** if `size == 0`.
fn chunk_iter_parts(
v: ArrayView<'_, A, D>,
axis: Axis,
size: usize,
) -> (AxisIterCore, usize, D) {
+ assert_ne!(size, 0, "Chunk size must be nonzero.");
let axis_len = v.len_of(axis);
- let size = if size > axis_len { axis_len } else { size };
let n_whole_chunks = axis_len / size;
let chunk_remainder = axis_len % size;
let iter_len = if chunk_remainder == 0 {
@@ -1178,7 +1222,12 @@ fn chunk_iter_parts(
} else {
n_whole_chunks + 1
};
- let stride = v.stride_of(axis) * size as isize;
+ let stride = if n_whole_chunks == 0 {
+ // This case avoids potential overflow when `size > axis_len`.
+ 0
+ } else {
+ v.stride_of(axis) * size as isize
+ };
let axis = axis.index();
let mut inner_dim = v.dim.clone();
@@ -1193,7 +1242,7 @@ fn chunk_iter_parts(
let iter = AxisIterCore {
index: 0,
- len: iter_len,
+ end: iter_len,
stride,
inner_dim,
inner_strides: v.strides,
@@ -1270,7 +1319,7 @@ macro_rules! chunk_iter_impl {
D: Dimension,
{
fn next_back(&mut self) -> Option {
- let is_uneven = self.iter.len > self.n_whole_chunks;
+ let is_uneven = self.iter.end > self.n_whole_chunks;
let res = self.iter.next_back();
self.get_subview(res, is_uneven)
}
diff --git a/tests/iterators.rs b/tests/iterators.rs
index a1c19c39a..6408b2f8d 100644
--- a/tests/iterators.rs
+++ b/tests/iterators.rs
@@ -7,7 +7,7 @@
use ndarray::prelude::*;
use ndarray::Ix;
-use ndarray::{arr2, arr3, aview1, indices, s, Axis, Data, Dimension, Slice};
+use ndarray::{arr2, arr3, aview1, indices, s, Axis, Data, Dimension, Slice, Zip};
use itertools::assert_equal;
use itertools::{enumerate, rev};
@@ -260,6 +260,68 @@ fn axis_iter() {
);
}
+#[test]
+fn axis_iter_split_at() {
+ let a = Array::from_iter(0..5);
+ let iter = a.axis_iter(Axis(0));
+ let all: Vec<_> = iter.clone().collect();
+ for mid in 0..=all.len() {
+ let (left, right) = iter.clone().split_at(mid);
+ assert_eq!(&all[..mid], &left.collect::>()[..]);
+ assert_eq!(&all[mid..], &right.collect::>()[..]);
+ }
+}
+
+#[test]
+fn axis_iter_split_at_partially_consumed() {
+ let a = Array::from_iter(0..5);
+ let mut iter = a.axis_iter(Axis(0));
+ while iter.next().is_some() {
+ let remaining: Vec<_> = iter.clone().collect();
+ for mid in 0..=remaining.len() {
+ let (left, right) = iter.clone().split_at(mid);
+ assert_eq!(&remaining[..mid], &left.collect::>()[..]);
+ assert_eq!(&remaining[mid..], &right.collect::>()[..]);
+ }
+ }
+}
+
+#[test]
+fn axis_iter_zip() {
+ let a = Array::from_iter(0..5);
+ let iter = a.axis_iter(Axis(0));
+ let mut b = Array::zeros(5);
+ Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]);
+ assert_eq!(a, b);
+}
+
+#[test]
+fn axis_iter_zip_partially_consumed() {
+ let a = Array::from_iter(0..5);
+ let mut iter = a.axis_iter(Axis(0));
+ let mut consumed = 0;
+ while iter.next().is_some() {
+ consumed += 1;
+ let mut b = Array::zeros(a.len() - consumed);
+ Zip::from(&mut b).and(iter.clone()).apply(|b, a| *b = a[()]);
+ assert_eq!(a.slice(s![consumed..]), b);
+ }
+}
+
+#[test]
+fn axis_iter_zip_partially_consumed_discontiguous() {
+ let a = Array::from_iter(0..5);
+ let mut iter = a.axis_iter(Axis(0));
+ let mut consumed = 0;
+ while iter.next().is_some() {
+ consumed += 1;
+ let mut b = Array::zeros((a.len() - consumed) * 2);
+ b.slice_collapse(s![..;2]);
+ Zip::from(&mut b).and(iter.clone()).apply(|b, a| *b = a[()]);
+ assert_eq!(a.slice(s![consumed..]), b);
+ }
+}
+
#[test]
fn outer_iter_corner_cases() {
let a2 = ArcArray::::zeros((0, 3));
@@ -364,6 +426,89 @@ fn axis_chunks_iter() {
assert_equal(it, vec![a.view()]);
}
+#[test]
+fn axis_iter_mut_split_at() {
+ let mut a = Array::from_iter(0..5);
+ let mut a_clone = a.clone();
+ let all: Vec<_> = a_clone.axis_iter_mut(Axis(0)).collect();
+ for mid in 0..=all.len() {
+ let (left, right) = a.axis_iter_mut(Axis(0)).split_at(mid);
+ assert_eq!(&all[..mid], &left.collect::>()[..]);
+ assert_eq!(&all[mid..], &right.collect::>()[..]);
+ }
+}
+
+#[test]
+fn axis_iter_mut_split_at_partially_consumed() {
+ let mut a = Array::from_iter(0..5);
+ for consumed in 1..=a.len() {
+ for mid in 0..=(a.len() - consumed) {
+ let mut a_clone = a.clone();
+ let remaining: Vec<_> = {
+ let mut iter = a_clone.axis_iter_mut(Axis(0));
+ for _ in 0..consumed {
+ iter.next();
+ }
+ iter.collect()
+ };
+ let (left, right) = {
+ let mut iter = a.axis_iter_mut(Axis(0));
+ for _ in 0..consumed {
+ iter.next();
+ }
+ iter.split_at(mid)
+ };
+ assert_eq!(&remaining[..mid], &left.collect::>()[..]);
+ assert_eq!(&remaining[mid..], &right.collect::>()[..]);
+ }
+ }
+}
+
+#[test]
+fn axis_iter_mut_zip() {
+ let orig = Array::from_iter(0..5);
+ let mut cloned = orig.clone();
+ let iter = cloned.axis_iter_mut(Axis(0));
+ let mut b = Array::zeros(5);
+ Zip::from(&mut b).and(iter).apply(|b, mut a| {
+ a[()] += 1;
+ *b = a[()];
+ });
+ assert_eq!(cloned, b);
+ assert_eq!(cloned, orig + 1);
+}
+
+#[test]
+fn axis_iter_mut_zip_partially_consumed() {
+ let mut a = Array::from_iter(0..5);
+ for consumed in 1..=a.len() {
+ let remaining = a.len() - consumed;
+ let mut iter = a.axis_iter_mut(Axis(0));
+ for _ in 0..consumed {
+ iter.next();
+ }
+ let mut b = Array::zeros(remaining);
+ Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]);
+ assert_eq!(a.slice(s![consumed..]), b);
+ }
+}
+
+#[test]
+fn axis_iter_mut_zip_partially_consumed_discontiguous() {
+ let mut a = Array::from_iter(0..5);
+ for consumed in 1..=a.len() {
+ let remaining = a.len() - consumed;
+ let mut iter = a.axis_iter_mut(Axis(0));
+ for _ in 0..consumed {
+ iter.next();
+ }
+ let mut b = Array::zeros(remaining * 2);
+ b.slice_collapse(s![..;2]);
+ Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]);
+ assert_eq!(a.slice(s![consumed..]), b);
+ }
+}
+
#[test]
fn axis_chunks_iter_corner_cases() {
// examples provided by @bluss in PR #65
@@ -427,6 +572,19 @@ fn axis_chunks_iter_zero_stride() {
}
}
+#[should_panic]
+#[test]
+fn axis_chunks_iter_zero_chunk_size() {
+ let a = Array::from_iter(0..5);
+ a.axis_chunks_iter(Axis(0), 0);
+}
+
+#[test]
+fn axis_chunks_iter_zero_axis_len() {
+ let a = Array::from_iter(0..0);
+ assert!(a.axis_chunks_iter(Axis(0), 5).next().is_none());
+}
+
#[test]
fn axis_chunks_iter_mut() {
let a = ArcArray::from_iter(0..24);
@@ -438,6 +596,19 @@ fn axis_chunks_iter_mut() {
assert_eq!(col0, arr3(&[[[42, 1], [2, 3]], [[12, 13], [14, 15]]]));
}
+#[should_panic]
+#[test]
+fn axis_chunks_iter_mut_zero_chunk_size() {
+ let mut a = Array::from_iter(0..5);
+ a.axis_chunks_iter_mut(Axis(0), 0);
+}
+
+#[test]
+fn axis_chunks_iter_mut_zero_axis_len() {
+ let mut a = Array::from_iter(0..0);
+ assert!(a.axis_chunks_iter_mut(Axis(0), 5).next().is_none());
+}
+
#[test]
fn outer_iter_size_hint() {
// Check that the size hint is correctly computed