Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement arithmetic ops on more combinations of types #744

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions src/impl_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,74 @@ where
}
}

/// Perform elementwise
#[doc=$doc]
/// between references `self` and `rhs`,
/// and return the result as a new `Array`.
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, 'b, A, B, S2, D, E> $trt<&'b ArrayBase<S2, E>> for ArrayView<'a, A, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S2: Data<Elem=B>,
D: Dimension,
E: Dimension,
{
type Output = Array<A, D>;
fn $mth(self, rhs: &'b ArrayBase<S2, E>) -> Array<A, D> {
// FIXME: Can we co-broadcast arrays here? And how?
self.to_owned().$mth(rhs)
}
}

/// Perform elementwise
#[doc=$doc]
/// between references `self` and `rhs`,
/// and return the result as a new `Array`.
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, 'b, A, B, S, D, E> $trt<ArrayView<'b, B, E>> for &'a ArrayBase<S, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
S: Data<Elem=A>,
D: Dimension,
E: Dimension,
{
type Output = Array<A, D>;
fn $mth(self, rhs: ArrayView<'b, B, E>) -> Array<A, D> {
// FIXME: Can we co-broadcast arrays here? And how?
self.to_owned().$mth(rhs)
}
}

/// Perform elementwise
#[doc=$doc]
/// between references `self` and `rhs`,
/// and return the result as a new `Array`.
///
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, 'b, A, B, D, E> $trt<ArrayView<'b, B, E>> for ArrayView<'a, A, D>
where
A: Clone + $trt<B, Output=A>,
B: Clone,
D: Dimension,
E: Dimension,
{
type Output = Array<A, D>;
fn $mth(self, rhs: ArrayView<'b, B, E>) -> Array<A, D> {
// FIXME: Can we co-broadcast arrays here? And how?
self.to_owned().$mth(rhs)
}
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and the scalar `x`,
Expand Down Expand Up @@ -163,6 +231,21 @@ impl<'a, A, S, D, B> $trt<B> for &'a ArrayBase<S, D>
self.to_owned().$mth(x)
}
}

/// Perform elementwise
#[doc=$doc]
/// between the reference `self` and the scalar `x`,
/// and return the result as a new `Array`.
impl<'a, A, D, B> $trt<B> for ArrayView<'a, A, D>
where A: Clone + $trt<B, Output=A>,
D: Dimension,
B: ScalarOperand,
{
type Output = Array<A, D>;
fn $mth(self, x: B) -> Array<A, D> {
self.to_owned().$mth(x)
}
}
);
);

Expand Down Expand Up @@ -218,6 +301,23 @@ impl<'a, S, D> $trt<&'a ArrayBase<S, D>> for $scalar
})
}
}

/// Perform elementwise
/// between the scalar `self` and array `rhs`,
/// and return the result as a new `Array`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not important, but just something I saw. Allegedly this doc is not visible in rustdoc, but that comment may be outdated now(!). then we want that #[doc=$doc] back to make a complete sentence.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would especially discourage these impl blocks. We have to spend a lot of impl blocks on the arithmetic ops with scalars on the left hand side. Part of me want to just remove everything with left hand side scalars, because it can never(?) be properly general because of how we need to express it as an impl .. for $scalar (for a specific single type).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed this impl block in the most recent commit because it's not necessary.

Part of me want to just remove everything with left hand side scalars, because it can never(?) be properly general because of how we need to express it as an impl .. for $scalar (for a specific single type).

I understand the objection, and AFAIK it will never be possible to write these impls for generic scalars even after re_rebalance_coherence. However, IMO supporting left-hand-side scalars is important, even if it's just f32/f64. I use these impls all the time in my code. They are especially important for non-commutative ops, where it's not straightforward to just move the scalar to the other side.

impl<'a, D> $trt<ArrayView<'a, $scalar, D>> for $scalar
where
D: Dimension,
{
type Output = Array<$scalar, D>;
fn $mth(self, rhs: ArrayView<'a, $scalar, D>) -> Array<$scalar, D> {
if_commutative!($commutative {
rhs.$mth(self)
} or {
self.$mth(rhs.to_owned())
})
}
}
);
}

Expand Down Expand Up @@ -320,6 +420,19 @@ mod arithmetic_ops {
}
}

impl<'a, A, D> Neg for ArrayView<'a, A, D>
where
for<'b> &'b A: Neg<Output = A>,
D: Dimension,
{
type Output = Array<A, D>;
/// Perform an elementwise negation of reference `self` and return the
/// result as a new `Array`.
fn neg(self) -> Array<A, D> {
self.map(Neg::neg)
}
}

impl<A, S, D> Not for ArrayBase<S, D>
where
A: Clone + Not<Output = A>,
Expand Down Expand Up @@ -349,6 +462,19 @@ mod arithmetic_ops {
self.map(Not::not)
}
}

impl<'a, A, D> Not for ArrayView<'a, A, D>
where
for<'b> &'b A: Not<Output = A>,
D: Dimension,
{
type Output = Array<A, D>;
/// Perform an elementwise unary not of reference `self` and return the
/// result as a new `Array`.
fn not(self) -> Array<A, D> {
self.map(Not::not)
}
}
}

mod assign_ops {
Expand All @@ -359,6 +485,23 @@ mod assign_ops {
($trt:ident, $method:ident, $doc:expr) => {
use std::ops::$trt;

#[doc=$doc]
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
/// **Panics** if broadcasting isn’t possible.
impl<A, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
A: Clone + $trt<A>,
S: DataMut<Elem = A>,
S2: Data<Elem = A>,
D: Dimension,
E: Dimension,
{
fn $method(&mut self, rhs: ArrayBase<S2, E>) {
self.$method(&rhs)
}
}

#[doc=$doc]
/// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
///
Expand Down
31 changes: 16 additions & 15 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,18 +607,18 @@ pub type Ixs = isize;
///
/// ### Binary Operators with Two Arrays
///
/// Let `A` be an array or view of any kind. Let `B` be an array
/// Let `A` be an array or view of any kind. Let `O` be an array
/// with owned storage (either `Array` or `ArcArray`).
/// Let `C` be an array with mutable data (either `Array`, `ArcArray`
/// or `ArrayViewMut`).
/// Let `M` be an array with mutable data (either `Array`, `ArcArray`
/// or `ArrayViewMut`). Let `V` be an `ArrayView`.
/// The following combinations of operands
/// are supported for an arbitrary binary operator denoted by `@` (it can be
/// `+`, `-`, `*`, `/` and so on).
///
/// - `&A @ &A` which produces a new `Array`
/// - `B @ A` which consumes `B`, updates it with the result, and returns it
/// - `B @ &A` which consumes `B`, updates it with the result, and returns it
/// - `C @= &A` which performs an arithmetic operation in place
/// - `&A @ &A`, `&A @ V`, `V @ &A`, or `V @ V` which produce a new `Array`
/// - `O @ A` which consumes `O`, updates it with the result, and returns it
/// - `O @ &A` which consumes `O`, updates it with the result, and returns it
/// - `M @= &A` or `M @= A` which performs an arithmetic operation in place on `M`
///
/// Note that the element type needs to implement the operator trait and the
/// `Clone` trait.
Expand Down Expand Up @@ -646,18 +646,19 @@ pub type Ixs = isize;
/// are supported (scalar can be on either the left or right side, but
/// `ScalarOperand` docs has the detailed condtions).
///
/// - `&A @ K` or `K @ &A` which produces a new `Array`
/// - `B @ K` or `K @ B` which consumes `B`, updates it with the result and returns it
/// - `C @= K` which performs an arithmetic operation in place
/// - `&A @ K`, `V @ K`, `K @ &A`, or `K @ V` which produces a new `Array`
/// - `O @ K` or `K @ O` which consumes `O`, updates it with the result and returns it
/// - `M @= K` which performs an arithmetic operation in place
///
/// ### Unary Operators
///
/// Let `A` be an array or view of any kind. Let `B` be an array with owned
/// storage (either `Array` or `ArcArray`). The following operands are supported
/// for an arbitrary unary operator denoted by `@` (it can be `-` or `!`).
/// Let `A` be an array or view of any kind. Let `O` be an array with owned
/// storage (either `Array` or `ArcArray`). Let `V` be an `ArrayView`. The
/// following operands are supported for an arbitrary unary operator denoted by
/// `@` (it can be `-` or `!`).
///
/// - `@&A` which produces a new `Array`
/// - `@B` which consumes `B`, updates it with the result, and returns it
/// - `@&A` or `@V` which produces a new `Array`
/// - `@O` which consumes `O`, updates it with the result, and returns it
Copy link
Member

@bluss bluss Oct 17, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for updating this, even if it's a draft PR or whatever, it makes it easy to discuss changes 🙂

///
/// ## Broadcasting
///
Expand Down