Add cyclical learning rate schedulers #644

Merged (10 commits) on Nov 12, 2019
Changes from all commits
14 changes: 14 additions & 0 deletions tensorflow_addons/optimizers/BUILD
@@ -7,6 +7,7 @@ py_library(
srcs = [
"__init__.py",
"conditional_gradient.py",
"cyclical_learning_rate.py",
"lamb.py",
"lazy_adam.py",
"lookahead.py",
@@ -110,3 +111,16 @@ py_test(
":optimizers",
],
)

py_test(
name = "cyclical_learning_rate_test",
size = "small",
srcs = [
"cyclical_learning_rate_test.py",
],
main = "cyclical_learning_rate_test.py",
srcs_version = "PY2AND3",
deps = [
":optimizers",
],
)
2 changes: 2 additions & 0 deletions tensorflow_addons/optimizers/README.md
@@ -4,6 +4,7 @@
| Submodule | Maintainers | Contact Info |
|:---------- |:------------- |:--------------|
| conditional_gradient | Pengyu Kan, Vishnu Lokhande | [email protected], [email protected] |
| cyclical_learning_rate | Raphael Meudec | [email protected] |
| lamb | Jing Li, Junjie Ke | [email protected], [email protected] |
| lazy_adam | Saishruthi Swaminathan | [email protected] |
| lookahead | Zhao Hanguang | [email protected] |
@@ -16,6 +17,7 @@
| Submodule | Optimizer | Reference |
|:--------- |:---------- |:---------|
| conditional_gradient | ConditionalGradient | https://arxiv.org/pdf/1803.06453.pdf |
| cyclical_learning_rate | Cyclical Learning Rate | https://arxiv.org/abs/1506.01186 |
| lamb | LAMB | https://arxiv.org/abs/1904.00962 |
| lazy_adam | LazyAdam | https://arxiv.org/abs/1412.6980 |
| lookahead | Lookahead | https://arxiv.org/abs/1907.08610v1 |
8 changes: 8 additions & 0 deletions tensorflow_addons/optimizers/__init__.py
@@ -19,6 +19,14 @@
from __future__ import print_function

from tensorflow_addons.optimizers.conditional_gradient import ConditionalGradient
from tensorflow_addons.optimizers.cyclical_learning_rate import (
CyclicalLearningRate)
from tensorflow_addons.optimizers.cyclical_learning_rate import (
TriangularCyclicalLearningRate)
from tensorflow_addons.optimizers.cyclical_learning_rate import (
Triangular2CyclicalLearningRate)
from tensorflow_addons.optimizers.cyclical_learning_rate import (
ExponentialCyclicalLearningRate)
from tensorflow_addons.optimizers.lamb import LAMB
from tensorflow_addons.optimizers.lazy_adam import LazyAdam
from tensorflow_addons.optimizers.lookahead import Lookahead
293 changes: 293 additions & 0 deletions tensorflow_addons/optimizers/cyclical_learning_rate.py
@@ -0,0 +1,293 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cyclical Learning Rate Schedule policies for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package='Addons')
class CyclicalLearningRate(tf.keras.optimizers.schedules.LearningRateSchedule):
"""A LearningRateSchedule that uses cyclical schedule."""

def __init__(
self,
initial_learning_rate,
maximal_learning_rate,
step_size,
scale_fn,
scale_mode="cycle",
name=None,
):
"""Applies cyclical schedule to the learning rate.

See Cyclical Learning Rates for Training Neural Networks. https://arxiv.org/abs/1506.01186


```python
import tensorflow_addons as tfa

lr_schedule = tfa.optimizers.CyclicalLearningRate(
initial_learning_rate=1e-4,
maximal_learning_rate=1e-2,
step_size=2000,
scale_fn=lambda x: 1.,
scale_mode="cycle",
name="MyCyclicScheduler")

model.compile(optimizer=tf.keras.optimizers.SGD(
learning_rate=lr_schedule),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])

model.fit(data, labels, epochs=5)
```

You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.

Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The initial learning rate.
maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The maximum learning rate.
step_size: A scalar `float32` or `float64` `Tensor` or a
Python number. Step size.
scale_fn: A function. Scaling function applied to the cycle; it
receives the cycle index or the iteration count (depending on
`scale_mode`) and returns a multiplicative factor for the
learning rate amplitude.
scale_mode: Either 'cycle' or 'iterations'. Whether `scale_fn` is
evaluated on the cycle number or on the training iteration.
name: (Optional) Name for the operation.

Returns:
Updated learning rate value.
"""
super(CyclicalLearningRate, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.maximal_learning_rate = maximal_learning_rate
self.step_size = step_size
self.scale_fn = scale_fn
self.scale_mode = scale_mode
self.name = name

def __call__(self, step):
with tf.name_scope(self.name or "CyclicalLearningRate"):
initial_learning_rate = tf.convert_to_tensor(
self.initial_learning_rate, name="initial_learning_rate")
dtype = initial_learning_rate.dtype
maximal_learning_rate = tf.cast(self.maximal_learning_rate, dtype)
step_size = tf.cast(self.step_size, dtype)
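# `cycle` indexes the current cycle (each cycle spans 2 * step_size
# steps); `x` falls linearly from 1 to 0 over the first half of the
# cycle and climbs back to 1 over the second half, so (1 - x) traces
# the triangular wave that moves the learning rate from
# `initial_learning_rate` up to `maximal_learning_rate` and back.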
cycle = tf.floor(1 + step / (2 * step_size))
x = tf.abs(step / step_size - 2 * cycle + 1)

mode_step = cycle if self.scale_mode == "cycle" else step

return initial_learning_rate + (
maximal_learning_rate - initial_learning_rate) * tf.maximum(
tf.cast(0, dtype), (1 - x)) * self.scale_fn(mode_step)

def get_config(self):
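# Note: `scale_fn` is a Python callable and is not included in the
# returned config; the subclasses below recreate their own `scale_fn`
# in `__init__`.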
return {
"initial_learning_rate": self.initial_learning_rate,
"maximal_learning_rate": self.maximal_learning_rate,
"step_size": self.step_size,
"scale_mode": self.scale_mode,
}


@tf.keras.utils.register_keras_serializable(package='Addons')
class TriangularCyclicalLearningRate(CyclicalLearningRate):
def __init__(
self,
initial_learning_rate,
maximal_learning_rate,
step_size,
scale_mode="cycle",
name="TriangularCyclicalLearningRate",
):
"""Applies triangular cyclical schedule to the learning rate.

See Cyclical Learning Rates for Training Neural Networks. https://arxiv.org/abs/1506.01186


```python
import tensorflow_addons as tfa

lr_schedule = tfa.optimizers.TriangularCyclicalLearningRate(
initial_learning_rate=1e-4,
maximal_learning_rate=1e-2,
step_size=2000,
scale_mode="cycle",
name="MyCyclicScheduler")

model.compile(optimizer=tf.keras.optimizers.SGD(
learning_rate=lr_schedule),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])

model.fit(data, labels, epochs=5)
```

You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.

Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The initial learning rate.
maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The maximum learning rate.
step_size: A scalar `float32` or `float64` `Tensor` or a
Python number. Step size.
scale_mode: Either 'cycle' or 'iterations'. Whether the scaling is
applied per cycle or per training iteration.
name: (Optional) Name for the operation.

Returns:
Updated learning rate value.
"""
super(TriangularCyclicalLearningRate, self).__init__(
initial_learning_rate=initial_learning_rate,
maximal_learning_rate=maximal_learning_rate,
step_size=step_size,
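# Constant amplitude: the plain triangular policy from the paper.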
scale_fn=lambda x: 1.,
scale_mode=scale_mode,
name=name,
)


@tf.keras.utils.register_keras_serializable(package='Addons')
class Triangular2CyclicalLearningRate(CyclicalLearningRate):
def __init__(
self,
initial_learning_rate,
maximal_learning_rate,
step_size,
scale_mode="cycle",
name="Triangular2CyclicalLearningRate",
):
"""Applies triangular2 cyclical schedule to the learning rate.

See Cyclical Learning Rates for Training Neural Networks. https://arxiv.org/abs/1506.01186


```python
import tensorflow_addons as tfa

lr_schedule = tfa.optimizers.Triangular2CyclicalLearningRate(
initial_learning_rate=1e-4,
maximal_learning_rate=1e-2,
step_size=2000,
scale_mode="cycle",
name="MyCyclicScheduler")

model.compile(optimizer=tf.keras.optimizers.SGD(
learning_rate=lr_schedule),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])

model.fit(data, labels, epochs=5)
```

You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.

Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The initial learning rate.
maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The maximum learning rate.
step_size: A scalar `float32` or `float64` `Tensor` or a
Python number. Step size.
scale_mode: Either 'cycle' or 'iterations'. Whether the scaling is
applied per cycle or per training iteration.
name: (Optional) Name for the operation.

Returns:
Updated learning rate value.
"""
super(Triangular2CyclicalLearningRate, self).__init__(
initial_learning_rate=initial_learning_rate,
maximal_learning_rate=maximal_learning_rate,
step_size=step_size,
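# Halve the cycle amplitude after each completed cycle
# (the triangular2 policy).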
scale_fn=lambda x: 1 / (2.**(x - 1)),
scale_mode=scale_mode,
name=name,
)


@tf.keras.utils.register_keras_serializable(package='Addons')
class ExponentialCyclicalLearningRate(CyclicalLearningRate):
def __init__(
self,
initial_learning_rate,
maximal_learning_rate,
step_size,
scale_mode="iterations",
gamma=1.,
name="ExponentialCyclicalLearningRate",
):
"""Applies exponential cyclical schedule to the learning rate.

See Cyclical Learning Rates for Training Neural Networks. https://arxiv.org/abs/1506.01186


```python
import tensorflow_addons as tfa

lr_schedule = tfa.optimizers.ExponentialCyclicalLearningRate(
initial_learning_rate=1e-4,
maximal_learning_rate=1e-2,
step_size=2000,
scale_mode="cycle",
gamma=0.96,
name="MyCyclicScheduler")

model.compile(optimizer=tf.keras.optimizers.SGD(
learning_rate=lr_schedule),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])

model.fit(data, labels, epochs=5)
```

You can pass this schedule directly into a
`tf.keras.optimizers.Optimizer` as the learning rate.

Args:
initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The initial learning rate.
maximal_learning_rate: A scalar `float32` or `float64` `Tensor` or
a Python number. The maximum learning rate.
step_size: A scalar `float32` or `float64` `Tensor` or a
Python number. Step size.
scale_mode: Either 'cycle' or 'iterations'. Whether the exponential
scaling is applied per cycle or per training iteration.
gamma: A scalar `float32` or `float64` `Tensor` or a
Python number. The base of the exponential scaling factor,
`gamma**x`.
name: (Optional) Name for the operation.

Returns:
Updated learning rate value.
"""
super(ExponentialCyclicalLearningRate, self).__init__(
initial_learning_rate=initial_learning_rate,
maximal_learning_rate=maximal_learning_rate,
step_size=step_size,
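# Decay the cycle amplitude exponentially by a factor of `gamma` per
# iteration (or per cycle, depending on `scale_mode`).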
scale_fn=lambda x: gamma**x,
scale_mode=scale_mode,
name=name,
)
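
For reference, a minimal usage sketch of the schedules added by this PR. It assumes `tensorflow_addons` is installed and importable as `tfa` (the classes are exported from `tensorflow_addons.optimizers` by the `__init__.py` change above); the step values below are arbitrary and only illustrate the triangular pattern.

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Triangular policy: the learning rate climbs from 1e-4 to 1e-2 over
# `step_size` steps and falls back over the next `step_size` steps.
lr_schedule = tfa.optimizers.TriangularCyclicalLearningRate(
    initial_learning_rate=1e-4,
    maximal_learning_rate=1e-2,
    step_size=2000)

# Evaluating the schedule directly shows the cycle:
# step 0 -> 1e-4, step 2000 -> 1e-2 (peak), step 4000 -> 1e-4 again.
for step in [0, 1000, 2000, 3000, 4000]:
    print(step, float(lr_schedule(step)))

# The schedule can be passed to any Keras optimizer as the learning rate.
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)
```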