Make optimizer state public (#2561)
* Make optimizer state public

* Don't use local cube

* Change weight decay config
ArthurBrussee authored Nov 28, 2024
1 parent 4258502 commit 6f494e5
Showing 6 changed files with 65 additions and 91 deletions.
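The theme of the commit: the optimizer structs (AdaGrad, Adam, AdamW) drop their backend type parameter, and their state records expose public, documented fields. As a quick, hedged illustration of what that enables, the sketch below builds and inspects an Adam state using only the constructors and fields visible in the diffs that follow; it assumes `AdamState` and `AdaptiveMomentumState` are re-exported under `burn::optim` like the rest of the adam module, and it is not code from the commit.

```rust
// Sketch only (not part of the commit): with the state fields now public, user
// code can construct and inspect optimizer state directly. Paths assume the
// usual burn re-exports (burn::optim, burn::tensor); `B` is any Backend.
use burn::optim::{AdamState, AdaptiveMomentumState};
use burn::tensor::{backend::Backend, Tensor};

fn inspect_adam_state<B: Backend>(device: &B::Device) {
    // `AdaptiveMomentumState` derives `new`; its fields are public after this change.
    let momentum = AdaptiveMomentumState::<B, 2>::new(
        1,                             // time: number of iterations aggregated
        Tensor::zeros([2, 3], device), // moment_1: first-order momentum
        Tensor::zeros([2, 3], device), // moment_2: second-order momentum
    );
    let state = AdamState::new(momentum);

    // These reads were not possible before the fields became `pub`.
    println!("steps: {}", state.momentum.time);
    println!("m1 shape: {:?}", state.momentum.moment_1.shape());
}
```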
13 changes: 6 additions & 7 deletions crates/burn-core/src/optim/adagrad.rs
@@ -27,9 +27,9 @@ pub struct AdaGradConfig {

/// AdaGrad optimizer
#[derive(Clone)]
-pub struct AdaGrad<B: Backend> {
+pub struct AdaGrad {
lr_decay: LrDecay,
-weight_decay: Option<WeightDecay<B>>,
+weight_decay: Option<WeightDecay>,
}

/// AdaGrad state.
@@ -38,7 +38,7 @@ pub struct AdaGradState<B: Backend, const D: usize> {
lr_decay: LrDecayState<B, D>,
}

-impl<B: Backend> SimpleOptimizer<B> for AdaGrad<B> {
+impl<B: Backend> SimpleOptimizer<B> for AdaGrad {
type State<const D: usize> = AdaGradState<B, D>;

fn step<const D: usize>(
@@ -79,7 +79,7 @@ impl AdaGradConfig {
/// Returns an optimizer that can be used to optimize a module.
pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
&self,
-) -> OptimizerAdaptor<AdaGrad<B::InnerBackend>, M, B> {
+) -> OptimizerAdaptor<AdaGrad, M, B> {
let optim = AdaGrad {
lr_decay: LrDecay {
lr_decay: self.lr_decay,
@@ -157,7 +157,7 @@ mod tests {
use crate::optim::{GradientsParams, Optimizer};
use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
use crate::tensor::{Distribution, Tensor, TensorData};
-use crate::{nn, nn::Linear, TestAutodiffBackend, TestBackend};
+use crate::{nn, nn::Linear, TestAutodiffBackend};

const LEARNING_RATE: LearningRate = 0.01;

@@ -274,8 +274,7 @@ mod tests {
}

fn create_adagrad(
-) -> OptimizerAdaptor<AdaGrad<TestBackend>, Linear<TestAutodiffBackend>, TestAutodiffBackend>
-{
+) -> OptimizerAdaptor<AdaGrad, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
let config = AdaGradConfig::new();
AdaGrad {
lr_decay: LrDecay {
27 changes: 14 additions & 13 deletions crates/burn-core/src/optim/adam.rs
@@ -32,18 +32,19 @@ pub struct AdamConfig {

/// Adam optimizer as described in the paper [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf).
#[derive(Clone)]
-pub struct Adam<B: Backend> {
+pub struct Adam {
momentum: AdaptiveMomentum,
-weight_decay: Option<WeightDecay<B>>,
+weight_decay: Option<WeightDecay>,
}

/// Adam state.
#[derive(Record, Clone, new)]
pub struct AdamState<B: Backend, const D: usize> {
-momentum: AdaptiveMomentumState<B, D>,
+/// The current adaptive momentum.
+pub momentum: AdaptiveMomentumState<B, D>,
}

-impl<B: Backend> SimpleOptimizer<B> for Adam<B> {
+impl<B: Backend> SimpleOptimizer<B> for Adam {
type State<const D: usize> = AdamState<B, D>;

fn step<const D: usize>(
@@ -83,9 +84,7 @@ impl AdamConfig {
/// # Returns
///
/// Returns an optimizer that can be used to optimize a module.
-pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
-&self,
-) -> OptimizerAdaptor<Adam<B::InnerBackend>, M, B> {
+pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Adam, M, B> {
let optim = Adam {
momentum: AdaptiveMomentum {
beta_1: self.beta_1,
@@ -106,9 +105,12 @@ impl AdamConfig {
/// Adaptive momentum state.
#[derive(Record, new, Clone)]
pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
-time: usize,
-moment_1: Tensor<B, D>,
-moment_2: Tensor<B, D>,
+/// The number of iterations aggregated.
+pub time: usize,
+/// The first order momentum.
+pub moment_1: Tensor<B, D>,
+/// The second order momentum.
+pub moment_2: Tensor<B, D>,
}

#[derive(Clone)]
@@ -190,7 +192,7 @@ mod tests {
use crate::optim::{GradientsParams, Optimizer};
use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
use crate::tensor::{Distribution, Tensor, TensorData};
-use crate::{nn, TestAutodiffBackend, TestBackend};
+use crate::{nn, TestAutodiffBackend};

const LEARNING_RATE: LearningRate = 0.01;

@@ -350,8 +352,7 @@ mod tests {
.load_record(record)
}

-fn create_adam(
-) -> OptimizerAdaptor<Adam<TestBackend>, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
+fn create_adam() -> OptimizerAdaptor<Adam, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
{
let config = AdamConfig::new();
Adam {
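Because `Adam` no longer carries a backend parameter, the adaptor returned by `init` is simpler to spell out or to hide behind `impl Trait`. A hedged usage sketch follows (assumed `burn::optim` re-exports and `Optimizer<M, B>` trait bounds; not code from the commit):

```rust
// Sketch of calling the new `init` (assumed re-exports; not code from the commit).
use burn::module::AutodiffModule;
use burn::optim::{AdamConfig, Optimizer};
use burn::tensor::backend::AutodiffBackend;

fn build_adam<B: AutodiffBackend, M: AutodiffModule<B>>() -> impl Optimizer<M, B> {
    // Before this commit the concrete type was OptimizerAdaptor<Adam<B::InnerBackend>, M, B>;
    // it is now OptimizerAdaptor<Adam, M, B>, so the backend appears only once.
    AdamConfig::new().init::<B, M>()
}
```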
66 changes: 17 additions & 49 deletions crates/burn-core/src/optim/adamw.rs
@@ -1,14 +1,12 @@
+use super::{AdaptiveMomentumState, SimpleOptimizer};
+use crate::config::Config;
+use crate::optim::adaptor::OptimizerAdaptor;
+use crate::tensor::{backend::AutodiffBackend, Tensor};
use crate::{
self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record,
LearningRate,
};
-use std::marker::PhantomData;
-
-use super::SimpleOptimizer;
-use crate::config::Config;
-use crate::optim::adaptor::OptimizerAdaptor;
-use crate::tensor::{backend::AutodiffBackend, Tensor};
-use burn_tensor::{backend::Backend, ops::Device, ElementConversion};
+use burn_tensor::{backend::Backend, ops::Device};

/// AdamW configuration.
#[derive(Config)]
@@ -31,19 +29,19 @@ pub struct AdamWConfig {

/// AdamW optimizer as described in the paper [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101).
#[derive(Clone)]
-pub struct AdamW<B: Backend> {
+pub struct AdamW {
momentum: AdaptiveMomentumW,
weight_decay: f32,
-_phantom: PhantomData<B>,
}

/// AdamW state.
#[derive(Record, Clone, new)]
pub struct AdamWState<B: Backend, const D: usize> {
-momentum: AdaptiveMomentumWState<B, D>,
+/// The current adaptive momentum state.
+pub momentum: AdaptiveMomentumState<B, D>,
}

-impl<B: Backend> SimpleOptimizer<B> for AdamW<B> {
+impl<B: Backend> SimpleOptimizer<B> for AdamW {
type State<const D: usize> = AdamWState<B, D>;

/// A single optimization step for any tensor that represents the parameters of a model.
@@ -81,17 +79,14 @@ impl AdamWConfig {
/// # Returns
///
/// Returns an optimizer that can be used to optimize a module.
-pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
-&self,
-) -> OptimizerAdaptor<AdamW<B::InnerBackend>, M, B> {
+pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<AdamW, M, B> {
let optim = AdamW {
momentum: AdaptiveMomentumW {
beta_1: self.beta_1,
beta_2: self.beta_2,
epsilon: self.epsilon,
},
weight_decay: self.weight_decay,
-_phantom: Default::default(),
};

let mut optim = OptimizerAdaptor::from(optim);
@@ -102,14 +97,6 @@ impl AdamWConfig {
}
}

-/// Adaptive momentum state.
-#[derive(Record, new, Clone)]
-pub struct AdaptiveMomentumWState<B: Backend, const D: usize> {
-time: usize,
-moment_1: Tensor<B, D>,
-moment_2: Tensor<B, D>,
-}
-
#[derive(Clone)]
struct AdaptiveMomentumW {
beta_1: f32,
@@ -121,8 +108,8 @@ impl AdaptiveMomentumW {
pub fn transform<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
-state: Option<AdaptiveMomentumWState<B, D>>,
-) -> (Tensor<B, D>, AdaptiveMomentumWState<B, D>) {
+state: Option<AdaptiveMomentumState<B, D>>,
+) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
let state = if let Some(mut state) = state {
// Update first moment estimate.
let factor = 1.0 - self.beta_1;
@@ -151,10 +138,10 @@ impl AdaptiveMomentumW {
let factor = 1.0 - self.beta_2;
let moment_2 = grad.powf_scalar(2.0).mul_scalar(factor);

-AdaptiveMomentumWState::new(1, moment_1, moment_2)
+AdaptiveMomentumState::new(1, moment_1, moment_2)
};

-let time: i32 = (state.time as i32).elem();
+let time: i32 = state.time as i32;

// Compute bias-corrected first and second moment estimates.
let moment_1_corrected = state
@@ -173,36 +160,19 @@ impl AdaptiveMomentumW {

(
update_delta,
-AdaptiveMomentumWState::new(state.time, state.moment_1, state.moment_2),
+AdaptiveMomentumState::new(state.time, state.moment_1, state.moment_2),
)
}
}

-impl<B: Backend, const D: usize> AdaptiveMomentumWState<B, D> {
-/// Move state to device.
-///
-/// # Arguments
-///
-/// * `device` - Device to move state to.
-///
-/// # Returns
-///
-/// Returns state moved to device.
-pub fn to_device(mut self, device: &B::Device) -> Self {
-self.moment_1 = self.moment_1.to_device(device);
-self.moment_2 = self.moment_2.to_device(device);
-self
-}
-}
-
#[cfg(test)]
mod tests {
use super::*;
use crate::module::{Module, Param};
use crate::optim::{GradientsParams, Optimizer};
use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
use crate::tensor::{Distribution, Tensor, TensorData};
-use crate::{nn, TestAutodiffBackend, TestBackend};
+use crate::{nn, TestAutodiffBackend};
use tempfile::TempDir;

const LEARNING_RATE: LearningRate = 0.01;
@@ -366,8 +336,7 @@ mod tests {
}

fn create_adamw(
-) -> OptimizerAdaptor<AdamW<TestBackend>, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
-{
+) -> OptimizerAdaptor<AdamW, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend> {
let config = AdamWConfig::new();
AdamW {
momentum: AdaptiveMomentumW {
@@ -376,7 +345,6 @@ mod tests {
epsilon: config.epsilon,
},
weight_decay: config.weight_decay,
-_phantom: Default::default(),
}
.into()
}
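Besides folding `AdaptiveMomentumWState` into the shared `AdaptiveMomentumState`, the diff above deletes `_phantom: PhantomData<B>`: once `AdamW` stores nothing backend-specific, the type parameter can live on the impl instead of the struct. A self-contained, generic-Rust sketch of that pattern (illustrative only, not code from the commit):

```rust
use std::marker::PhantomData;

// Before: the unused type parameter forces a PhantomData field and spreads
// `<B>` to every type that mentions the optimizer.
pub struct WithPhantom<B> {
    pub weight_decay: f32,
    _phantom: PhantomData<B>,
}

// After: the struct stores only plain data; the generic parameter moves to the
// impl or method where it is actually needed, mirroring
// `impl<B: Backend> SimpleOptimizer<B> for AdamW` above.
pub struct WithoutPhantom {
    pub weight_decay: f32,
}

impl WithoutPhantom {
    pub fn scale<B: std::ops::Mul<f32, Output = B>>(&self, grad: B) -> B {
        grad * self.weight_decay
    }
}
```

This is also why the `_phantom: Default::default(),` lines disappear from both `init` and the test helper.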
14 changes: 7 additions & 7 deletions crates/burn-core/src/optim/decay.rs
@@ -4,13 +4,13 @@ use crate as burn;
use crate::record::Record;

use crate::config::Config;
-use crate::tensor::{ElementConversion, Tensor};
+use crate::tensor::Tensor;

/// Configuration to create [weight decay](WeightDecay).
#[derive(Config)]
pub struct WeightDecayConfig {
/// L2 penalty.
-pub penalty: f64,
+pub penalty: f32,
}

/// State of [weight decay](WeightDecay).
@@ -21,15 +21,15 @@ pub struct WeightDecayState<B: Backend, const D: usize> {

/// Weight decay implementation that transforms gradients.
#[derive(Clone)]
pub struct WeightDecay<B: Backend> {
penalty: B::FloatElem,
pub struct WeightDecay {
penalty: f32,
}

impl<B: Backend> WeightDecay<B> {
impl WeightDecay {
/// Creates a new [weight decay](WeightDecay) from a [config](WeightDecayConfig).
pub fn new(config: &WeightDecayConfig) -> Self {
Self {
penalty: config.penalty.elem(),
penalty: config.penalty,
}
}

@@ -43,7 +43,7 @@ impl<B: Backend> WeightDecay<B> {
/// # Returns
///
/// * `grad` - Transformed gradient.
-pub fn transform<const D: usize>(
+pub fn transform<B: Backend, const D: usize>(
&self,
grad: Tensor<B, D>,
tensor: Tensor<B, D>,
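The decay change follows the same pattern: `WeightDecay` loses its backend parameter, stores the penalty as a plain `f32` (the config field narrows from `f64`), and `transform` takes the backend as a per-call generic. A hedged usage sketch, assuming the types remain reachable under `burn::optim::decay` (not code from the commit):

```rust
use burn::optim::decay::{WeightDecay, WeightDecayConfig};
use burn::tensor::{backend::Backend, Tensor};

fn decayed_grad<B: Backend>(grad: Tensor<B, 2>, weights: Tensor<B, 2>) -> Tensor<B, 2> {
    // `penalty` is now a plain f32 (previously f64 converted into B::FloatElem).
    let decay = WeightDecay::new(&WeightDecayConfig::new(5e-4));
    // The backend is a parameter of `transform` itself, not of the struct.
    decay.transform(grad, weights)
}
```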