From 8e1b6458ef9ef4facd6c86003a6845ef64e5346e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Thu, 23 May 2019 14:52:06 +0200 Subject: [PATCH 1/3] WIP LAMB Optimizer --- onmt/opts.py | 10 +++- onmt/utils/optimizers.py | 106 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 1 deletion(-) diff --git a/onmt/opts.py b/onmt/opts.py index f763e4f254..8d644cbca8 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -429,7 +429,7 @@ def train_opts(parser): nargs="*", default=None, help='Criteria to use for early stopping.') group.add('--optim', '-optim', default='sgd', - choices=['sgd', 'adagrad', 'adadelta', 'adam', + choices=['sgd', 'adagrad', 'adadelta', 'adam', 'lamb', 'sparseadam', 'adafactor', 'fusedadam'], help="Optimization method.") group.add('--adagrad_accumulator_init', '-adagrad_accumulator_init', @@ -466,6 +466,14 @@ def train_opts(parser): 'suggested a value of 0.98 for beta2, this parameter may ' 'not work well for normal models / default ' 'baselines.') + group.add('--lamb_beta1', '-lamb_beta1', type=float, default=0.9, + help="The beta1 parameter used by Lamb.") + group.add('--lamb_beta2', '-lamb_beta2', type=float, default=0.999, + help="The beta2 parameter used by Lamb.") + group.add('--lamb_eps', '-lamb_eps', type=float, default=1e-8, + help="The epsilon parameter used by Lamb.") + group.add('--lamb_wd', '-lamb_wd', type=float, default=0.0, + help="The weight decay parameter used by Lamb." ) group.add('--label_smoothing', '-label_smoothing', type=float, default=0.0, help="Label smoothing value epsilon. " "Probabilities of all non-true labels " diff --git a/onmt/utils/optimizers.py b/onmt/utils/optimizers.py index 36039c33ee..63ecadbd61 100644 --- a/onmt/utils/optimizers.py +++ b/onmt/utils/optimizers.py @@ -6,6 +6,7 @@ import functools from copy import copy from math import sqrt +import math from onmt.utils.misc import fn_args @@ -82,6 +83,13 @@ def build_torch_optimizer(model, opt): params, lr=opt.learning_rate, betas=betas) + elif opt.optim == 'lamb': + optimizer = Lamb( + params, + lr=opt.learning_rate, + betas=(opt.lamb_beta1, opt.lamb_beta2), + eps=opt.lamb_eps, + weight_decay=opt.lamb_wd) else: raise ValueError('Invalid optimizer type: ' + opt.optim) @@ -517,3 +525,101 @@ def step(self, closure=None): p.data.add_(-group['weight_decay'] * lr_t, p.data) return loss + + +# Code below is an implementation of https://arxiv.org/pdf/1904.00962.pdf +# inspired but modified from https://github.com/cybertronai/pytorch-lamb + +class Lamb(torch.optim.Optimizer): + """Implements Lamb algorithm. + It has been proposed in `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes`_. + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + adam (bool, optional): always use trust ratio = 1, which turns this into + Adam. Useful for comparison purposes. + .. 
_Reducing BERT Pre-Training Time from 3 Days to 76 Minutes: + https://arxiv.org/abs/1904.00962 + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, adam=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay) + self.adam = adam + super(Lamb, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + if group['weight_decay'] != 0: + grad.add_(group['weight_decay'], p.data) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + denom = exp_avg_sq.sqrt().add_(group['eps']) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + # Apply bias to lr to avoid broadcast. + step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + + adam_step = exp_avg / denom + # L2 norm uses sum, but here since we're dividing, use mean to avoid overflow. + r1 = p.data.pow(2).mean().sqrt() + r2 = adam_step.pow(2).mean().sqrt() + r = 1 if r1 == 0 or r2 == 0 else min(r1/r2, 10) + state['r1'] = r1 + state['r2'] = r2 + state['r'] = r + if self.adam: + r = 1 + + p.data.add_(-step_size * r, adam_step) + + return loss \ No newline at end of file From c2ee86bf7acb995925053f0764a1c5ead7bed4ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Wed, 5 Jun 2019 16:31:44 +0200 Subject: [PATCH 2/3] add a few comments --- onmt/utils/optimizers.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/onmt/utils/optimizers.py b/onmt/utils/optimizers.py index 63ecadbd61..06495ca9e9 100644 --- a/onmt/utils/optimizers.py +++ b/onmt/utils/optimizers.py @@ -527,11 +527,10 @@ def step(self, closure=None): return loss -# Code below is an implementation of https://arxiv.org/pdf/1904.00962.pdf -# inspired but modified from https://github.com/cybertronai/pytorch-lamb - class Lamb(torch.optim.Optimizer): """Implements Lamb algorithm. + Based on https://github.com/cybertronai/pytorch-lamb + which is itself based on `torch.optimizers.Adam`. It has been proposed in `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes`_. 
Arguments: params (iterable): iterable of parameters to optimize or dicts defining @@ -591,6 +590,7 @@ def step(self, closure=None): # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like(p.data) + # in the paper, exp_avg is m_t and exp_avg_sq is v_t exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] beta1, beta2 = group['betas'] @@ -599,9 +599,11 @@ def step(self, closure=None): if group['weight_decay'] != 0: grad.add_(group['weight_decay'], p.data) - # Decay the first and second moment running average coefficient + # m = beta1 * m + (1 - beta1) * grad exp_avg.mul_(beta1).add_(1 - beta1, grad) + # v = beta2 * m + (1 - beta2) * grad**2 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + denom = exp_avg_sq.sqrt().add_(group['eps']) bias_correction1 = 1 - beta1 ** state['step'] @@ -622,4 +624,4 @@ def step(self, closure=None): p.data.add_(-step_size * r, adam_step) - return loss \ No newline at end of file + return loss From 9dd286f9b221e047325f48fa45c6c59dffe8fb8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Wed, 5 Jun 2019 16:40:06 +0200 Subject: [PATCH 3/3] fix flake --- onmt/opts.py | 2 +- onmt/utils/optimizers.py | 34 +++++++++++++++++++++------------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/onmt/opts.py b/onmt/opts.py index 8d644cbca8..b73527a3ec 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -473,7 +473,7 @@ def train_opts(parser): group.add('--lamb_eps', '-lamb_eps', type=float, default=1e-8, help="The epsilon parameter used by Lamb.") group.add('--lamb_wd', '-lamb_wd', type=float, default=0.0, - help="The weight decay parameter used by Lamb." ) + help="The weight decay parameter used by Lamb.") group.add('--label_smoothing', '-label_smoothing', type=float, default=0.0, help="Label smoothing value epsilon. " "Probabilities of all non-true labels " diff --git a/onmt/utils/optimizers.py b/onmt/utils/optimizers.py index 06495ca9e9..603f6212de 100644 --- a/onmt/utils/optimizers.py +++ b/onmt/utils/optimizers.py @@ -531,18 +531,20 @@ class Lamb(torch.optim.Optimizer): """Implements Lamb algorithm. Based on https://github.com/cybertronai/pytorch-lamb which is itself based on `torch.optimizers.Adam`. - It has been proposed in `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes`_. + It has been proposed in `Reducing BERT Pre-Training Time + from 3 Days to 76 Minutes`_. Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups + params (iterable): iterable of parameters to optimize or + dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) + betas (Tuple[float, float], optional): coefficients used + for computing running averages of gradient and + its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - adam (bool, optional): always use trust ratio = 1, which turns this into - Adam. Useful for comparison purposes. + adam (bool, optional): always use trust ratio = 1, + which turns this into Adam. Useful for comparison purposes. .. 
_Reducing BERT Pre-Training Time from 3 Days to 76 Minutes: https://arxiv.org/abs/1904.00962 """ @@ -554,9 +556,11 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError("Invalid beta parameter at index 0: {}". + format(betas[0])) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError("Invalid beta parameter at index 1: {}". + format(betas[1])) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) self.adam = adam @@ -578,7 +582,9 @@ def step(self, closure=None): continue grad = p.grad.data if grad.is_sparse: - raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') + raise RuntimeError( + "Lamb does not support sparse gradients," + "consider SparseAdam instead.") state = self.state[p] @@ -609,13 +615,15 @@ def step(self, closure=None): bias_correction1 = 1 - beta1 ** state['step'] bias_correction2 = 1 - beta2 ** state['step'] # Apply bias to lr to avoid broadcast. - step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + step_size = group['lr'] * \ + math.sqrt(bias_correction2) / bias_correction1 adam_step = exp_avg / denom - # L2 norm uses sum, but here since we're dividing, use mean to avoid overflow. + # L2 norm uses sum, but here since we're dividing, + # use mean to avoid overflow. r1 = p.data.pow(2).mean().sqrt() r2 = adam_step.pow(2).mean().sqrt() - r = 1 if r1 == 0 or r2 == 0 else min(r1/r2, 10) + r = 1 if r1 == 0 or r2 == 0 else min(r1/r2, 10) state['r1'] = r1 state['r2'] = r2 state['r'] = r
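
Not part of the patches above: a minimal usage sketch of the Lamb class added in this series, assuming the patched onmt.utils.optimizers module is importable under a 2019-era PyTorch (the patch relies on the older Tensor.add_(scalar, tensor) call form). The toy model, shapes, and training loop are illustrative only; the Lamb(params, lr, betas, eps, weight_decay) signature is the one defined in the patch::

    import torch
    from onmt.utils.optimizers import Lamb

    model = torch.nn.Linear(16, 4)        # toy model, illustrative only
    optimizer = Lamb(model.parameters(),
                     lr=1e-3,             # passed as opt.learning_rate in build_torch_optimizer
                     betas=(0.9, 0.999),  # mirrors -lamb_beta1 / -lamb_beta2
                     eps=1e-8,            # mirrors -lamb_eps
                     weight_decay=0.0)    # mirrors -lamb_wd

    for _ in range(3):
        optimizer.zero_grad()
        x = torch.randn(8, 16)
        loss = model(x).pow(2).mean()
        loss.backward()
        # step() applies an Adam-style update per parameter tensor, rescaled by
        # the trust ratio r = rms(w) / rms(adam_step), clipped at 10; passing
        # adam=True to the constructor forces r = 1 for comparison runs.
        optimizer.step()

From the command line, the flags added in opts.py select the same configuration: -optim lamb -lamb_beta1 0.9 -lamb_beta2 0.999 -lamb_eps 1e-8 -lamb_wd 0.0.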