From a852994a3bf336fb90f45950fc0a6b71260e111c Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Mon, 11 Sep 2017 13:18:05 -0700 Subject: [PATCH 01/39] Attention moe can mix attention layer types PiperOrigin-RevId: 168274573 --- tensor2tensor/models/attention_lm_moe.py | 66 +++++++++++++++++++----- 1 file changed, 52 insertions(+), 14 deletions(-) diff --git a/tensor2tensor/models/attention_lm_moe.py b/tensor2tensor/models/attention_lm_moe.py index adbb871b5..3afe77fc0 100644 --- a/tensor2tensor/models/attention_lm_moe.py +++ b/tensor2tensor/models/attention_lm_moe.py @@ -60,6 +60,13 @@ def get_choices(): ] +LAYER_SYMBOLS = { + "h": AttentionType.MULTIHEAD, # multi-Head + "e": AttentionType.LOCAL_EXPERTS, # Experts + "m": AttentionType.MEMORY_EFFICIENT, # Memory +} + + @registry.register_model class AttentionLmMoe(t2t_model.T2TModel): """Attention net. See file docstring.""" @@ -133,11 +140,20 @@ def print_shape(x, suffix, debug=False): assert hparams.batch_size >= hparams.max_length - for layer in xrange(hparams.num_hidden_layers): + num_hidden_layers = ( + len(hparams.attention_layers) or hparams.num_hidden_layers) + for layer in xrange(num_hidden_layers): with tf.variable_scope("layer_%d" % layer): + + # Use the layer type defined in attention_layers + if hparams.attention_layers: + attention_type = LAYER_SYMBOLS[hparams.attention_layers[layer]] + else: + attention_type = hparams.attention_type + with tf.variable_scope( - "attention_{}".format(hparams.attention_type)): - if hparams.attention_type == AttentionType.MULTIHEAD: + "attention_{}".format(attention_type)): + if attention_type == AttentionType.MULTIHEAD: y = dp( common_attention.multihead_attention, preprocess(x), @@ -151,7 +167,7 @@ def print_shape(x, suffix, debug=False): attention_type=("local_mask_right" if hparams.attention_local else "dot_product"), name="decoder_self_attention") - elif hparams.attention_type == AttentionType.MEMORY_EFFICIENT: + elif attention_type == AttentionType.MEMORY_EFFICIENT: assert hparams.layer_preprocess_sequence == "n" y = dp( common_attention.multihead_self_attention_memory_efficient, @@ -159,7 +175,7 @@ def print_shape(x, suffix, debug=False): decoder_self_attention_bias, hparams.num_heads, name="decoder_self_attention") - elif hparams.attention_type == AttentionType.LOCAL_EXPERTS: + elif attention_type == AttentionType.LOCAL_EXPERTS: y, loss = dp( common_attention.local_expert_attention, preprocess(x), @@ -350,6 +366,10 @@ def attention_lm_moe_base(): hparams.add_hparam("pos", "timing") # timing, none hparams.add_hparam("moe_layers", "2") # comma separated list of layer numbers # moe params. local attention moe. + # If attention_layers is set, the num_hidden_layers parameter will be ignored + # and each caracter of the string will correspond to one attention + # layer type + hparams.add_hparam("attention_layers", "") hparams.add_hparam("attention_type", AttentionType.MULTIHEAD) hparams.add_hparam("attention_local", int(False)) hparams.add_hparam("attention_moe_k", 2) @@ -370,14 +390,24 @@ def attention_lm_moe_base(): @registry.register_hparams -def attention_lm_moe_base_ae(): - """Base model with attention expert.""" +def attention_lm_moe_base_long_seq(): + """Hyper parameters specifics for long sequence generation.""" hparams = attention_lm_moe_base() - hparams.attention_type = AttentionType.LOCAL_EXPERTS - hparams.use_sepconv = int(True) + hparams.max_length = 0 # max_length == batch_size hparams.eval_drop_long_sequences = int(True) hparams.min_length_bucket = 256 # Avoid cyclic problems for big batches + hparams.use_sepconv = int(True) + + return hparams + + +@registry.register_hparams +def attention_lm_moe_base_ae(): + """Base model with attention expert.""" + hparams = attention_lm_moe_base_long_seq() + hparams.attention_type = AttentionType.LOCAL_EXPERTS + hparams.learning_rate = 0.05 hparams.learning_rate_warmup_steps = 10000 # According to noam, ("n", "da") seems better for harder-to-learn models @@ -389,12 +419,20 @@ def attention_lm_moe_base_ae(): @registry.register_hparams def attention_lm_moe_base_local(): """Base model with attention expert.""" - hparams = attention_lm_moe_base() + hparams = attention_lm_moe_base_long_seq() hparams.attention_local = int(True) - hparams.use_sepconv = int(True) - hparams.max_length = 0 # max_length == batch_size - hparams.eval_drop_long_sequences = int(True) - hparams.min_length_bucket = 256 # Avoid cyclic problems for big batches + return hparams + + +@registry.register_hparams +def attention_lm_moe_base_hybrid(): + """Base model with attention expert.""" + hparams = attention_lm_moe_base_long_seq() + hparams.attention_layers = "hehe" # Alternate local/expert + hparams.attention_local = int(True) + + # hparams.layer_preprocess_sequence = "n" + # hparams.layer_postprocess_sequence = "da" return hparams From 017f83a91da0a4f481834fa0eac44d446774acfe Mon Sep 17 00:00:00 2001 From: Ashish Vaswani Date: Mon, 11 Sep 2017 14:16:09 -0700 Subject: [PATCH 02/39] Bug fixes in masked_local_attention_2d and local_attention_2d. We needed to scatter the representations after attention back into the right positions. Added test2dGatherAndScatter, which tests for invertibility of 2d gather and scatter functions. PiperOrigin-RevId: 168283655 --- tensor2tensor/layers/common_attention.py | 49 ++++++++++++------- tensor2tensor/layers/common_attention_test.py | 20 ++++++++ 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index 6f7c9fa23..3b89ef1bc 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -650,16 +650,13 @@ def local_attention_2d(q, """ with tf.variable_scope( name, default_name="local_self_attention_2d", values=[q, k, v]): + q_shape = q.get_shape().as_list() v_shape = tf.shape(v) - depth_v = tf.shape(v)[4] - batch_size = tf.shape(q)[0] - num_heads = tf.shape(q)[1] - original_length = tf.shape(q)[2] * tf.shape(q)[3] q = pad_to_multiple_2d(q, query_shape) k = pad_to_multiple_2d(k, query_shape) v = pad_to_multiple_2d(v, query_shape) - + padded_q_shape = tf.shape(q) # Setting up k and v values paddings = [[0, 0], [0, 0], [memory_flange[0], memory_flange[1]], [memory_flange[0], memory_flange[1]], [0, 0]] @@ -684,12 +681,13 @@ def local_attention_2d(q, attention = tf.nn.softmax(logits + attention_bias) output = tf.matmul(attention, v_new) - - output = tf.reshape(output, [batch_size, num_heads, -1, depth_v]) - # Remove the padding if introduced - output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) - # [batch, heads, h, w, depth_v] - return tf.reshape(output, v_shape) + # putting the representations back in the right place + output = scatter_blocks_2d(output, q_indices, padded_q_shape) + # Remove the padding if introduced + output = tf.slice(output, [0, 0, 0, 0, 0], + [-1, -1, v_shape[2], v_shape[3], -1]) + output.set_shape(q_shape) + return output def pad_to_multiple_2d(x, block_shape): @@ -726,6 +724,19 @@ def gather_blocks_2d(x, indices): return tf.transpose(x_new, [2, 3, 0, 1, 4]) +def scatter_blocks_2d(x, indices, shape): + """scatters blocks from x into shape with indices.""" + x_shape = tf.shape(x) + # [length, batch, heads, dim] + x_t = tf.transpose(tf.reshape(x, [x_shape[0], x_shape[1], -1, x_shape[-1]]), + [2, 0, 1, 3]) + x_t_shape = tf.shape(x_t) + indices = tf.reshape(indices, [-1, 1]) + scattered_x = tf.scatter_nd(indices, x_t, x_t_shape) + scattered_x = tf.transpose(scattered_x, [1, 2, 0, 3]) + return tf.reshape(scattered_x, shape) + + def gather_indices_2d(x, block_shape, block_stride): """Getting gather indices.""" # making an identity matrix kernel @@ -769,11 +780,8 @@ def masked_local_attention_2d(q, """ with tf.variable_scope( name, default_name="local_masked_self_attention_2d", values=[q, k, v]): + q_shape = q.get_shape().as_list() v_shape = tf.shape(v) - depth_v = tf.shape(v)[4] - batch_size = tf.shape(q)[0] - num_heads = tf.shape(q)[1] - original_length = tf.shape(q)[2] * tf.shape(q)[3] def make_mask(query_shape, memory_flange): """creates a mask. @@ -808,6 +816,7 @@ def make_mask(query_shape, memory_flange): # 0. is visible location, 1.0 is masked. return 1. - final_mask q = pad_to_multiple_2d(q, query_shape) + padded_q_shape = tf.shape(q) k = pad_to_multiple_2d(k, query_shape) v = pad_to_multiple_2d(v, query_shape) # Setting up k and v values. Padding top, left, and right @@ -838,11 +847,13 @@ def make_mask(query_shape, memory_flange): tf.to_float(tf.logical_or(attention_mask, padding_mask)) *-1e9) attention = tf.nn.softmax(logits + attention_bias) output = tf.matmul(attention, v_new) - output = tf.reshape(output, [batch_size, num_heads, -1, depth_v]) + # putting the representations back in the right place + output = scatter_blocks_2d(output, q_indices, padded_q_shape) # Remove the padding if introduced - output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) - # [batch, heads, h, w, depth_v] - return tf.reshape(output, v_shape) + output = tf.slice(output, [0, 0, 0, 0, 0], + [-1, -1, v_shape[2], v_shape[3], -1]) + output.set_shape(q_shape) + return output def compute_qkv(query_antecedent, memory_antecedent, total_key_depth, diff --git a/tensor2tensor/layers/common_attention_test.py b/tensor2tensor/layers/common_attention_test.py index d8f6f2b39..644b27a98 100644 --- a/tensor2tensor/layers/common_attention_test.py +++ b/tensor2tensor/layers/common_attention_test.py @@ -162,5 +162,25 @@ def testMultiheadSelfAttentionMemoryEfficient(self): self.assertAllClose(dnorm_bias, dnorm_bias_f) self.assertAllClose(dx, dx_f) + def test2dGatherAndScatter(self): + """2d gather and scatter invertibility test.""" + batch_size = 2 + num_heads = 2 + height = 4 + width = 6 + depth = 8 + query_shape = (2, 3) + x = np.random.rand(batch_size, num_heads, height, width, depth) + with self.test_session() as session: + x_indices = common_attention.gather_indices_2d( + x, query_shape, query_shape) + gathered_x = common_attention.gather_blocks_2d(x, x_indices) + x_shape = tf.constant([batch_size, num_heads, height, width, depth]) + scattered_x = common_attention.scatter_blocks_2d( + gathered_x, x_indices, x_shape) + session.run(tf.global_variables_initializer()) + res = session.run(scattered_x) + self.assertAllClose(x, res) + if __name__ == "__main__": tf.test.main() From b5db405a4dc09bb72241dceebac9806cefa348e4 Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Mon, 11 Sep 2017 15:06:11 -0700 Subject: [PATCH 03/39] Fix the pad_remover for attention expert when hybrid attention layers. Now only applied for the attention expert. PiperOrigin-RevId: 168292028 --- tensor2tensor/models/attention_lm_moe.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/tensor2tensor/models/attention_lm_moe.py b/tensor2tensor/models/attention_lm_moe.py index 3afe77fc0..abdd68c8b 100644 --- a/tensor2tensor/models/attention_lm_moe.py +++ b/tensor2tensor/models/attention_lm_moe.py @@ -105,8 +105,7 @@ def _diet_expert(x): expert_fn = expert_utils.ffn_expert_fn( hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) - if (hparams.attention_type == AttentionType.LOCAL_EXPERTS - and not hparams.use_inputs): + if not hparams.use_inputs: # As preprocess and postprocess are called with batch of size one (all # batches concatenated), we just make sure that batch_norm is not use ( # should not either way) @@ -135,8 +134,6 @@ def print_shape(x, suffix, debug=False): batch_coordinate = dp_remove_pad(batch_coordinate) x = dp(print_shape, x, "in") - x = dp_remove_pad(x) - x = dp(print_shape, x, "in_flat") assert hparams.batch_size >= hparams.max_length @@ -176,9 +173,11 @@ def print_shape(x, suffix, debug=False): hparams.num_heads, name="decoder_self_attention") elif attention_type == AttentionType.LOCAL_EXPERTS: + x_in = preprocess(x) + x_in = dp_remove_pad(x_in) y, loss = dp( common_attention.local_expert_attention, - preprocess(x), + x_in, k=hparams.attention_moe_k, loss_coef=hparams.attention_load_balance, attention_num_experts=hparams.attention_num_experts, @@ -188,6 +187,7 @@ def print_shape(x, suffix, debug=False): split_batch=bool(hparams.attention_split_batch), attention_kq_size=hparams.attention_kq_size, attention_v_size=hparams.attention_v_size) + y = dp_restore_pad(y) # TODO(avaswani, epot, noam): Do we need to divide by num shards ? extra_loss += tf.add_n(loss) / dp.n else: @@ -214,15 +214,8 @@ def print_shape(x, suffix, debug=False): x, hparams.filter_size) else: - x_in = preprocess(x) additional_conv_params = dict() if hparams.use_sepconv: - # Restore padding so sequences don't attend to each others - # restore_pad will apply a reshape like x_ref, to restore the - # original shape. Here this works because the last dimension is - # constant between the output of attention and the original input - # but it shouldn't necessarily be the case. - x_in = dp_restore_pad(x_in) additional_conv_params = dict( padding="LEFT", # Parameters copied from the transformer model @@ -231,19 +224,15 @@ def print_shape(x, suffix, debug=False): ) y = dp( common_layers.conv_hidden_relu, - x_in, + preprocess(x), hparams.filter_size, hparams.hidden_size, dropout=hparams.relu_dropout, **additional_conv_params ) - if hparams.use_sepconv: - y = dp_remove_pad(y) x = postprocess(x, y) x = preprocess(x) - x = dp_restore_pad(x) - decoder_output = dp(tf.expand_dims, x, 2) return decoder_output, extra_loss From 15682d535244ca33983d8933df9725d459e02d4f Mon Sep 17 00:00:00 2001 From: Noam Shazeer Date: Mon, 11 Sep 2017 16:23:33 -0700 Subject: [PATCH 04/39] Added a new model "aligned" for aligned sequence problems without autoregression/masking. PiperOrigin-RevId: 168302680 --- tensor2tensor/models/__init__.py | 1 + tensor2tensor/models/aligned.py | 256 +++++++++++++++++++++++++++++++ 2 files changed, 257 insertions(+) create mode 100644 tensor2tensor/models/aligned.py diff --git a/tensor2tensor/models/__init__.py b/tensor2tensor/models/__init__.py index acebef809..f5fafe706 100644 --- a/tensor2tensor/models/__init__.py +++ b/tensor2tensor/models/__init__.py @@ -23,6 +23,7 @@ # pylint: disable=unused-import from tensor2tensor.layers import modalities +from tensor2tensor.models import aligned from tensor2tensor.models import attention_lm from tensor2tensor.models import attention_lm_moe from tensor2tensor.models import bluenet diff --git a/tensor2tensor/models/aligned.py b/tensor2tensor/models/aligned.py new file mode 100644 index 000000000..9cadc0cae --- /dev/null +++ b/tensor2tensor/models/aligned.py @@ -0,0 +1,256 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Single stack of transformations with no masking. + +Produces output aligned with inputs. + +Configurable using hyperparameters to use some combination of convolutions, +attention, mixtures of experts, etc. + +A good problem for this model is languagemodel_wiki_scramble1k50 . +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from tensor2tensor.layers import common_attention +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.utils import diet +from tensor2tensor.utils import expert_utils +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow as tf + + +ModeKeys = tf.estimator.ModeKeys # pylint: disable=invalid-name + + +def _should_preprocess(layer_type): + return layer_type not in ["timing", "pos_emb"] + + +def _should_postprocess(layer_type): + return layer_type not in ["timing", "pos_emb"] + + +@registry.register_model +class Aligned(t2t_model.T2TModel): + """Attention net. See file docstring.""" + + def model_fn_body_sharded(self, sharded_features): + # Remove dropout if not training + hparams = self._hparams + dp = self._data_parallelism + x = dp(tf.squeeze, sharded_features["inputs"], 2) + def preprocess(x): + return dp(common_layers.layer_preprocess, x, hparams) + def postprocess(x, y): + return dp(common_layers.layer_postprocess, x, y, hparams) + x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout) + extra_loss = 0.0 + ffn_hidden_sizes = [int(s) for s in hparams.ffn_hidden_sizes.split(",")] + moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")] + if hparams.diet_experts: + hsize, = moe_hidden_sizes + + def _diet_expert(x): + return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params()) + + expert_fn = _diet_expert + else: + expert_fn = expert_utils.ffn_expert_fn( + hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) + + batch_coordinate = dp(get_batch_coordinate, x) + + assert hparams.batch_size >= hparams.max_length + + layers = hparams.layers.strip(",").split(",") + for layer_num, layer_type in enumerate(layers): + with tf.variable_scope("%s_%d" % (layer_type, layer_num)): + if _should_preprocess(layer_type): + x = preprocess(x) + if layer_type == "timing": + y = dp(common_attention.add_timing_signal_nd, x) + elif layer_type == "pos_emb": + y = dp(common_attention.add_positional_embedding_nd, + x, hparams.max_length, name="pos_emb") + elif layer_type == "att": + # multihead attention + y = dp( + common_attention.multihead_attention, + x, + None, + None, # bias + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + attention_type=("local_unmasked" if hparams.attention_local + else "dot_product"), + name="decoder_self_attention") + elif layer_type == "local_expert_attention": + y, loss = dp( + common_attention.local_expert_attention, + x, + k=hparams.attention_moe_k, + loss_coef=hparams.attention_load_balance, + attention_num_experts=hparams.attention_num_experts, + train=hparams.mode == ModeKeys.TRAIN, + batch_coordinate=batch_coordinate, + mask_right=False, + split_batch=bool(hparams.attention_split_batch), + attention_kq_size=hparams.attention_kq_size, + attention_v_size=hparams.attention_v_size) + # TODO(avaswani, epot, noam): Do we need to divide by num shards ? + extra_loss += tf.add_n(loss) / dp.n + elif layer_type == "moe": + y, loss = expert_utils.distributed_moe( + dp, + self._ps_devices, + x, + hparams.mode == ModeKeys.TRAIN, + input_size=hparams.hidden_size, + expert_fn=expert_fn, + num_experts=hparams.moe_num_experts, + k=hparams.moe_k, + loss_coef=hparams.moe_loss_coef) + extra_loss += loss + elif layer_type == "ffn": + y = dp( + expert_utils.ffn_expert_fn( + hparams.hidden_size, + ffn_hidden_sizes, + hparams.hidden_size), + dp(expert_utils.flatten_all_but_last, x)) + y = dp(expert_utils.reshape_like, y, x) + elif layer_type == "conv": + y = dp( + common_layers.conv1d, + x, + hparams.hidden_size, + hparams.kernel_height, + activation=tf.nn.relu, + padding="SAME", + ) + else: + assert False, "unknown sublayer %s" % layer_type + if _should_postprocess(layer_type): + x = postprocess(x, y) + else: + x = y + x = preprocess(x) + + decoder_output = dp(tf.expand_dims, x, 2) + return decoder_output, extra_loss + + +def get_batch_coordinate(x): + """Return a flat int32 tensor of shape [1, batch_size*length, 1].""" + # Compute the batch coordinate before flattening all batches + batch_coordinate = tf.expand_dims( + common_attention.coordinate_tensor(tf.shape(x)[:-1], axis=0), axis=-1) + return batch_coordinate + + +@registry.register_hparams +def aligned_base(): + """Set of hyperparameters. + + Returns: + a hparams object + """ + hparams = common_hparams.basic_params1() + hparams.hidden_size = 512 + hparams.batch_size = 5000 + hparams.max_length = 1024 + hparams.dropout = 0.0 + hparams.layer_prepostprocess_dropout = 0.0 + hparams.label_smoothing = 0.0 + hparams.clip_grad_norm = 0. # i.e. no gradient clipping + hparams.optimizer_adam_epsilon = 1e-9 + hparams.learning_rate_decay_scheme = "noam" + hparams.learning_rate = 0.1 + hparams.learning_rate_warmup_steps = 2000 + hparams.initializer_gain = 1.0 + hparams.initializer = "uniform_unit_scaling" + hparams.weight_decay = 0.0 + hparams.optimizer_adam_beta1 = 0.9 + hparams.optimizer_adam_beta2 = 0.98 + hparams.shared_embedding_and_softmax_weights = int(False) + hparams.add_hparam("ffn_hidden_sizes", "2048") # Add new ones like this. + hparams.moe_num_experts = 32 + hparams.layer_preprocess_sequence = "n" + hparams.layer_postprocess_sequence = "da" + hparams.add_hparam("layers", "timing," + "att,ffn," * 4) + + # attention-related flags + hparams.add_hparam("num_heads", 8) + hparams.add_hparam("attention_key_channels", 0) + hparams.add_hparam("attention_value_channels", 0) + # All hyperparameters ending in "dropout" are automatically set to 0.0 + # when not in training mode. + hparams.add_hparam("attention_dropout", 0.0) + hparams.add_hparam("pos", "timing") # timing, none + # moe params. local attention moe. + hparams.add_hparam("attention_local", int(False)) + hparams.add_hparam("attention_moe_k", 2) + hparams.add_hparam("attention_num_experts", 16) + hparams.add_hparam("attention_split_batch", int(False)) + # Key, query and value dimensions for the attention + hparams.add_hparam("attention_kq_size", 128) + hparams.add_hparam("attention_v_size", 256) + # Loss coef for load balancing + hparams.add_hparam("attention_load_balance", 2e-2) + hparams.add_hparam("diet_experts", int(False)) + hparams.add_hparam("memory_efficient_ffn", int(False)) + # if True, we learn a non-autoregressive model from "inputs" to "targets". + # if False, we learn an autoregressive model to generate "targets" + return hparams + + +@registry.register_hparams +def aligned_with_conv(): + hparams = aligned_base() + hparams.layers = "timing," + "conv,att,ffn," * 4 + return hparams + + +@registry.register_hparams +def aligned_local(): + hparams = aligned_base() + hparams.attention_local = int(True) + return hparams + + +@registry.register_hparams +def aligned_pos_emb(): + hparams = aligned_base() + hparams.layers = "pos_emb," + "att,ffn," * 4 + return hparams + + +@registry.register_hparams +def aligned_moe(): + hparams = aligned_base() + hparams.layers = "timing," + "att,moe," * 4 + return hparams From a7c70874e0e545dd6d890c1122e4005e2b65ccf3 Mon Sep 17 00:00:00 2001 From: Ashish Vaswani Date: Tue, 12 Sep 2017 17:12:25 -0700 Subject: [PATCH 05/39] Added tests for 2-d local attention. Refactoring to use dot_product_attention in local_1d and 2d attention functions. Adding a flag for image summaries in dot_product_attention because we need to figure out the best way to get image summaries in 2d functions. PiperOrigin-RevId: 168472353 --- tensor2tensor/layers/common_attention.py | 101 +++++++++--------- tensor2tensor/layers/common_attention_test.py | 65 ++++++++++- 2 files changed, 117 insertions(+), 49 deletions(-) diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index 3b89ef1bc..fdba48b01 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -416,7 +416,8 @@ def dot_product_attention(q, bias, dropout_rate=0.0, image_shapes=None, - name=None): + name=None, + make_image_summary=True): """dot-product attention. Args: @@ -428,6 +429,7 @@ def dot_product_attention(q, image_shapes: optional tuple of integer scalars. see comments for attention_image_summary() name: an optional string + make_image_summary: True if you want an image summary. Returns: A Tensor. @@ -443,7 +445,8 @@ def dot_product_attention(q, weights = tf.nn.dropout(weights, 1.0 - dropout_rate) if (not tf.get_variable_scope().reuse and # Summaries don't work well within tf.while_loop() - "/while/" not in tf.contrib.framework.get_name_scope()): + "/while/" not in tf.contrib.framework.get_name_scope() and + make_image_summary): attention_image_summary(weights, image_shapes) return tf.matmul(weights, v) @@ -616,11 +619,9 @@ def pad_l_and_r(x, pad_length): v_new = tf.gather(v_t, gather_indices) v_new = tf.transpose(v_new, [2, 3, 0, 1, 4]) - logits = tf.matmul(q, k_new, transpose_b=True) - - attention = tf.nn.softmax(logits + attention_bias) - output = tf.matmul(attention, v_new) - + output = dot_product_attention( + q, k_new, v_new, attention_bias, dropout_rate=0., name="local_1d", + make_image_summary=False) output = tf.reshape(output, [batch_size, num_heads, -1, depth_v]) # Remove the padding if introduced output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) @@ -677,10 +678,9 @@ def local_attention_2d(q, attention_bias = tf.expand_dims( tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2) - logits = tf.matmul(q_new, k_new, transpose_b=True) - - attention = tf.nn.softmax(logits + attention_bias) - output = tf.matmul(attention, v_new) + output = dot_product_attention(q_new, k_new, v_new, attention_bias, + dropout_rate=0., name="local_2d", + make_image_summary=False) # putting the representations back in the right place output = scatter_blocks_2d(output, q_indices, padded_q_shape) # Remove the padding if introduced @@ -756,6 +756,42 @@ def gather_indices_2d(x, block_shape, block_stride): return tf.cast(indices, tf.int32) +def make_2d_block_raster_mask(query_shape, memory_flange): + """creates a mask for 2d block raster scany. + + The query mask can look to the left, top left, top, and top right, but + not to the right. Inside the query, we have the standard raster scan + masking. + Args: + query_shape: A tuple of ints (query_height, query_width) + memory_flange: A tuple of ints + (memory_flange_height, memory_flange_width) + + Returns: + A tensor of shape query_size, memory_size + """ + # mask inside the query block + query_triangle = tf.matrix_band_part( + tf.ones([np.prod(query_shape), np.prod(query_shape)]), -1, 0) + split_query_masks = tf.split(query_triangle, query_shape[0], axis=1) + # adding mask for left and right + mask_pieces = [ + tf.concat( + [tf.ones([np.prod(query_shape), memory_flange[1]]), + split_query_masks[i], + tf.zeros([np.prod(query_shape), memory_flange[1]]) + ], axis=1) for i in range(query_shape[0])] + # adding mask for top + final_mask = tf.concat( + [tf.ones( + [np.prod(query_shape), + (query_shape[1]+2*memory_flange[1])*memory_flange[0]]), + tf.concat(mask_pieces, axis=1) + ], axis=1) + # 0. is visible location, 1.0 is masked. + return 1. - final_mask + + def masked_local_attention_2d(q, k, v, @@ -782,39 +818,7 @@ def masked_local_attention_2d(q, name, default_name="local_masked_self_attention_2d", values=[q, k, v]): q_shape = q.get_shape().as_list() v_shape = tf.shape(v) - def make_mask(query_shape, memory_flange): - """creates a mask. - - The query mask can look to the left, top left, top, and top right, but - not the right. Inside the query, we have the standard raster scan - masking. - Args: - query_shape: A tuple of ints (query_height, query_width) - memory_flange: A tuple of ints - (memory_flange_height, memory_flange_width) - - Returns: - A tensor of shape query_size, memory_size - """ - - query_triangle = tf.matrix_band_part( - tf.ones([np.prod(query_shape), np.prod(query_shape)]), -1, 0) - split_query_masks = tf.split(query_triangle, query_shape[0], axis=1) - mask_pieces = [ - tf.concat( - [tf.ones([np.prod(query_shape), memory_flange[1]]), - split_query_masks[i], - tf.zeros([np.prod(query_shape), memory_flange[1]]) - ], axis=1) for i in range(query_shape[0])] - - final_mask = tf.concat( - [tf.ones( - [np.prod(query_shape), - (query_shape[1]+2*memory_flange[1])*memory_flange[0]]), - tf.concat(mask_pieces, axis=1) - ], axis=1) - # 0. is visible location, 1.0 is masked. - return 1. - final_mask + q = pad_to_multiple_2d(q, query_shape) padded_q_shape = tf.shape(q) k = pad_to_multiple_2d(k, query_shape) @@ -833,20 +837,21 @@ def make_mask(query_shape, memory_flange): k_and_v_indices = gather_indices_2d(k, memory_shape, query_shape) k_new = gather_blocks_2d(k, k_and_v_indices) v_new = gather_blocks_2d(v, k_and_v_indices) - logits = tf.matmul(q_new, k_new, transpose_b=True) # Combining the mask for padding and visible region attention_mask_shape = [np.prod(query_shape), (query_shape[0]+memory_flange[0])* (query_shape[1]+2*memory_flange[1])] - attention_mask = tf.cast(make_mask(query_shape, memory_flange), tf.bool) + attention_mask = tf.cast( + make_2d_block_raster_mask(query_shape, memory_flange), tf.bool) # reshaping attention mask to have same dims as logits attention_mask = tf.reshape(attention_mask, [1, 1, 1]+attention_mask_shape) padding_mask = tf.expand_dims( tf.cast(embedding_to_padding(k_new), tf.bool), axis=-2) attention_bias = ( tf.to_float(tf.logical_or(attention_mask, padding_mask)) *-1e9) - attention = tf.nn.softmax(logits + attention_bias) - output = tf.matmul(attention, v_new) + output = dot_product_attention(q_new, k_new, v_new, attention_bias, + dropout_rate=0., name="masked_local_2d", + make_image_summary=False) # putting the representations back in the right place output = scatter_blocks_2d(output, q_indices, padded_q_shape) # Remove the padding if introduced diff --git a/tensor2tensor/layers/common_attention_test.py b/tensor2tensor/layers/common_attention_test.py index 644b27a98..7823936fa 100644 --- a/tensor2tensor/layers/common_attention_test.py +++ b/tensor2tensor/layers/common_attention_test.py @@ -162,7 +162,7 @@ def testMultiheadSelfAttentionMemoryEfficient(self): self.assertAllClose(dnorm_bias, dnorm_bias_f) self.assertAllClose(dx, dx_f) - def test2dGatherAndScatter(self): + def test2dGatherAndScatterInvertibility(self): """2d gather and scatter invertibility test.""" batch_size = 2 num_heads = 2 @@ -182,5 +182,68 @@ def test2dGatherAndScatter(self): res = session.run(scattered_x) self.assertAllClose(x, res) + def test2dBlockRasterScanMask(self): + """Testing the 2d block raster scan mask.""" + query_shape = (2, 3) + memory_flange = (2, 1) + with self.test_session() as session: + mask = common_attention.make_2d_block_raster_mask( + query_shape, memory_flange) + res = session.run(mask) + correct_mask = np.array( + [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, + 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 1.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 1.0, 0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]]) + self.assertAllClose(correct_mask, res) + + def test2dGather(self): + """Testing 2d index gather and block gather functions.""" + batch_size = 2 + num_heads = 2 + height = 4 + width = 6 + depth = 8 + query_shape = (2, 3) + x = np.random.rand(batch_size, num_heads, height, width, depth) + y = np.reshape(x, (batch_size, num_heads, -1, depth)) + correct_indices = [[0, 1, 2, 6, 7, 8], + [3, 4, 5, 9, 10, 11], + [12, 13, 14, 18, 19, 20], + [15, 16, 17, 21, 22, 23]] + correct_gathered_x = [[[y[0, 0, correct_indices[0]], + y[0, 0, correct_indices[1]], + y[0, 0, correct_indices[2]], + y[0, 0, correct_indices[3]]], + [y[0, 1, correct_indices[0]], + y[0, 1, correct_indices[1]], + y[0, 1, correct_indices[2]], + y[0, 1, correct_indices[3]]]], + [[y[1, 0, correct_indices[0]], + y[1, 0, correct_indices[1]], + y[1, 0, correct_indices[2]], + y[1, 0, correct_indices[3]]], + [y[1, 1, correct_indices[0]], + y[1, 1, correct_indices[1]], + y[1, 1, correct_indices[2]], + y[1, 1, correct_indices[3]]]]] + + with self.test_session() as session: + x_indices = common_attention.gather_indices_2d( + x, query_shape, query_shape) + gathered_x = common_attention.gather_blocks_2d(x, x_indices) + x_indices, gathered_x = session.run([x_indices, gathered_x]) + self.assertAllEqual(correct_indices, x_indices) + self.assertAllClose(correct_gathered_x, gathered_x) + + if __name__ == "__main__": tf.test.main() From 4f0737502cbb3d0ce8bdf311f088f539c7101a59 Mon Sep 17 00:00:00 2001 From: Niki Parmar Date: Wed, 13 Sep 2017 11:36:40 -0700 Subject: [PATCH 06/39] use the right value for shape PiperOrigin-RevId: 168570496 --- tensor2tensor/layers/common_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index fdba48b01..1da33479b 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -482,7 +482,7 @@ def masked_local_attention_1d( # If (length < 2 * block_length), then we use only one block. block_length = tf.where(tf.less(length, block_length * 2), length, block_length) - depth_k = tf.shape(q)[3] + depth_k = tf.shape(k)[3] depth_v = tf.shape(v)[3] original_length = length padding_size = tf.mod(-length, block_length) From 802b95fcb5e23562d46e2efec6e1b7769c9d674e Mon Sep 17 00:00:00 2001 From: T2T Team Date: Wed, 13 Sep 2017 12:33:42 -0700 Subject: [PATCH 07/39] Separate out encoding a decoding steps. PiperOrigin-RevId: 168579149 --- tensor2tensor/models/transformer.py | 109 +++++++++++++++++++++++----- 1 file changed, 90 insertions(+), 19 deletions(-) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index a2e76dd13..4ee6746a1 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -41,34 +41,105 @@ class Transformer(t2t_model.T2TModel): """Attention net. See file docstring.""" + def encode(self, inputs, target_space, hparams): + """Encode transformer inputs. + + Args: + inputs: Transformer inputs [batch_size, input_length, hidden_dim] + target_space: scalar, target space ID. + hparams: hyperparmeters for model. + + Returns: + Tuple of: + encoder_output: Encoder representation. + [batch_size, input_length, hidden_dim] + encoder_decoder_attention_bias: Bias and mask weights for + encodre-decoder attention. [batch_size, input_length] + """ + inputs = common_layers.flatten4d3d(inputs) + + encoder_input, self_attention_bias, encoder_decoder_attention_bias = ( + transformer_prepare_encoder(inputs, target_space, hparams)) + + encoder_input = tf.nn.dropout( + encoder_input, 1.0 - hparams.layer_prepostprocess_dropout) + + encoder_output = transformer_encoder( + encoder_input, + self_attention_bias, + hparams) + + return encoder_output, encoder_decoder_attention_bias + + def decode( + self, + decoder_input, + encoder_output, + encoder_decoder_attention_bias, + decoder_self_attention_bias, + hparams): + """Decode Transformer outputs from encoder representation. + + Args: + decoder_input: inputs to bottom of the model. + [batch_size, decoder_length, hidden_dim] + encoder_output: Encoder representation. + [batch_size, input_length, hidden_dim] + encoder_decoder_attention_bias: Bias and mask weights for + encoder-decoder attention. [batch_size, input_length] + decoder_self_attention_bias: Bias and mask weights for decoder + self-attention. [batch_size, decoder_length] + hparams: hyperparmeters for model. + + Returns: + Final decoder representaiton. [batch_size, decoder_length, hidden_dim] + """ + decoder_input = tf.nn.dropout(decoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + + decoder_output = transformer_decoder( + decoder_input, + encoder_output, + decoder_self_attention_bias, + encoder_decoder_attention_bias, + hparams) + + # Expand since t2t expects 4d tensors. + return tf.expand_dims(decoder_output, axis=2) + def model_fn_body(self, features): + """Transformet main model_fn. + + Args: + features: Map of features to the model. Should contain the following: + "inputs": Transformer inputs [batch_size, input_length, hidden_dim] + "tragets": Target decoder outputs. + [batch_size, decoder_length, hidden_dim] + "target_space_id" + + Returns: + Final decoder representaiton. [batch_size, decoder_length, hidden_dim] + """ hparams = self._hparams - targets = features["targets"] + inputs = features["inputs"] + target_space = features["target_space_id"] + encoder_output, encoder_decoder_attention_bias = self.encode( + inputs, target_space, hparams) - inputs = common_layers.flatten4d3d(inputs) + targets = features["targets"] targets = common_layers.flatten4d3d(targets) - (encoder_input, encoder_self_attention_bias, - encoder_decoder_attention_bias) = transformer_prepare_encoder( - inputs, target_space, hparams) - (decoder_input, decoder_self_attention_bias) = transformer_prepare_decoder( + decoder_input, decoder_self_attention_bias = transformer_prepare_decoder( targets, hparams) - encoder_input = tf.nn.dropout(encoder_input, - 1.0 - hparams.layer_prepostprocess_dropout) - decoder_input = tf.nn.dropout(decoder_input, - 1.0 - hparams.layer_prepostprocess_dropout) - encoder_output = transformer_encoder(encoder_input, - encoder_self_attention_bias, hparams) - - decoder_output = transformer_decoder( - decoder_input, encoder_output, decoder_self_attention_bias, - encoder_decoder_attention_bias, hparams) - decoder_output = tf.expand_dims(decoder_output, 2) - - return decoder_output + return self.decode( + decoder_input, + encoder_output, + encoder_decoder_attention_bias, + decoder_self_attention_bias, + hparams) @registry.register_model From 466ce80f09cbdec86fca93e74a1dd5c286713f06 Mon Sep 17 00:00:00 2001 From: T2T Team Date: Wed, 13 Sep 2017 12:34:40 -0700 Subject: [PATCH 08/39] Split out timing signal function. PiperOrigin-RevId: 168579271 --- tensor2tensor/layers/common_attention.py | 47 +++++++++++++++++++++--- 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index 1da33479b..840131c6a 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -37,8 +37,9 @@ _expert_count = 0 -def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): - """Adds a bunch of sinusoids of different frequencies to a Tensor. +def get_timing_signal_1d( + length, channels, min_timescale=1.0, max_timescale=1.0e4): + """Gets a bunch of sinusoids of different frequencies. Each channel of the input Tensor is incremented by a sinusoid of a different frequency and phase. @@ -58,15 +59,15 @@ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): the channels dimension. Args: - x: a Tensor with shape [batch, length, channels] + length: scalar, length of timing signal sequence. + channels: scalar, size of timing embeddings to create. The number of + different timescales is equal to channels / 2. min_timescale: a float max_timescale: a float Returns: - a Tensor the same shape as x. + a Tensor of timing signals [1, length, channels] """ - length = tf.shape(x)[1] - channels = tf.shape(x)[2] position = tf.to_float(tf.range(length)) num_timescales = channels // 2 log_timescale_increment = ( @@ -78,6 +79,40 @@ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]]) signal = tf.reshape(signal, [1, length, channels]) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + """Adds a bunch of sinusoids of different frequencies to a Tensor. + + Each channel of the input Tensor is incremented by a sinusoid of a different + frequency and phase. + + This allows attention to learn to use absolute and relative positions. + Timing signals should be added to some precursors of both the query and the + memory inputs to attention. + + The use of relative position is possible because sin(x+y) and cos(x+y) can be + experessed in terms of y, sin(x) and cos(x). + + In particular, we use a geometric sequence of timescales starting with + min_timescale and ending with max_timescale. The number of different + timescales is equal to channels / 2. For each timescale, we + generate the two sinusoidal signals sin(timestep/timescale) and + cos(timestep/timescale). All of these sinusoids are concatenated in + the channels dimension. + + Args: + x: a Tensor with shape [batch, length, channels] + min_timescale: a float + max_timescale: a float + + Returns: + a Tensor the same shape as x. + """ + length = tf.shape(x)[1] + channels = tf.shape(x)[2] + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) return x + signal From 79ba4a8b98752cf0d5cbed6718f2ef6cdcfd1374 Mon Sep 17 00:00:00 2001 From: T2T Team Date: Wed, 13 Sep 2017 14:53:40 -0700 Subject: [PATCH 09/39] Adding has_inputs property to Problem. PiperOrigin-RevId: 168599850 --- tensor2tensor/data_generators/problem.py | 15 ++++++++------- tensor2tensor/utils/input_fn_builder.py | 16 ++++++++++------ 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index 4aa4862ef..cb8b47aee 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -17,20 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - import collections import os import random - # Dependency imports - import six - from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import text_encoder from tensor2tensor.utils import metrics from tensor2tensor.utils import registry - import tensorflow as tf @@ -385,6 +380,10 @@ def _preprocess(example): return dataset + @property + def has_inputs(self): + return "inputs" in self.get_feature_encoders() + @property def feature_info(self): """Retrieve dict. @@ -404,7 +403,8 @@ def feature_info(self): input_mods = hp.input_modality target_mod = hp.target_modality vocabs = hp.vocabulary - in_id = hp.input_space_id + if self.has_inputs: + in_id = hp.input_space_id out_id = hp.target_space_id features = collections.defaultdict(FeatureInfo) @@ -422,7 +422,8 @@ def feature_info(self): for name, encoder in six.iteritems(vocabs): features[name].encoder = encoder - features["inputs"].space_id = in_id + if self.has_inputs: + features["inputs"].space_id = in_id features["targets"].space_id = out_id self._feature_info = features diff --git a/tensor2tensor/utils/input_fn_builder.py b/tensor2tensor/utils/input_fn_builder.py index cfa782e8d..5a63a8bd1 100644 --- a/tensor2tensor/utils/input_fn_builder.py +++ b/tensor2tensor/utils/input_fn_builder.py @@ -127,16 +127,18 @@ def input_fn(): feature_map["problem_choice"] = problem_choice # Set shapes so the ranks are clear. - feature_map["inputs"].set_shape([None, None, None, None]) + if problem_instance.has_inputs: + feature_map["inputs"].set_shape([None, None, None, None]) + feature_map["input_space_id"].set_shape([]) feature_map["targets"].set_shape([None, None, None, None]) feature_map["problem_choice"].set_shape([]) - feature_map["input_space_id"].set_shape([]) feature_map["target_space_id"].set_shape([]) if mode == tf.estimator.ModeKeys.PREDICT: feature_map["infer_targets"] = feature_map["targets"] # Forced shape obfuscation is necessary for inference. - feature_map["inputs"]._shape = tf.TensorShape([None, None, None, None]) # pylint: disable=protected-access + if problem_instance.has_inputs: + feature_map["inputs"]._shape = tf.TensorShape([None, None, None, None]) # pylint: disable=protected-access feature_map["targets"]._shape = tf.TensorShape([None, None, None, None]) # pylint: disable=protected-access # This is because of a bug in the Estimator that short-circuits prediction @@ -238,11 +240,13 @@ def features_for_problem(problem_instance, feature_map["targets"] = feature_map["inputs"] # Ensure inputs and targets are proper rank. - while len(feature_map["inputs"].get_shape()) != 4: - feature_map["inputs"] = tf.expand_dims(feature_map["inputs"], axis=-1) + if problem_instance.has_inputs: + while len(feature_map["inputs"].get_shape()) != 4: + feature_map["inputs"] = tf.expand_dims(feature_map["inputs"], axis=-1) while len(feature_map["targets"].get_shape()) != 4: feature_map["targets"] = tf.expand_dims(feature_map["targets"], axis=-1) - feature_map["input_space_id"] = tf.constant(p_hparams.input_space_id) + if problem_instance.has_inputs: + feature_map["input_space_id"] = tf.constant(p_hparams.input_space_id) feature_map["target_space_id"] = tf.constant(p_hparams.target_space_id) return feature_map From 7035ffe8a1711476d137a5c6e0af85c70c718df7 Mon Sep 17 00:00:00 2001 From: T2T Team Date: Wed, 13 Sep 2017 15:04:34 -0700 Subject: [PATCH 10/39] Allowing explicit timing positions to be used, by adding function add_timing_signal_1d_given_position in common_attention.py that takes timing positions (as a tensor of shape [batch, length]). PiperOrigin-RevId: 168601518 --- tensor2tensor/layers/common_attention.py | 27 ++++++++++++++++++++++++ tensor2tensor/utils/t2t_model.py | 3 +++ 2 files changed, 30 insertions(+) diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index 840131c6a..daefb56c5 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -116,6 +116,33 @@ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): return x + signal +def add_timing_signal_1d_given_position(x, position, min_timescale=1.0, + max_timescale=1.0e4): + """Adds sinusoids of diff frequencies to a Tensor, with timing position given. + + Args: + x: a Tensor with shape [batch, length, channels] + position: a Tensor with shape [batch, length] + min_timescale: a float + max_timescale: a float + + Returns: + a Tensor the same shape as x. + """ + channels = tf.shape(x)[2] + num_timescales = channels // 2 + log_timescale_increment = ( + math.log(float(max_timescale) / float(min_timescale)) / + (tf.to_float(num_timescales) - 1)) + inv_timescales = min_timescale * tf.exp( + tf.to_float(tf.range(num_timescales)) * -log_timescale_increment) + scaled_time = (tf.expand_dims(tf.to_float(position), 2) * + tf.expand_dims(tf.expand_dims(inv_timescales, 0), 0)) + signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=2) + signal = tf.pad(signal, [[0, 0], [0, 0], [0, tf.mod(channels, 2)]]) + return x + signal + + def add_timing_signal_nd(x, min_timescale=1.0, max_timescale=1.0e4): """Adds a bunch of sinusoids of different frequencies to a Tensor. diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index 32627f7e3..916de50b7 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -464,6 +464,9 @@ def model_fn(self, features, skip=False, last_position_only=False): transformed_features["targets"] = target_modality.targets_bottom_sharded( sharded_features["targets"], dp) + # Allows later access to pre-embedding raw targets. + transformed_features["raw_targets"] = sharded_features["targets"] + # Construct the model body. with tf.variable_scope("body", reuse=self._problem_idx > 0): if skip: From 213859956b00db43925b54a7ff938fa034885959 Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Thu, 14 Sep 2017 15:40:13 -0700 Subject: [PATCH 11/39] fix off-by-one num_samples bug in decode_from_dataset PiperOrigin-RevId: 168757167 --- tensor2tensor/utils/decoding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py index d84fd740b..a27ff72df 100644 --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -176,7 +176,7 @@ def decode_from_dataset(estimator, target_file.write(str(decoded_target) + "\n") if (decode_hp.num_samples >= 0 and - num_predictions >= decode_hp.num_samples): + (num_predictions + 1) >= decode_hp.num_samples): break if decode_to_file: From 32375386d8b2abd8a8d619f482e92454e2afdca8 Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Thu, 14 Sep 2017 16:36:54 -0700 Subject: [PATCH 12/39] Use decode_hparams.batch_size when decoding from dataset PiperOrigin-RevId: 168765336 --- tensor2tensor/utils/decoding.py | 10 ++++++++-- tensor2tensor/utils/input_fn_builder.py | 23 ++++++++++++++++------- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py index a27ff72df..fc5f22c1a 100644 --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -47,7 +47,7 @@ def decode_hparams(overrides=""): save_images=False, problem_idx=0, extra_length=50, - batch_size=32, + batch_size=0, beam_size=4, alpha=0.6, return_beams=False, @@ -113,7 +113,8 @@ def decode_from_dataset(estimator, hparams=hparams, data_file_patterns=infer_problems_data, num_datashards=devices.data_parallelism().n, - fixed_problem=problem_idx) + fixed_problem=problem_idx, + batch_size=decode_hp.batch_size) # Get the predictions as an iterable predictions = estimator.predict(infer_input_fn) @@ -188,6 +189,11 @@ def decode_from_dataset(estimator, def decode_from_file(estimator, filename, decode_hp, decode_to_file=None): """Compute predictions on entries in filename and write them out.""" + if not decode_hp.batch_size: + decode_hp.batch_size = 32 + tf.logging.info( + "decode_hp.batch_size not specified; default=%d" % decode_hp.batch_size) + hparams = estimator.params problem_id = decode_hp.problem_idx inputs_vocab = hparams.problems[problem_id].vocabulary["inputs"] diff --git a/tensor2tensor/utils/input_fn_builder.py b/tensor2tensor/utils/input_fn_builder.py index 5a63a8bd1..c9dde1a14 100644 --- a/tensor2tensor/utils/input_fn_builder.py +++ b/tensor2tensor/utils/input_fn_builder.py @@ -34,7 +34,8 @@ def build_input_fn(mode, num_datashards=None, fixed_problem=None, worker_replicas=None, - worker_id=None): + worker_id=None, + batch_size=None): """Provides input to the graph, either from disk or via a placeholder. This function produces an input function that will feed data into @@ -61,6 +62,7 @@ def build_input_fn(mode, setting with hparams.problem_choice == distributed. worker_id: int, id of this worker replica. Used in multiproblem setting with hparams.problem_choice == distributed. + batch_size: int, if provided, will use a fixed batch size. Returns: A function that returns a dictionary of features and the target labels. @@ -98,6 +100,7 @@ def input_fn(): problem_filepatterns, num_datashards, mode, + batch_size=batch_size, name="problem_%d" % problem_idx) problem_batches.append(feature_map) @@ -211,19 +214,25 @@ def features_for_problem(problem_instance, data_filepatterns, num_datashards, mode, + batch_size=None, name="problem_inputs"): """Feature map for Problem.""" with tf.name_scope(name): with tf.device("/cpu:0"): # Input reading on CPU capacity = (p_hparams.max_expected_batch_size_per_shard * num_datashards) + batching_scheme = data_reader.hparams_to_batching_scheme( + hparams, + shard_multiplier=num_datashards, + drop_long_sequences=(mode == tf.estimator.ModeKeys.TRAIN or + hparams.eval_drop_long_sequences), + length_multiplier=(p_hparams.batch_size_multiplier)) + if batch_size: + # If batch_size is fixed, use a single input bucket + batching_scheme["batch_sizes"] = [batch_size] + batching_scheme["boundaries"] = [] feature_map = data_reader.input_pipeline( problem_instance, data_filepatterns, capacity, mode, hparams, - data_reader.hparams_to_batching_scheme( - hparams, - shard_multiplier=num_datashards, - drop_long_sequences=(mode == tf.estimator.ModeKeys.TRAIN or - hparams.eval_drop_long_sequences), - length_multiplier=(p_hparams.batch_size_multiplier))) + batching_scheme) # Reverse inputs and targets features if the problem was reversed. if problem_instance is not None: From e6e4263680be61d7d495c9e27e1617f453a683d7 Mon Sep 17 00:00:00 2001 From: Noam Shazeer Date: Thu, 14 Sep 2017 19:16:27 -0700 Subject: [PATCH 13/39] Add wiki_scramble_128 dataset. PiperOrigin-RevId: 168782469 --- tensor2tensor/data_generators/wiki.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tensor2tensor/data_generators/wiki.py b/tensor2tensor/data_generators/wiki.py index 6f6c97686..396d120c7 100644 --- a/tensor2tensor/data_generators/wiki.py +++ b/tensor2tensor/data_generators/wiki.py @@ -223,6 +223,19 @@ def generator(self, data_dir, tmp_dir, _): yield {"inputs": inputs, "targets": targets} +@registry.register_problem +class LanguagemodelWikiScramble128(LanguagemodelWikiScramble): + """Sequence length 128, 50% scrambed.""" + + @property + def sequence_length(self): + return 128 + + @property + def scramble_fraction(self): + return 0.5 + + @registry.register_problem class LanguagemodelWikiScramble1k50(LanguagemodelWikiScramble): """Sequence length 1024, 50% scrambed.""" From 6cb0bc8bea37cdbc31cc6260d9727c577bd3a278 Mon Sep 17 00:00:00 2001 From: T2T Team Date: Fri, 15 Sep 2017 12:00:26 -0700 Subject: [PATCH 14/39] Add ability to average the last N checkpoints, without needing to specify individual checkpoints. PiperOrigin-RevId: 168867603 --- tensor2tensor/utils/avg_checkpoints.py | 38 ++++++++++++++++++++------ 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/tensor2tensor/utils/avg_checkpoints.py b/tensor2tensor/utils/avg_checkpoints.py index 77acd4353..4d1c56eda 100644 --- a/tensor2tensor/utils/avg_checkpoints.py +++ b/tensor2tensor/utils/avg_checkpoints.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function +import os + # Dependency imports import numpy as np @@ -30,6 +32,9 @@ flags.DEFINE_string("checkpoints", "", "Comma-separated list of checkpoints to average.") +flags.DEFINE_integer("num_last_checkpoints", 0, + "Averages the last N saved checkpoints." + " If the checkpoints flag is set, this is ignored.") flags.DEFINE_string("prefix", "", "Prefix (e.g., directory) to append to each checkpoint.") flags.DEFINE_string("output_path", "/tmp/averaged.ckpt", @@ -42,17 +47,32 @@ def checkpoint_exists(path): def main(_): - # Get the checkpoints list from flags and run some basic checks. - checkpoints = [c.strip() for c in FLAGS.checkpoints.split(",")] - checkpoints = [c for c in checkpoints if c] - if not checkpoints: - raise ValueError("No checkpoints provided for averaging.") - if FLAGS.prefix: - checkpoints = [FLAGS.prefix + c for c in checkpoints] + if FLAGS.checkpoints: + # Get the checkpoints list from flags and run some basic checks. + checkpoints = [c.strip() for c in FLAGS.checkpoints.split(",")] + checkpoints = [c for c in checkpoints if c] + if not checkpoints: + raise ValueError("No checkpoints provided for averaging.") + if FLAGS.prefix: + checkpoints = [FLAGS.prefix + c for c in checkpoints] + else: + assert FLAGS.num_last_checkpoints >= 1, "Must average at least one model" + assert FLAGS.prefix, ("Prefix must be provided when averaging last" + " N checkpoints") + checkpoint_state = tf.train.get_checkpoint_state( + os.path.dirname(FLAGS.prefix)) + # Checkpoints are ordered from oldest to newest. + checkpoints = checkpoint_state.all_model_checkpoint_paths[ + -FLAGS.num_last_checkpoints:] + checkpoints = [c for c in checkpoints if checkpoint_exists(c)] if not checkpoints: - raise ValueError( - "None of the provided checkpoints exist. %s" % FLAGS.checkpoints) + if FLAGS.checkpoints: + raise ValueError( + "None of the provided checkpoints exist. %s" % FLAGS.checkpoints) + else: + raise ValueError("Could not find checkpoints at %s" % + os.path.dirname(FLAGS.prefix)) # Read variables from all checkpoints and average them. tf.logging.info("Reading variables and averaging checkpoints:") From be19196ded9f907098bf4747a138d632e0a9736b Mon Sep 17 00:00:00 2001 From: Noam Shazeer Date: Fri, 15 Sep 2017 14:30:54 -0700 Subject: [PATCH 15/39] Working on a model for cnn_dailymail summarization task. Make greedy inference and beam search work in prepend mode. After this change, inference in prepend mode requires batch size 1, since padding is not properly ignored. PiperOrigin-RevId: 168889211 --- tensor2tensor/data_generators/inspect.py | 4 ++-- tensor2tensor/data_generators/problem.py | 12 +++++++----- tensor2tensor/models/transformer.py | 8 ++++++++ tensor2tensor/utils/data_reader.py | 3 +-- tensor2tensor/utils/t2t_model.py | 22 +++++++++++++++++++--- 5 files changed, 37 insertions(+), 12 deletions(-) diff --git a/tensor2tensor/data_generators/inspect.py b/tensor2tensor/data_generators/inspect.py index 848b74a2d..c84f00606 100644 --- a/tensor2tensor/data_generators/inspect.py +++ b/tensor2tensor/data_generators/inspect.py @@ -67,9 +67,9 @@ def main(_): inputs = [int(i) for i in x.features.feature["inputs"].int64_list.value] targets = [int(i) for i in x.features.feature["targets"].int64_list.value] if FLAGS.print_inputs: - print(encoder.decode(inputs) if encoder else inputs) + print("INPUTS:\n" + encoder.decode(inputs) if encoder else inputs) if FLAGS.print_targets: - print(encoder.decode(targets) if encoder else targets) + print("TARGETS:\n" + encoder.decode(targets) if encoder else targets) total_input_tokens += len(inputs) total_target_tokens += len(targets) total_sequences += 1 diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index cb8b47aee..a006d5627 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -102,15 +102,18 @@ def default_model_hparams(): data_dir=None) -def preprocess_examples_common(examples, hparams): +def preprocess_examples_common(examples, hparams, mode): """Preprocessing steps common to all models.""" if hparams.max_input_seq_length > 0: examples["inputs"] = examples["inputs"][:hparams.max_input_seq_length] if hparams.max_target_seq_length > 0: examples["targets"] = examples["targets"][:hparams.max_target_seq_length] if hparams.prepend_mode != "none": - examples["targets"] = tf.concat( - [examples["inputs"], [0], examples["targets"]], 0) + if mode == tf.estimator.ModeKeys.PREDICT: + examples["partial_targets"] = tf.concat([examples["inputs"], [0]], 0) + else: + examples["targets"] = tf.concat( + [examples["inputs"], [0], examples["targets"]], 0) return examples @@ -196,8 +199,7 @@ def example_reading_spec(self): return (data_fields, data_items_to_decoders) def preprocess_examples(self, examples, mode, hparams): - del mode - return preprocess_examples_common(examples, hparams) + return preprocess_examples_common(examples, hparams, mode) def eval_metrics(self): return [ diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 4ee6746a1..7d52824fa 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -725,6 +725,14 @@ def transformer_parameter_attention_b(): return hparams +@registry.register_hparams +def transformer_prepend(): + hparams = transformer_base() + hparams.prepend_mode = "prepend_inputs_masked_attention" + hparams.max_length = 0 + return hparams + + @registry.register_ranged_hparams("transformer_base") def transformer_base_range(rhp): """Small range of hyperparameters.""" diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py index 834e631ac..d94e85e39 100644 --- a/tensor2tensor/utils/data_reader.py +++ b/tensor2tensor/utils/data_reader.py @@ -288,7 +288,7 @@ def input_pipeline(problem, data_file_pattern, capacity, mode, hparams, def _preprocess(example, problem, data_file_pattern, hparams, mode): """Preprocessing for example.""" if problem is None: - example = preprocess_examples_common(example, hparams) + example = preprocess_examples_common(example, hparams, mode) example = preprocessing(example, data_file_pattern) else: example = problem.preprocess_examples(example, mode, hparams) @@ -384,7 +384,6 @@ def padded_batch(dataset, batch_size, padded_shapes=None): def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1): """A default set of length-bucket boundaries.""" - assert min_length <= max_length assert length_bucket_step > 1.0 x = min_length boundaries = [] diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index 916de50b7..812e5aee3 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -228,10 +228,19 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams, samples: an integer `Tensor`. Top samples from the beam search """ + batch_size = tf.shape(features["inputs"])[0] + batch_size = tf.Print(batch_size, [batch_size], "beam_decode batch_size=") + def symbols_to_logits_fn(ids): """Go from ids to logits.""" ids = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) ids = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0], [0, 0]]) + if "partial_targets" in features: + pt = features["partial_targets"] + pt_length = tf.shape(pt)[1] + pt = tf.tile(pt, [1, beam_size]) + pt = tf.reshape(pt, [batch_size * beam_size, pt_length, 1, 1]) + ids = tf.concat([pt, ids], axis=1) features["targets"] = ids self._coverage = None @@ -247,7 +256,6 @@ def symbols_to_logits_fn(ids): logits = logits[:, current_output_position, :, :] return tf.squeeze(logits, axis=[1, 2]) - batch_size = tf.shape(features["inputs"])[0] initial_ids = tf.zeros([batch_size], dtype=tf.int32) inputs_old = features["inputs"] @@ -263,7 +271,9 @@ def symbols_to_logits_fn(ids): target_modality = self._hparams.problems[self._problem_idx].target_modality vocab_size = target_modality.top_dimensionality # Setting decode length to input length + decode_length - decode_length = tf.shape(features["inputs"])[1] + tf.constant(decode_length) + decode_length = tf.constant(decode_length) + if "partial_targets" not in features: + decode_length += tf.shape(features["inputs"])[1] ids, scores = beam_search.beam_search(symbols_to_logits_fn, initial_ids, beam_size, decode_length, vocab_size, alpha) @@ -333,7 +343,9 @@ def infer_step(recent_output, recent_logits, unused_loss): # Create an initial output tensor. This will be passed # to the infer_step, which adds one timestep at every iteration. if "partial_targets" in features: - initial_output = tf.convert_to_tensor(features["partial_targets"]) + initial_output = tf.to_int64(tf.expand_dims( + tf.expand_dims(features["partial_targets"], 2), 3)) + batch_size = tf.shape(initial_output)[0] else: batch_size = tf.shape(features["inputs"])[0] initial_output = tf.zeros((batch_size, 0, 1, 1), dtype=tf.int64) @@ -366,6 +378,10 @@ def infer_step(recent_output, recent_logits, unused_loss): if inputs_old is not None: # Restore to not confuse Estimator. features["inputs"] = inputs_old losses = {"training": loss} + if "partial_targets" in features: + partial_target_length = tf.shape(features["partial_targets"])[1] + result = tf.slice( + result, [0, partial_target_length, 0, 0], [-1, -1, -1, -1]) return result, logits, losses def sample(self, features, last_position_only=False): From 6970dea82e2605e21054e3eccb4888a7bda9535e Mon Sep 17 00:00:00 2001 From: Manoj Kumar Date: Mon, 18 Sep 2017 11:14:48 -0700 Subject: [PATCH 16/39] Change ptb data generator to encode end of sentences with tags during PiperOrigin-RevId: 169116643 --- tensor2tensor/data_generators/ptb.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensor2tensor/data_generators/ptb.py b/tensor2tensor/data_generators/ptb.py index 893c2b77c..31bc83c0a 100644 --- a/tensor2tensor/data_generators/ptb.py +++ b/tensor2tensor/data_generators/ptb.py @@ -42,9 +42,9 @@ def _read_words(filename): """Reads words from a file.""" with tf.gfile.GFile(filename, "r") as f: if sys.version_info[0] >= 3: - return f.read().replace("\n", " ").split() + return f.read().replace("\n", " %s " % EOS).split() else: - return f.read().decode("utf-8").replace("\n", " ").split() + return f.read().decode("utf-8").replace("\n", " %s " % EOS).split() def _build_vocab(filename, vocab_path, vocab_size): @@ -151,7 +151,7 @@ def generator(self, data_dir, tmp_dir, train): def _generator(self, filename, encoder): with tf.gfile.GFile(filename, "r") as f: for line in f: - line = " ".join(line.replace("\n", EOS).split()) + line = " ".join(line.replace("\n", " %s " % EOS).split()) tok = encoder.encode(line) if tok: yield {"inputs": [0], "targets": tok} From 3aa13683e30b15633b0c358f9d888f702cff0c3f Mon Sep 17 00:00:00 2001 From: T2T Team Date: Mon, 18 Sep 2017 13:19:02 -0700 Subject: [PATCH 17/39] Rename ambiguous function names. PiperOrigin-RevId: 169135518 --- tensor2tensor/layers/common_layers.py | 6 +++--- tensor2tensor/layers/common_layers_test.py | 2 +- tensor2tensor/models/attention_lm.py | 2 +- tensor2tensor/models/attention_lm_moe.py | 2 +- tensor2tensor/models/bytenet.py | 2 +- tensor2tensor/models/lstm.py | 4 ++-- tensor2tensor/models/multimodel.py | 2 +- tensor2tensor/models/slicenet.py | 2 +- tensor2tensor/models/transformer.py | 2 +- tensor2tensor/models/transformer_vae.py | 2 +- 10 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py index bd9ff896d..6554e0d31 100644 --- a/tensor2tensor/layers/common_layers.py +++ b/tensor2tensor/layers/common_layers.py @@ -209,7 +209,7 @@ def embedding(x, vocab_size, dense_size, name=None, reuse=None, multiplier=1.0): return tf.reshape(emb_x, [shape[0], shape[1], shape[2], static_shape[4]]) -def shift_left(x, pad_value=None): +def shift_right(x, pad_value=None): """Shift the second dimension of x right by one.""" if pad_value is None: shifted_targets = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])[:, :-1, :, :] @@ -218,7 +218,7 @@ def shift_left(x, pad_value=None): return shifted_targets -def shift_left_3d(x, pad_value=None): +def shift_right_3d(x, pad_value=None): """Shift the second dimension of x right by one.""" if pad_value is None: shifted_targets = tf.pad(x, [[0, 0], [1, 0], [0, 0]])[:, :-1, :] @@ -815,7 +815,7 @@ def decompress_seqcnn(x, # Flatten x and embedded targets. Flat targets are factor* larger on axis=1. flat_x = tf.reshape(x, [-1, 1, 1, hidden_size]) flat_targets = tf.reshape(targets_emb, [-1, factor, 1, hidden_size]) - shifted_targets = shift_left(flat_targets) + shifted_targets = shift_right(flat_targets) # Run a SeqCNN large-batch to produce factor outputs out of every target. flat_x += tf.zeros_like(shifted_targets) # Broadcast on axis=1. flat_outputs = conv_block( diff --git a/tensor2tensor/layers/common_layers_test.py b/tensor2tensor/layers/common_layers_test.py index d11f8ce2c..ee07c48d3 100644 --- a/tensor2tensor/layers/common_layers_test.py +++ b/tensor2tensor/layers/common_layers_test.py @@ -281,7 +281,7 @@ def testShiftLeft(self): expected = np.zeros((5, 7, 1, 11)) expected[:, 1, :] = np.ones_like(expected[:, 1, :]) with self.test_session() as session: - a = common_layers.shift_left(tf.constant(x1, dtype=tf.float32)) + a = common_layers.shift_right(tf.constant(x1, dtype=tf.float32)) actual = session.run(a) self.assertAllEqual(actual, expected) diff --git a/tensor2tensor/models/attention_lm.py b/tensor2tensor/models/attention_lm.py index 3302f45be..696057233 100644 --- a/tensor2tensor/models/attention_lm.py +++ b/tensor2tensor/models/attention_lm.py @@ -79,7 +79,7 @@ def attention_lm_prepare_decoder(targets, hparams): else: decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(tf.shape(targets)[1])) - decoder_input = common_layers.shift_left_3d(targets) + decoder_input = common_layers.shift_right_3d(targets) if hparams.pos == "timing": decoder_input = common_attention.add_timing_signal_1d(decoder_input) return (decoder_input, decoder_self_attention_bias) diff --git a/tensor2tensor/models/attention_lm_moe.py b/tensor2tensor/models/attention_lm_moe.py index abdd68c8b..42a9fbabf 100644 --- a/tensor2tensor/models/attention_lm_moe.py +++ b/tensor2tensor/models/attention_lm_moe.py @@ -262,7 +262,7 @@ def attention_lm_moe_prepare_decoder(targets, hparams): common_attention.attention_bias_lower_triangle(tf.shape(targets)[1])) # TODO(epot): The padding remover should take into account that the input is # shifted. - decoder_input = common_layers.shift_left_3d(targets) + decoder_input = common_layers.shift_right_3d(targets) if hparams.pos == "timing": decoder_input = common_attention.add_timing_signal_1d(decoder_input) return (decoder_input, decoder_self_attention_bias, pad_remover) diff --git a/tensor2tensor/models/bytenet.py b/tensor2tensor/models/bytenet.py index e4537ef3f..5af0c4435 100644 --- a/tensor2tensor/models/bytenet.py +++ b/tensor2tensor/models/bytenet.py @@ -66,7 +66,7 @@ def bytenet_internal(inputs, targets, hparams): final_encoder = residual_dilated_conv(inputs, hparams.num_block_repeat, "SAME", "encoder", hparams) - shifted_targets = common_layers.shift_left(targets) + shifted_targets = common_layers.shift_right(targets) kernel = (hparams.kernel_height, hparams.kernel_width) decoder_start = common_layers.conv_block( tf.concat([final_encoder, shifted_targets], axis=3), diff --git a/tensor2tensor/models/lstm.py b/tensor2tensor/models/lstm.py index d1c3101b4..20475a5a9 100644 --- a/tensor2tensor/models/lstm.py +++ b/tensor2tensor/models/lstm.py @@ -221,7 +221,7 @@ def lstm_seq2seq_internal(inputs, targets, hparams, train): _, final_encoder_state = lstm( tf.reverse(inputs, axis=[1]), hparams, train, "encoder") # LSTM decoder. - shifted_targets = common_layers.shift_left(targets) + shifted_targets = common_layers.shift_right(targets) decoder_outputs, _ = lstm( common_layers.flatten4d3d(shifted_targets), hparams, @@ -240,7 +240,7 @@ def lstm_seq2seq_internal_attention(inputs, targets, hparams, train): encoder_outputs, final_encoder_state = lstm( tf.reverse(inputs, axis=[1]), hparams, train, "encoder") # LSTM decoder with attention - shifted_targets = common_layers.shift_left(targets) + shifted_targets = common_layers.shift_right(targets) decoder_outputs, _ = lstm_attention_decoder( common_layers.flatten4d3d(shifted_targets), hparams, train, "decoder", final_encoder_state, encoder_outputs) diff --git a/tensor2tensor/models/multimodel.py b/tensor2tensor/models/multimodel.py index 5df8fcd3c..a4c82d942 100644 --- a/tensor2tensor/models/multimodel.py +++ b/tensor2tensor/models/multimodel.py @@ -99,7 +99,7 @@ def prepare_decoder(targets, target_space_emb): common_attention.attention_bias_lower_triangle(tf.shape(targets)[1])) target_space_emb = tf.reshape(target_space_emb, [1, 1, -1]) target_space_emb = tf.tile(target_space_emb, [tf.shape(targets)[0], 1, 1]) - decoder_input = common_layers.shift_left_3d( + decoder_input = common_layers.shift_right_3d( targets, pad_value=target_space_emb) decoder_input = common_attention.add_timing_signal_1d(decoder_input) return (decoder_input, decoder_self_attention_bias) diff --git a/tensor2tensor/models/slicenet.py b/tensor2tensor/models/slicenet.py index 6b07dc640..5377fd97e 100644 --- a/tensor2tensor/models/slicenet.py +++ b/tensor2tensor/models/slicenet.py @@ -198,7 +198,7 @@ def norm_fn(x, name): similarity_loss = 0.0 # Use attention from each target to look at input and retrieve. - targets_shifted = common_layers.shift_left( + targets_shifted = common_layers.shift_right( targets_flat, pad_value=target_space_emb) if hparams.attention_type == "none": targets_with_attention = tf.zeros_like(targets_shifted) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 7d52824fa..9e5fdacc6 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -238,7 +238,7 @@ def transformer_prepare_decoder(targets, hparams): if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( tf.shape(targets)[1]) - decoder_input = common_layers.shift_left_3d(targets) + decoder_input = common_layers.shift_right_3d(targets) if hparams.pos == "timing": decoder_input = common_attention.add_timing_signal_1d(decoder_input) return (decoder_input, decoder_self_attention_bias) diff --git a/tensor2tensor/models/transformer_vae.py b/tensor2tensor/models/transformer_vae.py index e3279495a..86950d6b7 100644 --- a/tensor2tensor/models/transformer_vae.py +++ b/tensor2tensor/models/transformer_vae.py @@ -187,7 +187,7 @@ def encode(x, x_space, hparams, name): def decode(cond_vec, cond_add, gold, c, ed, hparams): """Transformer decoder.""" drop_gold = tf.nn.dropout(gold, 1.0 - hparams.layer_prepostprocess_dropout) - decoder_input = common_layers.shift_left(drop_gold, pad_value=cond_vec) + decoder_input = common_layers.shift_right(drop_gold, pad_value=cond_vec) if cond_add is not None: decoder_input += cond_add decoder_input = tf.squeeze(decoder_input, axis=2) From 558fe96923d5d56622a8d596a5441069c85cc72e Mon Sep 17 00:00:00 2001 From: T2T Team Date: Mon, 18 Sep 2017 14:31:05 -0700 Subject: [PATCH 18/39] Move the final layer_preprocess in the encoder and decoder in to the variable scopes so that they don't share parameters. PiperOrigin-RevId: 169147528 --- tensor2tensor/models/transformer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 9e5fdacc6..855e0fa55 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -282,10 +282,10 @@ def transformer_encoder(encoder_input, y = transformer_ffn_layer( common_layers.layer_preprocess(x, hparams), hparams, pad_remover) x = common_layers.layer_postprocess(x, y, hparams) - # if normalization is done in layer_preprocess, then it shuold also be done - # on the output, since the output can grow very large, being the sum of - # a whole stack of unnormalized layer outputs. - return common_layers.layer_preprocess(x, hparams) + # if normalization is done in layer_preprocess, then it shuold also be done + # on the output, since the output can grow very large, being the sum of + # a whole stack of unnormalized layer outputs. + return common_layers.layer_preprocess(x, hparams) def transformer_decoder(decoder_input, @@ -336,10 +336,10 @@ def transformer_decoder(decoder_input, y = transformer_ffn_layer( common_layers.layer_preprocess(x, hparams), hparams) x = common_layers.layer_postprocess(x, y, hparams) - # if normalization is done in layer_preprocess, then it shuold also be done - # on the output, since the output can grow very large, being the sum of - # a whole stack of unnormalized layer outputs. - return common_layers.layer_preprocess(x, hparams) + # if normalization is done in layer_preprocess, then it shuold also be done + # on the output, since the output can grow very large, being the sum of + # a whole stack of unnormalized layer outputs. + return common_layers.layer_preprocess(x, hparams) def transformer_ffn_layer(x, hparams, pad_remover=None): From 1e712d3a0bd9c20d31d128f060339951f9e56d1a Mon Sep 17 00:00:00 2001 From: Noam Shazeer Date: Mon, 18 Sep 2017 16:07:01 -0700 Subject: [PATCH 19/39] More experiments with "aligned" model and wiki_scramble dataset. PiperOrigin-RevId: 169162566 --- tensor2tensor/data_generators/wiki.py | 6 + tensor2tensor/layers/common_attention.py | 30 +++- tensor2tensor/models/aligned.py | 220 +++++++++++++++++++++-- 3 files changed, 238 insertions(+), 18 deletions(-) diff --git a/tensor2tensor/data_generators/wiki.py b/tensor2tensor/data_generators/wiki.py index 396d120c7..30a16817b 100644 --- a/tensor2tensor/data_generators/wiki.py +++ b/tensor2tensor/data_generators/wiki.py @@ -31,6 +31,7 @@ from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder +from tensor2tensor.utils import metrics from tensor2tensor.utils import registry import tensorflow as tf @@ -222,6 +223,11 @@ def generator(self, data_dir, tmp_dir, _): inputs = self.scramble(targets) yield {"inputs": inputs, "targets": targets} + def eval_metrics(self): + return [ + metrics.Metrics.ACC, metrics.Metrics.NEG_LOG_PERPLEXITY + ] + @registry.register_problem class LanguagemodelWikiScramble128(LanguagemodelWikiScramble): diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index daefb56c5..9b4235cc3 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -251,18 +251,42 @@ def embedding_to_padding(emb): return tf.to_float(tf.equal(emb_sum, 0.0)) +def attention_bias_local(length, max_backward, max_forward): + """Create an bias tensor to be added to attention logits. + + A position may attend to positions at most max_distance from it, + forward and backwards. + + This does not actually save any computation. + + Args: + length: an integer Scalar. + max_backward: an int64 Scalar - maximum distance backward to attend. + negative values indicate unlimited. + max_forward: an int64 Scalar - maximum distance forward to attend. + negative values indicate unlimited. + + Returns: + a `Tensor` with shape [1, 1, length, length]. + """ + band = tf.matrix_band_part( + tf.ones([length, length]), max_backward, max_forward) + ret = -1e9 * (1.0 - band) + return tf.reshape(ret, [1, 1, length, length]) + + def attention_bias_lower_triangle(length): """Create an bias tensor to be added to attention logits. + Allows a query to attend to all positions up to and including its own. + Args: length: a Scalar. Returns: a `Tensor` with shape [1, 1, length, length]. """ - lower_triangle = tf.matrix_band_part(tf.ones([length, length]), -1, 0) - ret = -1e9 * (1.0 - lower_triangle) - return tf.reshape(ret, [1, 1, length, length]) + return attention_bias_local(length, -1, 0) def attention_bias_ignore_padding(memory_padding): diff --git a/tensor2tensor/models/aligned.py b/tensor2tensor/models/aligned.py index 9cadc0cae..90100c842 100644 --- a/tensor2tensor/models/aligned.py +++ b/tensor2tensor/models/aligned.py @@ -44,7 +44,8 @@ def _should_preprocess(layer_type): - return layer_type not in ["timing", "pos_emb"] + return layer_type not in [ + "timing", "pos_emb", "att_memory_efficient"] def _should_postprocess(layer_type): @@ -81,8 +82,6 @@ def _diet_expert(x): batch_coordinate = dp(get_batch_coordinate, x) - assert hparams.batch_size >= hparams.max_length - layers = hparams.layers.strip(",").split(",") for layer_num, layer_type in enumerate(layers): with tf.variable_scope("%s_%d" % (layer_type, layer_num)): @@ -94,7 +93,25 @@ def _diet_expert(x): y = dp(common_attention.add_positional_embedding_nd, x, hparams.max_length, name="pos_emb") elif layer_type == "att": - # multihead attention + y = dp( + common_attention.multihead_attention, + x, + None, + None, # bias + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout) + elif layer_type == "att_memory_efficient": + assert hparams.layer_preprocess_sequence == "n" + zero_bias = tf.zeros([1, 1, 1, 1]) + y = dp( + common_attention.multihead_self_attention_memory_efficient, + x, + zero_bias, + hparams.num_heads) + elif layer_type == "att_local": y = dp( common_attention.multihead_attention, x, @@ -105,10 +122,29 @@ def _diet_expert(x): hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - attention_type=("local_unmasked" if hparams.attention_local - else "dot_product"), - name="decoder_self_attention") - elif layer_type == "local_expert_attention": + attention_type="local_unmasked", + block_length=hparams.local_attention_window, + block_width=hparams.local_attention_window) + elif layer_type == "att_pseudolocal": + # This is an inefficient implementation of local attention, for the + # purpose of testing model quality. + def _pseudolocal_bias(x): + return common_attention.attention_bias_local( + tf.shape(x)[1], + hparams.local_attention_window, + hparams.local_attention_window) + pseudolocal_bias = dp(_pseudolocal_bias, x) + y = dp( + common_attention.multihead_attention, + x, + None, + pseudolocal_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout) + elif layer_type == "att_local_expert": y, loss = dp( common_attention.local_expert_attention, x, @@ -176,6 +212,10 @@ def get_batch_coordinate(x): def aligned_base(): """Set of hyperparameters. + languagemodel_wiki_scramble1k50, 1gpu, 7k steps (10min): log(ppl)_eval = 2.60 + 12.0 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.00 + Returns: a hparams object """ @@ -183,6 +223,7 @@ def aligned_base(): hparams.hidden_size = 512 hparams.batch_size = 5000 hparams.max_length = 1024 + hparams.min_length_bucket = 1024 hparams.dropout = 0.0 hparams.layer_prepostprocess_dropout = 0.0 hparams.label_smoothing = 0.0 @@ -196,12 +237,12 @@ def aligned_base(): hparams.weight_decay = 0.0 hparams.optimizer_adam_beta1 = 0.9 hparams.optimizer_adam_beta2 = 0.98 - hparams.shared_embedding_and_softmax_weights = int(False) + hparams.shared_embedding_and_softmax_weights = int(True) hparams.add_hparam("ffn_hidden_sizes", "2048") # Add new ones like this. hparams.moe_num_experts = 32 hparams.layer_preprocess_sequence = "n" hparams.layer_postprocess_sequence = "da" - hparams.add_hparam("layers", "timing," + "att,ffn," * 4) + hparams.add_hparam("layers", "timing," + "conv,att,ffn," * 2) # attention-related flags hparams.add_hparam("num_heads", 8) @@ -223,34 +264,183 @@ def aligned_base(): hparams.add_hparam("attention_load_balance", 2e-2) hparams.add_hparam("diet_experts", int(False)) hparams.add_hparam("memory_efficient_ffn", int(False)) + hparams.add_hparam("local_attention_window", 128) # if True, we learn a non-autoregressive model from "inputs" to "targets". # if False, we learn an autoregressive model to generate "targets" return hparams @registry.register_hparams -def aligned_with_conv(): +def aligned_memory_efficient(): + """Use multihead_self_attention_memory_efficient. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.59 + 8.7 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.02 + + Returns: + a hparams object + """ hparams = aligned_base() - hparams.layers = "timing," + "conv,att,ffn," * 4 + hparams.layers = "timing," + "conv,att_memory_efficient,ffn," * 2 + return hparams + + +@registry.register_hparams +def aligned_local_expert(): + """Use local_expert_attention. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.72 + 10.2 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.27 + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.layers = "timing," + "conv,att_local_expert,ffn," * 2 return hparams @registry.register_hparams def aligned_local(): + """Use local attention code. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.57 + 12.8 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.08 + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.layers = "timing," + "conv,att_local,ffn," * 2 + return hparams + + +@registry.register_hparams +def aligned_local_1k(): + """Use local attention code, attend to full sequence. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.57 + 7.5 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.00 + + Returns: + a hparams object + """ + hparams = aligned_local() + hparams.local_attention_window = 1024 + return hparams + + +@registry.register_hparams +def aligned_pseudolocal(): + """Use a bias to simulate local attention. attention radius 128. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.57 + 12.0 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.06 + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.layers = "timing," + "conv,att_pseudolocal,ffn," * 2 + return hparams + + +@registry.register_hparams +def aligned_pseudolocal_256(): + """Use a bias to simulate local attention. attentio radius 256. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.56 + 12.0 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.05 + + Returns: + a hparams object + """ + hparams = aligned_pseudolocal() + hparams.local_attention_window = 256 + return hparams + + +@registry.register_hparams +def aligned_no_timing(): + """No timing signal. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.75 + 12.3 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.39 + + Returns: + a hparams object + """ hparams = aligned_base() - hparams.attention_local = int(True) + hparams.layers = "conv,att,ffn," * 2 + return hparams + + +@registry.register_hparams +def aligned_no_att(): + """No attention at all. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.89 + 20.8 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.70 + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.layers = "conv,ffn," * 2 return hparams @registry.register_hparams def aligned_pos_emb(): + """positional embedding insead of timing signal. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.67 + 12.1 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.00 + + Returns: + a hparams object + """ hparams = aligned_base() - hparams.layers = "pos_emb," + "att,ffn," * 4 + hparams.layers = "pos_emb," + "conv,att,ffn," * 2 return hparams @registry.register_hparams def aligned_moe(): + """mixture of experts instead of ffn. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.62 + 6.7 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 1.94 + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.layers = "timing," + "conv,att,moe," * 2 + return hparams + + +@registry.register_hparams +def aligned_8k(): + """version for languagemodel_wiki_scramble8k50. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.93 + 1.5 steps/sec on P100 + + Returns: + a hparams object + """ hparams = aligned_base() - hparams.layers = "timing," + "att,moe," * 4 + hparams.max_length = 8192 + hparams.batch_size = 8192 return hparams From 1c7d365dd37a5873017b9529e9fa6fba9c1a6e50 Mon Sep 17 00:00:00 2001 From: T2T Team Date: Mon, 18 Sep 2017 16:34:12 -0700 Subject: [PATCH 20/39] Initial version of fast decoding for transformer models. PiperOrigin-RevId: 169166125 --- tensor2tensor/layers/common_attention.py | 33 ++++- tensor2tensor/models/transformer.py | 177 ++++++++++++++++++++++- tensor2tensor/models/transformer_test.py | 43 +++++- tensor2tensor/utils/t2t_model.py | 19 ++- 4 files changed, 259 insertions(+), 13 deletions(-) diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index 9b4235cc3..582f8e9b3 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -1064,6 +1064,7 @@ def multihead_attention(query_antecedent, kv_filter_width=1, q_padding="VALID", kv_padding="VALID", + cache=None, name=None): """Multihead scaled-dot-product attention with input/output transformations. @@ -1087,11 +1088,28 @@ def multihead_attention(query_antecedent, to be. q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding. kv_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding. - + cache: dict, containing Tensors which are the results of previous + attentions, used for fast decoding. Expects the dict to contrain two + keys; 'k' and 'v', for the initial call the values for these keys should + be empty Tensors of the appropriate shape. + 'k' [batch_size, 0, key_channels] + 'v' [batch_size, 0, value_channels] name: an optional string + Caching: + WARNING: For decoder self-attention, i.e. when memory_antecedent == None, + the caching assumes that the bias contains future masking. + + The caching works by saving all the previous key and value values so that + you are able to send just the last query location to this attention + function. I.e. if the cache dict is provided it assumes the query is of the + shape [batch_size, 1, hiddem_dim] rather than the full memory. + Returns: - A Tensor. + The result of the attention transformation. The output shape is + [batch_size, length_q, hidden_dim] + unless the cache dict is provided in which case only the last memory + position is calculated and the output shape is [batch_size, 1, hidden_dim] Raises: ValueError: if the key depth or value depth are not divisible by the @@ -1111,6 +1129,17 @@ def multihead_attention(query_antecedent, total_value_depth, q_filter_width, kv_filter_width, q_padding, kv_padding) + if cache is not None: + if attention_type != "dot_product": + raise NotImplementedError( + "Caching is not guaranteed to work with attention types other than" + " dot_product.") + if bias is None: + raise ValueError("Bias required for caching. See function docstring " + "for details.") + k = cache["k"] = tf.concat([cache["k"], k], axis=1) + v = cache["v"] = tf.concat([cache["v"], v], axis=1) + q = split_heads(q, num_heads) k = split_heads(k, num_heads) v = split_heads(v, num_heads) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 855e0fa55..918fc8645 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -77,7 +77,8 @@ def decode( encoder_output, encoder_decoder_attention_bias, decoder_self_attention_bias, - hparams): + hparams, + cache=None): """Decode Transformer outputs from encoder representation. Args: @@ -90,6 +91,8 @@ def decode( decoder_self_attention_bias: Bias and mask weights for decoder self-attention. [batch_size, decoder_length] hparams: hyperparmeters for model. + cache: dict, containing tensors which are the results of previous + attentions, used for fast decoding. Returns: Final decoder representaiton. [batch_size, decoder_length, hidden_dim] @@ -102,7 +105,8 @@ def decode( encoder_output, decoder_self_attention_bias, encoder_decoder_attention_bias, - hparams) + hparams, + cache=cache) # Expand since t2t expects 4d tensors. return tf.expand_dims(decoder_output, axis=2) @@ -141,6 +145,152 @@ def model_fn_body(self, features): decoder_self_attention_bias, hparams) + # TODO(llion): Enable fast inference once it's been fully tested. + def x_greedy_infer( + self, features, decode_length, last_position_only=True): + """Fast version of greedy decoding. + + Args: + features: an map of string to `Tensor` + decode_length: an integer. How many additional timesteps to decode. + last_position_only: MUST be true for fast decoding! + + Returns: + samples: [batch_size, input_length + decode_length] + logits: Not returned + losses: Not returned + + Raises: + ValueError: If last_position_only if False + NotImplementedError: If there are multiple data shards. + """ + if not last_position_only: + raise ValueError("Fast decoding only deals with the last positions!") + if self._num_datashards != 1: + raise NotImplementedError("Fast decoding only supports a single shard.") + dp = self._data_parallelism + hparams = self._hparams + + inputs = features["inputs"] + batch_size = tf.shape(inputs)[0] + # TODO(llion): Support class modality + decode_length = tf.shape(inputs)[1] + decode_length + + # TODO(llion): Clean up this reshaping logic. + inputs = tf.expand_dims(inputs, axis=1) + if len(inputs.shape) < 5: + inputs = tf.expand_dims(inputs, axis=4) + s = tf.shape(inputs) + inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]]) + # _shard_features called to ensure that the variable names match + inputs = self._shard_features({"inputs": inputs})["inputs"] + input_modality = self._problem_hparams.input_modality["inputs"] + with tf.variable_scope(input_modality.name): + inputs = input_modality.bottom_sharded(inputs, dp) + with tf.variable_scope("body"): + encoder_output, encoder_decoder_attention_bias = dp( + self.encode, inputs, features["target_space_id"], hparams) + + if hparams.pos == "timing": + timing_signal = common_attention.get_timing_signal_1d( + decode_length + 1, hparams.hidden_size) + + target_modality = self._problem_hparams.target_modality + + def preprocess_targets(targets, i): + """Performs preprocessing steps on the targets to prepare for the decoder. + + This includes: + - Embedding the ids. + - Flattening to 3D tensor. + - Optionally adding timing signals. + + Args: + targets: inputs ids to the decoder. [batch_size, 1] + i: scalar, Step number of the decoding loop. + + Returns: + Processed targets [batch_size, 1, hidden_dim] + """ + # _shard_features called to ensure that the variable names match + targets = self._shard_features({"targets": targets})["targets"] + with tf.variable_scope(target_modality.name): + targets = target_modality.targets_bottom_sharded(targets, dp)[0] + targets = common_layers.flatten4d3d(targets) + + # TODO(llion): Explain! Is this even needed? + targets = tf.cond( + tf.equal(i, 0), + lambda: tf.zeros_like(targets), + lambda: targets) + + if hparams.pos == "timing": + targets += timing_signal[:, i:i+1] + return targets + + decoder_self_attention_bias = ( + common_attention.attention_bias_lower_triangle(decode_length)) + if hparams.proximity_bias: + decoder_self_attention_bias += common_attention.attention_bias_proximal( + decode_length) + + def symbols_to_logits_fn(ids, i, cache): + """Go from ids to logits for next symbol.""" + targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) + targets = preprocess_targets(targets, i) + + bias = decoder_self_attention_bias[:, :, i:i+1, :i+1] + + with tf.variable_scope("body"): + body_outputs = self._data_parallelism( + self.decode, + targets, + encoder_output[0], + encoder_decoder_attention_bias[0], + bias, + hparams, + cache) + + with tf.variable_scope(target_modality.name): + logits = target_modality.top_sharded(body_outputs, None, dp)[0] + + return tf.squeeze(logits, axis=[1, 2, 3]) + + def inner_loop(i, next_id, decoded_ids, cache): + logits = symbols_to_logits_fn(next_id, i, cache) + next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1) + decoded_ids = tf.concat([decoded_ids, next_id], axis=1) + return i+1, next_id, decoded_ids, cache + + key_channels = hparams.attention_key_channels or hparams.hidden_size + value_channels = hparams.attention_value_channels or hparams.hidden_size + num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers + + cache = { + "layer_%d" % layer: { + "k": tf.zeros([batch_size, 0, key_channels]), + "v": tf.zeros([batch_size, 0, value_channels]), + } for layer in range(num_layers) + } + decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64) + next_id = tf.zeros([batch_size, 1], dtype=tf.int64) + _, _, decoded_ids, _ = tf.while_loop( + # TODO(llion): Early stopping. + lambda i, *_: tf.less(i, decode_length), + inner_loop, + [tf.constant(0), next_id, decoded_ids, cache], + shape_invariants=[ + tf.TensorShape([]), + tf.TensorShape([None, None]), + tf.TensorShape([None, None]), + {"layer_%d" % layer: { + "k": tf.TensorShape([None, None, key_channels]), + "v": tf.TensorShape([None, None, value_channels]), + } for layer in range(num_layers)} + ]) + + return decoded_ids, None, None + @registry.register_model class TransformerEncoder(t2t_model.T2TModel): @@ -293,6 +443,7 @@ def transformer_decoder(decoder_input, decoder_self_attention_bias, encoder_decoder_attention_bias, hparams, + cache=None, name="decoder"): """A stack of transformer layers. @@ -304,6 +455,8 @@ def transformer_decoder(decoder_input, encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention (see common_attention.attention_bias()) hparams: hyperparameters for model + cache: dict, containing tensors which are the results of previous + attentions, used for fast decoding. name: a string Returns: @@ -313,20 +466,28 @@ def transformer_decoder(decoder_input, with tf.variable_scope(name): for layer in xrange(hparams.num_decoder_layers or hparams.num_hidden_layers): - with tf.variable_scope("layer_%d" % layer): + layer_name = "layer_%d" % layer + layer_cache = cache[layer_name] if cache is not None else None + with tf.variable_scope(layer_name): with tf.variable_scope("self_attention"): y = common_attention.multihead_attention( - common_layers.layer_preprocess( - x, hparams), None, decoder_self_attention_bias, + common_layers.layer_preprocess(x, hparams), + None, + decoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, - hparams.hidden_size, hparams.num_heads, hparams.attention_dropout) + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + cache=layer_cache) x = common_layers.layer_postprocess(x, y, hparams) if encoder_output is not None: with tf.variable_scope("encdec_attention"): + # TODO(llion): Add caching. y = common_attention.multihead_attention( - common_layers.layer_preprocess( - x, hparams), encoder_output, encoder_decoder_attention_bias, + common_layers.layer_preprocess(x, hparams), + encoder_output, + encoder_decoder_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py index 9e450a670..77e17a494 100644 --- a/tensor2tensor/models/transformer_test.py +++ b/tensor2tensor/models/transformer_test.py @@ -32,16 +32,22 @@ BATCH_SIZE = 3 INPUT_LENGTH = 5 TARGET_LENGTH = 7 -VOCAB_SIZE = 9 +VOCAB_SIZE = 10 class TransformerTest(tf.test.TestCase): - def getModel(self): + def getModel(self, mode=tf.estimator.ModeKeys.TRAIN): hparams = transformer.transformer_small() + hparams.hidden_size = 8 + hparams.filter_size = 32 + hparams.num_heads = 1 + hparams.layer_prepostprocess_dropout = 0.0 + p_hparams = problem_hparams.test_problem_hparams( hparams, VOCAB_SIZE, VOCAB_SIZE) hparams.problems = [p_hparams] + inputs = -1 + np.random.random_integers( VOCAB_SIZE, size=(BATCH_SIZE, INPUT_LENGTH, 1, 1)) targets = -1 + np.random.random_integers( @@ -64,6 +70,39 @@ def testTransformer(self): res = session.run(logits) self.assertEqual(res.shape, (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE)) + def testGreedyVsFast(self): + model, features = self.getModel() + + decode_length = 2 + + out_logits, _ = model.model_fn(features) + out_logits = tf.squeeze(out_logits[0], axis=[2, 3]) + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]), + labels=tf.reshape(features["targets"], [-1])) + loss = tf.reduce_mean(loss) + apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss) + + with self.test_session(): + tf.global_variables_initializer().run() + for _ in range(100): + apply_grad.run() + + model, _ = self.getModel(tf.estimator.ModeKeys.PREDICT) + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + greedy_result, _, _ = model._slow_greedy_infer( + features, decode_length, last_position_only=True) + greedy_result = tf.squeeze(greedy_result, axis=[2, 3]) + + fast_result, _, _ = model.x_greedy_infer(features, decode_length) + + with self.test_session(): + greedy_res = greedy_result.eval() + fast_res = fast_result.eval() + + self.assertEqual(fast_res.shape, (BATCH_SIZE, INPUT_LENGTH + decode_length)) + self.assertAllClose(greedy_res, fast_res) if __name__ == "__main__": tf.test.main() diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index 812e5aee3..6d38a5ba8 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -292,7 +292,24 @@ def symbols_to_logits_fn(ids): return {"outputs": ids[:, :top_beams, 1:], "scores": scores} return ids[:, :top_beams, 1:] - def _greedy_infer(self, features, decode_length, last_position_only): + def _greedy_infer(self, features, decode_length, last_position_only): + """A greedy inference method. + + Models should ideally implement a more efficient version of this function. + + Args: + features: an map of string to `Tensor` + decode_length: an integer. How many additional timesteps to decode. + last_position_only: a boolean, speed-up by computing last position only. + + Returns: + samples: an integer `Tensor`. + logits: `Tensor` of shape [batch_size, time, 1, 1, vocab_size]. + losses: a dictionary: {loss-name (string): floating point `Scalar`} + """ + return self._slow_greedy_infer(features, decode_length, last_position_only) + + def _slow_greedy_infer(self, features, decode_length, last_position_only): """A slow greedy inference method. Quadratic time in decode_length. From aa40c4b373ee76900497b1c93f122fc49317e776 Mon Sep 17 00:00:00 2001 From: T2T Team Date: Mon, 18 Sep 2017 20:18:50 -0700 Subject: [PATCH 21/39] Update experiment function signature to (run_config, hparams) PiperOrigin-RevId: 169187769 --- tensor2tensor/bin/t2t-trainer | 7 ++ tensor2tensor/utils/model_builder.py | 17 +---- tensor2tensor/utils/trainer_utils.py | 93 ++++++++++++----------- tensor2tensor/utils/trainer_utils_test.py | 23 ++++-- 4 files changed, 75 insertions(+), 65 deletions(-) diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer index 7c7b48932..5defbb465 100644 --- a/tensor2tensor/bin/t2t-trainer +++ b/tensor2tensor/bin/t2t-trainer @@ -43,6 +43,7 @@ import tensorflow as tf flags = tf.flags FLAGS = flags.FLAGS +# See trainer_utils.py for additional command-line flags. flags.DEFINE_string("t2t_usr_dir", "", "Path to a Python module that will be imported. The " "__init__.py file should include the necessary imports. " @@ -53,6 +54,12 @@ flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen", "Temporary storage directory.") flags.DEFINE_bool("generate_data", False, "Generate data before training?") +flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.") +flags.DEFINE_string("output_dir", "", "Base output directory for run.") +flags.DEFINE_string("master", "", "Address of TensorFlow master.") +flags.DEFINE_string("schedule", "local_run", + "Method of tf.contrib.learn.Experiment to run.") + def main(_): tf.logging.set_verbosity(tf.logging.INFO) diff --git a/tensor2tensor/utils/model_builder.py b/tensor2tensor/utils/model_builder.py index 7c4172743..a0d362035 100644 --- a/tensor2tensor/utils/model_builder.py +++ b/tensor2tensor/utils/model_builder.py @@ -50,9 +50,7 @@ def model_fn(model, worker_id=0, worker_replicas=1, eval_run_autoregressive=False, - decode_hparams=None, - autotune=False, - objective=None): + decode_hparams=None): """Builds the model for all modes. * TRAIN: Constructs loss and train_op @@ -72,8 +70,6 @@ def model_fn(model, worker_replicas: int, number of workers. eval_run_autoregressive: bool, whether to run evaluation autoregressively. decode_hparams: HParams for decode settings. Used when mode == PREDICT. - autotune: bool, whether this model is being used for autotuning. - objective: str, the objective if autotune==True. Returns: tf.estimator.EstimatorSpec @@ -193,8 +189,6 @@ def nth_model(n): if mode == tf.estimator.ModeKeys.EVAL: eval_metrics_fns = metrics.create_evaluation_metrics( zip(problem_names, hparams.problem_instances), hparams) - _check_autotune_metrics( - eval_metrics_fns, autotune=autotune, objective=objective) eval_metrics = {} for metric_name, metric_fn in six.iteritems(eval_metrics_fns): @@ -391,15 +385,6 @@ def _exp_decay_after(step, rate, from_which_step): name="exponential_decay_step_cond") -def _check_autotune_metrics(metrics_dict, autotune=False, objective=None): - if not autotune: - return - - if objective not in metrics_dict: - raise ValueError("Tuning objective %s not among evaluation metrics %s" % - (objective, metrics_dict.keys())) - - def _log_variable_sizes(var_list, tag): """Log the sizes and shapes of variables, and the total size. diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py index 8ed7fb678..f2bb62c1f 100644 --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -45,7 +45,6 @@ "If True, logs the contents of the registry and exits.") flags.DEFINE_bool("tfdbg", False, "If True, use the TF debugger CLI on train/eval.") -flags.DEFINE_string("output_dir", "", "Base output directory for run.") flags.DEFINE_string("model", "", "Which model to use.") flags.DEFINE_string("hparams_set", "", "Which parameters to use.") flags.DEFINE_string("hparams_range", "", "Parameters range.") @@ -61,7 +60,6 @@ flags.DEFINE_string("data_dir", "/tmp/data", "Directory with training data.") flags.DEFINE_integer("train_steps", 250000, "The number of steps to run training for.") -flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.") flags.DEFINE_bool("eval_run_autoregressive", False, "Run eval autoregressively where we condition on previous" "generated output instead of the actual target.") @@ -80,9 +78,6 @@ "Whether to log device placement.") # Distributed training flags -flags.DEFINE_string("master", "", "Address of TensorFlow master.") -flags.DEFINE_string("schedule", "local_run", - "Method of tf.contrib.learn.Experiment to run.") flags.DEFINE_integer("local_eval_frequency", 2000, "Run evaluation every this steps during local training.") flags.DEFINE_bool("locally_shard_to_cpu", False, @@ -91,7 +86,7 @@ flags.DEFINE_bool("daisy_chain_variables", True, "copy variables around in a daisy chain") flags.DEFINE_bool("sync", False, "Sync compute on PS.") -flags.DEFINE_string("worker_job", "/job:worker", "name of worker job") +flags.DEFINE_string("worker_job", "/job:localhost", "name of worker job") flags.DEFINE_integer("worker_gpu", 1, "How many GPUs to use.") flags.DEFINE_integer("worker_replicas", 1, "How many workers to use.") flags.DEFINE_integer("worker_id", 0, "Which worker task are we.") @@ -113,29 +108,26 @@ def make_experiment_fn(data_dir, model_name, train_steps, eval_steps): """Returns experiment_fn for learn_runner. Wraps create_experiment.""" - def experiment_fn(output_dir): + def experiment_fn(run_config, hparams): return create_experiment( - output_dir=output_dir, - data_dir=data_dir, + data_dir, model_name=model_name, train_steps=train_steps, - eval_steps=eval_steps) + eval_steps=eval_steps, + hparams=hparams, + run_config=run_config) return experiment_fn -def create_experiment(output_dir, data_dir, model_name, train_steps, - eval_steps): +def create_experiment(data_dir, model_name, train_steps, eval_steps, hparams, + run_config): """Create Experiment.""" - hparams = create_hparams( - FLAGS.hparams_set, FLAGS.problems, data_dir, passed_hparams=FLAGS.hparams) - if FLAGS.worker_id == 0 and FLAGS.schedule in ["local_run", "train"]: - save_metadata(output_dir, hparams) estimator, input_fns = create_experiment_components( - hparams=hparams, - output_dir=output_dir, data_dir=data_dir, - model_name=model_name) + model_name=model_name, + hparams=hparams, + run_config=run_config) train_monitors = [] eval_hooks = [] if FLAGS.tfdbg: @@ -153,9 +145,12 @@ def create_experiment(output_dir, data_dir, model_name, train_steps, eval_hooks=eval_hooks) -def create_experiment_components(hparams, output_dir, data_dir, model_name): +def create_experiment_components(data_dir, model_name, hparams, run_config): """Constructs and returns Estimator and train/eval input functions.""" - tf.logging.info("Creating experiment, storing model files in %s", output_dir) + tf.logging.info("Creating experiment, storing model files in %s", + run_config.model_dir) + + hparams = add_problem_hparams(hparams, FLAGS.problems) num_datashards = devices.data_parallelism().n train_input_fn = input_fn_builder.build_input_fn( @@ -176,11 +171,6 @@ def create_experiment_components(hparams, output_dir, data_dir, model_name): worker_replicas=FLAGS.worker_replicas, worker_id=FLAGS.worker_id) - autotune = False - objective = None - if hasattr(FLAGS, "autotune"): - autotune = FLAGS.autotune - objective = FLAGS.objective model_fn = model_builder.build_model_fn( model_name, problem_names=FLAGS.problems.split("-"), @@ -188,20 +178,13 @@ def create_experiment_components(hparams, output_dir, data_dir, model_name): worker_id=FLAGS.worker_id, worker_replicas=FLAGS.worker_replicas, eval_run_autoregressive=FLAGS.eval_run_autoregressive, - decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams), - autotune=autotune, - objective=objective) + decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams)) + estimator = tf.estimator.Estimator( model_fn=model_fn, - model_dir=output_dir, + model_dir=run_config.model_dir, params=hparams, - config=tf.contrib.learn.RunConfig( - master=FLAGS.master, - gpu_memory_fraction=FLAGS.worker_gpu_memory_fraction, - session_config=session_config(), - keep_checkpoint_max=FLAGS.keep_checkpoint_max, - keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours, - save_checkpoints_secs=FLAGS.save_checkpoints_secs)) + config=run_config) return estimator, { tf.estimator.ModeKeys.TRAIN: train_input_fn, @@ -279,7 +262,7 @@ def save_metadata(output_dir, hparams): f.write(hparams.to_json()) -def create_hparams(params_id, problems, data_dir, passed_hparams=None): +def create_hparams(params_id, data_dir, passed_hparams=None): """Returns hyperparameters, including any flag value overrides. If the hparams FLAG is set, then it will use any values specified in @@ -288,7 +271,6 @@ def create_hparams(params_id, problems, data_dir, passed_hparams=None): Args: params_id: which set of parameters to choose (must be in _PARAMS above). - problems: the string with problem names to get problem_hparams from. data_dir: the directory containing the training data. passed_hparams: command-line overrides for some hparams. @@ -301,7 +283,22 @@ def create_hparams(params_id, problems, data_dir, passed_hparams=None): if passed_hparams: hparams = hparams.parse(passed_hparams) - return add_problem_hparams(hparams, problems) + return hparams + + +def create_run_config(output_dir): + """Create a RunConfig object.""" + + run_config = tf.contrib.learn.RunConfig( + model_dir=output_dir, + master=FLAGS.master, + gpu_memory_fraction=FLAGS.worker_gpu_memory_fraction, + session_config=session_config(), + keep_checkpoint_max=FLAGS.keep_checkpoint_max, + keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours, + save_checkpoints_secs=FLAGS.save_checkpoints_secs) + + return run_config def run(data_dir, model, output_dir, train_steps, eval_steps, schedule): @@ -327,9 +324,17 @@ def run(data_dir, model, output_dir, train_steps, eval_steps, schedule): train_steps=train_steps, eval_steps=eval_steps) + # Create hparams and run_config + run_config = create_run_config(output_dir) + hparams = create_hparams( + FLAGS.hparams_set, data_dir, passed_hparams=FLAGS.hparams) + if FLAGS.worker_id == 0 and schedule in ["local_run", "train"]: + save_metadata(output_dir, hparams) + if schedule == "local_run": # Run the local demo. - exp = exp_fn(output_dir) + + exp = exp_fn(run_config, hparams) if exp.train_steps > 0 and exp.eval_steps > 0: tf.logging.info("Performing local training and evaluation.") exp.train_and_evaluate() @@ -341,8 +346,10 @@ def run(data_dir, model, output_dir, train_steps, eval_steps, schedule): exp.evaluate(delay_secs=0) else: # Perform distributed training/evaluation. - learn_runner.run( - experiment_fn=exp_fn, schedule=schedule, output_dir=output_dir) + learn_runner.run(experiment_fn=exp_fn, + schedule=schedule, + run_config=run_config, + hparams=hparams) def validate_flags(): diff --git a/tensor2tensor/utils/trainer_utils_test.py b/tensor2tensor/utils/trainer_utils_test.py index 6045dd2e0..1a971ac0c 100644 --- a/tensor2tensor/utils/trainer_utils_test.py +++ b/tensor2tensor/utils/trainer_utils_test.py @@ -33,8 +33,14 @@ import tensorflow as tf +flags = tf.flags FLAGS = tf.flags.FLAGS +flags.DEFINE_string("schedule", "local_run", "") +flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.") +flags.DEFINE_string("master", "", "Address of TensorFlow master.") +flags.DEFINE_string("output_dir", "", "Base output directory for run.") + @registry.register_problem class TinyAlgo(algorithmic.AlgorithmicIdentityBinary40): @@ -84,13 +90,17 @@ def testHParamsImported(self): def testSingleStep(self): model_name = "transformer" - FLAGS.hparams_set = "transformer_test" + FLAGS.worker_job = "/job:localhost" + data_dir = TrainerUtilsTest.data_dir + hparams = trainer_utils.create_hparams("transformer_test", data_dir) exp = trainer_utils.create_experiment( - output_dir=tf.test.get_temp_dir(), - data_dir=TrainerUtilsTest.data_dir, + data_dir=data_dir, model_name=model_name, train_steps=1, - eval_steps=1) + eval_steps=1, + hparams=hparams, + run_config=trainer_utils.create_run_config( + output_dir=tf.test.get_temp_dir())) exp.test() def testSingleEvalStepRawSession(self): @@ -100,12 +110,13 @@ def testSingleEvalStepRawSession(self): model_name = "transformer" FLAGS.hparams_set = "transformer_test" FLAGS.problems = "tiny_algo" + FLAGS.worker_job = "/job:localhost" data_dir = "/tmp" # Used only when a vocab file or such like is needed. # Create the problem object, hparams, placeholders, features dict. encoders = registry.problem(FLAGS.problems).feature_encoders(data_dir) - hparams = trainer_utils.create_hparams(FLAGS.hparams_set, FLAGS.problems, - data_dir) + hparams = trainer_utils.create_hparams(FLAGS.hparams_set, data_dir) + hparams = trainer_utils.add_problem_hparams(hparams, FLAGS.problems) inputs_ph = tf.placeholder(dtype=tf.int32) # Just length dimension. batch_inputs = tf.reshape(inputs_ph, [1, -1, 1, 1]) # Make it 4D. # In INFER mode targets can be None. From aec87db8df301d33b7a84722e72de19832963bd7 Mon Sep 17 00:00:00 2001 From: T2T Team Date: Tue, 19 Sep 2017 10:50:54 -0700 Subject: [PATCH 22/39] [tf.contrib.data] Standardize transformation functions for use with `Dataset.apply()`. PiperOrigin-RevId: 169264919 --- tensor2tensor/utils/data_reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py index d94e85e39..acf4ae026 100644 --- a/tensor2tensor/utils/data_reader.py +++ b/tensor2tensor/utils/data_reader.py @@ -367,8 +367,8 @@ def batching_fn(bucket_id, grouped_dataset): if hasattr(dataset, "apply"): # If the Dataset supports dynamic window size, use it. dataset = dataset.apply( - tf.contrib.data.group_by_window, - args=(example_to_bucket_id, batching_fn, None, window_size_fn)) + tf.contrib.data.group_by_window( + example_to_bucket_id, batching_fn, None, window_size_fn)) else: dataset = dataset.group_by_window(example_to_bucket_id, batching_fn, window_size) From 0b8573c7d89f55f6e9d8c3ca7d7d7640293e09e9 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Tue, 19 Sep 2017 12:15:37 -0700 Subject: [PATCH 23/39] @recompute_grad decorator PiperOrigin-RevId: 169279745 --- tensor2tensor/layers/rev_block.py | 47 ++++++++++++++++++++++++-- tensor2tensor/layers/rev_block_test.py | 26 ++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/tensor2tensor/layers/rev_block.py b/tensor2tensor/layers/rev_block.py index 8502e0a8b..3dff92c5c 100644 --- a/tensor2tensor/layers/rev_block.py +++ b/tensor2tensor/layers/rev_block.py @@ -18,11 +18,15 @@ From [The Reversible Residual Network: Backpropagation Without Storing Activations](https://arxiv.org/abs/1707.04585). + +Also contains the @recompute_grad decorator, which recomputes the forward +function on the backwards pass. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import re # Dependency imports @@ -286,8 +290,8 @@ def custom_grad_fn(inputs, variables, ys, grad_ys): # idxs. f_var_grads.reverse() g_var_grads.reverse() - for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list(zip( - g_vars_idxs, g_var_grads)): + for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list( + zip(g_vars_idxs, g_var_grads)): for i, grad in zip(idxs, grads): variable_grads[i] = grad @@ -316,3 +320,42 @@ def forward(x1, x2, *side_inputs): gate_outputs=is_training) return forward(x1, x2, *(f_side_input + g_side_input)) + + +def recompute_grad(fn): + """Decorator that recomputes the function on the backwards pass. + + Args: + fn: a function that takes Tensors (all as positional arguments) and returns + a tuple of Tensors. + + Returns: + A wrapped fn that is identical to fn when called, but its activations will + be discarded and recomputed on the backwards pass (i.e. on a call to + tf.gradients). + """ + + @functools.wraps(fn) + def wrapped(*args): + return _recompute_grad(fn, args) + + return wrapped + + +def _recompute_grad(fn, args): + """See recompute_grad.""" + + def grad_fn(inputs, variables, outputs, output_grads): + del outputs + # recompute outputs + outputs = fn(*inputs) + grads = tf.gradients(outputs, inputs + variables, output_grads) + grad_inputs = grads[:len(inputs)] + grad_vars = grads[len(inputs):] + return grad_inputs, grad_vars + + @common_layers.fn_with_custom_grad(grad_fn) + def fn_with_recompute(*args): + return fn(*args) + + return fn_with_recompute(*args) diff --git a/tensor2tensor/layers/rev_block_test.py b/tensor2tensor/layers/rev_block_test.py index 5aecc8ea3..3e5f7c932 100644 --- a/tensor2tensor/layers/rev_block_test.py +++ b/tensor2tensor/layers/rev_block_test.py @@ -137,5 +137,31 @@ def f(x): self._testRevBlock(x=x, f=f) +class RecomputeTest(tf.test.TestCase): + + def testRecompute(self): + + @rev_block.recompute_grad + def fn_recompute(x, y): + return x + y, x**y + + def fn(x, y): + return x + y, x**y + + x = tf.ones((3, 3)) + y = tf.ones((3, 3)) + out1 = tf.reduce_sum(fn_recompute(x, y)) + out2 = tf.reduce_sum(fn(x, y)) + + grad1 = tf.gradients(out1, [x, y]) + grad2 = tf.gradients(out2, [x, y]) + + with self.test_session() as sess: + outs = sess.run([out1, out2, grad1, grad2]) + self.assertAllClose(outs[0], outs[1]) + for g1, g2 in zip(outs[2], outs[3]): + self.assertAllClose(g1, g2) + + if __name__ == "__main__": tf.test.main() From 77e91f6c3414ae0135eb710d4862aca07fc26c9d Mon Sep 17 00:00:00 2001 From: T2T Team Date: Tue, 19 Sep 2017 14:11:00 -0700 Subject: [PATCH 24/39] Register `lstm_seq2seq` hparams. PiperOrigin-RevId: 169297690 --- tensor2tensor/models/lstm.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tensor2tensor/models/lstm.py b/tensor2tensor/models/lstm.py index 20475a5a9..f336bd6b4 100644 --- a/tensor2tensor/models/lstm.py +++ b/tensor2tensor/models/lstm.py @@ -266,13 +266,20 @@ def model_fn_body(self, features): @registry.register_hparams -def lstm_attention(): - """hparams for LSTM with attention.""" +def lstm_seq2seq(): + """hparams for LSTM.""" hparams = common_hparams.basic_params1() hparams.batch_size = 1024 hparams.hidden_size = 128 hparams.num_hidden_layers = 2 hparams.initializer = "uniform_unit_scaling" + return hparams + + +@registry.register_hparams +def lstm_attention(): + """hparams for LSTM with attention.""" + hparams = lstm_seq2seq() # Attention hparams.add_hparam("attn_vec_size", hparams.hidden_size) From 12126bd1306a0a7876e617ce664e610ec1e1b22a Mon Sep 17 00:00:00 2001 From: Etienne Pot Date: Tue, 19 Sep 2017 14:18:36 -0700 Subject: [PATCH 25/39] Add flag to profile ops/memory PiperOrigin-RevId: 169299088 --- tensor2tensor/utils/trainer_utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py index f2bb62c1f..50cfcc5d0 100644 --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -35,6 +35,7 @@ from tensor2tensor.utils import registry import tensorflow as tf +from tensorflow.contrib.hooks.python.training.profiler_hook import ProfilerHook from tensorflow.contrib.learn.python.learn import learn_runner from tensorflow.python import debug @@ -45,6 +46,8 @@ "If True, logs the contents of the registry and exits.") flags.DEFINE_bool("tfdbg", False, "If True, use the TF debugger CLI on train/eval.") +flags.DEFINE_bool("dbgprofile", False, + "If True, record the timeline for chrome://tracing/.") flags.DEFINE_string("model", "", "Which model to use.") flags.DEFINE_string("hparams_set", "", "Which parameters to use.") flags.DEFINE_string("hparams_range", "", "Parameters range.") @@ -134,6 +137,15 @@ def create_experiment(data_dir, model_name, train_steps, eval_steps, hparams, hook = debug.LocalCLIDebugHook() train_monitors.append(hook) eval_hooks.append(hook) + if FLAGS.dbgprofile: + # Recorded traces can be visualized with chrome://tracing/ + # The memory/tensor lifetime is also profiled + train_monitors.append(ProfilerHook( + save_steps=10, + output_dir=run_config.model_dir, + show_dataflow=True, + show_memory=True, + )) return tf.contrib.learn.Experiment( estimator=estimator, train_input_fn=input_fns[tf.estimator.ModeKeys.TRAIN], From 9d63460a32a6abe44f4caf4a7112a8f3708d2263 Mon Sep 17 00:00:00 2001 From: T2T Team Date: Tue, 19 Sep 2017 14:23:31 -0700 Subject: [PATCH 26/39] Enable fast decoding. PiperOrigin-RevId: 169299895 --- tensor2tensor/models/transformer.py | 5 ++--- tensor2tensor/models/transformer_test.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 918fc8645..9fe0bc5f7 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -145,8 +145,7 @@ def model_fn_body(self, features): decoder_self_attention_bias, hparams) - # TODO(llion): Enable fast inference once it's been fully tested. - def x_greedy_infer( + def _greedy_infer( self, features, decode_length, last_position_only=True): """Fast version of greedy decoding. @@ -242,7 +241,7 @@ def symbols_to_logits_fn(ids, i, cache): bias = decoder_self_attention_bias[:, :, i:i+1, :i+1] with tf.variable_scope("body"): - body_outputs = self._data_parallelism( + body_outputs = dp( self.decode, targets, encoder_output[0], diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py index 77e17a494..04c527ac1 100644 --- a/tensor2tensor/models/transformer_test.py +++ b/tensor2tensor/models/transformer_test.py @@ -95,7 +95,7 @@ def testGreedyVsFast(self): features, decode_length, last_position_only=True) greedy_result = tf.squeeze(greedy_result, axis=[2, 3]) - fast_result, _, _ = model.x_greedy_infer(features, decode_length) + fast_result, _, _ = model._greedy_infer(features, decode_length) with self.test_session(): greedy_res = greedy_result.eval() From bc191b54e0e3977ef9016384e2ec4920e660e70c Mon Sep 17 00:00:00 2001 From: Niki Parmar Date: Tue, 19 Sep 2017 15:39:34 -0700 Subject: [PATCH 27/39] Fix formatting in identity output PiperOrigin-RevId: 169312588 --- tensor2tensor/utils/decoding.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py index fc5f22c1a..664935c94 100644 --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -74,14 +74,18 @@ def log_decode_results(inputs, (problem_name, prediction_idx)) show_and_save_image(inputs / 255., save_path) elif inputs_vocab: - decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs.flatten())) + if identity_output: + decoded_inputs = " ".join(map(str, inputs.flatten())) + else: + decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs.flatten())) + tf.logging.info("Inference results INPUT: %s" % decoded_inputs) decoded_targets = None if identity_output: - decoded_outputs = "".join(map(str, outputs.flatten())) + decoded_outputs = " ".join(map(str, outputs.flatten())) if targets is not None: - decoded_targets = "".join(map(str, targets.flatten())) + decoded_targets = " ".join(map(str, targets.flatten())) else: decoded_outputs = "".join( map(str, targets_vocab.decode(_save_until_eos(outputs.flatten())))) From f07b59f398f8112ea1ff99f120c6d94babb74905 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Tue, 19 Sep 2017 21:01:28 -0700 Subject: [PATCH 28/39] Fix output shape of TransformerEncoder PiperOrigin-RevId: 169345762 --- tensor2tensor/models/transformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 9fe0bc5f7..b4f083eca 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -309,6 +309,7 @@ def model_fn_body(self, features): 1.0 - hparams.layer_prepostprocess_dropout) encoder_output = transformer_encoder(encoder_input, encoder_self_attention_bias, hparams) + encoder_output = tf.expand_dims(encoder_output, 2) return encoder_output From 21b3b55fa60cefab78581b8c536ad089b2315f46 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Tue, 19 Sep 2017 21:22:36 -0700 Subject: [PATCH 29/39] SavedModel export and decoding fixes PiperOrigin-RevId: 169347220 --- tensor2tensor/bin/t2t-decoder | 13 +++-- tensor2tensor/utils/data_reader.py | 67 ++++++++++++++++++++--- tensor2tensor/utils/model_builder.py | 12 +++- tensor2tensor/utils/trainer_utils.py | 29 ++++++++-- tensor2tensor/utils/trainer_utils_test.py | 3 +- 5 files changed, 104 insertions(+), 20 deletions(-) diff --git a/tensor2tensor/bin/t2t-decoder b/tensor2tensor/bin/t2t-decoder index 8da8ae5a2..5b5b09555 100644 --- a/tensor2tensor/bin/t2t-decoder +++ b/tensor2tensor/bin/t2t-decoder @@ -46,6 +46,7 @@ import tensorflow as tf flags = tf.flags FLAGS = flags.FLAGS +flags.DEFINE_string("output_dir", "", "Training directory to load from.") flags.DEFINE_string("decode_from_file", None, "Path to decode file") flags.DEFINE_string("decode_to_file", None, "Path prefix to inference output file") @@ -58,6 +59,8 @@ flags.DEFINE_string("t2t_usr_dir", "", "The imported files should contain registrations, " "e.g. @registry.register_model calls, that will then be " "available to the t2t-decoder.") +flags.DEFINE_string("master", "", "Address of TensorFlow master.") +flags.DEFINE_string("schedule", "local_run", "Must be local_run for decoding.") def main(_): @@ -65,16 +68,18 @@ def main(_): usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) trainer_utils.log_registry() trainer_utils.validate_flags() + assert FLAGS.schedule == "local_run" data_dir = os.path.expanduser(FLAGS.data_dir) output_dir = os.path.expanduser(FLAGS.output_dir) hparams = trainer_utils.create_hparams( - FLAGS.hparams_set, FLAGS.problems, data_dir, passed_hparams=FLAGS.hparams) + FLAGS.hparams_set, data_dir, passed_hparams=FLAGS.hparams) + hparams = trainer_utils.add_problem_hparams(hparams, FLAGS.problems) estimator, _ = trainer_utils.create_experiment_components( - hparams=hparams, - output_dir=output_dir, data_dir=data_dir, - model_name=FLAGS.model) + model_name=FLAGS.model, + hparams=hparams, + run_config=trainer_utils.create_run_config(output_dir)) decode_hp = decoding.decode_hparams(FLAGS.decode_hparams) decode_hp.add_hparam("shards", FLAGS.decode_shards) diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py index acf4ae026..4b0541d31 100644 --- a/tensor2tensor/utils/data_reader.py +++ b/tensor2tensor/utils/data_reader.py @@ -156,13 +156,30 @@ def cast_int64_to_int32(features): return f -def feature_placeholders(data_fields): - feature_map = {} - for (field, tp) in data_fields: - if not field.startswith("targets"): - feature_map[field] = tf.placeholder( - dtype=tp, shape=[None] * 4, name=field) - return feature_map +def feature_placeholders(data_fields, data_items_to_decoders): + """Construct Placeholders and run decoders.""" + example = {} + for field, config in data_fields.items(): + if isinstance(config, tf.VarLenFeature): + shape = [None] + else: + shape = config.shape + + example[field] = tf.placeholder(dtype=config.dtype, shape=shape, name=field) + + # Decode + if data_items_to_decoders is None: + data_items_to_decoders = { + field: tf.contrib.slim.tfexample_decoder.Tensor(field) + for field in data_fields + } + + decoded_example = {} + for field, decoder in data_items_to_decoders.items(): + keys_to_tensors = {key: example[key] for key in decoder.keys} + decoded_example[field] = decoder.tensors_to_item(keys_to_tensors) + + return decoded_example def default_example_reading_spec(data_file_pattern): @@ -216,7 +233,7 @@ def read_examples(problem, if data_file_pattern is None: # Create placeholders for input, rather than reading data from disk. - return feature_placeholders(data_fields) + return feature_placeholders(data_fields, data_items_to_decoders) is_training = mode == tf.estimator.ModeKeys.TRAIN dataset = examples_reader( @@ -520,3 +537,37 @@ def get_data_filepatterns(problems, data_dir, mode): else: datasets.append("%s-dev*" % path) return datasets + + +def serving_input_fn(problem, hparams): + """Input fn for serving, starting from Placeholders.""" + data_fields, data_items_to_decoders = problem.example_reading_spec() + + # Feature placeholders that mimic what's on disk + example = feature_placeholders(data_fields, data_items_to_decoders) + + # Preprocess + example = problem.preprocess_examples(example, tf.estimator.ModeKeys.PREDICT, + hparams) + example = cast_int64_to_int32(example) + + # 4-D inputs and space ids + constants = {} + constants["target_space_id"] = tf.constant( + problem.get_hparams().target_space_id) + constants["problem_choice"] = tf.constant(0) + if problem.has_inputs: + while len(example["inputs"].get_shape()) != 4: + example["inputs"] = tf.expand_dims(example["inputs"], axis=-1) + constants["input_space_id"] = tf.constant( + problem.get_hparams().input_space_id) + example.pop("targets") + else: + while len(example["targets"].get_shape()) != 4: + example["targets"] = tf.expand_dims(example["targets"], axis=-1) + + features = constants + features.update(example) + + return tf.estimator.export.ServingInputReceiver( + features=features, receiver_tensors=example) diff --git a/tensor2tensor/utils/model_builder.py b/tensor2tensor/utils/model_builder.py index a0d362035..4a4717bd4 100644 --- a/tensor2tensor/utils/model_builder.py +++ b/tensor2tensor/utils/model_builder.py @@ -182,7 +182,17 @@ def nth_model(n): "problem_choice": batched_problem_choice, } _del_dict_nones(predictions) - return tf.estimator.EstimatorSpec(mode, predictions=predictions) + + export_out = {"outputs": predictions["outputs"]} + if "scores" in predictions: + export_out["scores"] = predictions["scores"] + + return tf.estimator.EstimatorSpec( + mode, + predictions=predictions, + export_outputs={ + "output": tf.estimator.export.PredictOutput(export_out) + }) total_loss, logits = model_output diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py index 50cfcc5d0..cec1b444d 100644 --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -46,6 +46,8 @@ "If True, logs the contents of the registry and exits.") flags.DEFINE_bool("tfdbg", False, "If True, use the TF debugger CLI on train/eval.") +flags.DEFINE_bool("export_saved_model", False, + "Whether to export a SavedModel for serving.") flags.DEFINE_bool("dbgprofile", False, "If True, record the timeline for chrome://tracing/.") flags.DEFINE_string("model", "", "Which model to use.") @@ -131,6 +133,7 @@ def create_experiment(data_dir, model_name, train_steps, eval_steps, hparams, model_name=model_name, hparams=hparams, run_config=run_config) + train_monitors = [] eval_hooks = [] if FLAGS.tfdbg: @@ -146,6 +149,15 @@ def create_experiment(data_dir, model_name, train_steps, eval_steps, hparams, show_dataflow=True, show_memory=True, )) + + optional_kwargs = {} + if FLAGS.export_saved_model: + assert len(hparams.problem_instances) == 1 + problem = hparams.problem_instances[0] + optional_kwargs["export_strategies"] = [ + make_export_strategy(problem, hparams) + ] + return tf.contrib.learn.Experiment( estimator=estimator, train_input_fn=input_fns[tf.estimator.ModeKeys.TRAIN], @@ -154,7 +166,13 @@ def create_experiment(data_dir, model_name, train_steps, eval_steps, hparams, eval_steps=eval_steps, min_eval_frequency=FLAGS.local_eval_frequency, train_monitors=train_monitors, - eval_hooks=eval_hooks) + eval_hooks=eval_hooks, + **optional_kwargs) + + +def make_export_strategy(problem, hparams): + return tf.contrib.learn.make_export_strategy( + lambda: data_reader.serving_input_fn(problem, hparams), as_text=True) def create_experiment_components(data_dir, model_name, hparams, run_config): @@ -358,10 +376,11 @@ def run(data_dir, model, output_dir, train_steps, eval_steps, schedule): exp.evaluate(delay_secs=0) else: # Perform distributed training/evaluation. - learn_runner.run(experiment_fn=exp_fn, - schedule=schedule, - run_config=run_config, - hparams=hparams) + learn_runner.run( + experiment_fn=exp_fn, + schedule=schedule, + run_config=run_config, + hparams=hparams) def validate_flags(): diff --git a/tensor2tensor/utils/trainer_utils_test.py b/tensor2tensor/utils/trainer_utils_test.py index 1a971ac0c..5e9e31031 100644 --- a/tensor2tensor/utils/trainer_utils_test.py +++ b/tensor2tensor/utils/trainer_utils_test.py @@ -90,9 +90,9 @@ def testHParamsImported(self): def testSingleStep(self): model_name = "transformer" - FLAGS.worker_job = "/job:localhost" data_dir = TrainerUtilsTest.data_dir hparams = trainer_utils.create_hparams("transformer_test", data_dir) + hparams = trainer_utils.add_problem_hparams(hparams, FLAGS.problems) exp = trainer_utils.create_experiment( data_dir=data_dir, model_name=model_name, @@ -110,7 +110,6 @@ def testSingleEvalStepRawSession(self): model_name = "transformer" FLAGS.hparams_set = "transformer_test" FLAGS.problems = "tiny_algo" - FLAGS.worker_job = "/job:localhost" data_dir = "/tmp" # Used only when a vocab file or such like is needed. # Create the problem object, hparams, placeholders, features dict. From 620d6a541478d73b93db33f6c75c5d837523f8d0 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Tue, 19 Sep 2017 23:06:42 -0700 Subject: [PATCH 30/39] Add Travis build shield to README PiperOrigin-RevId: 169354689 --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index af9778725..e37db796d 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](http welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) [![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) +[![Travis](https://img.shields.io/travis/tensorflow/tensor2tensor.svg)]() [T2T](https://github.com/tensorflow/tensor2tensor) is a modular and extensible library and binaries for supervised learning with TensorFlow and with support From 4280f4402ff68213e2d04b502f4b404a8fb0acfb Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Wed, 20 Sep 2017 10:26:31 -0700 Subject: [PATCH 31/39] Rm all refs to local_run in favor of train_and_evaluate PiperOrigin-RevId: 169412526 --- tensor2tensor/bin/t2t-decoder | 5 +-- tensor2tensor/bin/t2t-trainer | 2 +- tensor2tensor/utils/devices.py | 2 +- tensor2tensor/utils/trainer_utils.py | 38 +++++++---------------- tensor2tensor/utils/trainer_utils_test.py | 2 +- 5 files changed, 18 insertions(+), 31 deletions(-) diff --git a/tensor2tensor/bin/t2t-decoder b/tensor2tensor/bin/t2t-decoder index 5b5b09555..d2fe41f2f 100644 --- a/tensor2tensor/bin/t2t-decoder +++ b/tensor2tensor/bin/t2t-decoder @@ -60,7 +60,8 @@ flags.DEFINE_string("t2t_usr_dir", "", "e.g. @registry.register_model calls, that will then be " "available to the t2t-decoder.") flags.DEFINE_string("master", "", "Address of TensorFlow master.") -flags.DEFINE_string("schedule", "local_run", "Must be local_run for decoding.") +flags.DEFINE_string("schedule", "train_and_evaluate", + "Must be train_and_evaluate for decoding.") def main(_): @@ -68,7 +69,7 @@ def main(_): usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) trainer_utils.log_registry() trainer_utils.validate_flags() - assert FLAGS.schedule == "local_run" + assert FLAGS.schedule == "train_and_evaluate" data_dir = os.path.expanduser(FLAGS.data_dir) output_dir = os.path.expanduser(FLAGS.output_dir) diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer index 5defbb465..c986522f3 100644 --- a/tensor2tensor/bin/t2t-trainer +++ b/tensor2tensor/bin/t2t-trainer @@ -57,7 +57,7 @@ flags.DEFINE_bool("generate_data", False, "Generate data before training?") flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.") flags.DEFINE_string("output_dir", "", "Base output directory for run.") flags.DEFINE_string("master", "", "Address of TensorFlow master.") -flags.DEFINE_string("schedule", "local_run", +flags.DEFINE_string("schedule", "train_and_evaluate", "Method of tf.contrib.learn.Experiment to run.") diff --git a/tensor2tensor/utils/devices.py b/tensor2tensor/utils/devices.py index d04b73563..d532b6d5f 100644 --- a/tensor2tensor/utils/devices.py +++ b/tensor2tensor/utils/devices.py @@ -109,7 +109,7 @@ def _replica_device_setter(worker_device): ps_tasks=FLAGS.ps_replicas, ps_device=FLAGS.ps_job + "/GPU:0" if FLAGS.ps_gpu > 0 else FLAGS.ps_job) - if FLAGS.schedule == "local_run": + if FLAGS.schedule == "train_and_evaluate": assert not FLAGS.sync datashard_devices = ["gpu:%d" % d for d in _gpu_order(FLAGS.worker_gpu)] if FLAGS.locally_shard_to_cpu or FLAGS.worker_gpu < 1: diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py index cec1b444d..69d981f7c 100644 --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -334,11 +334,6 @@ def create_run_config(output_dir): def run(data_dir, model, output_dir, train_steps, eval_steps, schedule): """Runs an Estimator locally or distributed. - This function chooses one of two paths to execute: - - 1. Running locally if schedule=="local_run". - 3. Distributed training/evaluation otherwise. - Args: data_dir: The directory the data can be found in. model: The name of the model to use. @@ -358,29 +353,15 @@ def run(data_dir, model, output_dir, train_steps, eval_steps, schedule): run_config = create_run_config(output_dir) hparams = create_hparams( FLAGS.hparams_set, data_dir, passed_hparams=FLAGS.hparams) - if FLAGS.worker_id == 0 and schedule in ["local_run", "train"]: + + if is_chief(): save_metadata(output_dir, hparams) - if schedule == "local_run": - # Run the local demo. - - exp = exp_fn(run_config, hparams) - if exp.train_steps > 0 and exp.eval_steps > 0: - tf.logging.info("Performing local training and evaluation.") - exp.train_and_evaluate() - elif exp.train_steps > 0: - tf.logging.info("Performing local training.") - exp.train() - elif exp.eval_steps > 0: - tf.logging.info("Performing local evaluation.") - exp.evaluate(delay_secs=0) - else: - # Perform distributed training/evaluation. - learn_runner.run( - experiment_fn=exp_fn, - schedule=schedule, - run_config=run_config, - hparams=hparams) + learn_runner.run( + experiment_fn=exp_fn, + schedule=schedule, + run_config=run_config, + hparams=hparams) def validate_flags(): @@ -398,6 +379,11 @@ def validate_flags(): "Using default output_dir=%s.", FLAGS.output_dir) +def is_chief(): + schedules = ["train", "train_and_evaluate"] + return FLAGS.worker_id == 0 and FLAGS.schedule in schedules + + def session_config(): """The TensorFlow Session config to use.""" graph_options = tf.GraphOptions(optimizer_options=tf.OptimizerOptions( diff --git a/tensor2tensor/utils/trainer_utils_test.py b/tensor2tensor/utils/trainer_utils_test.py index 5e9e31031..16a8149f4 100644 --- a/tensor2tensor/utils/trainer_utils_test.py +++ b/tensor2tensor/utils/trainer_utils_test.py @@ -36,7 +36,7 @@ flags = tf.flags FLAGS = tf.flags.FLAGS -flags.DEFINE_string("schedule", "local_run", "") +flags.DEFINE_string("schedule", "train_and_evaluate", "") flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.") flags.DEFINE_string("master", "", "Address of TensorFlow master.") flags.DEFINE_string("output_dir", "", "Base output directory for run.") From 0841742e0b88640999312ecd23a454a49bc04412 Mon Sep 17 00:00:00 2001 From: T2T Team Date: Wed, 20 Sep 2017 17:20:28 -0700 Subject: [PATCH 32/39] Support class modality in fast decoding. PiperOrigin-RevId: 169476287 --- tensor2tensor/models/transformer.py | 9 +++++---- tensor2tensor/utils/t2t_model.py | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index b4f083eca..7d4ce27be 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -172,8 +172,11 @@ def _greedy_infer( inputs = features["inputs"] batch_size = tf.shape(inputs)[0] - # TODO(llion): Support class modality - decode_length = tf.shape(inputs)[1] + decode_length + target_modality = self._problem_hparams.target_modality + if t2t_model.is_class_modality(target_modality): + decode_length = 1 + else: + decode_length = tf.shape(inputs)[1] + decode_length # TODO(llion): Clean up this reshaping logic. inputs = tf.expand_dims(inputs, axis=1) @@ -194,8 +197,6 @@ def _greedy_infer( timing_signal = common_attention.get_timing_signal_1d( decode_length + 1, hparams.hidden_size) - target_modality = self._problem_hparams.target_modality - def preprocess_targets(targets, i): """Performs preprocessing steps on the targets to prepare for the decoder. diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index 6d38a5ba8..3fc110ebf 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -44,7 +44,7 @@ def fn_with_timing(*args, **kwargs): return fn_with_timing -def _is_class_modality(mod): +def is_class_modality(mod): # TODO(lukaszkaiser): should be based on type, like CLASS_LABEL, not string. prefix = "class_label_modality_" if len(mod.name) < len(prefix): @@ -198,7 +198,7 @@ def infer(self, # generated sequences, than to see the most likely sequence repeatedly. beam_size = 1 self._hparams.sampling_method = "random" - if _is_class_modality( + if is_class_modality( self._hparams.problems[self._problem_idx].target_modality): beam_size = 1 # No use to run beam-search for a single class. if beam_size == 1: @@ -371,7 +371,7 @@ def infer_step(recent_output, recent_logits, unused_loss): initial_output = tf.slice(initial_output, [0, 0, 0, 0], tf.shape(initial_output)) target_modality = self._hparams.problems[self._problem_idx].target_modality - if _is_class_modality(target_modality): + if is_class_modality(target_modality): decode_length = 1 else: decode_length = tf.shape(features["inputs"])[1] + decode_length From 09f1f17d1f673a57cbf69dbb65176a34be426f6b Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Wed, 20 Sep 2017 18:18:33 -0700 Subject: [PATCH 33/39] Minimally port remaining problems to Problem class PiperOrigin-RevId: 169482949 --- tensor2tensor/data_generators/all_problems.py | 1 + tensor2tensor/data_generators/problem.py | 5 +- .../data_generators/problem_hparams.py | 510 ++++++------------ .../data_generators/problem_hparams_test.py | 50 -- tensor2tensor/models/bluenet_test.py | 3 +- tensor2tensor/models/bytenet_test.py | 3 +- tensor2tensor/models/lstm_test.py | 6 +- tensor2tensor/models/neural_gpu_test.py | 2 +- .../models/transformer_revnet_test.py | 3 +- tensor2tensor/models/transformer_test.py | 3 +- tensor2tensor/models/xception_test.py | 3 +- tensor2tensor/utils/beam_search.py | 1 - tensor2tensor/utils/data_reader.py | 81 +-- tensor2tensor/utils/data_reader_test.py | 2 +- tensor2tensor/utils/trainer_utils.py | 25 +- 15 files changed, 198 insertions(+), 500 deletions(-) delete mode 100644 tensor2tensor/data_generators/problem_hparams_test.py diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index 52354704d..5877b541e 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -29,6 +29,7 @@ from tensor2tensor.data_generators import image from tensor2tensor.data_generators import imdb from tensor2tensor.data_generators import lm1b +from tensor2tensor.data_generators import problem_hparams from tensor2tensor.data_generators import ptb from tensor2tensor.data_generators import snli from tensor2tensor.data_generators import wiki diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index a006d5627..4ada1d212 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -257,10 +257,9 @@ def get_hparams(self, model_hparams=None): if self._hparams is not None: return self._hparams - assert model_hparams is not None - if self._encoders is None: - self.get_feature_encoders(model_hparams.data_dir) + data_dir = (model_hparams and model_hparams.data_dir) or None + self.get_feature_encoders(data_dir) hp = _default_hparams() ret = self.hparams(hp, model_hparams) diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py index 147fc7538..88212b0db 100644 --- a/tensor2tensor/data_generators/problem_hparams.py +++ b/tensor2tensor/data_generators/problem_hparams.py @@ -24,345 +24,185 @@ # Dependency imports +from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder from tensor2tensor.layers import modalities # pylint: disable=unused-import from tensor2tensor.utils import registry import tensorflow as tf - -def problem_hparams(problem_name, model_hparams): - """Generate problem hyperparameters based on problem name. - - Args: - problem_name: a string - model_hparams: a tf.contrib.training.HParams - - Returns: - a tf.contrib.training.HParams - """ - base_name, was_reversed, was_copy = parse_problem_name(problem_name) - p = _lookup_problem_hparams_fn(base_name)(model_hparams) - if was_reversed: - _reverse_problem_hparams(p) - if was_copy: - _copy_problem_hparams(p) - return p - - -def parse_problem_name(problem_name): - """Determines if problem_name specifies a copy and/or reversal. - - Args: - problem_name: A string containing a single problem name from FLAGS.problems. - - Returns: - base_name: A string with the base problem name. - was_reversed: A boolean. - was_copy: A boolean. - """ - # Recursively strip tags until we reach a base name. - if problem_name.endswith("_rev"): - base, _, was_copy = parse_problem_name(problem_name[:-4]) - return base, True, was_copy - elif problem_name.endswith("_copy"): - base, was_reversed, _ = parse_problem_name(problem_name[:-5]) - return base, was_reversed, True - return problem_name, False, False - - -def _lookup_problem_hparams_fn(name): - if name not in PROBLEM_HPARAMS_MAP: - map_str = "* " + "\n* ".join(sorted(PROBLEM_HPARAMS_MAP.keys())) - error_msg = "%s not in the supported set of problems:\n%s" % (name, map_str) - raise LookupError(error_msg) - return PROBLEM_HPARAMS_MAP.get(name) - - -def _copy_problem_hparams(p_hparams): - """Use input modality, vocab, and space id for target.""" - p = p_hparams - # Duplicate input modality. - p.target_modality = p.input_modality["inputs"] - # Duplicate input vocabulary. - p.vocabulary["targets"] = p.vocabulary["inputs"] - # Duplicate input space ids. - p.target_space_id = p.input_space_id - # Mark that p was reversed. - p.was_copy = True - - -def _reverse_problem_hparams(p_hparams): - """Swap input/output modalities, vocab, and space ids.""" - p = p_hparams - - # Swap modalities. - input_modality = p.input_modality["inputs"] - target_modality = p.target_modality - p.input_modality["inputs"] = target_modality - p.target_modality = input_modality - - # Swap vocabularies. - input_vocabulary = p.vocabulary["inputs"] - target_vocabulary = p.vocabulary["targets"] - p.vocabulary["inputs"] = target_vocabulary - p.vocabulary["targets"] = input_vocabulary - - # Swap input/target space ids. - input_space_id = p.input_space_id - target_space_id = p.target_space_id - p.input_space_id = target_space_id - p.target_space_id = input_space_id - - # Mark that p was reversed. - p.was_reversed = True - - -def default_problem_hparams(): - """A set of basic model hyperparameters.""" - return tf.contrib.training.HParams( - # Use this parameter to get comparable perplexity numbers with different - # tokenizations. This value should be set to the ratio of the number of - # tokens in the test set according to the tokeization used to the number - # of tokens in the test set in the "official" tokenization. For example, - # if we are using a word-piece based model and we want to compute - # per-word perplexity, then we set loss_multiplier to the number of - # wordpieces per word in the test set. - loss_multiplier=1.0, - - # Use this parameter to allow for larger sequences in the batch. Without - # the use of this parameter, the size of the inner two dimensions will be - # used to judge the sequence length. - batch_size_multiplier=1, - - # To make queues of the right capacity, it's good to know the maximal - # expected batch size, as it can vary a lot. It only affects performance - # of input readers and memory use. The defaults should be safe and fast, - # but decrease if your reader uses a lot of memory and increase if slow. - max_expected_batch_size_per_shard=64, - - # Modalities used to map from input features to a space compatible with - # chosen model architecture. One modality spec (which is a 2-tuple, - # (modality_full_name, vocab_size)) per feature key. modality_full_name is - # a string type:name, e.g. class_label:2d. Leaving off the name uses the - # default modality for that type (e.g. class_label == - # class_label:default). - input_modality={}, - - # Modality used to map from hidden representation to the target space. - # Specified as a modality spec, a 2-tuple described above. - target_modality=None, - - # Identifiers used to tell the model which input/target space will be - # expected. For example, it can tell that we expect French as characters - # as output, or Spanish as sound. An integer with the following semantics: - # 0: Generic / unknown output space (default) - # 1: Image labels - # 2: English characters - # 3: English tokens - # 4: English bpe tokens - # 5: French characters - # 6: French tokens - # 7: German characters - # 8: German tokens - # 9: German bpe tokens - # 10: Digit cipher lexicon 0 - # 11: Digit cipher lexicon 1 - # 12: Audio waveform domain - # 13: Audio spectral domain - # 14: Parse characters - # 15: Parse tokens - # 16: Chinese tokens - # 17: Icelandic characters - # 18: Icelandic tokens - # 19: Icelandic parse tokens - # 20: Macedonian tokens - # 21: Czech tokens - # 22: Czech characters - # Add more above if needed. - input_space_id=0, - target_space_id=0, - - # Vocabulary per feature key. - # a vocabulary converts to/from human-readable strings. - # E.g. {"inputs": text_encoder.ByteTextEncoder(), - # "targets": text_encoder.SubwordTextEncoder("vocab_filename.txt")} - vocabulary={ - "inputs": text_encoder.TextEncoder(), - "targets": text_encoder.TextEncoder() - }, - - # This is a marker to keep track if the problem was reversed or copied. - # Only set automatically, do not override the default. - # - # These tags can be combined in order to perform copies of the input or - # the targets. For instance `problem_copy` will copy the inputs, but - # `problem_rev_copy` will copy the targets. - was_reversed=False, - was_copy=False,) - - -def test_problem_hparams(unused_model_hparams, input_vocab_size, - target_vocab_size): +# TODO(rsepassi): Merge these problems with their data generators. Currenlty +# they only implement the hparams. + + +class AudioTimitProblem(problem.Problem): + """Base class for TIMIT problems.""" + + def example_reading_spec(self): + data_fields = { + "inputs": tf.VarLenFeature(tf.int64), + "audio/sample_count": tf.FixedLenFeature((), tf.int64), + "audio/sample_width": tf.FixedLenFeature((), tf.int64), + "targets": tf.VarLenFeature(tf.int64), + } + return data_fields, None + + def preprocess_examples(self, examples, mode, hparams): + examples = super(AudioTimitProblem, self).preprocess_examples( + examples, mode, hparams) + # Reshape audio to proper shape + sample_count = tf.to_int32(examples.pop("audio/sample_count")) + sample_width = tf.to_int32(examples.pop("audio/sample_width")) + channel_count = 1 + examples["inputs"] = tf.reshape(examples["inputs"], + [sample_count, sample_width, channel_count]) + return examples + + +@registry.register_problem +class AudioTimitCharactersTune(AudioTimitProblem): + """TIMIT to characters.""" + + def feature_encoders(self, _): + return { + "inputs": text_encoder.TextEncoder(), + "targets": text_encoder.ByteTextEncoder(), + } + + def hparams(self, defaults, model_hparams): + hp = defaults + hp.input_modality = { + "inputs": (registry.Modalities.AUDIO, None), + } + hp.target_modality = (registry.Modalities.SYMBOL, 256) + + +@registry.register_problem +class AudioTimitTokens8kTune(AudioTimitProblem): + """TIMIT to tokens.""" + + @property + def target_vocab_size(self): + return 2**13 # 8192 + + def feature_encoders(self, data_dir): + vocab_filename = os.path.join(data_dir, + "vocab.endefr.%d" % self.target_vocab_size) + subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) + return { + "inputs": text_encoder.TextEncoder(), + "targets": subtokenizer, + } + + def hparams(self, defaults, model_hparams): + hp = defaults + hp.input_modality = { + "inputs": (registry.Modalities.AUDIO, None), + } + hp.target_modality = (registry.Modalities.SYMBOL, + self.get_feature_encoders()["targets"].vocab_size) + hp.batch_size_multiplier = 256 + hp.loss_multiplier = 2.0 + hp.input_space_id = 13 + hp.target_space_id = 3 + + +@registry.register_problem +class AudioTimitTokens8kTest(AudioTimitTokens8kTune): + """TIMIT to tokens.""" + pass + + +@registry.register_problem +class ParsingEnglishPtb8k(problem.Problem): + """Parsing.""" + + @property + def target_vocab_size(self): + return 2**13 # 8192 + + def feature_encoders(self, data_dir): + vocab_filename = os.path.join(data_dir, + "vocab.endefr.%d" % self.target_vocab_size) + subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) + return { + "inputs": subtokenizer, + "targets": subtokenizer, + } + + def hparams(self, defaults, model_hparams): + hp = defaults + hp.input_modality = { + "inputs": (registry.Modalities.SYMBOL, + self.get_feature_encoders()["inputs"].vocab_size), + } + hp.target_modality = (registry.Modalities.SYMBOL, + self.get_feature_encoders()["targets"].vocab_size) + hp.batch_size_multiplier = 256 + hp.loss_multiplier = 2.0 + hp.input_space_id = 3 + hp.target_space_id = 15 + + +@registry.register_problem +class ParsingEnglishPtb16k(problem.Problem): + """Parsing.""" + + @property + def vocab_prefix(self): + return "wsj" + + @property + def inputs_target_vocab_size(self): + return 2**9 # 512 + + @property + def targets_target_vocab_size(self): + return 2**14 # 16384 + + def feature_encoders(self, data_dir): + source_vocab_filename = os.path.join( + data_dir, + self.vocab_prefix + "_source.vocab.%d" % self.inputs_target_vocab_size) + target_vocab_filename = os.path.join( + data_dir, + self.vocab_prefix + "_target.vocab.%d" % self.targets_target_vocab_size) + source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename) + target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename) + return { + "inputs": source_subtokenizer, + "targets": target_subtokenizer, + } + + def hparams(self, defaults, model_hparams): + hp = defaults + hp.input_modality = { + "inputs": (registry.Modalities.SYMBOL, + self.get_feature_encoders()["inputs"].vocab_size), + } + hp.target_modality = (registry.Modalities.SYMBOL, + self.get_feature_encoders()["targets"].vocab_size) + hp.input_space_id = 3 + hp.target_space_id = 15 + + +class TestProblem(problem.Problem): + """Test problem.""" + + def __init__(self, input_vocab_size, target_vocab_size): + super(TestProblem, self).__init__(False, False) + self.input_vocab_size = input_vocab_size + self.target_vocab_size = target_vocab_size + + def hparams(self, defaults, model_hparams): + hp = defaults + hp.input_modality = { + "inputs": (registry.Modalities.SYMBOL, self.input_vocab_size) + } + hp.target_modality = (registry.Modalities.SYMBOL, self.target_vocab_size) + + +def test_problem_hparams(input_vocab_size=None, target_vocab_size=None): """Problem hparams for testing model bodies.""" - p = default_problem_hparams() - p.input_modality = {"inputs": (registry.Modalities.SYMBOL, input_vocab_size)} - p.target_modality = (registry.Modalities.SYMBOL, target_vocab_size) - p.vocabulary = { - "inputs": text_encoder.TextEncoder(), - "targets": text_encoder.TextEncoder() - } - return p - - -def audio_timit_characters(unused_model_hparams): - """English audio transcription benchmark.""" - p = default_problem_hparams() - p.input_modality = { - "inputs": (registry.Modalities.AUDIO, None), - } - p.target_modality = (registry.Modalities.SYMBOL, 256) - p.vocabulary = { - "inputs": text_encoder.TextEncoder(), - "targets": text_encoder.ByteTextEncoder(), - } - p.batch_size_multiplier = 256 - p.loss_multiplier = 2.0 - p.input_space_id = 12 - p.target_space_id = 2 - return p - - -def audio_timit_tokens(model_hparams, wrong_vocab_size): - """English audio transcription benchmark. - - Args: - model_hparams: a tf.contrib.training.HParams - wrong_vocab_size: a number used in the filename indicating the approximate - vocabulary size. This is not to be confused with the actual vocabulary - size. - Returns: - a tf.contrib.training.HParams - """ - p = default_problem_hparams() - # This vocab file must be present within the data directory. - vocab_filename = os.path.join(model_hparams.data_dir, - "vocab.endefr.%d" % wrong_vocab_size) - subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) - p.input_modality = { - "inputs": (registry.Modalities.AUDIO, None), - } - p.target_modality = (registry.Modalities.SYMBOL, subtokenizer.vocab_size) - p.vocabulary = { - "inputs": text_encoder.TextEncoder(), - "targets": subtokenizer, - } - p.batch_size_multiplier = 256 - p.loss_multiplier = 2.0 - p.input_space_id = 13 - p.target_space_id = 3 - return p - - -def wmt_parsing_characters(model_hparams): - """English to parse tree translation benchmark.""" - del model_hparams # Unused. - p = default_problem_hparams() - p.input_modality = {"inputs": (registry.Modalities.SYMBOL, 256)} - p.target_modality = (registry.Modalities.SYMBOL, 256) - p.vocabulary = { - "inputs": text_encoder.ByteTextEncoder(), - "targets": text_encoder.ByteTextEncoder(), - } - p.loss_multiplier = 2.0 - p.input_space_id = 2 - p.target_space_id = 14 - return p - - -def wmt_parsing_tokens(model_hparams, wrong_vocab_size): - """English to parse tree translation benchmark. - - Args: - model_hparams: a tf.contrib.training.HParams - wrong_vocab_size: a number used in the filename indicating the approximate - vocabulary size. This is not to be confused with the actual vocabulary - size. - Returns: - a tf.contrib.training.HParams - """ - p = default_problem_hparams() - # This vocab file must be present within the data directory. - vocab_filename = os.path.join(model_hparams.data_dir, - "vocab.endefr.%d" % wrong_vocab_size) - subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) - p.input_modality = { - "inputs": (registry.Modalities.SYMBOL, subtokenizer.vocab_size) - } - p.target_modality = (registry.Modalities.SYMBOL, subtokenizer.vocab_size) - p.vocabulary = { - "inputs": subtokenizer, - "targets": subtokenizer, - } - p.input_space_id = 3 - p.target_space_id = 15 - return p - - -def wsj_parsing_tokens(model_hparams, prefix, wrong_source_vocab_size, - wrong_target_vocab_size): - """English to parse tree translation benchmark. - - Args: - model_hparams: a tf.contrib.training.HParams - prefix: name to use as prefix for vocabulary files. - wrong_source_vocab_size: a number used in the filename indicating the - approximate vocabulary size. This is not to be confused with the actual - vocabulary size. - wrong_target_vocab_size: a number used in the filename indicating the - approximate target vocabulary size. This is not to be confused with the - actual target vocabulary size. - Returns: - a tf.contrib.training.HParams - """ - p = default_problem_hparams() - # This vocab file must be present within the data directory. - source_vocab_filename = os.path.join( - model_hparams.data_dir, - prefix + "_source.vocab.%d" % wrong_source_vocab_size) - target_vocab_filename = os.path.join( - model_hparams.data_dir, - prefix + "_target.vocab.%d" % wrong_target_vocab_size) - source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename) - target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename) - p.input_modality = { - "inputs": (registry.Modalities.SYMBOL, source_subtokenizer.vocab_size) - } - p.target_modality = (registry.Modalities.SYMBOL, - target_subtokenizer.vocab_size) - p.vocabulary = { - "inputs": source_subtokenizer, - "targets": target_subtokenizer, - } - p.input_space_id = 3 - p.target_space_id = 15 - return p - - -# Dictionary of named hyperparameter settings for various problems. -# This is only accessed through the problem_hparams function below. -PROBLEM_HPARAMS_MAP = { - "audio_timit_characters_tune": - audio_timit_characters, - "audio_timit_characters_test": - audio_timit_characters, - "audio_timit_tokens_8k_tune": - lambda p: audio_timit_tokens(p, 2**13), - "audio_timit_tokens_8k_test": - lambda p: audio_timit_tokens(p, 2**13), - "parsing_english_ptb8k": - lambda p: wmt_parsing_tokens(p, 2**13), - "parsing_english_ptb16k": - lambda p: wsj_parsing_tokens( # pylint: disable=g-long-lambda - p, "wsj", 2**14, 2**9), -} + p = TestProblem(input_vocab_size, target_vocab_size) + return p.get_hparams() diff --git a/tensor2tensor/data_generators/problem_hparams_test.py b/tensor2tensor/data_generators/problem_hparams_test.py deleted file mode 100644 index df92919ef..000000000 --- a/tensor2tensor/data_generators/problem_hparams_test.py +++ /dev/null @@ -1,50 +0,0 @@ -# coding=utf-8 -# Copyright 2017 The Tensor2Tensor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for tensor2tensor.problem_hparams.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Dependency imports - -from tensor2tensor.data_generators import problem_hparams - -import tensorflow as tf - - -class ProblemHparamsTest(tf.test.TestCase): - - def testParseProblemName(self): - problem_name = "base" - self.assertEqual( - problem_hparams.parse_problem_name(problem_name), ("base", False, - False)) - problem_name = "base_rev" - self.assertEqual( - problem_hparams.parse_problem_name(problem_name), ("base", True, False)) - problem_name = "base_copy" - self.assertEqual( - problem_hparams.parse_problem_name(problem_name), ("base", False, True)) - problem_name = "base_copy_rev" - self.assertEqual( - problem_hparams.parse_problem_name(problem_name), ("base", True, True)) - problem_name = "base_rev_copy" - self.assertEqual( - problem_hparams.parse_problem_name(problem_name), ("base", True, True)) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensor2tensor/models/bluenet_test.py b/tensor2tensor/models/bluenet_test.py index d559fd953..daf87529e 100644 --- a/tensor2tensor/models/bluenet_test.py +++ b/tensor2tensor/models/bluenet_test.py @@ -36,8 +36,7 @@ def testBlueNet(self): x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1)) y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 1, 1, 1)) hparams = bluenet.bluenet_tiny() - p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, - vocab_size) + p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) with self.test_session() as session: tf.train.get_or_create_global_step() features = { diff --git a/tensor2tensor/models/bytenet_test.py b/tensor2tensor/models/bytenet_test.py index 56f421153..f96d3b999 100644 --- a/tensor2tensor/models/bytenet_test.py +++ b/tensor2tensor/models/bytenet_test.py @@ -36,8 +36,7 @@ def testByteNet(self): x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1)) y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 6, 1, 1)) hparams = bytenet.bytenet_base() - p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, - vocab_size) + p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) with self.test_session() as session: features = { "inputs": tf.constant(x, dtype=tf.int32), diff --git a/tensor2tensor/models/lstm_test.py b/tensor2tensor/models/lstm_test.py index c1190d016..0d4bc6d80 100644 --- a/tensor2tensor/models/lstm_test.py +++ b/tensor2tensor/models/lstm_test.py @@ -37,8 +37,7 @@ def testLSTMSeq2Seq(self): x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1)) y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 6, 1, 1)) hparams = common_hparams.basic_params1() - p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, - vocab_size) + p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) with self.test_session() as session: features = { "inputs": tf.constant(x, dtype=tf.int32), @@ -58,8 +57,7 @@ def testLSTMSeq2SeqAttention(self): y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 6, 1, 1)) hparams = lstm.lstm_attention() - p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, - vocab_size) + p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) x = tf.constant(x, dtype=tf.int32) x._shape = tf.TensorShape([None, None, 1, 1]) diff --git a/tensor2tensor/models/neural_gpu_test.py b/tensor2tensor/models/neural_gpu_test.py index 164623699..75149ddd5 100644 --- a/tensor2tensor/models/neural_gpu_test.py +++ b/tensor2tensor/models/neural_gpu_test.py @@ -39,7 +39,7 @@ def testNeuralGPU(self): target_length = input_length input_vocab_size = 9 target_vocab_size = 11 - p_hparams = problem_hparams.test_problem_hparams(hparams, input_vocab_size, + p_hparams = problem_hparams.test_problem_hparams(input_vocab_size, target_vocab_size) inputs = -1 + np.random.random_integers( input_vocab_size, size=(batch_size, input_length, 1, 1)) diff --git a/tensor2tensor/models/transformer_revnet_test.py b/tensor2tensor/models/transformer_revnet_test.py index f9bc8cfb2..f61b88b5b 100644 --- a/tensor2tensor/models/transformer_revnet_test.py +++ b/tensor2tensor/models/transformer_revnet_test.py @@ -46,8 +46,7 @@ def testTransformer(self): target_length = 7 vocab_size = 9 hparams = transformer_revnet_test() - p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, - vocab_size) + p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) hparams.problems = [p_hparams] inputs = -1 + np.random.random_integers( vocab_size, size=(batch_size, input_length, 1, 1)) diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py index 04c527ac1..22848b249 100644 --- a/tensor2tensor/models/transformer_test.py +++ b/tensor2tensor/models/transformer_test.py @@ -44,8 +44,7 @@ def getModel(self, mode=tf.estimator.ModeKeys.TRAIN): hparams.num_heads = 1 hparams.layer_prepostprocess_dropout = 0.0 - p_hparams = problem_hparams.test_problem_hparams( - hparams, VOCAB_SIZE, VOCAB_SIZE) + p_hparams = problem_hparams.test_problem_hparams(VOCAB_SIZE, VOCAB_SIZE) hparams.problems = [p_hparams] inputs = -1 + np.random.random_integers( diff --git a/tensor2tensor/models/xception_test.py b/tensor2tensor/models/xception_test.py index eb4c6db20..9114fb781 100644 --- a/tensor2tensor/models/xception_test.py +++ b/tensor2tensor/models/xception_test.py @@ -36,8 +36,7 @@ def testXception(self): x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1)) y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 1, 1, 1)) hparams = xception.xception_tiny() - p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, - vocab_size) + p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) with self.test_session() as session: features = { "inputs": tf.constant(x, dtype=tf.int32), diff --git a/tensor2tensor/utils/beam_search.py b/tensor2tensor/utils/beam_search.py index be6c28559..c5e8eb85e 100644 --- a/tensor2tensor/utils/beam_search.py +++ b/tensor2tensor/utils/beam_search.py @@ -107,7 +107,6 @@ def beam_search(symbols_to_logits_fn, eos_id=EOS_ID): """Beam search with length penalties. - Uses an interface specific to the sequence cnn models; Requires a function that can take the currently decoded sybmols and return the logits for the next symbol. The implementation is inspired by https://arxiv.org/abs/1609.08144. diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py index 4b0541d31..08e01ccfb 100644 --- a/tensor2tensor/utils/data_reader.py +++ b/tensor2tensor/utils/data_reader.py @@ -29,8 +29,6 @@ from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import zip # pylint: disable=redefined-builtin -from tensor2tensor.data_generators import problem_hparams -from tensor2tensor.data_generators.problem import preprocess_examples_common from tensor2tensor.utils import registry import tensorflow as tf @@ -128,25 +126,6 @@ def decode_record(record): return dataset -def preprocessing(examples, data_file_pattern): - """Preprocessing of examples.""" - # This function is for obsolete problems only, as we're porting them - # all to the Problem class and its preprocess_examples method. Don't add. - if "audio" in data_file_pattern: - # Reshape audio to proper shape - sample_count = tf.to_int32(examples.pop("audio/sample_count")) - sample_width = tf.to_int32(examples.pop("audio/sample_width")) - channel_count = 1 - examples["inputs"] = tf.reshape(examples["inputs"], - [sample_count, sample_width, channel_count]) - if "wsj" in data_file_pattern: - examples["inputs"] = tf.bitcast(examples["inputs"], tf.int32) - elif "a2q_20161229" in data_file_pattern: - # we forgot the EOS when we preprocessed this data. - examples["targets"] = tf.concat([examples["targets"], [1]], 0) - return examples - - def cast_int64_to_int32(features): f = {} for k, v in six.iteritems(features): @@ -182,54 +161,12 @@ def feature_placeholders(data_fields, data_items_to_decoders): return decoded_example -def default_example_reading_spec(data_file_pattern): - """Example reading spec for problem_hparams problems.""" - # This function is for problems that have yet to be ported to the new Problem - # API. Do not add here. - data_items_to_decoders = None - # Read from image TFRecords if the file has "image" in its name. - if data_file_pattern and "image" in data_file_pattern: - label_key = "image/class/label" - data_fields = { - "image/encoded": tf.FixedLenFeature((), tf.string), - "image/format": tf.FixedLenFeature((), tf.string), - label_key: tf.VarLenFeature(tf.int64) - } - data_items_to_decoders = { - "inputs": - tf.contrib.slim.tfexample_decoder.Image( - image_key="image/encoded", - format_key="image/format", - channels=1 if "mnist" in data_file_pattern else 3), - "targets": - tf.contrib.slim.tfexample_decoder.Tensor(label_key), - } - elif data_file_pattern and "audio" in data_file_pattern: - data_type = tf.int64 if "timit" in data_file_pattern else tf.float32 - data_fields = { - "inputs": tf.VarLenFeature(data_type), - "audio/sample_count": tf.FixedLenFeature((), tf.int64), - "audio/sample_width": tf.FixedLenFeature((), tf.int64), - "targets": tf.VarLenFeature(tf.int64), - } - else: - data_fields = { - "inputs": tf.VarLenFeature(tf.int64), - "targets": tf.VarLenFeature(tf.int64) - } - return data_fields, data_items_to_decoders - - def read_examples(problem, data_file_pattern, capacity, mode=tf.estimator.ModeKeys.TRAIN): """Create Dataset of Example for problem and data_file_pattern.""" - if problem is None: - data_fields, data_items_to_decoders = default_example_reading_spec( - data_file_pattern) - else: - data_fields, data_items_to_decoders = problem.example_reading_spec() + data_fields, data_items_to_decoders = problem.example_reading_spec() if data_file_pattern is None: # Create placeholders for input, rather than reading data from disk. @@ -272,7 +209,7 @@ def input_pipeline(problem, data_file_pattern, capacity, mode, hparams, # reading, parsing, and preprocessing. Use Problem.dataset instead. dataset = read_examples(problem, data_file_pattern, capacity, mode=mode) dataset = dataset.map( - lambda ex: _preprocess(ex, problem, data_file_pattern, hparams, mode), + lambda ex: _preprocess(ex, problem, hparams, mode), num_threads=num_threads) dataset = dataset.filter( lambda ex: example_valid_size(ex, batching_scheme["max_length"])) @@ -302,14 +239,9 @@ def input_pipeline(problem, data_file_pattern, capacity, mode, hparams, return batched_examples -def _preprocess(example, problem, data_file_pattern, hparams, mode): +def _preprocess(example, problem, hparams, mode): """Preprocessing for example.""" - if problem is None: - example = preprocess_examples_common(example, hparams, mode) - example = preprocessing(example, data_file_pattern) - else: - example = problem.preprocess_examples(example, mode, hparams) - + example = problem.preprocess_examples(example, mode, hparams) # We do not want int64s as they are not supported on GPUs. example = cast_int64_to_int32(example) @@ -527,10 +459,7 @@ def get_data_filepatterns(problems, data_dir, mode): """Return the location of a dataset for a given mode.""" datasets = [] for problem in problems.split("-"): - try: - problem = registry.problem(problem).dataset_filename() - except ValueError: - problem, _, _ = problem_hparams.parse_problem_name(problem) + problem = registry.problem(problem).dataset_filename() path = os.path.join(data_dir, problem) if mode == tf.estimator.ModeKeys.TRAIN: datasets.append("%s-train*" % path) diff --git a/tensor2tensor/utils/data_reader_test.py b/tensor2tensor/utils/data_reader_test.py index f03ce6da2..ff01cf07f 100644 --- a/tensor2tensor/utils/data_reader_test.py +++ b/tensor2tensor/utils/data_reader_test.py @@ -143,7 +143,7 @@ def testTrainEvalBehavior(self): def testPreprocess(self): dataset = data_reader.read_examples(self.problem, self.filepatterns[0], 32) examples = dataset.make_one_shot_iterator().get_next() - examples = data_reader._preprocess(examples, self.problem, None, None, None) + examples = data_reader._preprocess(examples, self.problem, None, None) with tf.train.MonitoredSession() as sess: ex_val = sess.run(examples) # problem.preprocess_examples has been run diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py index 69d981f7c..09c86ca09 100644 --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -26,7 +26,6 @@ from tensor2tensor import models # pylint: disable=unused-import from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import -from tensor2tensor.data_generators import problem_hparams from tensor2tensor.utils import data_reader from tensor2tensor.utils import decoding from tensor2tensor.utils import devices @@ -236,24 +235,12 @@ def add_problem_hparams(hparams, problems): try: problem = registry.problem(problem_name) except LookupError: - problem = None - - if problem is None: - try: - p_hparams = problem_hparams.problem_hparams(problem_name, hparams) - except LookupError: - # The problem is not in the set of registered Problems nor in the old - # set of problem_hparams. - all_problem_names = sorted( - list(problem_hparams.PROBLEM_HPARAMS_MAP) + - registry.list_problems()) - error_lines = [ - "%s not in the set of supported problems:" % problem_name - ] + all_problem_names - error_msg = "\n * ".join(error_lines) - raise LookupError(error_msg) - else: - p_hparams = problem.get_hparams(hparams) + all_problem_names = sorted(registry.list_problems()) + error_lines = ["%s not in the set of supported problems:" % problem_name + ] + all_problem_names + error_msg = "\n * ".join(error_lines) + raise LookupError(error_msg) + p_hparams = problem.get_hparams(hparams) hparams.problem_instances.append(problem) hparams.problems.append(p_hparams) From f191c7864623dc8d130b916411f9d9f866997cc4 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Thu, 21 Sep 2017 09:37:39 -0700 Subject: [PATCH 34/39] Correct README for decoding PiperOrigin-RevId: 169554635 --- README.md | 3 +- docs/index.md | 9 +- docs/walkthrough.md | 182 +++++++++++++++++++++++++-- tensor2tensor/data_generators/wmt.py | 34 ++--- tensor2tensor/utils/decoding.py | 3 +- 5 files changed, 197 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index e37db796d..6ef815f4e 100644 --- a/README.md +++ b/README.md @@ -124,8 +124,7 @@ t2t-decoder \ --model=$MODEL \ --hparams_set=$HPARAMS \ --output_dir=$TRAIN_DIR \ - --decode_beam_size=$BEAM_SIZE \ - --decode_alpha=$ALPHA \ + --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \ --decode_from_file=$DECODE_FILE cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes diff --git a/docs/index.md b/docs/index.md index 9394809b3..3eb7f1c61 100644 --- a/docs/index.md +++ b/docs/index.md @@ -24,11 +24,6 @@ documentation, from basic tutorials to full code documentation. ## Deep Dive -* [Life of an Example](example_life.md): how all parts of T2T are connected and work together +* [Life of an Example](example_life.md): how all parts of T2T are connected and + work together * [Distributed Training](distributed_training.md) - -## Code documentation - -See our -[README](https://github.com/tensorflow/tensor2tensor/blob/master/README.md) -for now, code docs coming. diff --git a/docs/walkthrough.md b/docs/walkthrough.md index 57d7a03f4..6ef815f4e 100644 --- a/docs/walkthrough.md +++ b/docs/walkthrough.md @@ -1,4 +1,4 @@ -# T2T Install and Run Walkthrough +# T2T: Tensor2Tensor Transformers [![PyPI version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor) @@ -8,6 +8,26 @@ Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](http welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) [![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) +[![Travis](https://img.shields.io/travis/tensorflow/tensor2tensor.svg)]() + +[T2T](https://github.com/tensorflow/tensor2tensor) is a modular and extensible +library and binaries for supervised learning with TensorFlow and with support +for sequence tasks. It is actively used and maintained by researchers and +engineers within the Google Brain team. You can read more about Tensor2Tensor in +the recent [Google Research Blog post introducing +it](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html). + +We're eager to collaborate with you on extending T2T, so please feel +free to [open an issue on +GitHub](https://github.com/tensorflow/tensor2tensor/issues) or +send along a pull request to add your dataset or model. +See [our contribution +doc](CONTRIBUTING.md) for details and our [open +issues](https://github.com/tensorflow/tensor2tensor/issues). +You can chat with us and other users on +[Gitter](https://gitter.im/tensor2tensor/Lobby) and please join our +[Google Group](https://groups.google.com/forum/#!forum/tensor2tensor) to keep up +with T2T announcements. Here is a one-command version that installs tensor2tensor, downloads the data, trains an English-German translation model, and evaluates it: @@ -29,10 +49,28 @@ t2t-decoder \ --problems=translate_ende_wmt32k \ --model=transformer \ --hparams_set=transformer_base_single_gpu \ - --output_dir=~/t2t_train/base + --output_dir=~/t2t_train/base \ --decode_interactive ``` +See the [Walkthrough](#walkthrough) below for more details on each step. + +### Contents + +* [Walkthrough](#walkthrough) +* [Installation](#installation) +* [Features](#features) +* [T2T Overview](#t2t-overview) + * [Datasets](#datasets) + * [Problems and Modalities](#problems-and-modalities) + * [Models](#models) + * [Hyperparameter Sets](#hyperparameter-sets) + * [Trainer](#trainer) +* [Adding your own components](#adding-your-own-components) +* [Adding a dataset](#adding-a-dataset) + +--- + ## Walkthrough Here's a walkthrough training a good English-to-German translation @@ -80,16 +118,13 @@ echo "Goodbye world" >> $DECODE_FILE BEAM_SIZE=4 ALPHA=0.6 -t2t-trainer \ +t2t-decoder \ --data_dir=$DATA_DIR \ --problems=$PROBLEM \ --model=$MODEL \ --hparams_set=$HPARAMS \ --output_dir=$TRAIN_DIR \ - --train_steps=0 \ - --eval_steps=0 \ - --decode_beam_size=$BEAM_SIZE \ - --decode_alpha=$ALPHA \ + --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \ --decode_from_file=$DECODE_FILE cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes @@ -127,3 +162,136 @@ python -c "from tensor2tensor.models.transformer import Transformer" ``` --- + +## Features + +* Many state of the art and baseline models are built-in and new models can be + added easily (open an issue or pull request!). +* Many datasets across modalities - text, audio, image - available for + generation and use, and new ones can be added easily (open an issue or pull + request for public datasets!). +* Models can be used with any dataset and input mode (or even multiple); all + modality-specific processing (e.g. embedding lookups for text tokens) is done + with `Modality` objects, which are specified per-feature in the dataset/task + specification. +* Support for multi-GPU machines and synchronous (1 master, many workers) and + asynchronous (independent workers synchronizing through a parameter server) + [distributed training](https://github.com/tensorflow/tensor2tensor/tree/master/docs/distributed_training.md). +* Easily swap amongst datasets and models by command-line flag with the data + generation script `t2t-datagen` and the training script `t2t-trainer`. + +--- + +## T2T overview + +### Datasets + +**Datasets** are all standardized on `TFRecord` files with `tensorflow.Example` +protocol buffers. All datasets are registered and generated with the +[data +generator](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/bin/t2t-datagen) +and many common sequence datasets are already available for generation and use. + +### Problems and Modalities + +**Problems** define training-time hyperparameters for the dataset and task, +mainly by setting input and output **modalities** (e.g. symbol, image, audio, +label) and vocabularies, if applicable. All problems are defined either in +[`problem_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem_hparams.py) +or are registered with `@registry.register_problem` (run `t2t-datagen` to see +the list of all available problems). +**Modalities**, defined in +[`modality.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/modality.py), +abstract away the input and output data types so that **models** may deal with +modality-independent tensors. + +### Models + +**`T2TModel`s** define the core tensor-to-tensor transformation, independent of +input/output modality or task. Models take dense tensors in and produce dense +tensors that may then be transformed in a final step by a **modality** depending +on the task (e.g. fed through a final linear transform to produce logits for a +softmax over classes). All models are imported in the +[`models` subpackage](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/models/__init__.py), +inherit from `T2TModel` - defined in +[`t2t_model.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/t2t_model.py) - +and are registered with +[`@registry.register_model`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/registry.py). + +### Hyperparameter Sets + +**Hyperparameter sets** are defined and registered in code with +[`@registry.register_hparams`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/registry.py) +and are encoded in +[`tf.contrib.training.HParams`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py) +objects. The `HParams` are available to both the problem specification and the +model. A basic set of hyperparameters are defined in +[`common_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/layers/common_hparams.py) +and hyperparameter set functions can compose other hyperparameter set functions. + +### Trainer + +The **trainer** binary is the main entrypoint for training, evaluation, and +inference. Users can easily switch between problems, models, and hyperparameter +sets by using the `--model`, `--problems`, and `--hparams_set` flags. Specific +hyperparameters can be overridden with the `--hparams` flag. `--schedule` and +related flags control local and distributed training/evaluation +([distributed training documentation](https://github.com/tensorflow/tensor2tensor/tree/master/docs/distributed_training.md)). + +--- + +## Adding your own components + +T2T's components are registered using a central registration mechanism that +enables easily adding new ones and easily swapping amongst them by command-line +flag. You can add your own components without editing the T2T codebase by +specifying the `--t2t_usr_dir` flag in `t2t-trainer`. + +You can do so for models, hyperparameter sets, modalities, and problems. Please +do submit a pull request if your component might be useful to others. + +Here's an example with a new hyperparameter set: + +```python +# In ~/usr/t2t_usr/my_registrations.py + +from tensor2tensor.models import transformer +from tensor2tensor.utils import registry + +@registry.register_hparams +def transformer_my_very_own_hparams_set(): + hparams = transformer.transformer_base() + hparams.hidden_size = 1024 + ... +``` + +```python +# In ~/usr/t2t_usr/__init__.py +from . import my_registrations +``` + +``` +t2t-trainer --t2t_usr_dir=~/usr/t2t_usr --registry_help +``` + +You'll see under the registered HParams your +`transformer_my_very_own_hparams_set`, which you can directly use on the command +line with the `--hparams_set` flag. + +`t2t-datagen` also supports the `--t2t_usr_dir` flag for `Problem` +registrations. + +## Adding a dataset + +To add a new dataset, subclass +[`Problem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py) +and register it with `@registry.register_problem`. See +[`TranslateEndeWmt8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py) +for an example. + +Also see the [data generators +README](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/README.md). + +--- + +*Note: This is not an official Google product.* diff --git a/tensor2tensor/data_generators/wmt.py b/tensor2tensor/data_generators/wmt.py index befb9ac7f..cde0bc9ac 100644 --- a/tensor2tensor/data_generators/wmt.py +++ b/tensor2tensor/data_generators/wmt.py @@ -34,7 +34,6 @@ FLAGS = tf.flags.FLAGS - # End-of-sentence marker. EOS = text_encoder.EOS_ID @@ -186,7 +185,6 @@ def bi_vocabs_token_generator(source_path, # Data-set URLs. - _ENDE_TRAIN_DATASETS = [ [ "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz", # pylint: disable=line-too-long @@ -287,7 +285,6 @@ def bi_vocabs_token_generator(source_path, ], ] - # Generators. @@ -333,8 +330,8 @@ def generator(self, data_dir, tmp_dir, train): with tf.gfile.GFile(token_path, mode="a") as f: f.write("UNK\n") # Add UNK to the vocab. token_vocab = text_encoder.TokenTextEncoder(token_path, replace_oov="UNK") - return token_generator(train_path + ".en", train_path + ".de", - token_vocab, EOS) + return token_generator(train_path + ".en", train_path + ".de", token_vocab, + EOS) @property def input_space_id(self): @@ -360,7 +357,7 @@ def _preprocess_sgm(line, is_sgm): line = line.strip() if line.startswith(""): i = line.index(">") - return line[i+1:-6] # Strip first and last . + return line[i + 1:-6] # Strip first and last . def _compile_data(tmp_dir, datasets, filename): @@ -479,18 +476,24 @@ def targeted_vocab_size(self): def num_shards(self): return 10 # This is a small dataset. + @property + def source_vocab_name(self): + return "vocab.zhen-zh.%d" % self.targeted_vocab_size + + @property + def target_vocab_name(self): + return "vocab.zhen-en.%d" % self.targeted_vocab_size + def generator(self, data_dir, tmp_dir, train): - source_vocab_size = self.targeted_vocab_size - target_vocab_size = self.targeted_vocab_size datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS source_datasets = [[item[0], [item[1][0]]] for item in _ZHEN_TRAIN_DATASETS] target_datasets = [[item[0], [item[1][1]]] for item in _ZHEN_TRAIN_DATASETS] source_vocab = generator_utils.get_or_generate_vocab( - data_dir, tmp_dir, "vocab.zhen-zh.%d" % source_vocab_size, - source_vocab_size, source_datasets) + data_dir, tmp_dir, self.source_vocab_name, self.targeted_vocab_size, + source_datasets) target_vocab = generator_utils.get_or_generate_vocab( - data_dir, tmp_dir, "vocab.zhen-en.%d" % target_vocab_size, - target_vocab_size, target_datasets) + data_dir, tmp_dir, self.target_vocab_name, self.targeted_vocab_size, + target_datasets) tag = "train" if train else "dev" data_path = _compile_data(tmp_dir, datasets, "wmt_zhen_tok_%s" % tag) # We generate English->X data by convention, to train reverse translation @@ -508,11 +511,8 @@ def target_space_id(self): return problem.SpaceID.EN_TOK def feature_encoders(self, data_dir): - vocab_size = self.targeted_vocab_size - source_vocab_filename = os.path.join(data_dir, - "vocab.zhen-zh.%d" % vocab_size) - target_vocab_filename = os.path.join(data_dir, - "vocab.zhen-en.%d" % vocab_size) + source_vocab_filename = os.path.join(data_dir, self.source_vocab_name) + target_vocab_filename = os.path.join(data_dir, self.target_vocab_name) source_token = text_encoder.SubwordTextEncoder(source_vocab_filename) target_token = text_encoder.SubwordTextEncoder(target_vocab_filename) return { diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py index 664935c94..a08947202 100644 --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -138,6 +138,7 @@ def decode_from_dataset(estimator, inputs_vocab = problem_hparams.vocabulary.get("inputs", None) targets_vocab = problem_hparams.vocabulary["targets"] for num_predictions, prediction in enumerate(predictions): + num_predictions += 1 inputs = prediction["inputs"] targets = prediction["targets"] outputs = prediction["outputs"] @@ -181,7 +182,7 @@ def decode_from_dataset(estimator, target_file.write(str(decoded_target) + "\n") if (decode_hp.num_samples >= 0 and - (num_predictions + 1) >= decode_hp.num_samples): + num_predictions >= decode_hp.num_samples): break if decode_to_file: From c996878113b6d283253aef7de4f266484e4b50f6 Mon Sep 17 00:00:00 2001 From: T2T Team Date: Thu, 21 Sep 2017 10:53:17 -0700 Subject: [PATCH 35/39] Reproduces a bug with the SubwordTextEncoder in a test. PiperOrigin-RevId: 169566059 --- .../data_generators/text_encoder_test.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tensor2tensor/data_generators/text_encoder_test.py b/tensor2tensor/data_generators/text_encoder_test.py index b55a51bf4..0351d0d2f 100644 --- a/tensor2tensor/data_generators/text_encoder_test.py +++ b/tensor2tensor/data_generators/text_encoder_test.py @@ -107,6 +107,13 @@ def test_reserved_tokens_in_corpus(self): class SubwordTextEncoderTest(tf.test.TestCase): + @classmethod + def setUpClass(cls): + """Make sure the test dir exists and is empty.""" + cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), "encoder_test") + shutil.rmtree(cls.test_temp_dir, ignore_errors=True) + os.mkdir(cls.test_temp_dir) + def test_encode_decode(self): corpus = ( "This is a corpus of text that provides a bunch of tokens from which " @@ -216,6 +223,28 @@ def test_load_from_file(self): encoder._load_from_file_object(vocab) self.assertEqual(encoder._all_subtoken_strings, correct_vocab) + def test_reserved_token_chars_not_in_alphabet(self): + corpus = "dog" + token_counts = collections.Counter(corpus.split(" ")) + encoder1 = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 100) + filename = os.path.join(self.test_temp_dir, "out.voc") + encoder1.store_to_file(filename) + encoder2 = text_encoder.SubwordTextEncoder(filename=filename) + + for t in text_encoder.RESERVED_TOKENS: + for c in t: + # Verify that encoder1 can encode all reserved token chars. + encoder1.encode(c) + + # TODO(seabass): Implement the fix so that we can remove this assertion. + with self.assertRaises(AssertionError): + for t in text_encoder.RESERVED_TOKENS: + for c in t: + # Verify that encoder2 fails to encode the characters (i.e. + # reproduce the bug). + encoder2.encode(c) + if __name__ == "__main__": tf.test.main() From 8ee83501f149d38b11ef800a00e8f16bb7c661d5 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Thu, 21 Sep 2017 15:14:14 -0700 Subject: [PATCH 36/39] v1.2.3 PiperOrigin-RevId: 169607663 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a84f772b6..331abb78e 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.2.2', + version='1.2.3', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', From e892dc3cc5a5ef2e6fde5b6569281ac4abc7fa24 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Thu, 21 Sep 2017 17:20:44 -0700 Subject: [PATCH 37/39] Update example_life.md PiperOrigin-RevId: 169625024 --- docs/example_life.md | 195 ++++++++++++++++-- .../data_generators/cnn_dailymail.py | 2 +- tensor2tensor/data_generators/desc2code.py | 3 +- .../data_generators/gene_expression.py | 10 +- .../data_generators/generator_utils.py | 18 +- tensor2tensor/data_generators/image.py | 64 +++--- tensor2tensor/data_generators/imdb.py | 2 +- tensor2tensor/data_generators/problem.py | 26 +-- .../data_generators/problem_hparams.py | 16 +- tensor2tensor/data_generators/wiki.py | 4 +- tensor2tensor/layers/common_hparams.py | 7 +- tensor2tensor/utils/data_reader.py | 10 +- tensor2tensor/utils/data_reader_test.py | 8 +- 13 files changed, 263 insertions(+), 102 deletions(-) diff --git a/docs/example_life.md b/docs/example_life.md index 2983f5077..f3b18a817 100644 --- a/docs/example_life.md +++ b/docs/example_life.md @@ -9,26 +9,189 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO [![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) -This document show how a training example passes through the T2T pipeline, -and how all its parts are connected to work together. +This doc explains how a training example flows through T2T, from data generation +to training, evaluation, and decoding. It points out the various hooks available +in the `Problem` and `T2TModel` classes and gives an overview of the T2T code +(key functions, files, hyperparameters, etc.). -## The Life of an Example +Some key files and their functions: -A training example passes the following stages in T2T: -* raw input (text from command line or file) -* encoded input after [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) function `encode` is usually a sparse tensor, e.g., a vector of `tf.int32`s -* batched input after [data input pipeline](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/data_reader.py#L242) where the inputs, after [Problem.preprocess_examples](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L188) are grouped by their length and made into batches. -* dense input after being processed by a [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) function `bottom`. -* dense output after [T2T.model_fn_body](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/t2t_model.py#L542) -* back to sparse output through [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) function `top`. -* if decoding, back through [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) function `decode` to display on the screen. +* [`trainer_utils.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/trainer_utils.py): + Constructs and runs all the main components of the system (the `Problem`, + the `HParams`, the `Estimator`, the `Experiment`, the `input_fn`s and + `model_fn`). +* [`common_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/layers/common_hparams.py): + `basic_params1` serves as the base for all model hyperparameters. Registered + model hparams functions always start with this default set of + hyperparameters. +* [`problem.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py): + Every dataset in T2T subclasses `Problem`. +* [`t2t_model.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/t2t_model.py): + Every model in T2T subclasses `T2TModel`. -We go into these phases step by step below. +## Data Generation -## Feature Encoders +The `t2t-datagen` binary is the entrypoint for data generation. It simply looks +up the `Problem` specified by `--problem` and calls +`Problem.generate_data(data_dir, tmp_dir)`. -TODO: describe [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) which is a dict of encoders that have `encode` and `decode` functions. +All `Problem`s are expected to generate 2 sharded `TFRecords` files - 1 for +training and 1 for evaluation - with `tensorflow.Example` protocol buffers. The +expected names of the files are given by `Problem.{training, dev}_filepaths`. +Typically, the features in the `Example` will be `"inputs"` and `"targets"`; +however, some tasks have a different on-disk representation that is converted to +`"inputs"` and `"targets"` online in the input pipeline (e.g. image features are +typically stored with features `"image/encoded"` and `"image/format"` and the +decoding happens in the input pipeline). -## Modalities +For tasks that require a vocabulary, this is also the point at which the +vocabulary is generated and all examples are encoded. -TODO: describe [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) which has `bottom` and `top` but also sharded versions and one for targets. +There are several utility functions in +[`generator_utils`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/generator_utils.py) +that are commonly used by `Problem`s to generate data. Several are highlighted +below: + +* `generate_dataset_and_shuffle`: given 2 generators, 1 for training and 1 for + eval, yielding dictionaries of `>`, will produce sharded and shuffled `TFRecords` files with + `tensorflow.Example` protos. +* `maybe_download`: downloads a file at a URL to the given directory and + filename (see `maybe_download_from_drive` if the URL points to Google + Drive). +* `get_or_generate_vocab_inner`: given a target vocabulary size and a + generator that yields lines or tokens from the dataset, will build a + `SubwordTextEncoder` along with a backing vocabulary file that can be used + to map input strings to lists of ids. + [`SubwordTextEncoder`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/text_encoder.py) + uses word pieces and its encoding is fully invertible. + +## Data Input Pipeline + +Once the data is produced on disk, training, evaluation, and inference (if +decoding from the dataset) consume it by way of T2T input pipeline. This section +will give an overview of that pipeline with specific attention to the various +hooks in the `Problem` class and the model's `HParams` object (typically +registered in the model's file and specified by the `--hparams_set` flag). + +The entire input pipeline is implemented with the new `tf.data.Dataset` API +(previously `tf.contrib.data.Dataset`). + +The key function in the codebase for the input pipeline is +[`data_reader.input_pipeline`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/data_reader.py). +The full input function is built in +[`input_fn_builder.build_input_fn`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/input_fn_builder.py) +(which calls `data_reader.input_pipeline`). + +### Reading and decoding data + +`Problem.dataset_filename` specifies the prefix of the files on disk (they will +be suffixed with `-train` or `-dev` as well as their sharding). + +The features read from the files and their decoding is specified by +`Problem.example_reading_spec`, which returns 2 items: + +1. Dict mapping from on-disk feature name to on-disk types (`VarLenFeature` or + `FixedLenFeature`. +2. Dict mapping output feature name to decoder. This return value is optional + and is only needed for tasks whose features may require additional decoding + (e.g. images). You can find the available decoders in + `tf.contrib.slim.tfexample_decoder`. + +At this point in the input pipeline, the example is a `dict`. + +### Preprocessing + +The read `Example` now runs through `Problem.preprocess_example`, which by +default runs +[`problem.preprocess_example_common`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py), +which may truncate the inputs/targets or prepend to targets, governed by some +hyperparameters. + +### Batching + +Examples are bucketed by sequence length and then batched out of those buckets. +This significantly improves performance over a naive batching scheme for +variable length sequences because each example in a batch must be padded to +match the example with the maximum length in the batch. + +There are several hyperparameters that affect how examples are batched together: + +* `hp.batch_size`: this is the approximate total number of tokens in the batch + (i.e. for a sequence problem, long sequences will have smaller actual batch + size and short sequences will have a larger actual batch size in order to + generally have an equal number of tokens in the batch). +* `hp.max_length`: sequences with length longer than this will be dropped + during training (and also during eval if `hp.eval_drop_long_sequences` is + `True`). If not set, the maximum length of examples is set to + `hp.batch_size`. +* `hp.batch_size_multiplier`: multiplier for the maximum length +* `hp.min_length_bucket`: example length for the smallest bucket (i.e. the + smallest bucket will bucket examples up to this length). +* `hp.length_bucket_step`: controls how spaced out the length buckets are. + +## Building the Model + +At this point, the input features typically have `"inputs"` and `"targets"`, +each of which is a batched 4-D Tensor (e.g. of shape `[batch_size, +sequence_length, 1, 1]` for text input or `[batch_size, height, width, 3]` for +image input). + +A `T2TModel` is composed of transforms of the input features by `Modality`s, +then the body of the model, then transforms of the model output to predictions +by a `Modality`, and then a loss (during training). + +The `Modality` types for the various input features and for the target are +specified in `Problem.hparams`. A `Modality` is a feature adapter that enables +models to be agnostic to input/output spaces. You can see the various +`Modality`s in +[`modalities.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/layers/modalities.py). + +The sketch structure of a T2T model is as follows: + +```python +features = {...} # output from the input pipeline +input_modaly = ... # specified in Problem.hparams +target_modality = ... # specified in Problem.hparams + +transformed_features = {} +transformed_features["inputs"] = input_modality.bottom( + features["inputs"]) +transformed_features["targets"] = target_modality.targets_bottom( + features["targets"]) # for autoregressive models + +body_outputs = model.model_fn_body(transformed_features) + +predictions = target_modality.top(body_outputs, features["targets"]) +loss = target_modality.loss(predictions, features["targets"]) +``` + +Most `T2TModel`s only override `model_fn_body`. + +## Training, Eval, Inference modes + +Both the input function and model functions take a mode in the form of a +`tf.estimator.ModeKeys`, which allows the functions to behave differently in +different modes. + +In training, the model function constructs an optimizer and minimizes the loss. + +In evaluation, the model function constructs the evaluation metrics specified by +`Problem.eval_metrics`. + +In inference, the model function outputs predictions. + +## `Estimator` and `Experiment` + +With the input function and model functions constructed, the actual training +loop and related services (checkpointing, summaries, continuous evaluation, +etc.) are all handled by `Estimator` and `Experiment` objects, constructed in +[`trainer_utils.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/trainer_utils.py). + +## Decoding + +* [`decoding.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/decoding.py) + +TODO(rsepassi): Explain decoding (interactive, from file, and from dataset) and +`Problem.feature_encoders`. diff --git a/tensor2tensor/data_generators/cnn_dailymail.py b/tensor2tensor/data_generators/cnn_dailymail.py index 93e846a0b..2f8e9cf30 100644 --- a/tensor2tensor/data_generators/cnn_dailymail.py +++ b/tensor2tensor/data_generators/cnn_dailymail.py @@ -129,7 +129,7 @@ def use_train_shards_for_dev(self): def generator(self, data_dir, tmp_dir, _): encoder = generator_utils.get_or_generate_vocab_inner( data_dir, self.vocab_file, self.targeted_vocab_size, - lambda: story_generator(tmp_dir)) + story_generator(tmp_dir)) for story in story_generator(tmp_dir): summary, rest = _story_summary_split(story) encoded_summary = encoder.encode(summary) + [EOS] diff --git a/tensor2tensor/data_generators/desc2code.py b/tensor2tensor/data_generators/desc2code.py index 1e26b000c..174bd8107 100644 --- a/tensor2tensor/data_generators/desc2code.py +++ b/tensor2tensor/data_generators/desc2code.py @@ -195,8 +195,7 @@ def generator_target(): data_dir=data_dir, vocab_filename=self.vocab_target_filename, vocab_size=self.target_vocab_size, - generator_fn=generator_target, - ) + generator=generator_target(),) # Yield the training and testing samples eos_list = [EOS] diff --git a/tensor2tensor/data_generators/gene_expression.py b/tensor2tensor/data_generators/gene_expression.py index 43d5a6702..477e04017 100644 --- a/tensor2tensor/data_generators/gene_expression.py +++ b/tensor2tensor/data_generators/gene_expression.py @@ -159,17 +159,17 @@ def example_reading_spec(self): data_items_to_decoders = None return (data_fields, data_items_to_decoders) - def preprocess_examples(self, examples, mode, unused_hparams): + def preprocess_example(self, example, mode, unused_hparams): del mode # Reshape targets to contain num_output_predictions per output timestep - examples["targets"] = tf.reshape(examples["targets"], - [-1, 1, self.num_output_predictions]) + example["targets"] = tf.reshape(example["targets"], + [-1, 1, self.num_output_predictions]) # Slice off EOS - not needed, and messes up the GeneExpressionConv model # which expects the input length to be a multiple of the target length. - examples["inputs"] = examples["inputs"][:-1] + example["inputs"] = example["inputs"][:-1] - return examples + return example def eval_metrics(self): return [metrics.Metrics.LOG_POISSON, metrics.Metrics.R2] diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index 3e1086d37..f22e84794 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -300,7 +300,7 @@ def gunzip_file(gz_path, new_path): def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, - generator_fn): + generator): """Inner implementation for vocab generators. Args: @@ -308,7 +308,7 @@ def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, then do not save the vocab even if it doesn't exist. vocab_filename: relative filename where vocab file is stored vocab_size: target size of the vocabulary constructed by SubwordTextEncoder - generator_fn: a generator that produces tokens from the vocabulary + generator: a generator that produces tokens from the vocabulary Returns: A SubwordTextEncoder vocabulary object. @@ -325,7 +325,7 @@ def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, tf.logging.info("Generating vocab file: %s", vocab_filepath) token_counts = defaultdict(int) - for item in generator_fn(): + for item in generator: for tok in tokenizer.encode(text_encoder.native_to_unicode(item)): token_counts[tok] += 1 @@ -382,8 +382,8 @@ def generate(): file_byte_budget -= len(line) yield line - return get_or_generate_vocab_inner( - data_dir, vocab_filename, vocab_size, generator_fn=generate) + return get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, + generate()) def get_or_generate_tabbed_vocab(data_dir, tmp_dir, source_filename, @@ -416,8 +416,8 @@ def generate(): part = parts[index].strip() yield part - return get_or_generate_vocab_inner( - data_dir, vocab_filename, vocab_size, generator_fn=generate) + return get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, + generate()) def get_or_generate_txt_vocab(data_dir, vocab_filename, vocab_size, @@ -434,8 +434,8 @@ def generate(): for line in source_file: yield line.strip() - return get_or_generate_vocab_inner( - data_dir, vocab_filename, vocab_size, generator_fn=generate) + return get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, + generate()) def read_records(filename): diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py index 64b9d8639..084ef330a 100644 --- a/tensor2tensor/data_generators/image.py +++ b/tensor2tensor/data_generators/image.py @@ -91,19 +91,19 @@ class ImageCeleba(ImageProblem): "Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young" ).split() - def preprocess_examples(self, examples, unused_mode, unused_hparams): + def preprocess_example(self, example, unused_mode, unused_hparams): def resize(img, size): return tf.to_int64( tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA)) - inputs = examples["inputs"] + inputs = example["inputs"] # Remove boundaries in CelebA images. Remove 40 pixels each side # vertically and 20 pixels each side horizontally. inputs = tf.image.crop_to_bounding_box(inputs, 40, 20, 218 - 80, 178 - 40) - examples["inputs"] = resize(inputs, 8) - examples["targets"] = resize(inputs, 32) - return examples + example["inputs"] = resize(inputs, 8) + example["targets"] = resize(inputs, 32) + return example def hparams(self, defaults, unused_model_hparams): p = defaults @@ -301,7 +301,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1): self.dev_filepaths(data_dir, self.dev_shards, shuffled=False)) -def imagenet_preprocess_examples(examples, mode): +def imagenet_preprocess_example(example, mode): """Preprocessing used for Imagenet and similar problems.""" def preprocess(img): @@ -312,15 +312,15 @@ def preprocess(img): def resize(img): return tf.to_int64(tf.image.resize_images(img, [299, 299])) - inputs = tf.cast(examples["inputs"], tf.int64) + inputs = tf.cast(example["inputs"], tf.int64) if mode == tf.estimator.ModeKeys.TRAIN: - examples["inputs"] = tf.cond( # Preprocess 90% of the time. + example["inputs"] = tf.cond( # Preprocess 90% of the time. tf.less(tf.random_uniform([]), 0.9), lambda img=inputs: preprocess(img), lambda img=inputs: resize(img)) else: - examples["inputs"] = resize(inputs) - return examples + example["inputs"] = resize(inputs) + return example @registry.register_problem @@ -341,8 +341,8 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1): "instructions at https://github.com/tensorflow/models/blob/master" "/inception/README.md#getting-started") - def preprocess_examples(self, examples, mode, _): - return imagenet_preprocess_examples(examples, mode) + def preprocess_example(self, example, mode, _): + return imagenet_preprocess_example(example, mode) @registry.register_problem @@ -366,17 +366,17 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1): "instructions at https://github.com/tensorflow/models/blob/master" "/inception/README.md#getting-started") - def preprocess_examples(self, examples, mode, unused_hparams): + def preprocess_example(self, example, mode, unused_hparams): # Just resize with area. if self._was_reversed: - examples["inputs"] = tf.to_int64( - tf.image.resize_images(examples["inputs"], [32, 32], + example["inputs"] = tf.to_int64( + tf.image.resize_images(example["inputs"], [32, 32], tf.image.ResizeMethod.AREA)) else: - examples = imagenet_preprocess_examples(examples, mode) - examples["inputs"] = tf.to_int64( - tf.image.resize_images(examples["inputs"], [32, 32])) - return examples + example = imagenet_preprocess_example(example, mode) + example["inputs"] = tf.to_int64( + tf.image.resize_images(example["inputs"], [32, 32])) + return example @registry.register_problem @@ -386,17 +386,17 @@ class Img2imgImagenet(ImageProblem): def dataset_filename(self): return "image_imagenet" # Reuse Imagenet data. - def preprocess_examples(self, examples, unused_mode, unused_hparams): + def preprocess_example(self, example, unused_mode, unused_hparams): def resize(img, size): return tf.to_int64( tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA)) - inputs = examples["inputs"] + inputs = example["inputs"] # For Img2Img resize input and output images as desired. - examples["inputs"] = resize(inputs, 8) - examples["targets"] = resize(inputs, 32) - return examples + example["inputs"] = resize(inputs, 8) + example["targets"] = resize(inputs, 32) + return example def hparams(self, defaults, unused_model_hparams): p = defaults @@ -623,11 +623,11 @@ def class_labels(self): "ship", "truck" ] - def preprocess_examples(self, examples, mode, unused_hparams): + def preprocess_example(self, example, mode, unused_hparams): if mode == tf.estimator.ModeKeys.TRAIN: - examples["inputs"] = common_layers.cifar_image_augmentation( - examples["inputs"]) - return examples + example["inputs"] = common_layers.cifar_image_augmentation( + example["inputs"]) + return example def generator(self, data_dir, tmp_dir, is_training): if is_training: @@ -649,8 +649,8 @@ def generator(self, data_dir, tmp_dir, is_training): @registry.register_problem class ImageCifar10Plain(ImageCifar10): - def preprocess_examples(self, examples, mode, unused_hparams): - return examples + def preprocess_example(self, example, mode, unused_hparams): + return example # URLs and filenames for MSCOCO data. @@ -827,8 +827,8 @@ def train_shards(self): def dev_shards(self): return 10 - def preprocess_examples(self, examples, mode, _): - return imagenet_preprocess_examples(examples, mode) + def preprocess_example(self, example, mode, _): + return imagenet_preprocess_example(example, mode) def generator(self, data_dir, tmp_dir, is_training): if is_training: diff --git a/tensor2tensor/data_generators/imdb.py b/tensor2tensor/data_generators/imdb.py index d7eadcd1d..95d728b1e 100644 --- a/tensor2tensor/data_generators/imdb.py +++ b/tensor2tensor/data_generators/imdb.py @@ -79,7 +79,7 @@ def generator(self, data_dir, tmp_dir, train): # Generate vocab encoder = generator_utils.get_or_generate_vocab_inner( data_dir, self.vocab_file, self.targeted_vocab_size, - lambda: self.doc_generator(imdb_dir, "train")) + self.doc_generator(imdb_dir, "train")) # Generate examples dataset = "train" if train else "test" diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index 4ada1d212..37eee64ab 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -102,19 +102,19 @@ def default_model_hparams(): data_dir=None) -def preprocess_examples_common(examples, hparams, mode): +def preprocess_example_common(example, hparams, mode): """Preprocessing steps common to all models.""" if hparams.max_input_seq_length > 0: - examples["inputs"] = examples["inputs"][:hparams.max_input_seq_length] + example["inputs"] = example["inputs"][:hparams.max_input_seq_length] if hparams.max_target_seq_length > 0: - examples["targets"] = examples["targets"][:hparams.max_target_seq_length] + example["targets"] = example["targets"][:hparams.max_target_seq_length] if hparams.prepend_mode != "none": if mode == tf.estimator.ModeKeys.PREDICT: - examples["partial_targets"] = tf.concat([examples["inputs"], [0]], 0) + example["partial_targets"] = tf.concat([example["inputs"], [0]], 0) else: - examples["targets"] = tf.concat( - [examples["inputs"], [0], examples["targets"]], 0) - return examples + example["targets"] = tf.concat( + [example["inputs"], [0], example["targets"]], 0) + return example class Problem(object): @@ -154,7 +154,7 @@ class Problem(object): * example_reading_spec - Specify the names and types of the features on disk. - Specify tf.contrib.slim.tfexample_decoder - * preprocess_examples(examples, mode) + * preprocess_example(example, mode) - Preprocess the example feature dict from feature name to Tensor or SparseTensor. - Used in training, eval, and inference (specified by mode). @@ -198,8 +198,8 @@ def example_reading_spec(self): data_items_to_decoders = None return (data_fields, data_items_to_decoders) - def preprocess_examples(self, examples, mode, hparams): - return preprocess_examples_common(examples, hparams, mode) + def preprocess_example(self, example, mode, hparams): + return preprocess_example_common(example, hparams, mode) def eval_metrics(self): return [ @@ -310,10 +310,10 @@ def dataset(self, shuffle_files: whether to shuffle input files. Default behavior (i.e. when shuffle_files=None) is to shuffle if mode == TRAIN. hparams: tf.contrib.training.HParams; hparams to be passed to - Problem.preprocess_examples and Problem.hparams. If None, will use a + Problem.preprocess_example and Problem.hparams. If None, will use a default set that is a no-op. preprocess: bool, whether to map the Dataset through - Problem.preprocess_examples. + Problem.preprocess_example. Returns: Dataset containing dict. @@ -366,7 +366,7 @@ def decode_record(record): return dict(zip(decode_items, decoded)) def _preprocess(example): - example = self.preprocess_examples(example, mode, hparams) + example = self.preprocess_example(example, mode, hparams) self.maybe_reverse_features(example) self.maybe_copy_features(example) return example diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py index 88212b0db..576a27a79 100644 --- a/tensor2tensor/data_generators/problem_hparams.py +++ b/tensor2tensor/data_generators/problem_hparams.py @@ -47,16 +47,16 @@ def example_reading_spec(self): } return data_fields, None - def preprocess_examples(self, examples, mode, hparams): - examples = super(AudioTimitProblem, self).preprocess_examples( - examples, mode, hparams) + def preprocess_example(self, example, mode, hparams): + example = super(AudioTimitProblem, self).preprocess_example( + example, mode, hparams) # Reshape audio to proper shape - sample_count = tf.to_int32(examples.pop("audio/sample_count")) - sample_width = tf.to_int32(examples.pop("audio/sample_width")) + sample_count = tf.to_int32(example.pop("audio/sample_count")) + sample_width = tf.to_int32(example.pop("audio/sample_width")) channel_count = 1 - examples["inputs"] = tf.reshape(examples["inputs"], - [sample_count, sample_width, channel_count]) - return examples + example["inputs"] = tf.reshape(example["inputs"], + [sample_count, sample_width, channel_count]) + return example @registry.register_problem diff --git a/tensor2tensor/data_generators/wiki.py b/tensor2tensor/data_generators/wiki.py index 30a16817b..a1380c27f 100644 --- a/tensor2tensor/data_generators/wiki.py +++ b/tensor2tensor/data_generators/wiki.py @@ -127,7 +127,7 @@ def use_train_shards_for_dev(self): def generator(self, data_dir, tmp_dir, _): encoder = generator_utils.get_or_generate_vocab_inner( data_dir, self.vocab_file, self.targeted_vocab_size, - lambda: page_generator(tmp_dir, max_docs=10000)) + page_generator(tmp_dir, max_docs=10000)) for page in page_generator(tmp_dir): title = _page_title(page) encoded = encoder.encode(page) + [EOS] @@ -210,7 +210,7 @@ def scramble(self, seq): def generator(self, data_dir, tmp_dir, _): encoder = generator_utils.get_or_generate_vocab_inner( data_dir, self.vocab_file, self.targeted_vocab_size, - lambda: page_generator(tmp_dir, max_docs=1000)) + page_generator(tmp_dir, max_docs=1000)) case_num = 0 for page in page_generator(tmp_dir): encoded = encoder.encode(page) diff --git a/tensor2tensor/layers/common_hparams.py b/tensor2tensor/layers/common_hparams.py index 2e33c9e94..deae14ddc 100644 --- a/tensor2tensor/layers/common_hparams.py +++ b/tensor2tensor/layers/common_hparams.py @@ -126,13 +126,13 @@ def basic_params1(): # The maximum length of "input" sequence. # Sequences longer than this value will be truncated. 0 or negative values # mean there is no maximum or truncation. - # You can change this behavior by overridding preprocess_examples() method + # You can change this behavior by overridding preprocess_example() method # in your problem class. max_input_seq_length=0, # The maximum length of "target" sequence. # Sequences longer than this value will be truncated. 0 or negative values # mean there is no maximum or truncation. - # You can change this behavior by overridding preprocess_examples() method + # You can change this behavior by overridding preprocess_example() method # in your problem class. max_target_seq_length=0, # This flag allows us to optionally treat a seq-to-seq problem @@ -152,8 +152,7 @@ def basic_params1(): # position in the inputs portion can see the # entire inputs portion. This removes the challenge of # autoregressively predicting the inputs portion. - prepend_mode="none", - ) + prepend_mode="none",) class RangedHParams(object): diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py index 08e01ccfb..e88d208ac 100644 --- a/tensor2tensor/utils/data_reader.py +++ b/tensor2tensor/utils/data_reader.py @@ -241,7 +241,7 @@ def input_pipeline(problem, data_file_pattern, capacity, mode, hparams, def _preprocess(example, problem, hparams, mode): """Preprocessing for example.""" - example = problem.preprocess_examples(example, mode, hparams) + example = problem.preprocess_example(example, mode, hparams) # We do not want int64s as they are not supported on GPUs. example = cast_int64_to_int32(example) @@ -316,8 +316,8 @@ def batching_fn(bucket_id, grouped_dataset): if hasattr(dataset, "apply"): # If the Dataset supports dynamic window size, use it. dataset = dataset.apply( - tf.contrib.data.group_by_window( - example_to_bucket_id, batching_fn, None, window_size_fn)) + tf.contrib.data.group_by_window(example_to_bucket_id, batching_fn, + None, window_size_fn)) else: dataset = dataset.group_by_window(example_to_bucket_id, batching_fn, window_size) @@ -476,8 +476,8 @@ def serving_input_fn(problem, hparams): example = feature_placeholders(data_fields, data_items_to_decoders) # Preprocess - example = problem.preprocess_examples(example, tf.estimator.ModeKeys.PREDICT, - hparams) + example = problem.preprocess_example(example, tf.estimator.ModeKeys.PREDICT, + hparams) example = cast_int64_to_int32(example) # 4-D inputs and space ids diff --git a/tensor2tensor/utils/data_reader_test.py b/tensor2tensor/utils/data_reader_test.py index ff01cf07f..4f4d7530d 100644 --- a/tensor2tensor/utils/data_reader_test.py +++ b/tensor2tensor/utils/data_reader_test.py @@ -62,9 +62,9 @@ def example_reading_spec(self): data_items_to_decoders = None return (data_fields, data_items_to_decoders) - def preprocess_examples(self, examples, unused_mode, unused_hparams): - examples["new_field"] = tf.constant([42.42]) - return examples + def preprocess_example(self, example, unused_mode, unused_hparams): + example["new_field"] = tf.constant([42.42]) + return example def generate_test_data(problem, tmp_dir): @@ -146,7 +146,7 @@ def testPreprocess(self): examples = data_reader._preprocess(examples, self.problem, None, None) with tf.train.MonitoredSession() as sess: ex_val = sess.run(examples) - # problem.preprocess_examples has been run + # problem.preprocess_example has been run self.assertAllClose([42.42], ex_val["new_field"]) # int64 has been cast to int32 From 6237729d291d0fd7e2d4a4dfbfc6edcac6b756c4 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Thu, 21 Sep 2017 17:59:22 -0700 Subject: [PATCH 38/39] Fix travis shield link PiperOrigin-RevId: 169629386 --- README.md | 2 +- docs/walkthrough.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6ef815f4e..0e97770ba 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](http welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) [![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) -[![Travis](https://img.shields.io/travis/tensorflow/tensor2tensor.svg)]() +[![Travis](https://img.shields.io/travis/tensorflow/tensor2tensor.svg)](https://travis-ci.org/tensorflow/tensor2tensor) [T2T](https://github.com/tensorflow/tensor2tensor) is a modular and extensible library and binaries for supervised learning with TensorFlow and with support diff --git a/docs/walkthrough.md b/docs/walkthrough.md index 6ef815f4e..0e97770ba 100644 --- a/docs/walkthrough.md +++ b/docs/walkthrough.md @@ -8,7 +8,7 @@ Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](http welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) [![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) -[![Travis](https://img.shields.io/travis/tensorflow/tensor2tensor.svg)]() +[![Travis](https://img.shields.io/travis/tensorflow/tensor2tensor.svg)](https://travis-ci.org/tensorflow/tensor2tensor) [T2T](https://github.com/tensorflow/tensor2tensor) is a modular and extensible library and binaries for supervised learning with TensorFlow and with support From 76706efe22b8fc384ed462d9b648ed148cb7f527 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Thu, 21 Sep 2017 18:31:35 -0700 Subject: [PATCH 39/39] Make output of fn in @recompute_grad a list to avoid trying to concat tuple and list PiperOrigin-RevId: 169632380 --- tensor2tensor/layers/rev_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensor2tensor/layers/rev_block.py b/tensor2tensor/layers/rev_block.py index 3dff92c5c..8d1206ee8 100644 --- a/tensor2tensor/layers/rev_block.py +++ b/tensor2tensor/layers/rev_block.py @@ -348,7 +348,7 @@ def _recompute_grad(fn, args): def grad_fn(inputs, variables, outputs, output_grads): del outputs # recompute outputs - outputs = fn(*inputs) + outputs = list(fn(*inputs)) grads = tf.gradients(outputs, inputs + variables, output_grads) grad_inputs = grads[:len(inputs)] grad_vars = grads[len(inputs):]