From 992e3a2cf7eaf4399e1c40235f9036dc318fc9d5 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 22 Jul 2019 19:40:42 -0400 Subject: [PATCH 1/5] Add more tests for AxisIter/Mut --- tests/iterators.rs | 147 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 146 insertions(+), 1 deletion(-) diff --git a/tests/iterators.rs b/tests/iterators.rs index c9bb4289f..bb3381bb2 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -9,7 +9,7 @@ extern crate ndarray; use ndarray::prelude::*; use ndarray::Ix; -use ndarray::{arr2, arr3, aview1, indices, s, Axis, Data, Dimension, Slice}; +use ndarray::{arr2, arr3, aview1, indices, s, Axis, Data, Dimension, Slice, Zip}; use itertools::assert_equal; use itertools::{enumerate, rev}; @@ -262,6 +262,68 @@ fn axis_iter() { ); } +#[test] +fn axis_iter_split_at() { + let a = Array::from_iter(0..5); + let iter = a.axis_iter(Axis(0)); + let all: Vec<_> = iter.clone().collect(); + for mid in 0..=all.len() { + let (left, right) = iter.clone().split_at(mid); + assert_eq!(&all[..mid], &left.collect::>()[..]); + assert_eq!(&all[mid..], &right.collect::>()[..]); + } +} + +#[test] +fn axis_iter_split_at_partially_consumed() { + let a = Array::from_iter(0..5); + let mut iter = a.axis_iter(Axis(0)); + while iter.next().is_some() { + let remaining: Vec<_> = iter.clone().collect(); + for mid in 0..=remaining.len() { + let (left, right) = iter.clone().split_at(mid); + assert_eq!(&remaining[..mid], &left.collect::>()[..]); + assert_eq!(&remaining[mid..], &right.collect::>()[..]); + } + } +} + +#[test] +fn axis_iter_zip() { + let a = Array::from_iter(0..5); + let iter = a.axis_iter(Axis(0)); + let mut b = Array::zeros(5); + Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]); + assert_eq!(a, b); +} + +#[test] +fn axis_iter_zip_partially_consumed() { + let a = Array::from_iter(0..5); + let mut iter = a.axis_iter(Axis(0)); + let mut consumed = 0; + while iter.next().is_some() { + consumed += 1; + let mut b = Array::zeros(a.len() - 
consumed); + Zip::from(&mut b).and(iter.clone()).apply(|b, a| *b = a[()]); + assert_eq!(a.slice(s![consumed..]), b); + } +} + +#[test] +fn axis_iter_zip_partially_consumed_discontiguous() { + let a = Array::from_iter(0..5); + let mut iter = a.axis_iter(Axis(0)); + let mut consumed = 0; + while iter.next().is_some() { + consumed += 1; + let mut b = Array::zeros((a.len() - consumed) * 2); + b.slice_collapse(s![..;2]); + Zip::from(&mut b).and(iter.clone()).apply(|b, a| *b = a[()]); + assert_eq!(a.slice(s![consumed..]), b); + } +} + #[test] fn outer_iter_corner_cases() { let a2 = ArcArray::::zeros((0, 3)); @@ -366,6 +428,89 @@ fn axis_chunks_iter() { assert_equal(it, vec![a.view()]); } +#[test] +fn axis_iter_mut_split_at() { + let mut a = Array::from_iter(0..5); + let mut a_clone = a.clone(); + let all: Vec<_> = a_clone.axis_iter_mut(Axis(0)).collect(); + for mid in 0..=all.len() { + let (left, right) = a.axis_iter_mut(Axis(0)).split_at(mid); + assert_eq!(&all[..mid], &left.collect::>()[..]); + assert_eq!(&all[mid..], &right.collect::>()[..]); + } +} + +#[test] +fn axis_iter_mut_split_at_partially_consumed() { + let mut a = Array::from_iter(0..5); + for consumed in 1..=a.len() { + for mid in 0..=(a.len() - consumed) { + let mut a_clone = a.clone(); + let remaining: Vec<_> = { + let mut iter = a_clone.axis_iter_mut(Axis(0)); + for _ in 0..consumed { + iter.next(); + } + iter.collect() + }; + let (left, right) = { + let mut iter = a.axis_iter_mut(Axis(0)); + for _ in 0..consumed { + iter.next(); + } + iter.split_at(mid) + }; + assert_eq!(&remaining[..mid], &left.collect::>()[..]); + assert_eq!(&remaining[mid..], &right.collect::>()[..]); + } + } +} + +#[test] +fn axis_iter_mut_zip() { + let orig = Array::from_iter(0..5); + let mut cloned = orig.clone(); + let iter = cloned.axis_iter_mut(Axis(0)); + let mut b = Array::zeros(5); + Zip::from(&mut b).and(iter).apply(|b, mut a| { + a[()] += 1; + *b = a[()]; + }); + assert_eq!(cloned, b); + assert_eq!(cloned, orig + 1); +} + 
+#[test] +fn axis_iter_mut_zip_partially_consumed() { + let mut a = Array::from_iter(0..5); + for consumed in 1..=a.len() { + let remaining = a.len() - consumed; + let mut iter = a.axis_iter_mut(Axis(0)); + for _ in 0..consumed { + iter.next(); + } + let mut b = Array::zeros(remaining); + Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]); + assert_eq!(a.slice(s![consumed..]), b); + } +} + +#[test] +fn axis_iter_mut_zip_partially_consumed_discontiguous() { + let mut a = Array::from_iter(0..5); + for consumed in 1..=a.len() { + let remaining = a.len() - consumed; + let mut iter = a.axis_iter_mut(Axis(0)); + for _ in 0..consumed { + iter.next(); + } + let mut b = Array::zeros(remaining * 2); + b.slice_collapse(s![..;2]); + Zip::from(&mut b).and(iter).apply(|b, a| *b = a[()]); + assert_eq!(a.slice(s![consumed..]), b); + } +} + #[test] fn axis_chunks_iter_corner_cases() { // examples provided by @bluss in PR #65 From 60a17d7a785b0d8083bec517d888bb477a4c4d22 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 22 Jul 2019 19:41:00 -0400 Subject: [PATCH 2/5] Fix partially-consumed AxisIter/Mut This fixes the behavior of `.split_at()` and the `NdProducer` implementations for partially-consumed `AxisIter` or `AxisIterMut` instances. --- src/iterators/mod.rs | 120 +++++++++++++++++++++++++++++-------------- 1 file changed, 81 insertions(+), 39 deletions(-) diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 9feb429e6..638ed08e7 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -738,11 +738,19 @@ where #[derive(Debug)] pub struct AxisIterCore { + /// Index along the axis of the value of `.next()`, relative to the start + /// of the axis. index: Ix, - len: Ix, + /// (Exclusive) upper bound on `index`. Initially, this is equal to the + /// length of the axis. + end: Ix, + /// Stride along the axis (offset between consecutive pointers). stride: Ixs, + /// Shape of the iterator's items. inner_dim: D, + /// Strides of the iterator's items. 
inner_strides: D, + /// Pointer corresponding to `index == 0`. ptr: *mut A, } @@ -751,7 +759,7 @@ clone_bounds!( AxisIterCore[A, D] { @copy { index, - len, + end, stride, ptr, } @@ -767,54 +775,53 @@ impl AxisIterCore { Di: RemoveAxis, S: Data, { - let shape = v.shape()[axis.index()]; - let stride = v.strides()[axis.index()]; AxisIterCore { index: 0, - len: shape, - stride, + end: v.len_of(axis), + stride: v.stride_of(axis), inner_dim: v.dim.remove_axis(axis), inner_strides: v.strides.remove_axis(axis), ptr: v.ptr, } } + #[inline] unsafe fn offset(&self, index: usize) -> *mut A { debug_assert!( - index <= self.len, - "index={}, len={}, stride={}", + index <= self.end, + "index={}, end={}, stride={}", index, - self.len, + self.end, self.stride ); self.ptr.offset(index as isize * self.stride) } - /// Split the iterator at index, yielding two disjoint iterators. + /// Splits the iterator at `index`, yielding two disjoint iterators. + /// + /// `index` is relative to the current state of the iterator (which is not + /// necessarily the start of the axis). /// - /// **Panics** if `index` is strictly greater than the iterator's length + /// **Panics** if `index` is strictly greater than the iterator's remaining + /// length. 
fn split_at(self, index: usize) -> (Self, Self) { - assert!(index <= self.len); - let right_ptr = if index != self.len { - unsafe { self.offset(index) } - } else { - self.ptr - }; + assert!(index <= self.len()); + let mid = self.index + index; let left = AxisIterCore { - index: 0, - len: index, + index: self.index, + end: mid, stride: self.stride, inner_dim: self.inner_dim.clone(), inner_strides: self.inner_strides.clone(), ptr: self.ptr, }; let right = AxisIterCore { - index: 0, - len: self.len - index, + index: mid, + end: self.end, stride: self.stride, inner_dim: self.inner_dim, inner_strides: self.inner_strides, - ptr: right_ptr, + ptr: self.ptr, }; (left, right) } @@ -827,7 +834,7 @@ where type Item = *mut A; fn next(&mut self) -> Option { - if self.index >= self.len { + if self.index >= self.end { None } else { let ptr = unsafe { self.offset(self.index) }; @@ -837,7 +844,7 @@ where } fn size_hint(&self) -> (usize, Option) { - let len = self.len - self.index; + let len = self.len(); (len, Some(len)) } } @@ -847,16 +854,25 @@ where D: Dimension, { fn next_back(&mut self) -> Option { - if self.index >= self.len { + if self.index >= self.end { None } else { - self.len -= 1; - let ptr = unsafe { self.offset(self.len) }; + self.end -= 1; + let ptr = unsafe { self.offset(self.end) }; Some(ptr) } } } +impl ExactSizeIterator for AxisIterCore +where + D: Dimension, +{ + fn len(&self) -> usize { + self.end - self.index + } +} + /// An iterator that traverses over an axis and /// and yields each subview. /// @@ -899,9 +915,13 @@ impl<'a, A, D: Dimension> AxisIter<'a, A, D> { } } - /// Split the iterator at index, yielding two disjoint iterators. + /// Splits the iterator at `index`, yielding two disjoint iterators. + /// + /// `index` is relative to the current state of the iterator (which is not + /// necessarily the start of the axis). 
/// - /// **Panics** if `index` is strictly greater than the iterator's length + /// **Panics** if `index` is strictly greater than the iterator's remaining + /// length. pub fn split_at(self, index: usize) -> (Self, Self) { let (left, right) = self.iter.split_at(index); ( @@ -946,7 +966,7 @@ where D: Dimension, { fn len(&self) -> usize { - self.size_hint().0 + self.iter.len() } } @@ -981,9 +1001,13 @@ impl<'a, A, D: Dimension> AxisIterMut<'a, A, D> { } } - /// Split the iterator at index, yielding two disjoint iterators. + /// Splits the iterator at `index`, yielding two disjoint iterators. /// - /// **Panics** if `index` is strictly greater than the iterator's length + /// `index` is relative to the current state of the iterator (which is not + /// necessarily the start of the axis). + /// + /// **Panics** if `index` is strictly greater than the iterator's remaining + /// length. pub fn split_at(self, index: usize) -> (Self, Self) { let (left, right) = self.iter.split_at(index); ( @@ -1028,7 +1052,7 @@ where D: Dimension, { fn len(&self) -> usize { - self.size_hint().0 + self.iter.len() } } @@ -1048,7 +1072,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> { } #[doc(hidden)] fn as_ptr(&self) -> Self::Ptr { - self.iter.ptr + if self.len() > 0 { + // `self.iter.index` is guaranteed to be in-bounds if any of the + // iterator remains (i.e. if `self.len() > 0`). + unsafe { self.iter.offset(self.iter.index) } + } else { + // In this case, `self.iter.index` may be past the end, so we must + // not call `.offset()`. It's okay to return a dangling pointer + // because it will never be used in the length 0 case. 
+ std::ptr::NonNull::dangling().as_ptr() + } } fn contiguous_stride(&self) -> isize { @@ -1065,7 +1098,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> { } #[doc(hidden)] unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr { - self.iter.ptr.offset(self.iter.stride * i[0] as isize) + self.iter.offset(self.iter.index + i[0]) } #[doc(hidden)] @@ -1096,7 +1129,16 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> { } #[doc(hidden)] fn as_ptr(&self) -> Self::Ptr { - self.iter.ptr + if self.len() > 0 { + // `self.iter.index` is guaranteed to be in-bounds if any of the + // iterator remains (i.e. if `self.len() > 0`). + unsafe { self.iter.offset(self.iter.index) } + } else { + // In this case, `self.iter.index` may be past the end, so we must + // not call `.offset()`. It's okay to return a dangling pointer + // because it will never be used in the length 0 case. + std::ptr::NonNull::dangling().as_ptr() + } } fn contiguous_stride(&self) -> isize { @@ -1113,7 +1155,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> { } #[doc(hidden)] unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr { - self.iter.ptr.offset(self.iter.stride * i[0] as isize) + self.iter.offset(self.iter.index + i[0]) } #[doc(hidden)] @@ -1193,7 +1235,7 @@ fn chunk_iter_parts( let iter = AxisIterCore { index: 0, - len: iter_len, + end: iter_len, stride, inner_dim, inner_strides: v.strides, @@ -1270,7 +1312,7 @@ macro_rules! 
chunk_iter_impl { D: Dimension, { fn next_back(&mut self) -> Option { - let is_uneven = self.iter.len > self.n_whole_chunks; + let is_uneven = self.iter.end > self.n_whole_chunks; let res = self.iter.next_back(); self.get_subview(res, is_uneven) } From d0c4c55ed94c0b68742d4067738bf196afc22fae Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 22 Jul 2019 19:41:06 -0400 Subject: [PATCH 3/5] Make AxisIter/Mut debug check more strict --- src/iterators/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 638ed08e7..e3bfc7ec8 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -788,7 +788,7 @@ impl AxisIterCore { #[inline] unsafe fn offset(&self, index: usize) -> *mut A { debug_assert!( - index <= self.end, + index < self.end, "index={}, end={}, stride={}", index, self.end, @@ -857,8 +857,8 @@ where if self.index >= self.end { None } else { + let ptr = unsafe { self.offset(self.end - 1) }; self.end -= 1; - let ptr = unsafe { self.offset(self.end) }; Some(ptr) } } From a0130ad619a61fa397d2059c22f7cafa714b49bd Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 22 Jul 2019 19:41:16 -0400 Subject: [PATCH 4/5] Add tests for zero handling in AxisChunksIter/Mut --- tests/iterators.rs | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/iterators.rs b/tests/iterators.rs index bb3381bb2..ac5ea1229 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -574,6 +574,19 @@ fn axis_chunks_iter_zero_stride() { } } +#[should_panic] +#[test] +fn axis_chunks_iter_zero_chunk_size() { + let a = Array::from_iter(0..5); + a.axis_chunks_iter(Axis(0), 0); +} + +#[test] +fn axis_chunks_iter_zero_axis_len() { + let a = Array::from_iter(0..0); + assert!(a.axis_chunks_iter(Axis(0), 5).next().is_none()); +} + #[test] fn axis_chunks_iter_mut() { let a = ArcArray::from_iter(0..24); @@ -585,6 +598,19 @@ fn axis_chunks_iter_mut() { assert_eq!(col0, arr3(&[[[42, 1], [2, 3]], 
[[12, 13], [14, 15]]])); } +#[should_panic] +#[test] +fn axis_chunks_iter_mut_zero_chunk_size() { + let mut a = Array::from_iter(0..5); + a.axis_chunks_iter_mut(Axis(0), 0); +} + +#[test] +fn axis_chunks_iter_mut_zero_axis_len() { + let mut a = Array::from_iter(0..0); + assert!(a.axis_chunks_iter_mut(Axis(0), 5).next().is_none()); +} + #[test] fn outer_iter_size_hint() { // Check that the size hint is correctly computed From 4f857a319ff279474d51399dc9c3a0c7b8739728 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Mon, 22 Jul 2019 19:41:21 -0400 Subject: [PATCH 5/5] Fix zero handling in AxisChunksIter/Mut Now, chunk size of zero and axis length of zero are handled correctly. --- src/impl_methods.rs | 4 ++-- src/iterators/mod.rs | 11 +++++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 5f89005dc..e119e764d 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1005,7 +1005,7 @@ where /// The last view may have less elements if `size` does not divide /// the axis' dimension. /// - /// **Panics** if `axis` is out of bounds. + /// **Panics** if `axis` is out of bounds or if `size` is zero. /// /// ``` /// use ndarray::Array; @@ -1036,7 +1036,7 @@ where /// /// Iterator element is `ArrayViewMut` /// - /// **Panics** if `axis` is out of bounds. + /// **Panics** if `axis` is out of bounds or if `size` is zero. pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize) -> AxisChunksIterMut where S: DataMut, diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index e3bfc7ec8..3f14bc570 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -1206,13 +1206,15 @@ clone_bounds!( /// /// Returns an axis iterator with the correct stride to move between chunks, /// the number of chunks, and the shape of the last chunk. +/// +/// **Panics** if `size == 0`. 
fn chunk_iter_parts( v: ArrayView, axis: Axis, size: usize, ) -> (AxisIterCore, usize, D) { + assert_ne!(size, 0, "Chunk size must be nonzero."); let axis_len = v.len_of(axis); - let size = if size > axis_len { axis_len } else { size }; let n_whole_chunks = axis_len / size; let chunk_remainder = axis_len % size; let iter_len = if chunk_remainder == 0 { @@ -1220,7 +1222,12 @@ fn chunk_iter_parts( } else { n_whole_chunks + 1 }; - let stride = v.stride_of(axis) * size as isize; + let stride = if n_whole_chunks == 0 { + // This case avoids potential overflow when `size > axis_len`. + 0 + } else { + v.stride_of(axis) * size as isize + }; let axis = axis.index(); let mut inner_dim = v.dim.clone();