Skip to content

Commit

Permalink
FEAT: Use Baseiter optimizations and arbitrary order where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
bluss committed Apr 12, 2021
1 parent 33506cf commit e5224a5
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 81 deletions.
35 changes: 4 additions & 31 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -728,36 +728,6 @@ where
}
}

/// Move the axis which has the smallest absolute stride and a length
/// greater than one to be the last axis.
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
where
D: Dimension,
{
debug_assert_eq!(dim.ndim(), strides.ndim());
match dim.ndim() {
0 | 1 => {}
2 => {
if dim[1] <= 1
|| dim[0] > 1 && (strides[0] as isize).abs() < (strides[1] as isize).abs()
{
dim.slice_mut().swap(0, 1);
strides.slice_mut().swap(0, 1);
}
}
n => {
if let Some(min_stride_axis) = (0..n)
.filter(|&ax| dim[ax] > 1)
.min_by_key(|&ax| (strides[ax] as isize).abs())
{
let last = n - 1;
dim.slice_mut().swap(last, min_stride_axis);
strides.slice_mut().swap(last, min_stride_axis);
}
}
}
}

/// Remove axes with length one, except never removing the last axis.
pub(crate) fn squeeze<D>(dim: &mut D, strides: &mut D)
where
Expand Down Expand Up @@ -801,14 +771,17 @@ pub(crate) fn sort_axes_to_standard<D>(dim: &mut D, strides: &mut D)
where
D: Dimension,
{
debug_assert!(dim.ndim() > 1);
if dim.ndim() <= 1 {
return;
}
debug_assert_eq!(dim.ndim(), strides.ndim());
// bubble sort axes
let mut changed = true;
while changed {
changed = false;
for i in 0..dim.ndim() - 1 {
// make sure higher stride axes sort before.
debug_assert!(strides.get_stride(Axis(i)) >= 0);
if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() {
changed = true;
dim.slice_mut().swap(i, i + 1);
Expand Down
14 changes: 5 additions & 9 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::argument_traits::AssignElem;
use crate::dimension;
use crate::dimension::IntoDimension;
use crate::dimension::{
abs_index, axes_of, do_slice, merge_axes, move_min_stride_axis_to_last,
abs_index, axes_of, do_slice, merge_axes,
offset_from_ptr_to_memory, size_of_shape_checked, stride_offset, Axes,
};
use crate::dimension::broadcast::co_broadcast;
Expand Down Expand Up @@ -316,7 +316,7 @@ where
where
S: Data,
{
IndexedIter::new(self.view().into_elements_base())
IndexedIter::new(self.view().into_elements_base_keep_dims())
}

/// Return an iterator of indexes and mutable references to the elements of the array.
Expand All @@ -329,7 +329,7 @@ where
where
S: DataMut,
{
IndexedIterMut::new(self.view_mut().into_elements_base())
IndexedIterMut::new(self.view_mut().into_elements_base_keep_dims())
}

/// Return a sliced view of the array.
Expand Down Expand Up @@ -2175,9 +2175,7 @@ where
if let Some(slc) = self.as_slice_memory_order() {
slc.iter().fold(init, f)
} else {
let mut v = self.view();
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
v.into_elements_base().fold(init, f)
self.view().into_elements_base_any_order().fold(init, f)
}
}

Expand Down Expand Up @@ -2295,9 +2293,7 @@ where
match self.try_as_slice_memory_order_mut() {
Ok(slc) => slc.iter_mut().for_each(f),
Err(arr) => {
let mut v = arr.view_mut();
move_min_stride_axis_to_last(&mut v.dim, &mut v.strides);
v.into_elements_base().for_each(f);
arr.view_mut().into_elements_base_any_order().for_each(f);
}
}
}
Expand Down
45 changes: 34 additions & 11 deletions src/impl_views/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ use alloc::slice;

use crate::imp_prelude::*;

use crate::{Baseiter, ElementsBase, ElementsBaseMut, Iter, IterMut};

use crate::iter::{self, AxisIter, AxisIterMut};
use crate::iter::{self, Iter, IterMut, AxisIter, AxisIterMut};
use crate::iterators::base::{Baseiter, ElementsBase, ElementsBaseMut, OrderOption, PreserveOrder,
ArbitraryOrder, NoOptimization};
use crate::math_cell::MathCell;
use crate::IndexLonger;

Expand Down Expand Up @@ -140,14 +140,25 @@ impl<'a, A, D> ArrayView<'a, A, D>
where
D: Dimension,
{
/// Create a base iter fromt the view with the given order option
#[inline]
pub(crate) fn into_base_iter<F: OrderOption>(self) -> Baseiter<A, D> {
unsafe { Baseiter::new_with_order::<F>(self.ptr.as_ptr(), self.dim, self.strides) }
}

#[inline]
pub(crate) fn into_elements_base_keep_dims(self) -> ElementsBase<'a, A, D> {
ElementsBase::new::<NoOptimization>(self)
}

#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D> {
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
pub(crate) fn into_elements_base_preserve_order(self) -> ElementsBase<'a, A, D> {
ElementsBase::new::<PreserveOrder>(self)
}

#[inline]
pub(crate) fn into_elements_base(self) -> ElementsBase<'a, A, D> {
ElementsBase::new(self)
pub(crate) fn into_elements_base_any_order(self) -> ElementsBase<'a, A, D> {
ElementsBase::new::<ArbitraryOrder>(self)
}

pub(crate) fn into_iter_(self) -> Iter<'a, A, D> {
Expand Down Expand Up @@ -179,16 +190,28 @@ where
unsafe { RawArrayViewMut::new(self.ptr, self.dim, self.strides) }
}

/// Create a base iter fromt the view with the given order option
#[inline]
pub(crate) fn into_base_iter(self) -> Baseiter<A, D> {
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
pub(crate) fn into_base_iter<F: OrderOption>(self) -> Baseiter<A, D> {
unsafe { Baseiter::new_with_order::<F>(self.ptr.as_ptr(), self.dim, self.strides) }
}

#[inline]
pub(crate) fn into_elements_base(self) -> ElementsBaseMut<'a, A, D> {
ElementsBaseMut::new(self)
pub(crate) fn into_elements_base_keep_dims(self) -> ElementsBaseMut<'a, A, D> {
ElementsBaseMut::new::<NoOptimization>(self)
}

#[inline]
pub(crate) fn into_elements_base_preserve_order(self) -> ElementsBaseMut<'a, A, D> {
ElementsBaseMut::new::<PreserveOrder>(self)
}

#[inline]
pub(crate) fn into_elements_base_any_order(self) -> ElementsBaseMut<'a, A, D> {
ElementsBaseMut::new::<ArbitraryOrder>(self)
}


/// Return the array’s data as a slice, if it is contiguous and in standard order.
/// Otherwise return self in the Err branch of the result.
pub(crate) fn try_into_slice(self) -> Result<&'a mut [A], Self> {
Expand Down
22 changes: 6 additions & 16 deletions src/iterators/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,6 @@ pub(crate) struct Baseiter<A, D> {
index: Option<D>,
}

impl<A, D: Dimension> Baseiter<A, D> {
/// Creating a Baseiter is unsafe because shape and stride parameters need
/// to be correct to avoid performing an unsafe pointer offset while
/// iterating.
#[inline]
pub unsafe fn new(ptr: *mut A, dim: D, strides: D) -> Baseiter<A, D> {
Self::new_with_order::<NoOptimization>(ptr, dim, strides)
}
}

impl<A, D: Dimension> Baseiter<A, D> {
/// Creating a Baseiter is unsafe because shape and stride parameters need
/// to be correct to avoid performing an unsafe pointer offset while
Expand Down Expand Up @@ -246,9 +236,9 @@ clone_bounds!(
);

impl<'a, A, D: Dimension> ElementsBase<'a, A, D> {
pub fn new(v: ArrayView<'a, A, D>) -> Self {
pub fn new<F: OrderOption>(v: ArrayView<'a, A, D>) -> Self {
ElementsBase {
inner: v.into_base_iter(),
inner: v.into_base_iter::<F>(),
life: PhantomData,
}
}
Expand Down Expand Up @@ -332,7 +322,7 @@ where
inner: if let Some(slc) = self_.to_slice() {
ElementsRepr::Slice(slc.iter())
} else {
ElementsRepr::Counted(self_.into_elements_base())
ElementsRepr::Counted(self_.into_elements_base_preserve_order())
},
}
}
Expand All @@ -346,7 +336,7 @@ where
IterMut {
inner: match self_.try_into_slice() {
Ok(x) => ElementsRepr::Slice(x.iter_mut()),
Err(self_) => ElementsRepr::Counted(self_.into_elements_base()),
Err(self_) => ElementsRepr::Counted(self_.into_elements_base_preserve_order()),
},
}
}
Expand Down Expand Up @@ -391,9 +381,9 @@ pub(crate) struct ElementsBaseMut<'a, A, D> {
}

impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> {
pub fn new(v: ArrayViewMut<'a, A, D>) -> Self {
pub fn new<F: OrderOption>(v: ArrayViewMut<'a, A, D>) -> Self {
ElementsBaseMut {
inner: v.into_base_iter(),
inner: v.into_base_iter::<F>(),
life: PhantomData,
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/iterators/chunks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ where
type IntoIter = ExactChunksIter<'a, A, D>;
fn into_iter(self) -> Self::IntoIter {
ExactChunksIter {
iter: self.base.into_elements_base(),
iter: self.base.into_elements_base_any_order(),
chunk: self.chunk,
inner_strides: self.inner_strides,
}
Expand Down Expand Up @@ -169,7 +169,7 @@ where
type IntoIter = ExactChunksIterMut<'a, A, D>;
fn into_iter(self) -> Self::IntoIter {
ExactChunksIterMut {
iter: self.base.into_elements_base(),
iter: self.base.into_elements_base_any_order(),
chunk: self.chunk,
inner_strides: self.inner_strides,
}
Expand Down
5 changes: 3 additions & 2 deletions src/iterators/lanes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::marker::PhantomData;
use crate::imp_prelude::*;
use crate::{Layout, NdProducer};
use crate::iterators::Baseiter;
use crate::iterators::base::NoOptimization;

impl_ndproducer! {
['a, A, D: Dimension]
Expand Down Expand Up @@ -83,7 +84,7 @@ where
type IntoIter = LanesIter<'a, A, D>;
fn into_iter(self) -> Self::IntoIter {
LanesIter {
iter: self.base.into_base_iter(),
iter: self.base.into_base_iter::<NoOptimization>(),
inner_len: self.inner_len,
inner_stride: self.inner_stride,
life: PhantomData,
Expand Down Expand Up @@ -134,7 +135,7 @@ where
type IntoIter = LanesIterMut<'a, A, D>;
fn into_iter(self) -> Self::IntoIter {
LanesIterMut {
iter: self.base.into_base_iter(),
iter: self.base.into_base_iter::<NoOptimization>(),
inner_len: self.inner_len,
inner_stride: self.inner_stride,
life: PhantomData,
Expand Down
2 changes: 1 addition & 1 deletion src/iterators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
mod macros;

mod axis;
mod base;
pub(crate) mod base;
mod chunks;
pub mod iter;
mod lanes;
Expand Down
2 changes: 1 addition & 1 deletion src/iterators/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ where
type IntoIter = WindowsIter<'a, A, D>;
fn into_iter(self) -> Self::IntoIter {
WindowsIter {
iter: self.base.into_elements_base(),
iter: self.base.into_elements_base_any_order(),
window: self.window,
strides: self.strides,
}
Expand Down
3 changes: 1 addition & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,7 @@ pub use crate::slice::{
MultiSliceArg, NewAxis, Slice, SliceArg, SliceInfo, SliceInfoElem, SliceNextDim,
};

use crate::iterators::Baseiter;
use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut};
use crate::iterators::{ElementsBase, ElementsBaseMut};

pub use crate::arraytraits::AsArray;
#[cfg(feature = "std")]
Expand Down
29 changes: 23 additions & 6 deletions tests/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
clippy::many_single_char_names
)]

use std::collections::HashSet;
use std::hash::Hash;

use ndarray::prelude::*;
use ndarray::Zip;

Expand Down Expand Up @@ -117,6 +120,20 @@ fn test_window_zip() {
}
}

fn set<T>(iter: impl IntoIterator<Item = T>) -> HashSet<T>
where
T: Eq + Hash
{
iter.into_iter().collect()
}

/// Assert equal sets (same collection but order doesn't matter)
macro_rules! assert_set_eq {
($a:expr, $b:expr) => {
assert_eq!(set($a), set($b))
}
}

#[test]
fn test_window_neg_stride() {
let array = Array::from_iter(1..10).into_shape((3, 3)).unwrap();
Expand All @@ -131,24 +148,24 @@ fn test_window_neg_stride() {
answer.invert_axis(Axis(1));
answer.map_inplace(|a| a.invert_axis(Axis(1)));

itertools::assert_equal(
assert_set_eq!(
array.slice(s![.., ..;-1]).windows((2, 2)),
answer.iter()
answer.iter().map(Array::view)
);

answer.invert_axis(Axis(0));
answer.map_inplace(|a| a.invert_axis(Axis(0)));

itertools::assert_equal(
assert_set_eq!(
array.slice(s![..;-1, ..;-1]).windows((2, 2)),
answer.iter()
answer.iter().map(Array::view)
);

answer.invert_axis(Axis(1));
answer.map_inplace(|a| a.invert_axis(Axis(1)));

itertools::assert_equal(
assert_set_eq!(
array.slice(s![..;-1, ..]).windows((2, 2)),
answer.iter()
answer.iter().map(Array::view)
);
}

0 comments on commit e5224a5

Please sign in to comment.