[WIP] LAMB optimizer #1460

Open
wants to merge 3 commits into master
Changes from 1 commit
WIP LAMB Optimizer
francoishernandez committed May 23, 2019
commit 8e1b6458ef9ef4facd6c86003a6845ef64e5346e
10 changes: 9 additions & 1 deletion onmt/opts.py
@@ -429,7 +429,7 @@ def train_opts(parser):
nargs="*", default=None,
help='Criteria to use for early stopping.')
group.add('--optim', '-optim', default='sgd',
-              choices=['sgd', 'adagrad', 'adadelta', 'adam',
+              choices=['sgd', 'adagrad', 'adadelta', 'adam', 'lamb',
'sparseadam', 'adafactor', 'fusedadam'],
help="Optimization method.")
group.add('--adagrad_accumulator_init', '-adagrad_accumulator_init',
@@ -466,6 +466,14 @@ def train_opts(parser):
'suggested a value of 0.98 for beta2, this parameter may '
'not work well for normal models / default '
'baselines.')
group.add('--lamb_beta1', '-lamb_beta1', type=float, default=0.9,
help="The beta1 parameter used by Lamb.")
group.add('--lamb_beta2', '-lamb_beta2', type=float, default=0.999,
help="The beta2 parameter used by Lamb.")
group.add('--lamb_eps', '-lamb_eps', type=float, default=1e-8,
help="The epsilon parameter used by Lamb.")
group.add('--lamb_wd', '-lamb_wd', type=float, default=0.0,
help="The weight decay parameter used by Lamb." )
group.add('--label_smoothing', '-label_smoothing', type=float, default=0.0,
help="Label smoothing value epsilon. "
"Probabilities of all non-true labels "
106 changes: 106 additions & 0 deletions onmt/utils/optimizers.py
@@ -6,6 +6,7 @@
import functools
from copy import copy
from math import sqrt
import math

from onmt.utils.misc import fn_args

@@ -82,6 +83,13 @@ def build_torch_optimizer(model, opt):
params,
lr=opt.learning_rate,
betas=betas)
elif opt.optim == 'lamb':
optimizer = Lamb(
params,
lr=opt.learning_rate,
betas=(opt.lamb_beta1, opt.lamb_beta2),
eps=opt.lamb_eps,
weight_decay=opt.lamb_wd)
else:
raise ValueError('Invalid optimizer type: ' + opt.optim)
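
For reference, a minimal sketch of what this branch constructs once the new flags are parsed. The model and the literal values are illustrative stand-ins for the corresponding opt fields; the Lamb class itself is added further down in this diff.

import torch
from onmt.utils.optimizers import Lamb  # class added later in this PR

model = torch.nn.Linear(8, 8)          # stand-in for the real model
optimizer = Lamb(model.parameters(),
                 lr=1e-3,              # opt.learning_rate
                 betas=(0.9, 0.999),   # (opt.lamb_beta1, opt.lamb_beta2)
                 eps=1e-8,             # opt.lamb_eps
                 weight_decay=0.0)     # opt.lamb_wd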

@@ -517,3 +525,101 @@ def step(self, closure=None):
p.data.add_(-group['weight_decay'] * lr_t, p.data)

return loss


# The code below implements the LAMB optimizer (https://arxiv.org/pdf/1904.00962.pdf).
# It is inspired by, and adapted from, https://github.com/cybertronai/pytorch-lamb.

class Lamb(torch.optim.Optimizer):
"""Implements Lamb algorithm.
It has been proposed in `Reducing BERT Pre-Training Time from 3 Days to 76 Minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
adam (bool, optional): always use trust ratio = 1, which turns this into
Adam. Useful for comparison purposes.
.. _Reducing BERT Pre-Training Time from 3 Days to 76 Minutes:
https://arxiv.org/abs/1904.00962
"""

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, adam=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay)
self.adam = adam
super(Lamb, self).__init__(params, defaults)

def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

state = self.state[p]

# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']

state['step'] += 1

if group['weight_decay'] != 0:
grad.add_(group['weight_decay'], p.data)

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
denom = exp_avg_sq.sqrt().add_(group['eps'])

bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
# Fold the bias corrections into the scalar step size instead of
# correcting the moment tensors, avoiding an extra broadcast.
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

adam_step = exp_avg / denom
# A true L2 norm would sum; using the mean (i.e. an RMS norm) gives the
# same ratio r1/r2 while reducing the risk of overflow.
r1 = p.data.pow(2).mean().sqrt()
r2 = adam_step.pow(2).mean().sqrt()
r = 1 if r1 == 0 or r2 == 0 else min(r1/r2, 10)
state['r1'] = r1
state['r2'] = r2
state['r'] = r
if self.adam:
r = 1

p.data.add_(-step_size * r, adam_step)

return loss
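
A minimal usage sketch (toy linear model and random data, both hypothetical), assuming the class above is importable from onmt.utils.optimizers and a PyTorch version contemporary with this PR (the old add_/addcmul_ call signatures used in step). Each step applies roughly w <- w - step_size * r * exp_avg / (sqrt(exp_avg_sq) + eps), with r the clamped trust ratio computed above and kept per parameter in the optimizer state.

import torch
from onmt.utils.optimizers import Lamb

model = torch.nn.Linear(10, 2)     # toy model
optimizer = Lamb(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                 eps=1e-8, weight_decay=0.01)

x = torch.randn(32, 10)            # random toy batch
y = torch.randint(0, 2, (32,))
for _ in range(5):
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()

# The last trust ratio of each parameter is stored in the optimizer state.
ratios = {name: optimizer.state[p]['r']
          for name, p in model.named_parameters()}
print(ratios)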