From 09cab54daabff77ed3a08e0512e2937d6638aee9 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Wed, 21 Jun 2017 09:38:07 -0700 Subject: [PATCH 1/6] Enable users to register components without editing codebase PiperOrigin-RevId: 159703423 --- .gitignore | 5 +++++ README.md | 38 +++++++++++++++++++++++++++++++++++ setup.py | 2 +- tensor2tensor/bin/t2t-trainer | 30 ++++++++++++++++++++++++++- 4 files changed, 73 insertions(+), 2 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..09f934869 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +# Compiled python modules. +*.pyc + +# Python egg metadata, regenerated from source files by setuptools. +/*.egg-info diff --git a/README.md b/README.md index f13ed0343..aacfb9095 100644 --- a/README.md +++ b/README.md @@ -191,6 +191,44 @@ related flags control local and distributed training/evaluation --- +## Adding your own components + +T2T's components are registered using a central registration mechanism that +enables easily adding new ones and easily swapping amongst them by command-line +flag. You can add your own components without editing the T2T codebase by +specifying the `--t2t_usr_dir` flag in `t2t-trainer`. + +You can currently do so for models, hyperparameter sets, and modalities. Please +do submit a pull request if your component might be useful to others. + +Here's an example with a new hyperparameter set: + +```python +# In ~/usr/t2t_usr/my_registrations.py + +from tensor2tensor.models import transformer +from tensor2tensor.utils import registry + +@registry.register_hparams +def transformer_my_very_own_hparams_set(): + hparams = transformer.transformer_base() + hparams.hidden_size = 1024 + ... +``` + +```python +# In ~/usr/t2t_usr/__init__.py +import my_registrations +``` + +``` +t2t-trainer --t2t_usr_dir=~/usr/t2t_usr --registry_help +``` + +You'll see under the registered HParams your +`transformer_my_very_own_hparams_set`, which you can directly use on the command +line with the `--hparams_set` flag. + ## Adding a dataset See the [data generators diff --git a/setup.py b/setup.py index d31734dc2..d20c5bc33 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.0.4', + version='1.0.3', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer index 9fe799e1f..92f671826 100644 --- a/tensor2tensor/bin/t2t-trainer +++ b/tensor2tensor/bin/t2t-trainer @@ -29,17 +29,45 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import importlib +import os +import sys + # Dependency imports from tensor2tensor.utils import trainer_utils as utils import tensorflow as tf -FLAGS = tf.flags.FLAGS +flags = tf.flags +FLAGS = flags.FLAGS + +flags.DEFINE_string("t2t_usr_dir", "", + "Path to a Python module that will be imported. The " + "__init__.py file should include the necessary imports. " + "The imported files should contain registrations, " + "e.g. 
@registry.register_model calls, that will then be " + "available to the t2t-trainer.") + + +def import_usr_dir(): + """Import module at FLAGS.t2t_usr_dir, if provided.""" + if not FLAGS.t2t_usr_dir: + return + dir_path = os.path.expanduser(FLAGS.t2t_usr_dir) + if dir_path[-1] == "/": + dir_path = dir_path[:-1] + containing_dir, module_name = os.path.split(dir_path) + tf.logging.info("Importing user module %s from path %s", module_name, + containing_dir) + sys.path.insert(0, containing_dir) + importlib.import_module(module_name) + sys.path.pop(0) def main(_): tf.logging.set_verbosity(tf.logging.INFO) + import_usr_dir() utils.log_registry() utils.validate_flags() utils.run( From 01787ca8a53e96c56eb6826443c5a12a29e9209a Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Wed, 21 Jun 2017 12:22:25 -0700 Subject: [PATCH 2/6] Add leakr dictionary to prevent internal stuff leaking out PiperOrigin-RevId: 159726434 --- tensor2tensor/data_generators/lm_example.py | 2 +- tensor2tensor/utils/expert_utils.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tensor2tensor/data_generators/lm_example.py b/tensor2tensor/data_generators/lm_example.py index 7c4a42cec..d8a76baeb 100644 --- a/tensor2tensor/data_generators/lm_example.py +++ b/tensor2tensor/data_generators/lm_example.py @@ -38,7 +38,7 @@ tar xvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz # replace oov words with UNK -./blaze-bin/third_party/py/tensor2tensor/data_generators/replace_oov \ +$BINARYDIR/replace_oov \ --vocab_file=$DATADIR/vocab-2016-09-10.txt \ --in_filepattern=\ $DATADIR/1-billion-word-language-modeling-benchmark-r13output/\ diff --git a/tensor2tensor/utils/expert_utils.py b/tensor2tensor/utils/expert_utils.py index 8d3d1d50c..0bd69599d 100644 --- a/tensor2tensor/utils/expert_utils.py +++ b/tensor2tensor/utils/expert_utils.py @@ -1212,7 +1212,6 @@ def SampledSoftmaxLoss(features, sampler, num_classes, target_classes, Args: features: a Tensor with shape [batch_size, hidden_size] sampler: a candidate sampler object - (see learning/brain/google/python/ops/candidate_sampling.py) num_classes: an integer target_classes: an integer Tensor with shape [batch_size] target_params: a Tensor with shape [batch_size, hidden_size] @@ -1261,7 +1260,6 @@ def ParallelSampledSoftmaxLoss(params, target_classes: A list of num_datashards integer Tensors each with shape [batch_size_i] sampler: a candidate sampler object - (see learning/brain/google/python/ops/candidate_sampling.py) num_classes: an Integer data_parallelism: a Parallelism object target_weights: an optional list of num_datashards Tensors each with From b368f7c3ea582d0517708cc354832d307432903a Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Wed, 21 Jun 2017 12:44:24 -0700 Subject: [PATCH 3/6] Add gitter links PiperOrigin-RevId: 159729147 --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index aacfb9095..770931cf4 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/t Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues) [![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) +[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) 
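Aside on the `--t2t_usr_dir` mechanism added to `t2t-trainer` above: `import_usr_dir` is plain `importlib` plus a temporary `sys.path` entry, so the user package's `__init__.py` runs and any `@registry.*` decorators it imports get executed. Below is a minimal standalone sketch of that same pattern; the function name is illustrative and the commented example path is the one from the README example.

```python
import importlib
import os
import sys


def import_user_module(usr_dir):
  """Import the package at usr_dir so its registration decorators run."""
  dir_path = os.path.expanduser(usr_dir).rstrip("/")
  containing_dir, module_name = os.path.split(dir_path)
  # Make the parent directory importable, import the package by name
  # (which executes its __init__.py), then restore sys.path.
  sys.path.insert(0, containing_dir)
  try:
    importlib.import_module(module_name)
  finally:
    sys.path.pop(0)


# Illustrative call, matching the README example:
# import_user_module("~/usr/t2t_usr")
```
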
[T2T](https://github.com/tensorflow/tensor2tensor) is a modular and extensible @@ -22,6 +23,8 @@ send along a pull request to add your data-set or model. See [our contribution doc](CONTRIBUTING.md) for details and our [open issues](https://github.com/tensorflow/tensor2tensor/issues). +And chat with us and other users on +[Gitter](https://gitter.im/tensor2tensor/Lobby). --- From a8463f53b76847f5692d9acf64914d5285b7214d Mon Sep 17 00:00:00 2001 From: Noam Shazeer Date: Wed, 21 Jun 2017 12:44:28 -0700 Subject: [PATCH 4/6] Added parameter-attention option to transformer model. PiperOrigin-RevId: 159729158 --- tensor2tensor/models/common_attention.py | 69 ++++++++++++++++++++++++ tensor2tensor/models/transformer.py | 66 ++++++++++++++++++----- 2 files changed, 123 insertions(+), 12 deletions(-) diff --git a/tensor2tensor/models/common_attention.py b/tensor2tensor/models/common_attention.py index c89ae18c2..6d3d5d27c 100644 --- a/tensor2tensor/models/common_attention.py +++ b/tensor2tensor/models/common_attention.py @@ -398,3 +398,72 @@ def multihead_attention(query_antecedent, x = combine_heads(x) x = common_layers.conv1d(x, output_depth, 1, name="output_transform") return x + + +def parameter_attention(x, + total_key_depth, + total_value_depth, + output_depth, + memory_rows, + num_heads, + dropout_rate, + name=None): + """Attention over parameters. + + We use the same multi-headed attention as in the other layers, but the memory + keys and values are model parameters. There are no linear transformation + on the keys or values. + + We are also a bit more careful about memory usage, since the number of + memory positions may be very large. + + Args: + x: a Tensor with shape [batch, length_q, channels] + total_key_depth: an integer + total_value_depth: an integer + output_depth: an integer + memory_rows: an integer + num_heads: an integer dividing total_key_depth and total_value_depth + dropout_rate: a floating point number + name: an optional string + + Returns: + A Tensor. + """ + with tf.variable_scope(name, default_name="parameter_attention", + values=[x]): + head_size_k = total_key_depth // num_heads + head_size_v = total_value_depth // num_heads + var_shape_k = [num_heads, memory_rows, head_size_k] + var_shape_v = [num_heads, memory_rows, head_size_v] + k = tf.get_variable( + "k", var_shape_k, + initializer=tf.random_normal_initializer( + 0, output_depth ** -0.5)) * (num_heads ** 0.5) + v = tf.get_variable( + "v", var_shape_v, + initializer=tf.random_normal_initializer( + 0, output_depth ** -0.5)) * (output_depth ** 0.5) + batch_size = tf.shape(x)[0] + length = tf.shape(x)[1] + q = common_layers.conv1d(x, total_key_depth, 1, name="q_transform") + if dropout_rate: + # This is a cheaper form of attention dropout where we use to use + # the same dropout decisions across batch elemets and query positions, + # but different decisions across heads and memory positions. 
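+ # Concretely, noise_shape=[num_heads, memory_rows, 1] makes tf.nn.dropout
+ # sample one keep/drop decision per (head, memory row) and broadcast it
+ # across the value depth; since v is a shared parameter rather than a
+ # per-example activation, the same mask applies to every batch element
+ # and query position.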
+ v = tf.nn.dropout(v, 1.0 - dropout_rate, + noise_shape=[num_heads, memory_rows, 1]) + # query is [batch, length, hidden_size] + # reshape and transpose it to [heads, batch * length, head_size] + q = tf.reshape(q, [batch_size, length, num_heads, head_size_k]) + q = tf.transpose(q, [2, 0, 1, 3]) + q = tf.reshape(q, [num_heads, batch_size * length, head_size_k]) + weights = tf.matmul(q, k, transpose_b=True) + weights = tf.nn.softmax(weights) + y = tf.matmul(weights, v) + y = tf.reshape(y, [num_heads, batch_size, length, head_size_v]) + y = tf.transpose(y, [1, 2, 0, 3]) + y = tf.reshape(y, [batch_size, length, total_value_depth]) + y.set_shape([None, None, total_value_depth]) + y = common_layers.conv1d(y, output_depth, 1, name="output_transform") + return y diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 2c88cb045..264e0570d 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -167,12 +167,7 @@ def transformer_encoder(encoder_input, hparams.attention_dropout, summaries=summaries, name="encoder_self_attention")) - x = residual_fn(x, - common_layers.conv_hidden_relu( - x, - hparams.filter_size, - hparams.hidden_size, - dropout=hparams.relu_dropout)) + x = residual_fn(x, transformer_ffn_layer(x, hparams)) return x @@ -231,15 +226,40 @@ def transformer_decoder(decoder_input, hparams.attention_dropout, summaries=summaries, name="encdec_attention")) - x = residual_fn(x, - common_layers.conv_hidden_relu( - x, - hparams.filter_size, - hparams.hidden_size, - dropout=hparams.relu_dropout)) + x = residual_fn(x, transformer_ffn_layer(x, hparams)) return x +def transformer_ffn_layer(x, hparams): + """Feed-forward layer in the transformer. + + Args: + x: a Tensor of shape [batch_size, length, hparams.hidden_size] + hparams: hyperparmeters for model + + Returns: + a Tensor of shape [batch_size, length, hparams.hidden_size] + """ + if hparams.ffn_layer == "conv_hidden_relu": + return common_layers.conv_hidden_relu( + x, + hparams.filter_size, + hparams.hidden_size, + dropout=hparams.relu_dropout) + elif hparams.ffn_layer == "parameter_attention": + return common_attention.parameter_attention( + x, + hparams.parameter_attention_key_channels or hparams.hidden_size, + hparams.parameter_attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.filter_size, + hparams.num_heads, + hparams.attention_dropout) + else: + assert hparams.ffn_layer == "none" + return x + + @registry.register_hparams def transformer_base(): """Set of hyperparameters.""" @@ -268,6 +288,9 @@ def transformer_base(): hparams.add_hparam("num_heads", 8) hparams.add_hparam("attention_key_channels", 0) hparams.add_hparam("attention_value_channels", 0) + hparams.add_hparam("ffn_layer", "conv_hidden_relu") + hparams.add_hparam("parameter_attention_key_channels", 0) + hparams.add_hparam("parameter_attention_value_channels", 0) hparams.add_hparam("attention_dropout", 0.0) hparams.add_hparam("relu_dropout", 0.0) hparams.add_hparam("pos", "timing") # timing, none @@ -492,6 +515,25 @@ def transformer_big_dr2(): return hparams +@registry.register_hparams +def transformer_parameter_attention_a(): + hparams = transformer_base() + hparams.ffn_layer = "parameter_attention" + hparams.filter_size = 1536 + return hparams + + +@registry.register_hparams +def transformer_parameter_attention_b(): + hparams = transformer_base() + hparams.ffn_layer = "parameter_attention" + hparams.filter_size = 512 + hparams.parameter_attention_key_channels = 1024 + 
hparams.parameter_attention_value_channels = 1024 + hparams.num_heads = 16 + return hparams + + @registry.register_ranged_hparams("transformer_big_single_gpu") def transformer_range1(rhp): """Small range of hyperparameters.""" From 0fad2909b72ee6bf05eb184d0aff6d81fa93a192 Mon Sep 17 00:00:00 2001 From: Noam Shazeer Date: Wed, 21 Jun 2017 15:36:00 -0700 Subject: [PATCH 5/6] updated image transformer. now combines channels to have only 1024 positions for rev-cifar instead of 3072. PiperOrigin-RevId: 159754350 --- tensor2tensor/models/common_attention.py | 38 +++++++++++++++--------- tensor2tensor/models/modalities.py | 6 ++-- 2 files changed, 27 insertions(+), 17 deletions(-) diff --git a/tensor2tensor/models/common_attention.py b/tensor2tensor/models/common_attention.py index 6d3d5d27c..e9f3081d4 100644 --- a/tensor2tensor/models/common_attention.py +++ b/tensor2tensor/models/common_attention.py @@ -271,10 +271,14 @@ def attention_image_summary(attn, image_shapes=None): Args: attn: a Tensor with shape [batch, num_heads, query_length, memory_length] - image_shapes: optional quadruple of integer scalars. + image_shapes: optional tuple of integer scalars. If the query positions and memory positions represent the - pixels of a flattened image, then pass in their dimensions: + pixels of flattened images, then pass in their dimensions: (query_rows, query_cols, memory_rows, memory_cols). + If the query positions and memory positions represent the + pixels x channels of flattened images, then pass in their dimensions: + (query_rows, query_cols, query_channels, + memory_rows, memory_cols, memory_channels). """ num_heads = attn.get_shape().as_list()[1] # [batch, query_length, memory_length, num_heads] @@ -286,10 +290,20 @@ def attention_image_summary(attn, image_shapes=None): image = split_last_dimension(image, 3) image = tf.reduce_max(image, 4) if image_shapes is not None: - q_rows, q_cols, m_rows, m_cols = list(image_shapes) - image = tf.reshape(image, [-1, q_rows, q_cols, m_rows, m_cols, 3]) - image = tf.transpose(image, [0, 1, 3, 2, 4, 5]) - image = tf.reshape(image, [-1, q_rows * m_rows, q_cols * m_cols, 3]) + if len(image_shapes) == 4: + q_rows, q_cols, m_rows, m_cols = list(image_shapes) + image = tf.reshape(image, [-1, q_rows, q_cols, m_rows, m_cols, 3]) + image = tf.transpose(image, [0, 1, 3, 2, 4, 5]) + image = tf.reshape(image, [-1, q_rows * m_rows, q_cols * m_cols, 3]) + else: + assert len(image_shapes) == 6 + q_rows, q_cols, q_channnels, m_rows, m_cols, m_channels = list( + image_shapes) + image = tf.reshape(image, [-1, q_rows, q_cols, q_channnels, + m_rows, m_cols, m_channels, 3]) + image = tf.transpose(image, [0, 1, 4, 3, 2, 5, 6, 7]) + image = tf.reshape(image, [-1, q_rows * m_rows * q_channnels, + q_cols * m_cols * m_channels, 3]) tf.summary.image("attention", image, max_outputs=1) @@ -310,10 +324,8 @@ def dot_product_attention(q, bias: bias Tensor (see attention_bias()) dropout_rate: a floating point number summaries: a boolean - image_shapes: optional quadruple of integer scalars for image summary. - If the query positions and memory positions represent the - pixels of a flattened image, then pass in their dimensions: - (query_rows, query_cols, memory_rows, memory_cols). + image_shapes: optional tuple of integer scalars. 
+ see comments for attention_image_summary() name: an optional string Returns: @@ -356,10 +368,8 @@ def multihead_attention(query_antecedent, num_heads: an integer dividing total_key_depth and total_value_depth dropout_rate: a floating point number summaries: a boolean - image_shapes: optional quadruple of integer scalars for image summary. - If the query positions and memory positions represent the - pixels of a flattened image, then pass in their dimensions: - (query_rows, query_cols, memory_rows, memory_cols). + image_shapes: optional tuple of integer scalars. + see comments for attention_image_summary() name: an optional string Returns: diff --git a/tensor2tensor/models/modalities.py b/tensor2tensor/models/modalities.py index 0593189f0..fd9fb4432 100644 --- a/tensor2tensor/models/modalities.py +++ b/tensor2tensor/models/modalities.py @@ -441,8 +441,8 @@ class IdentityModality(modality.Modality): def targets_dimensionality(self): return self._vocab_size - def inputs_bottom_simple(self, inputs): - return tf.to_float(inputs) + def bottom(self, x): + return tf.to_float(x) - def targets_top_simple(self, body_output, _): + def top(self, body_output, _): return body_output From 8195f345a9a60f81ba8d6947a776d0dfad51ef5a Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Wed, 21 Jun 2017 17:13:48 -0700 Subject: [PATCH 6/6] Remove tensorflow dependency from setup.py to enable cpu installs and bump version to 1.0.5 PiperOrigin-RevId: 159767222 --- README.md | 7 +++++++ setup.py | 7 +++++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 770931cf4..69ad66ddc 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,14 @@ cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes ## Installation ``` +# Assumes tensorflow or tensorflow-gpu installed pip install tensor2tensor + +# Installs with tensorflow-gpu requirement +pip install tensor2tensor[tensorflow_gpu] + +# Installs with tensorflow (cpu) requirement +pip install tensor2tensor[tensorflow] ``` Binaries: diff --git a/setup.py b/setup.py index d20c5bc33..a2d541a30 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.0.3', + version='1.0.5', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', @@ -17,8 +17,11 @@ 'numpy', 'sympy', 'six', - 'tensorflow-gpu>=1.2.0rc1', ], + extras_require={ + 'tensorflow': ['tensorflow>=1.2.0rc1'], + 'tensorflow_gpu': ['tensorflow-gpu>=1.2.0rc1'], + }, classifiers=[ 'Development Status :: 4 - Beta', 'Intended Audience :: Developers',