Feat/adam optimizer (#140)
nathanielsimard authored Dec 30, 2022
1 parent 248039d commit eea5a26
Showing 17 changed files with 405 additions and 28 deletions.
26 changes: 26 additions & 0 deletions burn-autodiff/src/ops/tensor.rs
@@ -1062,6 +1062,32 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
unary_ops_wrapper(tensor.node.clone(), output, ops)
}

fn sqrt<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
#[derive(new, Debug)]
struct Backward<B: Backend, const D: usize> {
_b: B,
}

impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
for Backward<B, D>
{
fn partial(
&self,
state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
) -> B::TensorPrimitive<D> {
let value = B::div_scalar(&B::powf(&state.input.value(), -0.5), &2.to_elem());
B::mul(&state.output.grad(), &value)
}
}

let output = B::sqrt(tensor.tensor_ref());
let ops = Backward::<B, D>::new(B::default());

unary_ops_wrapper(tensor.node.clone(), output, ops)
}

fn erf<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
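For reference, the backward step above is the usual square-root derivative combined with the chain rule (with \(\bar{y}\) the gradient arriving from the output):

\[
\frac{\partial \sqrt{x}}{\partial x} = \frac{1}{2\sqrt{x}} = \frac{x^{-1/2}}{2},
\qquad
\bar{x} = \bar{y} \odot \frac{x^{-1/2}}{2},
\]

which is exactly what the code computes: B::powf(&state.input.value(), -0.5) divided by 2 via div_scalar, then multiplied by state.output.grad().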
2 changes: 2 additions & 0 deletions burn-autodiff/src/tests/mod.rs
@@ -18,6 +18,7 @@ mod pow;
mod relu;
mod reshape;
mod softmax;
mod sqrt;
mod sub;
mod transpose;

@@ -43,6 +44,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_mul!();
burn_autodiff::testgen_ad_neg!();
burn_autodiff::testgen_ad_powf!();
burn_autodiff::testgen_ad_sqrt!();
burn_autodiff::testgen_ad_relu!();
burn_autodiff::testgen_ad_reshape!();
burn_autodiff::testgen_ad_softmax!();
28 changes: 28 additions & 0 deletions burn-autodiff/src/tests/sqrt.rs
@@ -0,0 +1,28 @@
#[burn_tensor_testgen::testgen(ad_sqrt)]
mod tests {
use super::*;
use burn_tensor::Data;

#[test]
fn should_diff_sqrt() {
let data_1 = Data::<f32, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, 10.0]]);

let tensor_1 = TestADTensor::from_data(data_1);
let tensor_2 = TestADTensor::from_data(data_2);

let tensor_3 = tensor_1.matmul(&tensor_2.sqrt());
let tensor_4 = tensor_3.matmul(&tensor_2);
let grads = tensor_4.backward();

let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();

grad_1
.to_data()
.assert_approx_eq(&Data::from([[82.1126, 99.0832], [82.1126, 99.0832]]), 3);
grad_2
.to_data()
.assert_approx_eq(&Data::from([[30.3093, 33.1204], [34.5819, 38.7694]]), 3);
}
}
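The expected gradients can be checked by hand from the matmul and sqrt backward rules. A sketch of the chain rule, assuming backward() seeds the output gradient with ones (an assumption about the test harness, consistent with the asserted values):

\[
\bar{T_4} = \mathbf{1}, \qquad
\bar{T_3} = \bar{T_4}\, T_2^{\top}, \qquad
\bar{T_1} = \bar{T_3}\, \bigl(\sqrt{T_2}\bigr)^{\top}, \qquad
\bar{T_2} = T_3^{\top}\, \bar{T_4} + \bigl(T_1^{\top}\, \bar{T_3}\bigr) \odot \frac{1}{2\sqrt{T_2}},
\]

where \(T_3 = T_1 \sqrt{T_2}\) and \(T_4 = T_3 T_2\); plugging in the data above reproduces the asserted values to the stated precision.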
8 changes: 8 additions & 0 deletions burn-ndarray/src/element.rs
@@ -9,6 +9,7 @@ pub(crate) trait ExpElement {
fn exp_elem(self) -> Self;
fn log_elem(self) -> Self;
fn pow_elem(self, value: f32) -> Self;
fn sqrt_elem(self) -> Self;
}

macro_rules! impl_exp_elem {
@@ -23,6 +24,9 @@ macro_rules! impl_exp_elem {
fn pow_elem(self, value: f32) -> Self {
$elem::powf(self, value.into())
}
fn sqrt_elem(self) -> Self {
$elem::sqrt(self)
}
}
};
($elem:ident, $tmp:ident) => {
@@ -39,6 +43,10 @@
let tmp = $tmp::powf(self as $tmp, value as $tmp);
tmp as $elem
}
fn sqrt_elem(self) -> Self {
let tmp = $tmp::sqrt(self as $tmp);
tmp as $elem
}
}
};
}
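A standalone sketch of what the two macro branches generate for a concrete float and integer element type (the trait name and the types picked here are illustrative, not the crate's actual expansion):

trait SqrtElemSketch {
    fn sqrt_elem(self) -> Self;
}

// First branch ($elem:ident): delegate to the element type's own sqrt.
impl SqrtElemSketch for f32 {
    fn sqrt_elem(self) -> Self {
        f32::sqrt(self)
    }
}

// Second branch ($elem:ident, $tmp:ident): round-trip through a wider float,
// since integer element types have no sqrt of their own.
impl SqrtElemSketch for i64 {
    fn sqrt_elem(self) -> Self {
        let tmp = f64::sqrt(self as f64);
        tmp as i64
    }
}

fn main() {
    assert_eq!(4.0f32.sqrt_elem(), 2.0);
    assert_eq!(9i64.sqrt_elem(), 3);
}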
7 changes: 7 additions & 0 deletions burn-ndarray/src/ops/tensor.rs
@@ -531,6 +531,13 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
NdArrayTensor { array, shape }
}

fn sqrt<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor.array.mapv(|a| a.sqrt_elem()).into_shared();
let shape = tensor.shape;

NdArrayTensor { array, shape }
}

fn erf<const D: usize>(tensor: &NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor
.array
4 changes: 4 additions & 0 deletions burn-tch/src/ops/tensor.rs
@@ -447,6 +447,10 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
to_tensor(tensor.tensor.pow_tensor_scalar(value as f64))
}

fn sqrt<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
to_tensor(tensor.tensor.sqrt())
}

fn erf<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
to_tensor(tensor.tensor.erf())
}
15 changes: 15 additions & 0 deletions burn-tensor/src/tensor/base.rs
@@ -27,6 +27,16 @@ where
}
}

impl<B> Tensor<B, 1>
where
B: Backend,
{
/// Returns the first value of the tensor.
pub fn single_value(&self) -> B::Elem {
self.to_data().value[0]
}
}

impl<const D: usize, B> Tensor<B, D>
where
B: Backend,
@@ -90,6 +100,11 @@ where
Self::new(B::powf(&self.value, value))
}

/// Applies the element-wise square root operation.
pub fn sqrt(&self) -> Self {
Self::new(B::sqrt(&self.value))
}

/// Returns the shape of the current tensor.
pub fn shape(&self) -> &Shape<D> {
B::shape(&self.value)
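A minimal usage sketch of the new public sqrt method, written in the style of the crate's own test modules (the TestADTensor alias and the expected values are assumptions, not part of this commit):

#[test]
fn sqrt_usage_sketch() {
    let tensor = TestADTensor::from_data(Data::<f32, 2>::from([[1.0, 4.0], [9.0, 16.0]]));

    // Element-wise square root of every value.
    tensor
        .sqrt()
        .to_data()
        .assert_approx_eq(&Data::from([[1.0, 2.0], [3.0, 4.0]]), 3);
}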
1 change: 1 addition & 0 deletions burn-tensor/src/tensor/ops/base.rs
@@ -226,6 +226,7 @@ pub trait TensorOps<B: Backend> {
fn exp<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
fn log<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
fn powf<const D: usize>(tensor: &B::TensorPrimitive<D>, value: f32) -> B::TensorPrimitive<D>;
fn sqrt<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
fn erf<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
fn cat<const D: usize>(tensors: &[B::TensorPrimitive<D>], dim: usize) -> B::TensorPrimitive<D>;
fn relu<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
4 changes: 1 addition & 3 deletions burn/src/nn/layer_norm.rs
@@ -48,9 +48,7 @@ impl<B: Backend> LayerNorm<B> {
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let (var, mean) = input.var_mean_bias(D - 1);

let input_normalized = input
.sub(&mean)
.div(&var.powf(0.5).add_scalar(self.epsilon));
let input_normalized = input.sub(&mean).div(&var.sqrt().add_scalar(self.epsilon));

input_normalized
.mul(&self.gamma.unsqueeze())
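Written out, the normalization step now reads (with epsilon added outside the square root, exactly as the code does):

\[
\hat{x} = \frac{x - \mu}{\sqrt{\sigma^{2}} + \epsilon},
\]

so replacing var.powf(0.5) with var.sqrt() expresses the same quantity through the new dedicated op; the normalized result is then scaled by gamma as shown above.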