diff --git a/crates/burn-core/src/optim/adagrad.rs b/crates/burn-core/src/optim/adagrad.rs
index e64c75db6e..2df43619a6 100644
--- a/crates/burn-core/src/optim/adagrad.rs
+++ b/crates/burn-core/src/optim/adagrad.rs
@@ -27,9 +27,9 @@ pub struct AdaGradConfig {
 
 /// AdaGrad optimizer
 #[derive(Clone)]
-pub struct AdaGrad<B: Backend> {
+pub struct AdaGrad {
     lr_decay: LrDecay,
-    weight_decay: Option<WeightDecay<B>>,
+    weight_decay: Option<WeightDecay>,
 }
 
 /// AdaGrad state.
@@ -38,7 +38,7 @@ pub struct AdaGradState<B: Backend, const D: usize> {
     lr_decay: LrDecayState<B, D>,
 }
 
-impl<B: Backend> SimpleOptimizer<B> for AdaGrad<B> {
+impl<B: Backend> SimpleOptimizer<B> for AdaGrad {
     type State<const D: usize> = AdaGradState<B, D>;
 
     fn step<const D: usize>(
@@ -79,7 +79,7 @@ impl AdaGradConfig {
     /// Returns an optimizer that can be used to optimize a module.
     pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
         &self,
-    ) -> OptimizerAdaptor<AdaGrad<B::InnerBackend>, M, B> {
+    ) -> OptimizerAdaptor<AdaGrad, M, B> {
         let optim = AdaGrad {
             lr_decay: LrDecay {
                 lr_decay: self.lr_decay,
@@ -157,7 +157,7 @@ mod tests {
     use crate::optim::{GradientsParams, Optimizer};
     use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
     use crate::tensor::{Distribution, Tensor, TensorData};
-    use crate::{nn, nn::Linear, TestAutodiffBackend, TestBackend};
+    use crate::{nn, nn::Linear, TestAutodiffBackend};
 
     const LEARNING_RATE: LearningRate = 0.01;
 
@@ -274,8 +274,7 @@ mod tests {
     }
 
     fn create_adagrad(
-    ) -> OptimizerAdaptor<AdaGrad<TestBackend>, Linear<TestAutodiffBackend>, TestAutodiffBackend>
-    {
+    ) -> OptimizerAdaptor<AdaGrad, Linear<TestAutodiffBackend>, TestAutodiffBackend> {
         let config = AdaGradConfig::new();
         AdaGrad {
             lr_decay: LrDecay {
diff --git a/crates/burn-core/src/optim/adam.rs b/crates/burn-core/src/optim/adam.rs
index d7b4fe59b4..c2a5e3f679 100644
--- a/crates/burn-core/src/optim/adam.rs
+++ b/crates/burn-core/src/optim/adam.rs
@@ -32,18 +32,19 @@ pub struct AdamConfig {
 
 /// Adam optimizer as described in the paper [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf).
 #[derive(Clone)]
-pub struct Adam<B: Backend> {
+pub struct Adam {
     momentum: AdaptiveMomentum,
-    weight_decay: Option<WeightDecay<B>>,
+    weight_decay: Option<WeightDecay>,
 }
 
 /// Adam state.
 #[derive(Record, Clone, new)]
 pub struct AdamState<B: Backend, const D: usize> {
-    momentum: AdaptiveMomentumState<B, D>,
+    /// The current adaptive momentum.
+    pub momentum: AdaptiveMomentumState<B, D>,
 }
 
-impl<B: Backend> SimpleOptimizer<B> for Adam<B> {
+impl<B: Backend> SimpleOptimizer<B> for Adam {
     type State<const D: usize> = AdamState<B, D>;
 
     fn step<const D: usize>(
@@ -83,9 +84,7 @@ impl AdamConfig {
     /// # Returns
     ///
     /// Returns an optimizer that can be used to optimize a module.
-    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
-        &self,
-    ) -> OptimizerAdaptor<Adam<B::InnerBackend>, M, B> {
+    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<Adam, M, B> {
         let optim = Adam {
             momentum: AdaptiveMomentum {
                 beta_1: self.beta_1,
@@ -106,9 +105,12 @@ impl AdamConfig {
 /// Adaptive momentum state.
 #[derive(Record, new, Clone)]
 pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
-    time: usize,
-    moment_1: Tensor<B, D>,
-    moment_2: Tensor<B, D>,
+    /// The number of iterations aggregated.
+    pub time: usize,
+    /// The first order momentum.
+    pub moment_1: Tensor<B, D>,
+    /// The second order momentum.
+    pub moment_2: Tensor<B, D>,
 }
 
 #[derive(Clone)]
@@ -190,7 +192,7 @@ mod tests {
     use crate::optim::{GradientsParams, Optimizer};
     use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
     use crate::tensor::{Distribution, Tensor, TensorData};
-    use crate::{nn, TestAutodiffBackend, TestBackend};
+    use crate::{nn, TestAutodiffBackend};
 
     const LEARNING_RATE: LearningRate = 0.01;
 
@@ -350,8 +352,7 @@ mod tests {
             .load_record(record)
     }
 
-    fn create_adam(
-    ) -> OptimizerAdaptor<Adam<TestBackend>, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
+    fn create_adam() -> OptimizerAdaptor<Adam, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
     {
         let config = AdamConfig::new();
         Adam {
diff --git a/crates/burn-core/src/optim/adamw.rs b/crates/burn-core/src/optim/adamw.rs
index 666ec18f13..32f4898dc6 100644
--- a/crates/burn-core/src/optim/adamw.rs
+++ b/crates/burn-core/src/optim/adamw.rs
@@ -1,14 +1,12 @@
+use super::{AdaptiveMomentumState, SimpleOptimizer};
+use crate::config::Config;
+use crate::optim::adaptor::OptimizerAdaptor;
+use crate::tensor::{backend::AutodiffBackend, Tensor};
 use crate::{
     self as burn, grad_clipping::GradientClippingConfig, module::AutodiffModule, record::Record,
     LearningRate,
 };
-use std::marker::PhantomData;
-
-use super::SimpleOptimizer;
-use crate::config::Config;
-use crate::optim::adaptor::OptimizerAdaptor;
-use crate::tensor::{backend::AutodiffBackend, Tensor};
-use burn_tensor::{backend::Backend, ops::Device, ElementConversion};
+use burn_tensor::{backend::Backend, ops::Device};
 
 /// AdamW configuration.
 #[derive(Config)]
@@ -31,19 +29,19 @@ pub struct AdamWConfig {
 
 /// AdamW optimizer as described in the paper [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101).
 #[derive(Clone)]
-pub struct AdamW<B: Backend> {
+pub struct AdamW {
     momentum: AdaptiveMomentumW,
     weight_decay: f32,
-    _phantom: PhantomData<B>,
 }
 
 /// AdamW state.
 #[derive(Record, Clone, new)]
 pub struct AdamWState<B: Backend, const D: usize> {
-    momentum: AdaptiveMomentumWState<B, D>,
+    /// The current adaptive momentum state.
+    pub momentum: AdaptiveMomentumState<B, D>,
 }
 
-impl<B: Backend> SimpleOptimizer<B> for AdamW<B> {
+impl<B: Backend> SimpleOptimizer<B> for AdamW {
     type State<const D: usize> = AdamWState<B, D>;
     /// A single optimization step for any tensor that represents the parameters of a model.
     fn step<const D: usize>(
@@ -81,9 +79,7 @@ impl AdamWConfig {
     /// # Returns
     ///
     /// Returns an optimizer that can be used to optimize a module.
-    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
-        &self,
-    ) -> OptimizerAdaptor<AdamW<B::InnerBackend>, M, B> {
+    pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(&self) -> OptimizerAdaptor<AdamW, M, B> {
         let optim = AdamW {
             momentum: AdaptiveMomentumW {
                 beta_1: self.beta_1,
@@ -91,7 +87,6 @@ impl AdamWConfig {
                 epsilon: self.epsilon,
             },
             weight_decay: self.weight_decay,
-            _phantom: Default::default(),
         };
 
         let mut optim = OptimizerAdaptor::from(optim);
@@ -102,14 +97,6 @@ impl AdamWConfig {
     }
 }
 
-/// Adaptive momentum state.
-#[derive(Record, new, Clone)]
-pub struct AdaptiveMomentumWState<B: Backend, const D: usize> {
-    time: usize,
-    moment_1: Tensor<B, D>,
-    moment_2: Tensor<B, D>,
-}
-
 #[derive(Clone)]
 struct AdaptiveMomentumW {
     beta_1: f32,
@@ -121,8 +108,8 @@ impl AdaptiveMomentumW {
     pub fn transform<B: Backend, const D: usize>(
         &self,
         grad: Tensor<B, D>,
-        state: Option<AdaptiveMomentumWState<B, D>>,
-    ) -> (Tensor<B, D>, AdaptiveMomentumWState<B, D>) {
+        state: Option<AdaptiveMomentumState<B, D>>,
+    ) -> (Tensor<B, D>, AdaptiveMomentumState<B, D>) {
         let state = if let Some(mut state) = state {
             // Update first moment estimate.
             let factor = 1.0 - self.beta_1;
@@ -151,10 +138,10 @@ impl AdaptiveMomentumW {
             let factor = 1.0 - self.beta_2;
             let moment_2 = grad.powf_scalar(2.0).mul_scalar(factor);
 
-            AdaptiveMomentumWState::new(1, moment_1, moment_2)
+            AdaptiveMomentumState::new(1, moment_1, moment_2)
         };
 
-        let time: i32 = (state.time as i32).elem();
+        let time: i32 = state.time as i32;
 
         // Compute bias-corrected first and second moment estimates.
         let moment_1_corrected = state
@@ -173,28 +160,11 @@ impl AdaptiveMomentumW {
 
         (
             update_delta,
-            AdaptiveMomentumWState::new(state.time, state.moment_1, state.moment_2),
+            AdaptiveMomentumState::new(state.time, state.moment_1, state.moment_2),
         )
     }
 }
 
-impl<B: Backend, const D: usize> AdaptiveMomentumWState<B, D> {
-    /// Move state to device.
-    ///
-    /// # Arguments
-    ///
-    /// * `device` - Device to move state to.
-    ///
-    /// # Returns
-    ///
-    /// Returns state moved to device.
-    pub fn to_device(mut self, device: &B::Device) -> Self {
-        self.moment_1 = self.moment_1.to_device(device);
-        self.moment_2 = self.moment_2.to_device(device);
-        self
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -202,7 +172,7 @@ mod tests {
     use crate::optim::{GradientsParams, Optimizer};
     use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
     use crate::tensor::{Distribution, Tensor, TensorData};
-    use crate::{nn, TestAutodiffBackend, TestBackend};
+    use crate::{nn, TestAutodiffBackend};
     use tempfile::TempDir;
 
     const LEARNING_RATE: LearningRate = 0.01;
@@ -366,8 +336,7 @@ mod tests {
     }
 
     fn create_adamw(
-    ) -> OptimizerAdaptor<AdamW<TestBackend>, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
-    {
+    ) -> OptimizerAdaptor<AdamW, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend> {
         let config = AdamWConfig::new();
         AdamW {
             momentum: AdaptiveMomentumW {
@@ -376,7 +345,6 @@ mod tests {
                 epsilon: config.epsilon,
             },
             weight_decay: config.weight_decay,
-            _phantom: Default::default(),
         }
         .into()
     }
diff --git a/crates/burn-core/src/optim/decay.rs b/crates/burn-core/src/optim/decay.rs
index 9979f57c29..27f258c0d8 100644
--- a/crates/burn-core/src/optim/decay.rs
+++ b/crates/burn-core/src/optim/decay.rs
@@ -4,13 +4,13 @@ use crate as burn;
 use crate::record::Record;
 
 use crate::config::Config;
-use crate::tensor::{ElementConversion, Tensor};
+use crate::tensor::Tensor;
 
 /// Configuration to create [weight decay](WeightDecay).
 #[derive(Config)]
 pub struct WeightDecayConfig {
     /// L2 penalty.
-    pub penalty: f64,
+    pub penalty: f32,
 }
 
 /// State of [weight decay](WeightDecay).
@@ -21,15 +21,15 @@ pub struct WeightDecayState<B: Backend, const D: usize> {
 
 /// Weight decay implementation that transforms gradients.
 #[derive(Clone)]
-pub struct WeightDecay<B: Backend> {
-    penalty: B::FloatElem,
+pub struct WeightDecay {
+    penalty: f32,
 }
 
-impl<B: Backend> WeightDecay<B> {
+impl WeightDecay {
     /// Creates a new [weight decay](WeightDecay) from a [config](WeightDecayConfig).
     pub fn new(config: &WeightDecayConfig) -> Self {
         Self {
-            penalty: config.penalty.elem(),
+            penalty: config.penalty,
         }
     }
 
@@ -43,7 +43,7 @@ impl<B: Backend> WeightDecay<B> {
     /// # Returns
     ///
     /// * `grad` - Transformed gradient.
-    pub fn transform<const D: usize>(
+    pub fn transform<B: Backend, const D: usize>(
         &self,
         grad: Tensor<B, D>,
         tensor: Tensor<B, D>,
diff --git a/crates/burn-core/src/optim/rmsprop.rs b/crates/burn-core/src/optim/rmsprop.rs
index 63ccbe542d..bff9859df5 100644
--- a/crates/burn-core/src/optim/rmsprop.rs
+++ b/crates/burn-core/src/optim/rmsprop.rs
@@ -41,7 +41,7 @@ impl RmsPropConfig {
     /// Returns an optimizer that can be used to optimize a module.
     pub fn init<B: AutodiffBackend, M: AutodiffModule<B>>(
         &self,
-    ) -> OptimizerAdaptor<RmsProp<B::InnerBackend>, M, B> {
+    ) -> OptimizerAdaptor<RmsProp, M, B> {
         let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);
 
         let mut optim = OptimizerAdaptor::from(RmsProp {
@@ -65,16 +65,16 @@ impl RmsPropConfig {
 /// Optimizer that implements stochastic gradient descent with momentum.
 /// The optimizer can be configured with [RmsPropConfig](RmsPropConfig).
 #[derive(Clone)]
-pub struct RmsProp<B: Backend> {
+pub struct RmsProp {
     alpha: f32,
     // epsilon: f32,
     centered: bool,
     // momentum: Option<Momentum<B>>,
     momentum: RmsPropMomentum,
-    weight_decay: Option<WeightDecay<B>>,
+    weight_decay: Option<WeightDecay>,
 }
 
-impl<B: Backend> SimpleOptimizer<B> for RmsProp<B> {
+impl<B: Backend> SimpleOptimizer<B> for RmsProp {
     type State<const D: usize> = RmsPropState<B, D>;
 
     fn step<const D: usize>(
@@ -136,15 +136,19 @@ impl<B: Backend> SimpleOptimizer<B> for RmsProp<B> {
 /// State of [RmsProp](RmsProp)
 #[derive(Record, Clone, new)]
 pub struct RmsPropState<B: Backend, const D: usize> {
-    square_avg: SquareAvgState<B, D>,
-    centered: CenteredState<B, D>,
-    momentum: Option<RmsPropMomentumState<B, D>>,
+    /// Current squared average state.
+    pub square_avg: SquareAvgState<B, D>,
+    /// Current centered state
+    pub centered: CenteredState<B, D>,
+    /// Current gradient momentum, if any.
+    pub momentum: Option<RmsPropMomentumState<B, D>>,
 }
 
 /// [SquareAvgState](SquareAvgState) is to store and pass optimizer step params.
 #[derive(Record, Clone, new)]
 pub struct SquareAvgState<B: Backend, const D: usize> {
-    square_avg: Tensor<B, D>,
+    /// Current squared average.
+    pub square_avg: Tensor<B, D>,
 }
 
 impl<B: Backend, const D: usize> SquareAvgState<B, D> {
@@ -183,8 +187,10 @@ impl<B: Backend, const D: usize> SquareAvgState<B, D> {
 /// [CenteredState](CenteredState) is to store and pass optimizer step params.
 #[derive(Record, Clone, new)]
 pub struct CenteredState<B: Backend, const D: usize> {
-    grad_avg: Option<Tensor<B, D>>,
-    avg: Tensor<B, D>,
+    /// The averaged gradient to calculate the centered gradient, if available.
+    pub grad_avg: Option<Tensor<B, D>>,
+    /// The current average value.
+    pub avg: Tensor<B, D>,
 }
 
 impl<B: Backend, const D: usize> CenteredState<B, D> {
@@ -316,7 +322,7 @@ mod tests {
     use crate::optim::{GradientsParams, Optimizer};
     use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
     use crate::tensor::{Distribution, Tensor, TensorData};
-    use crate::{nn, TestAutodiffBackend, TestBackend};
+    use crate::{nn, TestAutodiffBackend};
     use tempfile::TempDir;
 
     const LEARNING_RATE: LearningRate = 0.01;
@@ -530,8 +536,7 @@ mod tests {
     }
 
     fn create_rmsprop(
-    ) -> OptimizerAdaptor<RmsProp<TestBackend>, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend>
-    {
+    ) -> OptimizerAdaptor<RmsProp, nn::Linear<TestAutodiffBackend>, TestAutodiffBackend> {
         RmsPropConfig {
             alpha: 0.99,
             epsilon: 1e-9,
diff --git a/crates/burn-core/src/optim/sgd.rs b/crates/burn-core/src/optim/sgd.rs
index 325325cbfa..9b5d66c4fa 100644
--- a/crates/burn-core/src/optim/sgd.rs
+++ b/crates/burn-core/src/optim/sgd.rs
@@ -28,13 +28,14 @@ pub struct SgdConfig {
 #[derive(Clone)]
 pub struct Sgd<B: Backend> {
     momentum: Option<Momentum<B>>,
-    weight_decay: Option<WeightDecay<B>>,
+    weight_decay: Option<WeightDecay>,
 }
 
 /// State of [Sgd](Sgd).
 #[derive(Record, Clone, new)]
 pub struct SgdState<B: Backend, const D: usize> {
-    momentum: Option<MomentumState<B, D>>,
+    /// The current state of the momentum (if any).
+    pub momentum: Option<MomentumState<B, D>>,
 }
 
 impl SgdConfig {
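
Usage sketch (not part of the patch): with the backend generic removed from the optimizer structs, a config now produces the same `Adam`/`AdamW`/`RmsProp` type for every backend, and the backend is only fixed when `init` is called. The backend `Autodiff<NdArray>`, the `Linear` module, and the `with_weight_decay` setter below are assumptions about the public `burn` API made for illustration, not code from this diff.

```rust
use burn::backend::{Autodiff, NdArray};
use burn::nn::Linear;
use burn::optim::{decay::WeightDecayConfig, AdamConfig, Optimizer};

// Placeholder backend and module types; any autodiff backend and any
// `AutodiffModule` should work the same way.
type B = Autodiff<NdArray>;

fn build_optimizer() -> impl Optimizer<Linear<B>, B> {
    AdamConfig::new()
        // `penalty` is an `f32` after this patch (previously `f64`).
        .with_weight_decay(Some(WeightDecayConfig::new(1e-4)))
        // The adaptor is now `OptimizerAdaptor<Adam, Linear<B>, B>`;
        // `Adam` itself no longer carries an inner-backend parameter.
        .init::<B, Linear<B>>()
}
```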