Add cyclical learning rate schedulers (#644)
* Add cyclical learning rate schedulers
Raphael Meudec authored and seanpmorgan committed Nov 12, 2019
1 parent 776b751 commit 0a225c7
Showing 5 changed files with 479 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tensorflow_addons/optimizers/BUILD
@@ -7,6 +7,7 @@ py_library(
srcs = [
"__init__.py",
"conditional_gradient.py",
"cyclical_learning_rate.py",
"lamb.py",
"lazy_adam.py",
"lookahead.py",
@@ -110,3 +111,16 @@ py_test(
":optimizers",
],
)

py_test(
name = "cyclical_learning_rate_test",
size = "small",
srcs = [
"cyclical_learning_rate_test.py",
],
main = "cyclical_learning_rate_test.py",
srcs_version = "PY2AND3",
deps = [
":optimizers",
],
)
2 changes: 2 additions & 0 deletions tensorflow_addons/optimizers/README.md
@@ -4,6 +4,7 @@
| Submodule | Maintainers | Contact Info |
|:---------- |:------------- |:--------------|
| conditional_gradient | Pengyu Kan, Vishnu Lokhande | [email protected], [email protected] |
| cyclical_learning_rate | Raphael Meudec | [email protected] |
| lamb | Jing Li, Junjie Ke | [email protected], [email protected] |
| lazy_adam | Saishruthi Swaminathan | [email protected] |
| lookahead | Zhao Hanguang | [email protected] |
@@ -16,6 +17,7 @@
| Submodule | Optimizer | Reference |
|:--------- |:---------- |:---------|
| conditional_gradient | ConditionalGradient | https://arxiv.org/pdf/1803.06453.pdf |
| cyclical_learning_rate | Cyclical Learning Rate | https://arxiv.org/abs/1506.01186 |
| lamb | LAMB | https://arxiv.org/abs/1904.00962 |
| lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 |
| lookahead | Lookahead | https://arxiv.org/abs/1907.08610v1 |
8 changes: 8 additions & 0 deletions tensorflow_addons/optimizers/__init__.py
@@ -19,6 +19,14 @@
from __future__ import print_function

from tensorflow_addons.optimizers.conditional_gradient import ConditionalGradient
from tensorflow_addons.optimizers.cyclical_learning_rate import (
CyclicalLearningRate)
from tensorflow_addons.optimizers.cyclical_learning_rate import (
TriangularCyclicalLearningRate)
from tensorflow_addons.optimizers.cyclical_learning_rate import (
Triangular2CyclicalLearningRate)
from tensorflow_addons.optimizers.cyclical_learning_rate import (
ExponentialCyclicalLearningRate)
from tensorflow_addons.optimizers.lamb import LAMB
from tensorflow_addons.optimizers.lazy_adam import LazyAdam
from tensorflow_addons.optimizers.lookahead import Lookahead
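With these re-exports in place, the new schedules can be imported directly from `tensorflow_addons.optimizers` and passed to any Keras optimizer. A minimal usage sketch (illustrative only, not part of this diff; assumes TensorFlow 2.x and an installed build of this addons package):

```python
import tensorflow as tf
from tensorflow_addons.optimizers import TriangularCyclicalLearningRate

# A LearningRateSchedule can be handed to a Keras optimizer as its
# learning rate, just like the built-in schedules.
lr_schedule = TriangularCyclicalLearningRate(
    initial_learning_rate=1e-4,
    maximal_learning_rate=1e-2,
    step_size=2000)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
```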
293 changes: 293 additions & 0 deletions tensorflow_addons/optimizers/cyclical_learning_rate.py
@@ -0,0 +1,293 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cyclical Learning Rate Schedule policies for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package='Addons')
class CyclicalLearningRate(tf.keras.optimizers.schedules.LearningRateSchedule):
"""A LearningRateSchedule that uses cyclical schedule."""

def __init__(
self,
initial_learning_rate,
maximal_learning_rate,
step_size,
scale_fn,
scale_mode="cycle",
name=None,
):
"""Applies cyclical schedule to the learning rate.
See Cyclical Learning Rates for Training Neural Networks. https://arxiv.org/abs/1506.01186
```python
from tensorflow_addons.optimizers import CyclicalLearningRate
lr_schedule = CyclicalLearningRate(
initial_learning_rate=1e-4,
maximal_learning_rate=1e-2,
step_size=2000,
scale_fn=lambda x: 1.,
scale_mode="cycle",
name="MyCyclicScheduler")
model.compile(optimizer=tf.keras.optimizers.SGD(
learning_rate=lr_schedule),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(data, labels, epochs=5)
```
You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.
Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The initial learning rate.
maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The maximum learning rate.
step_size: A scalar `float32` or `float64` `Tensor` or a
Python number. Step size.
scale_fn: A function. Scaling applied to the cycle amplitude; it
receives the cycle number or the iteration, depending on `scale_mode`.
scale_mode: ['cycle', 'iterations']. Whether `scale_fn` is evaluated
on the cycle number or on the raw step.
name: (Optional) Name for the operation.
Returns:
Updated learning rate value.
"""
super(CyclicalLearningRate, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.maximal_learning_rate = maximal_learning_rate
self.step_size = step_size
self.scale_fn = scale_fn
self.scale_mode = scale_mode
self.name = name

def __call__(self, step):
with tf.name_scope(self.name or "CyclicalLearningRate"):
initial_learning_rate = tf.convert_to_tensor(
self.initial_learning_rate, name="initial_learning_rate")
dtype = initial_learning_rate.dtype
maximal_learning_rate = tf.cast(self.maximal_learning_rate, dtype)
step_size = tf.cast(self.step_size, dtype)
# Cast the step to the schedule's dtype so integer global steps divide
# cleanly by the float step size.
step = tf.cast(step, dtype)
cycle = tf.floor(1 + step / (2 * step_size))
x = tf.abs(step / step_size - 2 * cycle + 1)

mode_step = cycle if self.scale_mode == "cycle" else step

return initial_learning_rate + (
maximal_learning_rate - initial_learning_rate) * tf.maximum(
tf.cast(0, dtype), (1 - x)) * self.scale_fn(mode_step)

def get_config(self):
# Note: `scale_fn` and `name` are not included in the serialized config.
return {
"initial_learning_rate": self.initial_learning_rate,
"maximal_learning_rate": self.maximal_learning_rate,
"step_size": self.step_size,
"scale_mode": self.scale_mode,
}
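For orientation, here is a plain-Python sketch (illustrative, not part of the commit) of the triangular wave that `__call__` produces, using the docstring's hyperparameters and `scale_fn(x) == 1`:

```python
import math

init_lr, max_lr, step_size = 1e-4, 1e-2, 2000.

def clr(step):
    # Mirrors __call__ above: one full triangle spans 2 * step_size steps.
    cycle = math.floor(1 + step / (2 * step_size))
    x = abs(step / step_size - 2 * cycle + 1)
    return init_lr + (max_lr - init_lr) * max(0., 1 - x)  # scale_fn == 1

for step in (0, 1000, 2000, 3000, 4000):
    print(step, clr(step))
# The rate climbs from 1e-4 at step 0 to 1e-2 at step 2000, then falls
# back to 1e-4 at step 4000, completing one cycle.
```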


@tf.keras.utils.register_keras_serializable(package='Addons')
class TriangularCyclicalLearningRate(CyclicalLearningRate):
def __init__(
self,
initial_learning_rate,
maximal_learning_rate,
step_size,
scale_mode="cycle",
name="TriangularCyclicalLearningRate",
):
"""Applies triangular cyclical schedule to the learning rate.
See Cyclical Learning Rates for Training Neural Networks. https://arxiv.org/abs/1506.01186
```python
from tensorflow_addons.optimizers import TriangularCyclicalLearningRate
lr_schedule = TriangularCyclicalLearningRate(
initial_learning_rate=1e-4,
maximal_learning_rate=1e-2,
step_size=2000,
scale_mode="cycle",
name="MyCyclicScheduler")
model.compile(optimizer=tf.keras.optimizers.SGD(
learning_rate=lr_schedule),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(data, labels, epochs=5)
```
You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.
Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The initial learning rate.
maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The maximum learning rate.
step_size: A scalar `float32` or `float64` `Tensor` or a
Python number. Step size.
scale_mode: ['cycle', 'iterations']. Mode to apply during the cyclic
schedule.
name: (Optional) Name for the operation.
Returns:
Updated learning rate value.
"""
super(TriangularCyclicalLearningRate, self).__init__(
initial_learning_rate=initial_learning_rate,
maximal_learning_rate=maximal_learning_rate,
step_size=step_size,
scale_fn=lambda x: 1.,
scale_mode=scale_mode,
name=name,
)


@tf.keras.utils.register_keras_serializable(package='Addons')
class Triangular2CyclicalLearningRate(CyclicalLearningRate):
def __init__(
self,
initial_learning_rate,
maximal_learning_rate,
step_size,
scale_mode="cycle",
name="Triangular2CyclicalLearningRate",
):
"""Applies triangular2 cyclical schedule to the learning rate.
See Cyclical Learning Rates for Training Neural Networks. https://arxiv.org/abs/1506.01186
```python
from tensorflow_addons.optimizers import Triangular2CyclicalLearningRate
lr_schedule = Triangular2CyclicalLearningRate(
initial_learning_rate=1e-4,
maximal_learning_rate=1e-2,
step_size=2000,
scale_mode="cycle",
name="MyCyclicScheduler")
model.compile(optimizer=tf.keras.optimizers.SGD(
learning_rate=lr_schedule),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(data, labels, epochs=5)
```
You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.
Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The initial learning rate.
maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The maximum learning rate.
step_size: A scalar `float32` or `float64` `Tensor` or a
Python number. Step size.
scale_mode: ['cycle', 'iterations']. Mode to apply during the cyclic
schedule.
name: (Optional) Name for the operation.
Returns:
Updated learning rate value.
"""
super(Triangular2CyclicalLearningRate, self).__init__(
initial_learning_rate=initial_learning_rate,
maximal_learning_rate=maximal_learning_rate,
step_size=step_size,
scale_fn=lambda x: 1 / (2.**(x - 1)),
scale_mode=scale_mode,
name=name,
)
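The only difference from the triangular variant is the fixed `scale_fn`, which halves the cycle amplitude after every full cycle when `scale_mode="cycle"`. A quick illustrative check (not part of the diff):

```python
scale_fn = lambda cycle: 1 / (2.**(cycle - 1))
print([scale_fn(c) for c in (1, 2, 3, 4)])  # [1.0, 0.5, 0.25, 0.125]
```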


@tf.keras.utils.register_keras_serializable(package='Addons')
class ExponentialCyclicalLearningRate(CyclicalLearningRate):
def __init__(
self,
initial_learning_rate,
maximal_learning_rate,
step_size,
scale_mode="iterations",
gamma=1.,
name="ExponentialCyclicalLearningRate",
):
"""Applies exponential cyclical schedule to the learning rate.
See Cyclical Learning Rates for Training Neural Networks. https://arxiv.org/abs/1506.01186
```python
from tensorflow_addons.optimizers import ExponentialCyclicalLearningRate
lr_schedule = ExponentialCyclicalLearningRate(
initial_learning_rate=1e-4,
maximal_learning_rate=1e-2,
step_size=2000,
scale_mode="cycle",
gamma=0.96,
name="MyCyclicScheduler")
model.compile(optimizer=tf.keras.optimizers.SGD(
learning_rate=lr_schedule),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(data, labels, epochs=5)
```
You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.
Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The initial learning rate.
maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The maximum learning rate.
step_size: A scalar `float32` or `float64` `Tensor` or a
Python number. Step size.
scale_mode: ['cycle', 'iterations']. Mode to apply during the cyclic
schedule.
gamma: A scalar `float32` or `float64` `Tensor` or a
Python number. Gamma value.
name: (Optional) Name for the operation.
Returns:
Updated learning rate value.
"""
super(ExponentialCyclicalLearningRate, self).__init__(
initial_learning_rate=initial_learning_rate,
maximal_learning_rate=maximal_learning_rate,
step_size=step_size,
scale_fn=lambda x: gamma**x,
scale_mode=scale_mode,
name=name,
)
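Because a `LearningRateSchedule` is callable on a step value, the exponential variant can be probed directly in eager mode. A short sketch (illustrative only; assumes TensorFlow 2.x and this addons build; `gamma=0.9994` is chosen here just to make the per-iteration decay visible over a few thousand steps):

```python
import tensorflow as tf
from tensorflow_addons.optimizers import ExponentialCyclicalLearningRate

lr_schedule = ExponentialCyclicalLearningRate(
    initial_learning_rate=1e-4,
    maximal_learning_rate=1e-2,
    step_size=2000,
    gamma=0.9994)

for step in (0, 1000, 2000, 3000, 4000):
    # With the default scale_mode="iterations", gamma**step gradually
    # damps the triangular amplitude as training progresses.
    print(step, float(lr_schedule(step)))
```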
162 changes: 162 additions & 0 deletions tensorflow_addons/optimizers/cyclical_learning_rate_test.py (file contents not loaded in this view)
