Skip to content

Commit 439c4f5

Browse files
afrozenatorcopybara-github
authored andcommitted
[TRAX] Extract out decoding timing from decoding test.
PiperOrigin-RevId: 333374580
1 parent 245741e commit 439c4f5

File tree

3 files changed

+94
-56
lines changed

3 files changed

+94
-56
lines changed

oss_scripts/oss_tests.sh

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ pytest --disable-warnings \
6262
--ignore=trax/supervised/trainer_lib_test.py \
6363
--ignore=trax/supervised/training_test.py \
6464
--ignore=trax/supervised/decoding_test.py \
65+
--ignore=trax/supervised/decoding_timing_test.py \
6566
trax/supervised
6667
set_status
6768

trax/supervised/decoding_test.py

-56
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import functools
2020
import os
21-
import time
2221

2322
from jax import test_util # pylint: disable=unused-import
2423
from jax.config import config
@@ -202,61 +201,6 @@ def test_autoregressive_sample_reformer2_lsh(self):
202201
self.assertEqual(s.shape[0], 1)
203202
self.assertEqual(s.shape[1], 10)
204203

205-
def test_autoregressive_sample_reformer2_timing(self):
206-
max_len = 16
207-
208-
def _self_attention_fn():
209-
return functools.partial(
210-
layers.SelfAttention,
211-
predict_drop_len=2 * max_len,
212-
predict_mem_len=2 * max_len)
213-
214-
def _causal_attention_fn():
215-
return functools.partial(
216-
layers.CausalAttention,
217-
max_inference_length=2 * max_len)
218-
219-
pred_model = models.Reformer2(
220-
mode='predict',
221-
d_model=4*1024,
222-
d_ff=32*1024,
223-
dropout=0.05,
224-
max_len=max_len,
225-
n_heads=16,
226-
n_encoder_layers=3,
227-
n_decoder_layers=3,
228-
encoder_attention_type=_self_attention_fn(),
229-
encoder_decoder_attention_type=_causal_attention_fn(),
230-
input_vocab_size=32,
231-
ff_sparsity=128,
232-
axial_pos_shape=None,
233-
)
234-
235-
shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
236-
shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
237-
pred_model.init(input_signature=(shape1l, shape11))
238-
inputs = np.arange(16, dtype=np.int32).reshape(1, 16)
239-
240-
# This is decoding.autoregressive_sample but simplified and with timing.
241-
result, counter, start_time, total_time = [], 0, time.time(), 0.0
242-
for sample in decoding.autoregressive_sample_stream(
243-
pred_model, inputs, temperature=0.0): # accelerate=False):
244-
elapsed_time = time.time() - start_time
245-
start_time = time.time()
246-
if counter > 3:
247-
total_time += elapsed_time
248-
result.append(sample[:, None])
249-
counter += 1
250-
if counter >= 14:
251-
break
252-
253-
print('\n\n\nTotal time (10 tokens): %.4fs\n\n\n' % total_time)
254-
self.assertLess(total_time, 10.0) # If it's > 10s, it's some bug.
255-
# Check resulting shapes.
256-
s = np.concatenate(result, axis=1)
257-
self.assertEqual(s.shape[0], 1)
258-
self.assertEqual(s.shape[1], 14)
259-
260204
def test_autoregressive_sample_reformer2_copy_self_attn_quality(self):
261205
max_len = 32
262206

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# coding=utf-8
2+
# Copyright 2020 The Trax Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Lint as: python3
17+
"""Timing tests for decoding."""
18+
19+
import functools
20+
import time
21+
22+
from jax import test_util # pylint: disable=unused-import
23+
from jax.config import config
24+
import numpy as np
25+
from tensorflow.compat.v2 import test
26+
27+
from trax import layers
28+
from trax import models
29+
from trax import shapes
30+
from trax.supervised import decoding
31+
32+
33+
class DecodingTimingTest(test.TestCase):
34+
35+
def test_autoregressive_sample_reformer2_timing(self):
36+
max_len = 16
37+
38+
def _self_attention_fn():
39+
return functools.partial(
40+
layers.SelfAttention,
41+
predict_drop_len=2 * max_len,
42+
predict_mem_len=2 * max_len)
43+
44+
def _causal_attention_fn():
45+
return functools.partial(
46+
layers.CausalAttention,
47+
max_inference_length=2 * max_len)
48+
49+
pred_model = models.Reformer2(
50+
mode='predict',
51+
d_model=4*1024,
52+
d_ff=32*1024,
53+
dropout=0.05,
54+
max_len=max_len,
55+
n_heads=16,
56+
n_encoder_layers=3,
57+
n_decoder_layers=3,
58+
encoder_attention_type=_self_attention_fn(),
59+
encoder_decoder_attention_type=_causal_attention_fn(),
60+
input_vocab_size=32,
61+
ff_sparsity=128,
62+
axial_pos_shape=None,
63+
)
64+
65+
shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
66+
shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)
67+
pred_model.init(input_signature=(shape1l, shape11))
68+
inputs = np.arange(16, dtype=np.int32).reshape(1, 16)
69+
70+
# This is decoding.autoregressive_sample but simplified and with timing.
71+
result, counter, start_time, total_time = [], 0, time.time(), 0.0
72+
for sample in decoding.autoregressive_sample_stream(
73+
pred_model, inputs, temperature=0.0): # accelerate=False):
74+
elapsed_time = time.time() - start_time
75+
start_time = time.time()
76+
if counter > 3:
77+
total_time += elapsed_time
78+
result.append(sample[:, None])
79+
counter += 1
80+
if counter >= 14:
81+
break
82+
83+
print('\n\n\nTotal time (10 tokens): %.4fs\n\n\n' % total_time)
84+
self.assertLess(total_time, 20.0) # If it's > 20s, it's some bug.
85+
# Check resulting shapes.
86+
s = np.concatenate(result, axis=1)
87+
self.assertEqual(s.shape[0], 1)
88+
self.assertEqual(s.shape[1], 14)
89+
90+
91+
if __name__ == '__main__':
92+
config.config_with_absl()
93+
test.main()

0 commit comments

Comments
 (0)