From 176148c9e2a0ff520dd2f108a9182075874a442b Mon Sep 17 00:00:00 2001
From: T2T Team
Date: Thu, 3 Oct 2019 11:42:15 -0700
Subject: [PATCH] Fix attention rng mismatch between forward and reverse
 direction

PiperOrigin-RevId: 272707157
---
 tensor2tensor/trax/layers/attention.py        |  3 +-
 tensor2tensor/trax/layers/reversible.py       |  8 +--
 .../trax/models/research/reformer.py          |  8 ++-
 .../trax/models/research/reformer_test.py     | 60 +++++++++++++++++++
 4 files changed, 72 insertions(+), 7 deletions(-)

diff --git a/tensor2tensor/trax/layers/attention.py b/tensor2tensor/trax/layers/attention.py
index a276a8b7c..32f2b6640 100644
--- a/tensor2tensor/trax/layers/attention.py
+++ b/tensor2tensor/trax/layers/attention.py
@@ -344,7 +344,8 @@ def new_params_and_state(self, input_shape, input_dtype, rng):
 class BaseCausalAttention(base.Layer):
   """Base class for variants of causal self-attention."""
 
-  def __init__(self):
+  def __init__(self, mode='train'):
+    del mode
     super(BaseCausalAttention, self).__init__(n_inputs=3)
 
   def forward(self, inputs, params=(), state=(), rng=None, **kwargs):
diff --git a/tensor2tensor/trax/layers/reversible.py b/tensor2tensor/trax/layers/reversible.py
index 5b52aeaee..29f845244 100644
--- a/tensor2tensor/trax/layers/reversible.py
+++ b/tensor2tensor/trax/layers/reversible.py
@@ -101,8 +101,8 @@ def reverse(self, output, params=(), state=(), **kwargs):
     rngs = backend.random.split(rng, self._n_layers)
 
     layer_val = output
-    for layer, p, s, rng in reversed(zip(self.sublayers,
-                                         params, state, rngs)):
+    for layer, p, s, rng in reversed(list(zip(self.sublayers,
+                                              params, state, rngs))):
       layer_val = layer.reverse(layer_val, p, s, rng=rng, **kwargs)
 
     return layer_val
@@ -116,8 +116,8 @@ def reverse_and_grad(self, output, ct, params=(), state=(), **kwargs):
     layer_val = output
     layer_ct = ct
     params_ct = []
-    for layer, p, s, rng in reversed(zip(self.sublayers,
-                                         params, state, rngs)):
+    for layer, p, s, rng in reversed(list(zip(self.sublayers,
+                                              params, state, rngs))):
       layer_val, layer_ct = layer.reverse_and_grad(
           layer_val, layer_ct, p, s, rng=rng, **kwargs)
       layer_ct, p_ct = layer_ct
diff --git a/tensor2tensor/trax/models/research/reformer.py b/tensor2tensor/trax/models/research/reformer.py
index 56b8a5aa5..913343e67 100644
--- a/tensor2tensor/trax/models/research/reformer.py
+++ b/tensor2tensor/trax/models/research/reformer.py
@@ -254,14 +254,18 @@ def __init__(self, attention):
     super(ApplyAttentionWrapper, self).__init__(attention, [], [])
     self.attention = attention
 
-  def forward_and_backward(self, inputs, ct, **kwargs):
+  def forward_and_backward(self, inputs, ct, rng=None, **kwargs):
     # Simultaneous forward pass and backprop through the attention mechanism.
     qkv = inputs[:3]
     passthrough = inputs[3:]
     out_ct = ct[0]
     passthrough_ct = ct[1:]
+    if rng is not None:
+      # Adjust RNG to match the forward pass.
+      rng = backend.random.split(rng, self._n_layers)[0]
 
-    out, qkv_ct = self.attention.forward_and_backward(qkv, out_ct, **kwargs)
+    out, qkv_ct = self.attention.forward_and_backward(
+        qkv, out_ct, rng=rng, **kwargs)
     return (out,) + passthrough, qkv_ct + passthrough_ct
 
 
diff --git a/tensor2tensor/trax/models/research/reformer_test.py b/tensor2tensor/trax/models/research/reformer_test.py
index 799939748..0a8bcdad8 100644
--- a/tensor2tensor/trax/models/research/reformer_test.py
+++ b/tensor2tensor/trax/models/research/reformer_test.py
@@ -21,10 +21,43 @@
 from absl.testing import absltest
 from absl.testing import parameterized
 
+import jax
+import numpy as onp
+
+from tensor2tensor.trax import backend
 from tensor2tensor.trax import layers as tl
+from tensor2tensor.trax.backend import numpy as np
 from tensor2tensor.trax.models.research import reformer
 
 
+class PoisonOnRNGMismatchAttention(tl.BaseCausalAttention):
+  """Fills gradients with NaNs if reverse rng does not match forward rng."""
+
+  # pylint: disable=protected-access
+  def forward_and_backward(self, inputs, ct, rng=None, **kwargs):
+    assert backend.get_name() == 'jax', (
+        'JAX backend is required to use forward_and_backward.')
+
+    if ct is not None and tl.Layer._STASH_OUT is not None:
+      recovered_rng = tl.Layer._STASH_OUT.pop(self)
+      is_same = (rng[0] == recovered_rng[0]) & (rng[1] == recovered_rng[1])
+      is_same = is_same.astype(np.float32)
+      # Divides by zero if rngs are not the same, which results in NaNs.
+      inputs = (inputs[0] / is_same, inputs[1] / is_same, inputs[2] / is_same)
+
+    def _do_forward(x):  # pylint: disable=invalid-name
+      res, _ = self.forward(x, rng=rng, **kwargs)
+      return res
+    output, vjpfun = jax.vjp(_do_forward, inputs)
+    return output, vjpfun(ct)[0]
+
+  def forward(self, inputs, params=(), state=(), rng=None, **kwargs):
+    if tl.Layer._STASH_IN is not None:
+      tl.Layer._STASH_IN[self] = rng
+    return inputs[2], state
+  # pylint: enable=protected-access
+
+
 class ReformerTest(parameterized.TestCase):
 
   def test_reformer_lm_forward_shape(self):
@@ -39,6 +72,33 @@ def test_reformer_lm_forward_shape(self):
         model, tuple(input_shape), integer_inputs=True)
     self.assertEqual(((1, 8, 16), (1, 8, 16)), final_shape)
 
+  def test_reformer_rng_consistency(self):
+    with backend.use_backend('jax'):
+      vocab_size = 16
+      batch_size = 1
+      input_shape = ((batch_size, 8), (batch_size, 8))
+      model = reformer.ReformerLM(
+          vocab_size, d_model=32, d_ff=64,
+          d_attention_key=16, d_attention_value=16, n_layers=1, n_heads=2,
+          max_len=16, n_chunks=2, n_attention_chunks=1, mode='train',
+          attention_type=PoisonOnRNGMismatchAttention)
+
+      rng = backend.random.get_prng(0)
+      params, state = model.initialize_once(
+          input_shape, (np.int32, np.int32), rng)
+
+      def dummy_loss_fn(params):
+        inputs = (np.zeros(input_shape[0], dtype=np.int32),) * 2
+        output = model(inputs, params=params, state=state, rng=rng)
+        dummy_loss = backend.numpy.sum(output[0])
+        return dummy_loss
+
+      grad_fn = backend.grad(dummy_loss_fn)
+      grads = grad_fn(params)
+      # PoisonOnRNGMismatchAttention uses NaNs to signal an rng mismatch.
+      for grad in jax.tree_util.tree_leaves(grads):
+        assert onp.all(onp.isfinite(grad))
+
 
 if __name__ == '__main__':
   absltest.main()
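
Note on the two mechanics behind this change. First, splitting a JAX rng is deterministic, so the reverse direction can recover the exact rng the attention sublayer received on the forward pass by re-doing the same backend.random.split(rng, self._n_layers) and taking the first piece, as the patched ApplyAttentionWrapper.forward_and_backward does. Second, under Python 3 zip() returns an iterator, which reversed() rejects, hence reversed(list(zip(...))) in reversible.py. The sketch below is illustrative only and not part of the patch; it assumes JAX is installed and calls jax.random directly instead of going through the trax backend wrappers.

import jax

rng = jax.random.PRNGKey(0)
n_layers = 3

# 1) Re-splitting the same rng yields the same pieces, so the reverse pass
#    can reproduce the rng the attention sublayer saw going forward.
forward_rngs = jax.random.split(rng, n_layers)
attention_rng = jax.random.split(rng, n_layers)[0]
assert (attention_rng == forward_rngs[0]).all()

# 2) Under Python 3, zip() returns an iterator, which reversed() rejects;
#    wrapping it in list() restores the sequence behavior that the
#    reverse / reverse_and_grad loops rely on.
layers = ['layer0', 'layer1', 'layer2']
try:
  reversed(zip(layers, forward_rngs))
except TypeError:
  pass  # zip objects are not reversible
names = [name for name, _ in reversed(list(zip(layers, forward_rngs)))]
assert names == ['layer2', 'layer1', 'layer0']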