# coding=utf-8
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Timing tests for decoding."""

import functools
import time

from jax import test_util  # pylint: disable=unused-import
from jax.config import config
import numpy as np
from tensorflow.compat.v2 import test

from trax import layers
from trax import models
from trax import shapes
from trax.supervised import decoding


class DecodingTimingTest(test.TestCase):

  def test_autoregressive_sample_reformer2_timing(self):
    max_len = 16

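    # In 'predict' mode the SelfAttention layers keep a cache of past
    # activations: `predict_mem_len` is how many positions the cache holds
    # and `predict_drop_len` is how many are dropped once it fills up.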
    def _self_attention_fn():
      return functools.partial(
          layers.SelfAttention,
          predict_drop_len=2 * max_len,
          predict_mem_len=2 * max_len)

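    # `max_inference_length` bounds the fast-decoding cache in the causal
    # encoder-decoder attention layers; 2 * max_len leaves headroom.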
    def _causal_attention_fn():
      return functools.partial(
          layers.CausalAttention,
          max_inference_length=2 * max_len)

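    # A large configuration (d_model=4096, d_ff=32768, sparse feed-forward)
    # so that per-token decoding time reflects real compute, not overhead.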
    pred_model = models.Reformer2(
        mode='predict',
        d_model=4 * 1024,
        d_ff=32 * 1024,
        dropout=0.05,
        max_len=max_len,
        n_heads=16,
        n_encoder_layers=3,
        n_decoder_layers=3,
        encoder_attention_type=_self_attention_fn(),
        encoder_decoder_attention_type=_causal_attention_fn(),
        input_vocab_size=32,
        ff_sparsity=128,
        axial_pos_shape=None,
    )

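    # Initialize for inference: the encoder sees the full (1, max_len)
    # input, while the decoder is fed one token at a time, hence (1, 1).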
    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
    pred_model.init(input_signature=(shape1l, shape11))
    inputs = np.arange(max_len, dtype=np.int32).reshape(1, max_len)

    # This is decoding.autoregressive_sample but simplified and with timing.
    result, counter, start_time, total_time = [], 0, time.time(), 0.0
    for sample in decoding.autoregressive_sample_stream(
        pred_model, inputs, temperature=0.0):  # accelerate=False):
      elapsed_time = time.time() - start_time
      start_time = time.time()
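      # Skip the first few steps, which include JIT compilation and cache
      # set-up, so only steady-state time for 10 tokens is accumulated.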
      if counter > 3:
        total_time += elapsed_time
      result.append(sample[:, None])
      counter += 1
      if counter >= 14:
        break

    print('\n\n\nTotal time (10 tokens): %.4fs\n\n\n' % total_time)
    self.assertLess(total_time, 20.0)  # If it's > 20s, it's some bug.
    # Check resulting shapes.
    s = np.concatenate(result, axis=1)
    self.assertEqual(s.shape[0], 1)
    self.assertEqual(s.shape[1], 14)


if __name__ == '__main__':
  config.config_with_absl()
  test.main()