TensorFlow 2 Gradient Checkpointing

This is a simple decorator that enables gradient checkpointing (see Chen et al. (2016)) in TF2. It isn't very polished, but it has been letting me train bigger GPT-2 models on smaller hardware, so I thought I'd share it.

Basic Usage

Use the checkpointable decorator to allow a function (or callable object such as a Keras Layer) to use gradient checkpointing. If checkpointing is desired, call the decorated function with the _checkpoint keyword argument set to True.

The example below builds a model with 40000 "layers" (200 additive steps per call to f, and 200 calls), but checkpointing means only around 400 activations need to be in memory at any point. On a GTX 1070 Ti, this code results in an OOM error when the _checkpoint argument is set to False.

import tensorflow as tf

from checkpointing import checkpointable

@checkpointable
def f(x, y, some_str, some_bool, z=None):
    # 200 additive steps per checkpointed call; their intermediates are
    # recomputed during the backward pass instead of being kept in memory
    for _ in range(200):
        x += y * z
    return x

initial = tf.ones(100000, dtype=tf.float32)
y = tf.ones(100000, dtype=tf.float32) + 1e-7
z = tf.ones(100000, dtype=tf.float32) + 1e-7
with tf.GradientTape() as g:
    g.watch(initial)
    x = initial
    for _ in range(200):  # 200 checkpointed calls -> 40000 "layers" in total
        x = f(x, y, 'a', True, z=z, _checkpoint=True)
    loss = tf.reduce_sum(x)
print(g.gradient(loss, initial))

Arguments which are not float32 tensors (or nested list/tuple structures of such tensors) are allowed, but ignored for the purposes of gradient computation.
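For example, a checkpointed function can take a nested list of tensors alongside plain Python values; under the behavior described above, the float32 tensors are tracked for gradients while everything else is simply passed through. This is a minimal sketch, and the function and argument names are illustrative, not part of the library:

@checkpointable
def scale_fn(x, scales, label):
    # `scales` is a nested (list) structure of float32 tensors, so it is
    # tracked for gradients; `label` is a plain string and is ignored
    for s in scales:
        x = x * s
    return x

x = tf.ones(10)
scales = [tf.constant(2.0), tf.constant(0.5)]
with tf.GradientTape() as tape:
    tape.watch(scales)
    out = scale_fn(x, scales, 'demo', _checkpoint=True)
    loss = tf.reduce_sum(out)
print(tape.gradient(loss, scales))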

Variables

If the decorated function uses variables which are not arguments, pass a list of them via the _watch_vars keyword argument as shown below.

layer = SomeKerasLayer()
wrapped_layer = checkpointable(layer)

with tf.GradientTape() as g:
    g.watch(layer.trainable_variables)
    output = wrapped_layer(*args, **kwargs, _checkpoint=True, _watch_vars=layer.trainable_variables)
print(g.gradient(output, layer.trainable_variables))
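A concrete version of the pattern above might look like the following, using a standard Keras Dense layer in place of SomeKerasLayer; the layer choice, shapes, and loss are just for illustration:

layer = tf.keras.layers.Dense(256)
layer.build((None, 256))  # create the layer's variables before watching them
wrapped_layer = checkpointable(layer)

inputs = tf.random.normal((32, 256))
with tf.GradientTape() as g:
    g.watch(layer.trainable_variables)
    output = wrapped_layer(inputs, _checkpoint=True, _watch_vars=layer.trainable_variables)
    loss = tf.reduce_sum(output)
print(g.gradient(loss, layer.trainable_variables))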

Warning: Dropout

Because gradient checkpointing relies on re-running the forward pass, stochastic layers such as dropout will give different results on each pass. There is a hacky workaround available, which you can enable by passing _force_seed=True to the decorated function. This uses Python's random module to draw a random number and sets it as TensorFlow's random seed before each forward pass, so the original and recomputed passes see the same randomness. If you have a better idea for addressing this issue, please do let me know.
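A minimal sketch of that workaround with a Keras Dropout layer; the layer, shapes, and the assumption that ordinary keyword arguments such as training are forwarded to the wrapped callable are illustrative:

dropout = tf.keras.layers.Dropout(0.5)
wrapped_dropout = checkpointable(dropout)

inputs = tf.random.normal((32, 256))
with tf.GradientTape() as g:
    g.watch(inputs)
    # _force_seed=True makes the decorator set TensorFlow's random seed before
    # each forward pass, so the recomputation sees the same dropout mask
    out = wrapped_dropout(inputs, training=True, _checkpoint=True, _force_seed=True)
    loss = tf.reduce_sum(out)
print(g.gradient(loss, inputs))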
