diff --git a/setup.py b/setup.py index 88ed4a4ea..0669ab1a6 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.2.6', + version='1.2.7', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index 8ce66dc6e..833717432 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -21,9 +21,9 @@ from collections import defaultdict import gzip -import io import os import random +import stat import tarfile # Dependency imports @@ -190,8 +190,8 @@ def maybe_download(directory, filename, url): print() tf.gfile.Rename(inprogress_filepath, filepath) statinfo = os.stat(filepath) - tf.logging.info("Successfully downloaded %s, %s bytes." % (filename, - statinfo.st_size)) + tf.logging.info("Successfully downloaded %s, %s bytes." % + (filename, statinfo.st_size)) else: tf.logging.info("Not downloading, file already found: %s" % filepath) return filepath @@ -243,7 +243,7 @@ def maybe_download_from_drive(directory, filename, url): print() statinfo = os.stat(filepath) tf.logging.info("Successfully downloaded %s, %s bytes." % (filename, - statinfo.st_size)) + statinfo.st_size)) return filepath @@ -258,8 +258,11 @@ def gunzip_file(gz_path, new_path): tf.logging.info("File %s already exists, skipping unpacking" % new_path) return tf.logging.info("Unpacking %s to %s" % (gz_path, new_path)) + # We may be unpacking into a newly created directory, add write mode. + mode = stat.S_IRWXU or stat.S_IXGRP or stat.S_IRGRP or stat.S_IROTH + os.chmod(os.path.dirname(new_path), mode) with gzip.open(gz_path, "rb") as gz_file: - with io.open(new_path, "wb") as new_file: + with tf.gfile.GFile(new_path, mode="wb") as new_file: for line in gz_file: new_file.write(line) diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py index e9ae45f01..0c3988bc5 100644 --- a/tensor2tensor/data_generators/image.py +++ b/tensor2tensor/data_generators/image.py @@ -24,6 +24,7 @@ import json import os import random +import struct import tarfile import zipfile @@ -925,3 +926,58 @@ class ImageMsCocoTokens32k(ImageMsCocoTokens8k): @property def targeted_vocab_size(self): return 2**15 # 32768 + + +@registry.register_problem +class OcrTest(Image2TextProblem): + """OCR test problem.""" + + @property + def is_small(self): + return True + + @property + def is_character_level(self): + return True + + @property + def target_space_id(self): + return problem.SpaceID.EN_CHR + + @property + def train_shards(self): + return 1 + + @property + def dev_shards(self): + return 1 + + def preprocess_example(self, example, mode, _): + # Resize from usual size ~1350x60 to 90x4 in this test. + img = example["inputs"] + example["inputs"] = tf.to_int64( + tf.image.resize_images(img, [90, 4], tf.image.ResizeMethod.AREA)) + return example + + def generator(self, data_dir, tmp_dir, is_training): + # In this test problem, we assume that the data is in tmp_dir/ocr/ in + # files names 0.png, 0.txt, 1.png, 1.txt and so on until num_examples. + num_examples = 2 + ocr_dir = os.path.join(tmp_dir, "ocr/") + tf.logging.info("Looking for OCR data in %s." 
% ocr_dir) + for i in xrange(num_examples): + image_filepath = os.path.join(ocr_dir, "%d.png" % i) + text_filepath = os.path.join(ocr_dir, "%d.txt" % i) + with tf.gfile.Open(text_filepath, "rb") as f: + label = f.read() + with tf.gfile.Open(image_filepath, "rb") as f: + encoded_image_data = f.read() + # In PNG files width and height are stored in these bytes. + width, height = struct.unpack(">ii", encoded_image_data[16:24]) + yield { + "image/encoded": [encoded_image_data], + "image/format": ["png"], + "image/class/label": label.strip(), + "image/height": [height], + "image/width": [width] + } diff --git a/tensor2tensor/data_generators/translate_enfr.py b/tensor2tensor/data_generators/translate_enfr.py index 152d3d963..8076d4792 100644 --- a/tensor2tensor/data_generators/translate_enfr.py +++ b/tensor2tensor/data_generators/translate_enfr.py @@ -34,50 +34,54 @@ # End-of-sentence marker. EOS = text_encoder.EOS_ID -_ENFR_TRAIN_DATASETS = [ +_ENFR_TRAIN_SMALL_DATA = [ [ "https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz", ("baseline-1M-enfr/baseline-1M_train.en", "baseline-1M-enfr/baseline-1M_train.fr") ], - # [ - # "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz", - # ("commoncrawl.fr-en.en", "commoncrawl.fr-en.fr") - # ], - # [ - # "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz", - # ("training/europarl-v7.fr-en.en", "training/europarl-v7.fr-en.fr") - # ], - # [ - # "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz", - # ("training/news-commentary-v9.fr-en.en", - # "training/news-commentary-v9.fr-en.fr") - # ], - # [ - # "http://www.statmt.org/wmt10/training-giga-fren.tar", - # ("giga-fren.release2.fixed.en.gz", - # "giga-fren.release2.fixed.fr.gz") - # ], - # [ - # "http://www.statmt.org/wmt13/training-parallel-un.tgz", - # ("un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr") - # ], ] -_ENFR_TEST_DATASETS = [ +_ENFR_TEST_SMALL_DATA = [ [ "https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz", ("baseline-1M-enfr/baseline-1M_valid.en", "baseline-1M-enfr/baseline-1M_valid.fr") ], - # [ - # "http://data.statmt.org/wmt17/translation-task/dev.tgz", - # ("dev/newstest2013.en", "dev/newstest2013.fr") - # ], +] +_ENFR_TRAIN_LARGE_DATA = [ + [ + "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz", + ("commoncrawl.fr-en.en", "commoncrawl.fr-en.fr") + ], + [ + "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz", + ("training/europarl-v7.fr-en.en", "training/europarl-v7.fr-en.fr") + ], + [ + "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz", + ("training/news-commentary-v9.fr-en.en", + "training/news-commentary-v9.fr-en.fr") + ], + [ + "http://www.statmt.org/wmt10/training-giga-fren.tar", + ("giga-fren.release2.fixed.en.gz", + "giga-fren.release2.fixed.fr.gz") + ], + [ + "http://www.statmt.org/wmt13/training-parallel-un.tgz", + ("un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr") + ], +] +_ENFR_TEST_LARGE_DATA = [ + [ + "http://data.statmt.org/wmt17/translation-task/dev.tgz", + ("dev/newstest2013.en", "dev/newstest2013.fr") + ], ] @registry.register_problem -class TranslateEnfrWmt8k(translate.TranslateProblem): +class TranslateEnfrWmtSmall8k(translate.TranslateProblem): """Problem spec for WMT En-Fr translation.""" @property @@ -88,11 +92,18 @@ def targeted_vocab_size(self): def vocab_name(self): return "vocab.enfr" + @property + def use_small_dataset(self): + return True + def generator(self, data_dir, tmp_dir, train): symbolizer_vocab = generator_utils.get_or_generate_vocab( data_dir, tmp_dir, 
self.vocab_file, self.targeted_vocab_size, - _ENFR_TRAIN_DATASETS) - datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS + _ENFR_TRAIN_SMALL_DATA) + if self.use_small_dataset: + datasets = _ENFR_TRAIN_SMALL_DATA if train else _ENFR_TEST_SMALL_DATA + else: + datasets = _ENFR_TRAIN_LARGE_DATA if train else _ENFR_TEST_LARGE_DATA tag = "train" if train else "dev" data_path = translate.compile_data(tmp_dir, datasets, "wmt_enfr_tok_%s" % tag) @@ -109,7 +120,7 @@ def target_space_id(self): @registry.register_problem -class TranslateEnfrWmt32k(TranslateEnfrWmt8k): +class TranslateEnfrWmtSmall32k(TranslateEnfrWmtSmall8k): @property def targeted_vocab_size(self): @@ -117,7 +128,23 @@ def targeted_vocab_size(self): @registry.register_problem -class TranslateEnfrWmtCharacters(translate.TranslateProblem): +class TranslateEnfrWmt8k(TranslateEnfrWmtSmall8k): + + @property + def use_small_dataset(self): + return False + + +@registry.register_problem +class TranslateEnfrWmt32k(TranslateEnfrWmtSmall32k): + + @property + def use_small_dataset(self): + return False + + +@registry.register_problem +class TranslateEnfrWmtSmallCharacters(translate.TranslateProblem): """Problem spec for WMT En-Fr translation.""" @property @@ -130,7 +157,10 @@ def vocab_name(self): def generator(self, data_dir, tmp_dir, train): character_vocab = text_encoder.ByteTextEncoder() - datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS + if self.use_small_dataset: + datasets = _ENFR_TRAIN_SMALL_DATA if train else _ENFR_TEST_SMALL_DATA + else: + datasets = _ENFR_TRAIN_LARGE_DATA if train else _ENFR_TEST_LARGE_DATA tag = "train" if train else "dev" data_path = translate.compile_data(tmp_dir, datasets, "wmt_enfr_chr_%s" % tag) @@ -144,3 +174,11 @@ def input_space_id(self): @property def target_space_id(self): return problem.SpaceID.FR_CHR + + +@registry.register_problem +class TranslateEnfrWmtCharacters(TranslateEnfrWmtSmallCharacters): + + @property + def use_small_dataset(self): + return False diff --git a/tensor2tensor/data_generators/translate_enzh.py b/tensor2tensor/data_generators/translate_enzh.py index 6b0f36c23..0ee3bfd08 100644 --- a/tensor2tensor/data_generators/translate_enzh.py +++ b/tensor2tensor/data_generators/translate_enzh.py @@ -35,9 +35,13 @@ # End-of-sentence marker. EOS = text_encoder.EOS_ID + +# End-of-sentence marker. +EOS = text_encoder.EOS_ID + # This is far from being the real WMT17 task - only toyset here -# you need to register to get UN data and CWT data -# also by convention this is EN to ZH - use translate_enzh_wmt8k_rev for ZH to EN task +# you need to register to get UN data and CWT data. Also, by convention, +# this is EN to ZH - use translate_enzh_wmt8k_rev for ZH to EN task _ENZH_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/" "training-parallel-nc-v12.tgz"), ("training/news-commentary-v12.zh-en.en", diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index 2178e6fe5..cf7ef9115 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -2958,15 +2958,20 @@ def pad_and_reshape(x): @expert_utils.add_var_scope() def multihead_self_attention_reduced( - x, factor, nonlinearity, reduction_type, multihead_params): + x, + factor, + multihead_params, + nonlinearity="none", + reduction_type="conv", +): """Reduce the length dimension by compressing with conv. 
Args: x (tf.Tensor): float32 of shape [batch, length, depth] factor (int): compression factor for the memory sequence + multihead_params (dict): parameters for multihead attention nonlinearity (str): Add some non-linearity after the memory block reduction_type (str): type of compression - multihead_params (dict): parameters for multihead attention Returns: (tf.Tensor): float32 of shape [batch, length, depth] diff --git a/tensor2tensor/layers/common_hparams.py b/tensor2tensor/layers/common_hparams.py index d2d8bb2e5..c8ba0d03c 100644 --- a/tensor2tensor/layers/common_hparams.py +++ b/tensor2tensor/layers/common_hparams.py @@ -116,12 +116,15 @@ def basic_params1(): # If set to True, drop sequences longer than max_length during eval. # This affects the validity of the evaluation metrics. eval_drop_long_sequences=int(False), + # TODO(lukaszkaiser): these parameters should probably be set elsewhere. # in SymbolModality, share the output embeddings and the softmax # variables. # You can also share the input embeddings with the output embeddings # by using a problem_hparams that uses the same modality object for # the input_modality and target_modality. shared_embedding_and_softmax_weights=int(False), + # In SymbolModality, skip the top layer, assume we're providing logits. + symbol_modality_skip_top=int(False), # For each feature for which you want to override the default input # modality, add an entry to this semicolon-separated string. Entries are # formatted "feature_name:modality_type:modality_name", e.g. diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py index 08fd2f56b..7089529c8 100644 --- a/tensor2tensor/layers/common_layers.py +++ b/tensor2tensor/layers/common_layers.py @@ -326,7 +326,7 @@ def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs): raise ValueError("Inputs to conv must have statically known rank 4. " "Shape: " + str(static_shape)) # Add support for left padding. - if "padding" in kwargs and kwargs["padding"] == "LEFT": + if kwargs.get("padding") == "LEFT": dilation_rate = (1, 1) if "dilation_rate" in kwargs: dilation_rate = kwargs["dilation_rate"] @@ -344,15 +344,9 @@ def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs): def conv2d_kernel(kernel_size_arg, name_suffix): """Call conv2d but add suffix to name.""" - if "name" in kwargs: - original_name = kwargs["name"] - name = kwargs.pop("name") + "_" + name_suffix - else: - original_name = None - name = "conv_" + name_suffix - original_force2d = None - if "force2d" in kwargs: - original_force2d = kwargs.pop("force2d") + name = "{}_{}".format(kwargs.get("name", "conv"), name_suffix) + original_name = kwargs.pop("name", None) + original_force2d = kwargs.pop("force2d", None) result = conv_fn(inputs, filters, kernel_size_arg, name=name, **kwargs) if original_name is not None: kwargs["name"] = original_name # Restore for other calls. @@ -1483,8 +1477,22 @@ def padded_cross_entropy(logits, return tf.reduce_sum(xent * weights), tf.reduce_sum(weights) -def smoothing_cross_entropy(logits, labels, vocab_size, confidence): - """Cross entropy with label smoothing to limit over-confidence.""" +def smoothing_cross_entropy(logits, labels, vocab_size, confidence, + gaussian=False): + """Cross entropy with label smoothing to limit over-confidence. + + Args: + logits: Tensor of size [batch_size, ?, ?, ?, vocab_size] + labels: Tensor of size [batch_size, ?, ?, ?] + vocab_size: Tensor representing the size of the vocabulary. 
+ confidence: Used to determine on and off values for label smoothing. + If `gaussian` is true, `confidence` is the variance to the gaussian + distribution. + gaussian: Uses a gaussian distribution for label smoothing + + Returns: + + """ with tf.name_scope("smoothing_cross_entropy", [logits, labels]): # Low confidence is given to all non-true labels, uniformly. low_confidence = (1.0 - confidence) / tf.to_float(vocab_size - 1) @@ -1492,12 +1500,23 @@ def smoothing_cross_entropy(logits, labels, vocab_size, confidence): # We subtract it just for readability, makes no difference on learning. normalizing = -(confidence * tf.log(confidence) + tf.to_float( vocab_size - 1) * low_confidence * tf.log(low_confidence + 1e-20)) - # Soft targets. - soft_targets = tf.one_hot( - tf.cast(labels, tf.int32), - depth=vocab_size, - on_value=confidence, - off_value=low_confidence) + + if gaussian: + labels = tf.cast(labels, tf.float32) + + normal_dist = tf.distributions.Normal(loc=labels, scale=confidence) + # Locations to evaluate the probability distributions. + soft_targets = normal_dist.prob(tf.cast(tf.range(vocab_size), tf.float32) + [:, None, None, None, None]) + # Reordering soft_targets from [vocab_size, batch_size, ?, ?, ?] to match + # logits: [batch_size, ?, ?, ?, vocab_size] + soft_targets = tf.transpose(soft_targets, perm=[1, 2, 3, 4, 0]) + else: + soft_targets = tf.one_hot( + tf.cast(labels, tf.int32), + depth=vocab_size, + on_value=confidence, + off_value=low_confidence) xentropy = tf.nn.softmax_cross_entropy_with_logits( logits=logits, labels=soft_targets) return xentropy - normalizing diff --git a/tensor2tensor/layers/modalities.py b/tensor2tensor/layers/modalities.py index a29aa93b1..9e0f73045 100644 --- a/tensor2tensor/layers/modalities.py +++ b/tensor2tensor/layers/modalities.py @@ -115,6 +115,8 @@ def top(self, body_output, _): else: scope_name = "softmax" reuse = False + if self._model_hparams.symbol_modality_skip_top: + return tf.expand_dims(body_output, 3) with tf.variable_scope(scope_name, reuse=reuse): var = self._get_weights() if (self._model_hparams.factored_logits and diff --git a/tensor2tensor/layers/modalities_test.py b/tensor2tensor/layers/modalities_test.py index 93dda6d09..7421a7e07 100644 --- a/tensor2tensor/layers/modalities_test.py +++ b/tensor2tensor/layers/modalities_test.py @@ -40,6 +40,7 @@ def testSymbolModalityInputs(self): symbol_modality_num_shards=4, hidden_size=hidden_size, multiply_embedding_mode="sqrt_depth", + symbol_modality_skip_top=0, shared_embedding_and_softmax_weights=0) x = -1 + np.random.random_integers( vocab_size, size=(batch_size, length, 1, 1)) @@ -65,6 +66,7 @@ def testSymbolModalityTargets(self): symbol_modality_num_shards=4, hidden_size=hidden_size, label_smoothing=0.2, + symbol_modality_skip_top=0, shared_embedding_and_softmax_weights=0, factored_logits=0, mode=tf.estimator.ModeKeys.TRAIN) @@ -99,6 +101,7 @@ def testSymbolModalityTargetsFactored(self): symbol_modality_num_shards=4, hidden_size=hidden_size, label_smoothing=0.2, + symbol_modality_skip_top=0, shared_embedding_and_softmax_weights=0, factored_logits=1, mode=tf.estimator.ModeKeys.TRAIN) diff --git a/tensor2tensor/models/__init__.py b/tensor2tensor/models/__init__.py index f5fafe706..f4c8a9a82 100644 --- a/tensor2tensor/models/__init__.py +++ b/tensor2tensor/models/__init__.py @@ -36,9 +36,11 @@ from tensor2tensor.models import shake_shake from tensor2tensor.models import slicenet from tensor2tensor.models import transformer +from tensor2tensor.models import transformer_adv from 
tensor2tensor.models import transformer_alternative from tensor2tensor.models import transformer_moe from tensor2tensor.models import transformer_revnet +from tensor2tensor.models import transformer_sketch from tensor2tensor.models import transformer_vae from tensor2tensor.models import xception # pylint: enable=unused-import diff --git a/tensor2tensor/models/lstm.py b/tensor2tensor/models/lstm.py index 2f5475276..c3e378359 100644 --- a/tensor2tensor/models/lstm.py +++ b/tensor2tensor/models/lstm.py @@ -19,8 +19,6 @@ from __future__ import division from __future__ import print_function -import collections - # Dependency imports from tensor2tensor.layers import common_hparams @@ -29,7 +27,6 @@ from tensor2tensor.utils import t2t_model import tensorflow as tf -from tensorflow.python.util import nest def lstm(inputs, hparams, train, name, initial_state=None): @@ -60,22 +57,28 @@ def dropout_lstm_cell(): input_keep_prob=1.0 - hparams.dropout * tf.to_float(train)) layers = [dropout_lstm_cell() for _ in range(hparams.num_hidden_layers)] - AttentionMechanism = (tf.contrib.seq2seq.LuongAttention if hparams.attention_mechanism == "luong" - else tf.contrib.seq2seq.BahdanauAttention) - attention_mechanism = AttentionMechanism(hparams.hidden_size, encoder_outputs) - + if hparams.attention_mechanism == "luong": + attention_mechanism_class = tf.contrib.seq2seq.LuongAttention + elif hparams.attention_mechanism == "bahdanau": + attention_mechanism_class = tf.contrib.seq2seq.BahdanauAttention + else: + raise ValueError("Unknown hparams.attention_mechanism = %s, must be " + "luong or bahdanu." % hparams.attention_mechanism) + attention_mechanism = attention_mechanism_class( + hparams.hidden_size, encoder_outputs) + cell = tf.contrib.seq2seq.AttentionWrapper( tf.nn.rnn_cell.MultiRNNCell(layers), [attention_mechanism]*hparams.num_heads, attention_layer_size=[hparams.attention_layer_size]*hparams.num_heads, - output_attention=(hparams.output_attention==1)) + output_attention=(hparams.output_attention == 1)) - batch_size = inputs.get_shape()[0].value if batch_size is None: batch_size = tf.shape(inputs)[0] - initial_state = cell.zero_state(batch_size, tf.float32).clone(cell_state=initial_state) + initial_state = cell.zero_state(batch_size, tf.float32).clone( + cell_state=initial_state) with tf.variable_scope(name): output, state = tf.nn.dynamic_rnn( @@ -84,11 +87,11 @@ def dropout_lstm_cell(): initial_state=initial_state, dtype=tf.float32, time_major=False) - + # For multi-head attention project output back to hidden size if hparams.output_attention == 1 and hparams.num_heads > 1: output = tf.layers.dense(output, hparams.hidden_size) - + return output, state @@ -131,6 +134,9 @@ def lstm_seq2seq_internal_attention(inputs, targets, hparams, train): class LSTMSeq2seq(t2t_model.T2TModel): def model_fn_body(self, features): + # TODO(lukaszkaiser): investigate this issue and repair. + if self._hparams.initializer == "orthogonal": + raise ValueError("LSTM models fail with orthogonal initializer.") train = self._hparams.mode == tf.estimator.ModeKeys.TRAIN return lstm_seq2seq_internal(features["inputs"], features["targets"], self._hparams, train) @@ -140,6 +146,9 @@ def model_fn_body(self, features): class LSTMSeq2seqAttention(t2t_model.T2TModel): def model_fn_body(self, features): + # TODO(lukaszkaiser): investigate this issue and repair. 
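+    # (Same guard as in LSTMSeq2seq.model_fn_body above: the orthogonal
+    # initializer currently breaks these LSTM models, so fail fast here too.)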
+ if self._hparams.initializer == "orthogonal": + raise ValueError("LSTM models fail with orthogonal initializer.") train = self._hparams.mode == tf.estimator.ModeKeys.TRAIN return lstm_seq2seq_internal_attention( features["inputs"], features["targets"], self._hparams, train) @@ -155,11 +164,11 @@ def lstm_seq2seq(): hparams.initializer = "uniform_unit_scaling" hparams.initializer_gain = 1.0 hparams.weight_decay = 0.0 - return hparams + def lstm_attention_base(): - """ Base attention params. """ + """Base attention params.""" hparams = lstm_seq2seq() hparams.add_hparam("attention_layer_size", hparams.hidden_size) hparams.add_hparam("output_attention", int(True)) @@ -169,33 +178,37 @@ def lstm_attention_base(): @registry.register_hparams def lstm_bahdanau_attention(): - """hparams for LSTM with bahdanau attention.""" + """Hparams for LSTM with bahdanau attention.""" hparams = lstm_attention_base() hparams.add_hparam("attention_mechanism", "bahdanau") return hparams + @registry.register_hparams def lstm_luong_attention(): - """hparams for LSTM with luong attention.""" + """Hparams for LSTM with luong attention.""" hparams = lstm_attention_base() hparams.add_hparam("attention_mechanism", "luong") return hparams + @registry.register_hparams def lstm_attention(): - """ For backwards compatibility, Defaults to bahdanau """ + """For backwards compatibility, defaults to bahdanau.""" return lstm_bahdanau_attention() + @registry.register_hparams def lstm_bahdanau_attention_multi(): - """ Multi-head Luong attention """ + """Multi-head Bahdanu attention.""" hparams = lstm_bahdanau_attention() hparams.num_heads = 4 return hparams + @registry.register_hparams def lstm_luong_attention_multi(): - """ Multi-head Luong attention """ + """Multi-head Luong attention.""" hparams = lstm_luong_attention() hparams.num_heads = 4 - return hparams \ No newline at end of file + return hparams diff --git a/tensor2tensor/models/lstm_test.py b/tensor2tensor/models/lstm_test.py index 0d4bc6d80..b8be74f23 100644 --- a/tensor2tensor/models/lstm_test.py +++ b/tensor2tensor/models/lstm_test.py @@ -24,7 +24,6 @@ import numpy as np from tensor2tensor.data_generators import problem_hparams -from tensor2tensor.layers import common_hparams from tensor2tensor.models import lstm import tensorflow as tf @@ -36,7 +35,7 @@ def testLSTMSeq2Seq(self): vocab_size = 9 x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1)) y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 6, 1, 1)) - hparams = common_hparams.basic_params1() + hparams = lstm.lstm_seq2seq() p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) with self.test_session() as session: features = { diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 9a090e40f..1d8603687 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -64,24 +64,21 @@ def encode(self, inputs, target_space, hparams): encoder_input, self_attention_bias, encoder_decoder_attention_bias = ( transformer_prepare_encoder(inputs, target_space, hparams)) - encoder_input = tf.nn.dropout( - encoder_input, 1.0 - hparams.layer_prepostprocess_dropout) + encoder_input = tf.nn.dropout(encoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) - encoder_output = transformer_encoder( - encoder_input, - self_attention_bias, - hparams) + encoder_output = transformer_encoder(encoder_input, self_attention_bias, + hparams) return encoder_output, encoder_decoder_attention_bias - def decode( - self, - 
decoder_input, - encoder_output, - encoder_decoder_attention_bias, - decoder_self_attention_bias, - hparams, - cache=None): + def decode(self, + decoder_input, + encoder_output, + encoder_decoder_attention_bias, + decoder_self_attention_bias, + hparams, + cache=None): """Decode Transformer outputs from encoder representation. Args: @@ -129,11 +126,12 @@ def model_fn_body(self, features): """ hparams = self._hparams - inputs = features["inputs"] - - target_space = features["target_space_id"] - encoder_output, encoder_decoder_attention_bias = self.encode( - inputs, target_space, hparams) + inputs = features.get("inputs") + encoder_output, encoder_decoder_attention_bias = (None, None) + if inputs is not None: + target_space = features["target_space_id"] + encoder_output, encoder_decoder_attention_bias = self.encode( + inputs, target_space, hparams) targets = features["targets"] targets = common_layers.flatten4d3d(targets) @@ -141,15 +139,11 @@ def model_fn_body(self, features): decoder_input, decoder_self_attention_bias = transformer_prepare_decoder( targets, hparams) - return self.decode( - decoder_input, - encoder_output, - encoder_decoder_attention_bias, - decoder_self_attention_bias, - hparams) + return self.decode(decoder_input, encoder_output, + encoder_decoder_attention_bias, + decoder_self_attention_bias, hparams) - def _greedy_infer( - self, features, decode_length, last_position_only=True): + def _greedy_infer(self, features, decode_length, last_position_only=True): """Fast version of greedy decoding. Args: @@ -185,18 +179,16 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams, Returns: samples: an integer `Tensor`. Top samples from the beam search """ - return self._fast_decode( - features, decode_length, last_position_only, beam_size, top_beams, - alpha) - - def _fast_decode( - self, - features, - decode_length, - last_position_only=True, - beam_size=1, - top_beams=1, - alpha=1.0): + return self._fast_decode(features, decode_length, last_position_only, + beam_size, top_beams, alpha) + + def _fast_decode(self, + features, + decode_length, + last_position_only=True, + beam_size=1, + top_beams=1, + alpha=1.0): """Fast decoding. Implements both greedy and beam search decoding, uses beam search iff @@ -277,12 +269,10 @@ def preprocess_targets(targets, i): # TODO(llion): Explain! Is this even needed? 
targets = tf.cond( - tf.equal(i, 0), - lambda: tf.zeros_like(targets), - lambda: targets) + tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets) if hparams.pos == "timing": - targets += timing_signal[:, i:i+1] + targets += timing_signal[:, i:i + 1] return targets decoder_self_attention_bias = ( @@ -297,17 +287,12 @@ def symbols_to_logits_fn(ids, i, cache): targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) targets = preprocess_targets(targets, i) - bias = decoder_self_attention_bias[:, :, i:i+1, :i+1] + bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] with tf.variable_scope("body"): - body_outputs = dp( - self.decode, - targets, - cache["encoder_output"], - cache["encoder_decoder_attention_bias"], - bias, - hparams, - cache) + body_outputs = dp(self.decode, targets, cache["encoder_output"], + cache["encoder_decoder_attention_bias"], bias, + hparams, cache) with tf.variable_scope(target_modality.name): logits = target_modality.top_sharded(body_outputs, None, dp)[0] @@ -322,7 +307,8 @@ def symbols_to_logits_fn(ids, i, cache): "layer_%d" % layer: { "k": tf.zeros([batch_size, 0, key_channels]), "v": tf.zeros([batch_size, 0, value_channels]), - } for layer in range(num_layers) + } + for layer in range(num_layers) } # Set 2nd dim to None since it's not invariant in the tf.while_loop @@ -342,19 +328,25 @@ def symbols_to_logits_fn(ids, i, cache): vocab_size = target_modality.top_dimensionality initial_ids = tf.zeros([batch_size], dtype=tf.int32) decoded_ids, _ = beam_search.beam_search( - symbols_to_logits_fn, initial_ids, beam_size, decode_length, - vocab_size, alpha, states=cache) + symbols_to_logits_fn, + initial_ids, + beam_size, + decode_length, + vocab_size, + alpha, + states=cache) if top_beams == 1: decoded_ids = decoded_ids[:, 0, 1:] else: decoded_ids = decoded_ids[:, :top_beams, 1:] else: # Greedy + def inner_loop(i, next_id, decoded_ids, cache): logits, cache = symbols_to_logits_fn(next_id, i, cache) next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1) decoded_ids = tf.concat([decoded_ids, next_id], axis=1) - return i+1, next_id, decoded_ids, cache + return i + 1, next_id, decoded_ids, cache decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64) next_id = tf.zeros([batch_size, 1], dtype=tf.int64) @@ -384,8 +376,8 @@ def model_fn_body(self, features): inputs = common_layers.flatten4d3d(inputs) - (encoder_input, encoder_self_attention_bias, - _) = (transformer_prepare_encoder(inputs, target_space, hparams)) + (encoder_input, encoder_self_attention_bias, _) = ( + transformer_prepare_encoder(inputs, target_space, hparams)) encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.layer_prepostprocess_dropout) @@ -396,29 +388,6 @@ def model_fn_body(self, features): return encoder_output -@registry.register_model -class TransformerDecoder(t2t_model.T2TModel): - """Transformer, decoder only.""" - - def model_fn_body(self, features): - hparams = self._hparams - targets = features["targets"] - - targets = common_layers.flatten4d3d(targets) - - (decoder_input, decoder_self_attention_bias) = transformer_prepare_decoder( - targets, hparams) - - decoder_input = tf.nn.dropout(decoder_input, - 1.0 - hparams.layer_prepostprocess_dropout) - - decoder_output = transformer_decoder( - decoder_input, None, decoder_self_attention_bias, None, hparams) - decoder_output = tf.expand_dims(decoder_output, 2) - - return decoder_output - - def transformer_prepare_encoder(inputs, target_space, hparams): """Prepare one shard of the model for the encoder. 
@@ -574,12 +543,12 @@ def transformer_decoder(decoder_input, with tf.variable_scope("encdec_attention"): # TODO(llion): Add caching. y = common_attention.multihead_attention( - common_layers.layer_preprocess(x, hparams), - encoder_output, - encoder_decoder_attention_bias, + common_layers.layer_preprocess( + x, hparams), encoder_output, encoder_decoder_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, - hparams.hidden_size, hparams.num_heads, + hparams.hidden_size, + hparams.num_heads, hparams.attention_dropout) x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): @@ -1056,3 +1025,19 @@ def transformer_relative_big(): hparams.self_attention_type = "dot_product_relative" hparams.max_relative_position = 20 return hparams + + +@registry.register_hparams +def transformer_tpu(): + """HParams for Transformer model on TPU.""" + hparams = transformer_base() + hparams.use_pad_remover = int(False) # where op not supported + hparams.optimizer = "TrueAdam" + hparams.learning_rate = 0.2 + + # Inputs + # Each example in the batch will be of (padded) length hparams.max_length + hparams.max_length = 64 + hparams.tpu_batch_size_per_shard = 16 + + return hparams diff --git a/tensor2tensor/models/transformer_adv.py b/tensor2tensor/models/transformer_adv.py new file mode 100644 index 000000000..2a12aa389 --- /dev/null +++ b/tensor2tensor/models/transformer_adv.py @@ -0,0 +1,229 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Adversarial Transformer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from tensor2tensor.layers import common_layers +from tensor2tensor.models import transformer +from tensor2tensor.models import transformer_vae +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow as tf + + +def encode(x, x_space, hparams, name): + """Transformer preparations and encoder.""" + with tf.variable_scope(name): + (encoder_input, encoder_self_attention_bias, + ed) = transformer.transformer_prepare_encoder(x, x_space, hparams) + encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout) + return transformer.transformer_encoder( + encoder_input, encoder_self_attention_bias, hparams), ed + + +def decode(encoder_output, encoder_decoder_attention_bias, targets, + hparams, name, reuse=False): + """Transformer decoder.""" + with tf.variable_scope(name, reuse=reuse): + targets = common_layers.flatten4d3d(targets) + + decoder_input, decoder_self_bias = transformer.transformer_prepare_decoder( + targets, hparams) + + decoder_input = tf.nn.dropout(decoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + + decoder_output = transformer.transformer_decoder( + decoder_input, + encoder_output, + decoder_self_bias, + encoder_decoder_attention_bias, + hparams) + + # Expand since t2t expects 4d tensors. + return tf.expand_dims(decoder_output, axis=2) + + +def reverse_gradient(x, delta=1.0): + return tf.stop_gradient((1.0 + delta) * x) - delta * x + + +def adversary(embedded, inputs, hparams, name, reuse=False): + with tf.variable_scope(name, reuse=reuse): + h0, i0 = common_layers.pad_to_same_length( + embedded, inputs, final_length_divisible_by=16) + h0 = tf.concat([h0, tf.expand_dims(i0, axis=2)], axis=-1) + h0 = tf.layers.dense(h0, hparams.hidden_size, name="io") + h1 = transformer_vae.compress(h0, None, False, hparams, "compress1") + h2 = transformer_vae.compress(h1, None, False, hparams, "compress2") + res_dense = tf.reduce_mean(h2, axis=[1, 2]) + res_single = tf.squeeze(tf.layers.dense(res_dense, 1), axis=-1) + return tf.nn.sigmoid(res_single) + + +def softmax_embed(x, embedding, batch_size, hparams): + """Softmax x and embed.""" + x = tf.reshape(tf.nn.softmax(x), [-1, 34*1024]) + x = tf.matmul(x, embedding) + return tf.reshape(x, [batch_size, -1, 1, hparams.hidden_size]) + + +def adv_transformer_internal(inputs, targets, target_space, hparams): + """Adversarial Transformer, main step used for training.""" + with tf.variable_scope("adv_transformer"): + batch_size = tf.shape(targets)[0] + targets = tf.reshape(targets, [batch_size, -1, 1]) + embedding = tf.get_variable("embedding", [34*1024, hparams.hidden_size]) + targets_emb = tf.gather(embedding, targets) + + # Noisy embedded targets. + targets_noisy = tf.one_hot(targets, 34*1024) + noise_val = hparams.noise_val + targets_noisy += tf.random_uniform(tf.shape(targets_noisy), + minval=-noise_val, maxval=noise_val) + targets_emb_noisy = softmax_embed( + targets_noisy, embedding, batch_size, hparams) + + # Encoder. + if inputs is not None: + inputs_emb = common_layers.flatten4d3d(inputs) + inputs, ed = encode(inputs_emb, target_space, hparams, "input_enc") + else: + ed = None + + # Masking. + masking = common_layers.inverse_lin_decay(60000) + masking *= common_layers.inverse_exp_decay(20000) # Not much at start. 
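+    # (Both inverse_*_decay helpers ramp from roughly 0 up to 1.0 over their
+    # step budget, so `masking` grows as training progresses: positions whose
+    # fresh uniform draw exceeds it keep the true target embedding below,
+    # while the remaining positions are replaced by noise.)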
+ masking -= tf.random_uniform([]) * 0.4 + mask = tf.less(masking, tf.random_uniform(tf.shape(targets))) + mask = tf.expand_dims(tf.to_float(mask), 3) + noise = tf.random_uniform(tf.shape(targets_emb)) + targets_emb = mask * targets_emb + (1.0 - mask) * noise + + # Decoder. + res_dec = decode(inputs, ed, targets_emb, hparams, "decoder") + res = tf.layers.dense(res_dec, 34*1024, name="res_sm") + res_emb = softmax_embed(res, embedding, batch_size, hparams) + + # Extra steps. + extra_step_prob = masking * 0.6 + if hparams.mode != tf.estimator.ModeKeys.TRAIN: + extra_step_prob = 1.0 + for _ in xrange(hparams.extra_steps): + def another_step(emb): + res_dec = decode(inputs, ed, emb, hparams, "decoder", reuse=True) + res = tf.layers.dense(res_dec, 34*1024, name="res_sm", reuse=True) + return softmax_embed(res, embedding, batch_size, hparams), res + res_emb, res = tf.cond(tf.less(tf.random_uniform([]), extra_step_prob), + lambda e=res_emb: another_step(e), + lambda: (res_emb, res)) + + # Adversary. + delta = masking * hparams.delta_max + true_logit = adversary(tf.stop_gradient(targets_emb_noisy), + tf.stop_gradient(inputs + inputs_emb), + hparams, "adversary") + gen_logit = adversary(reverse_gradient(res_emb, delta), + tf.stop_gradient(inputs + inputs_emb), + hparams, "adversary", reuse=True) + losses = {"adv": gen_logit - true_logit} + res = tf.stop_gradient(masking * res) + (1.0 - masking) * res + return res, losses + + +@registry.register_model +class TransformerAdv(t2t_model.T2TModel): + """Adversarial Transformer.""" + + def model_fn_body(self, features): + inputs = features.get("inputs", None) + return adv_transformer_internal( + inputs, features["targets_raw"], + features["target_space_id"], self._hparams) + + def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1, + last_position_only=False, alpha=0.0): + """Produce predictions from the model.""" + if not features: + features = {} + inputs_old = None + if "inputs" in features and len(features["inputs"].shape) < 4: + inputs_old = features["inputs"] + features["inputs"] = tf.expand_dims(features["inputs"], 2) + + # Create an initial targets tensor. + if "partial_targets" in features: + initial_output = tf.convert_to_tensor(features["partial_targets"]) + else: + batch_size = tf.shape(features["inputs"])[0] + length = tf.shape(features["inputs"])[1] + initial_output = tf.zeros((batch_size, 2 * length, 1, 1), dtype=tf.int64) + + features["targets"] = initial_output + sharded_logits, _ = self.model_fn( + features, False, last_position_only=last_position_only) + sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4) + samples = tf.concat(sharded_samples, 0) + + # More steps. + how_many_more_steps = 5 + for _ in xrange(how_many_more_steps): + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + features["targets"] = samples + sharded_logits, _ = self.model_fn( + features, False, last_position_only=last_position_only) + sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4) + samples = tf.concat(sharded_samples, 0) + + if inputs_old is not None: # Restore to not confuse Estimator. 
+ features["inputs"] = inputs_old + return samples + + +@registry.register_hparams +def transformer_adv_small(): + """Set of hyperparameters.""" + hparams = transformer.transformer_small() + hparams.batch_size = 2048 + hparams.learning_rate_warmup_steps = 4000 + hparams.num_hidden_layers = 3 + hparams.hidden_size = 384 + hparams.filter_size = 2048 + hparams.label_smoothing = 0.0 + hparams.weight_decay = 0.1 + hparams.symbol_modality_skip_top = int(True) + hparams.add_hparam("num_compress_steps", 2) + hparams.add_hparam("extra_steps", 0) + hparams.add_hparam("noise_val", 0.3) + hparams.add_hparam("delta_max", 2.0) + return hparams + + +@registry.register_hparams +def transformer_adv_base(): + """Set of hyperparameters.""" + hparams = transformer_adv_small() + hparams.batch_size = 1024 + hparams.hidden_size = 512 + hparams.filter_size = 4096 + hparams.num_hidden_layers = 6 + return hparams diff --git a/tensor2tensor/models/transformer_moe.py b/tensor2tensor/models/transformer_moe.py index c8a32a667..014a390c6 100644 --- a/tensor2tensor/models/transformer_moe.py +++ b/tensor2tensor/models/transformer_moe.py @@ -21,9 +21,9 @@ from __future__ import division from __future__ import print_function -# Dependency imports +import functools -from six.moves import xrange # pylint: disable=redefined-builtin +# Dependency imports from tensor2tensor.layers import common_attention from tensor2tensor.layers import common_hparams @@ -36,11 +36,43 @@ import tensorflow as tf +# The transformer architecture can be defined using the layer_types hparams. +# If not defined, the default types and num_hidden_layers are used as fallback +# values. +# +# Examples of usage: +# "a/a/a/a/a/a": Original base transformer (6 encoder and decoder layers of +# multihead full attention) +# "a/a/a-moe/a": 4 layers with 1 moe at layer 3 +# "loc/red/loc/red": Alternate between local and memory compressed attention +# "a/a/a#": Encoder only model (3 layers) +# "#a/a/a": Decoder only model (3 layers) +# "a/a-moe#a/a/a": Encoder (2 layers with 1 moe), decoder (3 layers) +# Note that all combinaisons are not necessarily possibles (some attention +# types are not necessarily compatible with the encoder, or can't accept certain +# types of masking) + +SEP_ENCODEC = "#" +SEP_LAYER = "/" +SEP_FF = "-" + + +def partial(fct, *args, **kwargs): + """Wrapper around functools.partial for Python 2 compatibility with wraps.""" + new_fct = functools.partial(fct, *args, **kwargs) + new_fct = functools.wraps(fct)(new_fct) + return new_fct + + @registry.register_model class TransformerMoe(t2t_model.T2TModel): """Attention net. 
See file docstring.""" + @expert_utils.add_var_scope("transformer_moe") def model_fn_body_sharded(self, sharded_features): + + # ========= Prepare the input and target ========= + hparams = self._hparams dp = self._data_parallelism targets = sharded_features["targets"] @@ -50,10 +82,10 @@ def model_fn_body_sharded(self, sharded_features): inputs = dp(common_layers.flatten4d3d, inputs) targets = dp(common_layers.flatten4d3d, targets) - def preprocess(x): + def dp_preprocess(x): return dp(common_layers.layer_preprocess, x, hparams) - def postprocess(x, y): + def dp_postprocess(x, y): return dp(common_layers.layer_postprocess, x, y, hparams) (encoder_input, encoder_self_attention_bias, @@ -66,98 +98,299 @@ def postprocess(x, y): 1.0 - hparams.layer_prepostprocess_dropout) decoder_input = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.layer_prepostprocess_dropout) - extra_loss = 0 + cache = dict(extra_loss=0) moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")] expert_fn = expert_utils.ffn_expert_fn( hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) + + # ========= Define some utils decorators ========= + + def prepostprocess(fct): + """Add pre and post processing.""" + # WARNING: Should be applied after dp (pre/post-process use dp and + # can be applied to function which doesn't use dp) + @functools.wraps(fct) + def decorated(x, *args, **kwargs): + x = dp_preprocess(x) + y = fct(x, *args, **kwargs) + return dp_postprocess(x, y) + return decorated + + def dp_wrapper(fct): + """Encapsulate the function in a data parallelism object.""" + @functools.wraps(fct) + def decorated(*args, **kwargs): + return dp(fct, *args, **kwargs) + return decorated + + def add_kwargs( + fct, + enco_kwargs=None, + deco_kwargs=None, + endeco_kwargs=None, # Enco-deco attention: overwrite deco_kwargs + ): + """Allow to have different arguments for the encoder and decoder.""" + # WARNING: If this decorator is applied before dp_wrapper, the kwargs + # may not be correctly dipatched across the devices. 
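+      # (This is why multihead_attention below is built in the order
+      # dp_wrapper -> add_kwargs -> prepostprocess.)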
+ @functools.wraps(fct) + def decorated(*args, **kwargs): + current_scope = tf.contrib.framework.get_name_scope() + if "/encoder/" in current_scope: + kwargs.update(enco_kwargs or {}) + elif "/decoder/" in current_scope: + kwargs.update(deco_kwargs or {}) + if "/att_ende_" in current_scope: + kwargs.update(endeco_kwargs or {}) + return fct(*args, **kwargs) + return decorated + + def capture_extra_loss(fct, loss_coef=1.0): + """Capture the additional loss.""" + @functools.wraps(fct) + def decorated(*args, **kwargs): + y, loss = fct(*args, **kwargs) + cache["extra_loss"] += loss * loss_coef + return y + return decorated + + def remove_kwargs(fct, extra_params): + """Remove some unused parameters.""" + @functools.wraps(fct) + def decorated(*args, **kwargs): + for k in extra_params: # Remove the extra params + kwargs.pop(k, None) + return fct(*args, **kwargs) + return decorated + + # def pad_remover(fct): + # """Remove/restore the padding on the input.""" + # @functools.wraps(fct) + # def decorated(x, *args, **kwargs): + # x = pad_remover.remove(x) + # x = fct(x, *args, **kwargs) + # x = pad_remover.restore(x) + # return x + # return decorated + + # ========= Define the available layers ========= + total_key_depth = hparams.attention_key_channels or hparams.hidden_size + total_value_depth = hparams.attention_value_channels or hparams.hidden_size + + # Multi-head full attention layer + multihead_attention = partial( + common_attention.multihead_attention, + total_key_depth=total_key_depth, + total_value_depth=total_value_depth, + output_depth=hparams.hidden_size, + num_heads=hparams.num_heads, + dropout_rate=hparams.attention_dropout, + ) + multihead_attention = dp_wrapper(multihead_attention) + multihead_attention = add_kwargs( # After dp to correctly dispatch kwargs + multihead_attention, + enco_kwargs={"bias": encoder_self_attention_bias}, + deco_kwargs={"bias": decoder_self_attention_bias}, + endeco_kwargs={"bias": encoder_decoder_attention_bias}, + ) + multihead_attention = prepostprocess(multihead_attention) + + # Local attention layer + # Reuse same parameters as multihead_attention (dp and pre/post-processing + # already applied) + # Only works for self attention. Always mask the future. + local_attention = partial( + multihead_attention, + block_length=hparams.attention_loc_block_length, + attention_type="local_mask_right", + ) + + # Memory-compressed multihead self attention layer + # Only works for self attention. Always mask the future. 
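+    # (multihead_self_attention_reduced shortens the memory sequence by
+    # `factor` before attending; see its updated signature in
+    # common_attention.py earlier in this diff.)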
+ compressed_attention = partial( + common_attention.multihead_self_attention_reduced, + factor=hparams.attention_red_factor, + nonlinearity=hparams.attention_red_nonlinearity, + reduction_type=hparams.attention_red_type, + multihead_params=dict( + total_key_depth=total_key_depth, + total_value_depth=total_value_depth, + num_heads=hparams.num_heads, + dropout_rate=hparams.attention_dropout, + ) + ) + compressed_attention = remove_kwargs( + compressed_attention, ["memory_antecedent"]) + compressed_attention = dp_wrapper(compressed_attention) + compressed_attention = prepostprocess(compressed_attention) + + # Mixture of expert layer + distributed_moe = partial( + expert_utils.distributed_moe, + dp, + self._ps_devices, + train=hparams.mode == tf.estimator.ModeKeys.TRAIN, + input_size=hparams.hidden_size, + expert_fn=expert_fn, + num_experts=hparams.moe_num_experts, + k=hparams.moe_k, + loss_coef=hparams.moe_loss_coef + ) + distributed_moe = capture_extra_loss(distributed_moe) + distributed_moe = prepostprocess(distributed_moe) + + # FC layer + conv_hidden_relu = partial( + common_layers.conv_hidden_relu, + hidden_size=hparams.filter_size, + output_size=hparams.hidden_size, + dropout=hparams.relu_dropout, + ) + conv_hidden_relu = dp_wrapper(conv_hidden_relu) + conv_hidden_relu = prepostprocess(conv_hidden_relu) + + # Separable convolution layer + # Reuse conv_hidden_relu (dp and pre/post-processing already applied) + # Mask the future for the decoder only + sep_conv_relu = partial( + conv_hidden_relu, + # Parameters copied from the transformer model, could add hparams + kernel_size=(3, 1), + second_kernel_size=(31, 1), + ) + sep_conv_relu = add_kwargs( + sep_conv_relu, + enco_kwargs={"padding": "SAME"}, + deco_kwargs={"padding": "LEFT"}, # Mask future for decoder + ) + + # This dictionary contains the list of all available layers + available_layers = dict( + # Attention layers + a=multihead_attention, # Standard multihead full attention + loc=local_attention, # Local attention + red=compressed_attention, # Memory-compressed attention + mem=None, # Memory efficient + # Feed-forward layers + moe=distributed_moe, # Mixture of expert layer + sep=sep_conv_relu, # Separable convolution + fc=conv_hidden_relu, # Fully connected + ) + + def extract_layer_types(layer_types): + """Parse the layer string. + + Args: + layer_types (str): String containing the network architecture. See + top file comment for examples of format. + + Returns: + list[tuple[str, str]]: Encoder layers: list of (attention, feed-forward) + list[tuple[str, str, str]]: Decoder layers: list of (self-attention, + enc-dec attention, feed-forward) + """ + # If the architecture has not explicitly been set, we just construct a + # standard transformer with the fallback values + if not layer_types: + layer_types = SEP_LAYER.join( + [hparams.default_att] * hparams.num_hidden_layers) + + # If encoder not explicitly defined, the encoder will have the same + # structure as the decoder + layer_types = layer_types.split(SEP_ENCODEC) + if len(layer_types) == 1: + layer_types *= 2 + + # Some models don't need the encoder (ex: language modeling) + # TODO(epot): What are the other conditions (has_input ?) 
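+      # (With prepend_mode != "none" the inputs are prepended to the targets
+      # and handled by the decoder alone, so the encoder part of the layer
+      # string is cleared.)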
+ if hparams.prepend_mode != "none": + layer_types[0] = "" + + # Extend the blocks and fill them with the default values if not specified + final_layers = ([], []) + for i, blocks_str in enumerate(layer_types): + for blocks_str in blocks_str.split(SEP_LAYER): + if not blocks_str: + continue + blocks_list = blocks_str.split(SEP_FF) + # Eventually use the fallback values for the layer_types. If the + # encoder is empty, do not use the enco-deco attention. + self_att = blocks_list[0] or hparams.default_att + ende_att = hparams.default_att if layer_types[0] else "_" + ff = hparams.default_ff + if len(blocks_list) > 1: + ff = blocks_list[-1] + if len(blocks_list) == 3: + ende_att = blocks_list[1] + if i == 0: # Encoder + blocks_tuple = (self_att, ff) + elif i == 1: # Decoder + blocks_tuple = (self_att, ende_att, ff) + final_layers[i].append(blocks_tuple) + + return final_layers + + # ========= Construct the transformer encoder and decoder ========= + + encoder_layers, decoder_layers = extract_layer_types(hparams.layer_types) + + # Display the encoder-decoder architecture + def print_layer(name, layers): + tf.logging.info("{} architecture:".format(name)) + for i, l in enumerate(layers): + tf.logging.info(" * Layer {}: {}".format(i, " - ".join(l))) + print_layer("Encoder", encoder_layers) + print_layer("Decoder", decoder_layers) + + encoder_outputs = [] + x = encoder_input - for layer in xrange(hparams.num_hidden_layers): - with tf.variable_scope("encoder_layer_%d" % layer): - with tf.variable_scope("encoder_self_attention"): - y = dp( - common_attention.multihead_attention, - preprocess(x), - None, - encoder_self_attention_bias, - hparams.attention_key_channels or hparams.hidden_size, - hparams.attention_value_channels or hparams.hidden_size, - hparams.hidden_size, - hparams.num_heads, - hparams.attention_dropout) - x = postprocess(x, y) - with tf.variable_scope("ffn"): - if str(layer) in hparams.moe_layers_encoder.split(","): - y, loss = expert_utils.distributed_moe( - dp, - self._ps_devices, - preprocess(x), - hparams.mode == tf.estimator.ModeKeys.TRAIN, - input_size=hparams.hidden_size, - expert_fn=expert_fn, - num_experts=hparams.moe_num_experts, - k=hparams.moe_k, - loss_coef=hparams.moe_loss_coef) - extra_loss += loss - else: - y = dp( - common_layers.conv_hidden_relu, - preprocess(x), - hparams.filter_size, - hparams.hidden_size, - dropout=hparams.relu_dropout) - x = postprocess(x, y) - encoder_output = preprocess(x) + with tf.variable_scope("encoder"): + for layer_num, block_types in enumerate(encoder_layers): + # Each encoder layers is composed of two blocks: + # * self-attention block + # * feed-forward block + att_type, ff_type = block_types + with tf.variable_scope("layer_{}".format(layer_num)): + with tf.variable_scope("att_{}".format(att_type)): + x = available_layers[att_type]( + x, + memory_antecedent=None, + ) + with tf.variable_scope("ff_{}".format(ff_type)): + x = available_layers[ff_type](x) + encoder_outputs.append(x) + if encoder_outputs: + encoder_outputs[-1] = dp_preprocess(x) + x = decoder_input - for layer in xrange(hparams.num_hidden_layers): - with tf.variable_scope("decoder_layer_%d" % layer): - with tf.variable_scope("decoder_self_attention"): - y = dp( - common_attention.multihead_attention, - preprocess(x), - None, - decoder_self_attention_bias, - hparams.attention_key_channels or hparams.hidden_size, - hparams.attention_value_channels or hparams.hidden_size, - hparams.hidden_size, - hparams.num_heads, - hparams.attention_dropout) - x = postprocess(x, y) - with 
tf.variable_scope("encoder_decoder_attention"): - y = dp( - common_attention.multihead_attention, - preprocess(x), - encoder_output, - encoder_decoder_attention_bias, - hparams.attention_key_channels or hparams.hidden_size, - hparams.attention_value_channels or hparams.hidden_size, - hparams.hidden_size, - hparams.num_heads, - hparams.attention_dropout) - x = postprocess(x, y) - with tf.variable_scope("ffn"): - if str(layer) in hparams.moe_layers_decoder.split(","): - y, loss = expert_utils.distributed_moe( - dp, - self._ps_devices, - preprocess(x), - hparams.mode == tf.estimator.ModeKeys.TRAIN, - input_size=hparams.hidden_size, - expert_fn=expert_fn, - num_experts=hparams.moe_num_experts, - k=hparams.moe_k, - loss_coef=hparams.moe_loss_coef) - extra_loss += loss - else: - y = dp( - common_layers.conv_hidden_relu, - preprocess(x), - hparams.filter_size, - hparams.hidden_size, - dropout=hparams.relu_dropout) - x = postprocess(x, y) - x = preprocess(x) + with tf.variable_scope("decoder"): + for layer_num, block_types in enumerate(decoder_layers): + # Each decoder layers is composed of three blocks: + # * self-attention block + # * enco-deco attention block (optional) + # * feed-forward block + self_att_type, att_ende_type, ff_type = block_types + with tf.variable_scope("layer_{}".format(layer_num)): + with tf.variable_scope("self_att_{}".format(self_att_type)): + x = available_layers[self_att_type]( + x, + memory_antecedent=None, + ) + with tf.variable_scope("att_ende_{}".format(att_ende_type)): + # Only add the enco-deco attention layer if there is an encoder + if encoder_outputs: + x = available_layers[att_ende_type]( + x, + memory_antecedent=encoder_outputs[-1], + ) + with tf.variable_scope("ff_{}".format(ff_type)): + x = available_layers[ff_type](x) + # If normalization is done in layer_preprocess, then it should also be + # done on the output, since the output can grow very large, being the sum + # of a whole stack of unnormalized layer outputs. + x = dp_preprocess(x) decoder_output = dp(tf.expand_dims, x, 2) - return decoder_output, extra_loss + return decoder_output, cache["extra_loss"] @registry.register_hparams @@ -185,6 +418,9 @@ def transformer_moe_base(): hparams.num_sampled_classes = 0 hparams.label_smoothing = 0.0 hparams.shared_embedding_and_softmax_weights = int(True) + # According to noam, ("n", "da") seems better for harder-to-learn models + hparams.layer_preprocess_sequence = "n" + hparams.layer_postprocess_sequence = "da" hparams.add_hparam("filter_size", 2048) # Add new ones like this. # attention-related flags @@ -192,8 +428,11 @@ def transformer_moe_base(): hparams.add_hparam("attention_key_channels", 0) hparams.add_hparam("attention_value_channels", 0) hparams.add_hparam("ffn_layer", "conv_hidden_relu") - hparams.add_hparam("parameter_attention_key_channels", 0) - hparams.add_hparam("parameter_attention_value_channels", 0) + # Other attention types params + hparams.add_hparam("attention_loc_block_length", 256) + hparams.add_hparam("attention_red_factor", 3) + hparams.add_hparam("attention_red_type", "conv") + hparams.add_hparam("attention_red_nonlinearity", "none") # All hyperparameters ending in "dropout" are automatically set to 0.0 # when not in training mode. 
hparams.add_hparam("attention_dropout", 0.0) @@ -201,28 +440,54 @@ def transformer_moe_base(): hparams.add_hparam("pos", "timing") # timing, none hparams.add_hparam("nbr_decoder_problems", 1) hparams.add_hparam("proximity_bias", int(False)) - # FLAGS RELATED TO MIXTURE-OF-EXPERTS - # comma-separated list of layer numbers. - # At each of these layers, we replace the ffn with a mixture of experts. - hparams.add_hparam("moe_layers_encoder", "2") - hparams.add_hparam("moe_layers_decoder", "2") + + # Decoder layers type. If set, num_decoder_layers parameter will be ignored + # and the number of decoder layer will be deduced from the string + # See top file comment for example of usage + hparams.add_hparam("layer_types", "") + # Default attention type (ex: a, loc, red,...) and feed-forward type (ex: fc, + # sep, moe,...) + hparams.add_hparam("default_att", "a") + hparams.add_hparam("default_ff", "fc") + return hparams @registry.register_hparams -def transformer_no_moe(): - """Without the mixture of experts (for comparison).""" +def transformer_moe_8k(): + """Hyper parameters specifics for long sequence generation.""" hparams = transformer_moe_base() - hparams.moe_layers_encoder = "" - hparams.moe_layers_decoder = "" + + hparams.batch_size = 8192 + hparams.max_length = 0 # max_length == batch_size + hparams.eval_drop_long_sequences = int(True) + hparams.min_length_bucket = 256 # Avoid cyclic problems for big batches + + hparams.default_ff = "sep" + hparams.hidden_size = 1024 + return hparams @registry.register_hparams -def transformer_moe_1b(): - """1-billion parameter model - requires multi-gpu sync training.""" - hparams = transformer_moe_base() - hparams.moe_n1 = 128 - hparams.moe_layers_encoder = "1,3" - hparams.moe_layers_decoder = "1,3" +def transformer_moe_12k(): + """Hyper parameters specifics for long sequence generation.""" + hparams = transformer_moe_8k() + hparams.batch_size = 12000 + # At 12k, the softmax become the memory bottleneck + hparams.factored_logit = int(True) return hparams + + +@registry.register_hparams +def transformer_moe_prepend_8k(): + """Model which formulate a seq2seq problem as language modeling.""" + hparams = transformer_moe_8k() + hparams.prepend_mode = "prepend_inputs_masked_attention", + hparams.eval_drop_long_sequences = int(False), + hparams.max_input_seq_length = 7500, + hparams.layer_types = "loc/red/loc-moe/red/loc" + hparams.moe_num_experts = 256 + return hparams + + diff --git a/tensor2tensor/models/transformer_sketch.py b/tensor2tensor/models/transformer_sketch.py new file mode 100644 index 000000000..b7bd9b1ef --- /dev/null +++ b/tensor2tensor/models/transformer_sketch.py @@ -0,0 +1,162 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformer Sketch for im2sketch problems. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from tensor2tensor.layers import common_hparams +from tensor2tensor.models import transformer +from tensor2tensor.models import transformer_vae +from tensor2tensor.models.transformer import transformer_base +from tensor2tensor.models.transformer import transformer_n_da +from tensor2tensor.models.transformer import transformer_small +from tensor2tensor.utils import registry + + +@registry.register_model +class TransformerSketch(transformer.Transformer): + """Transformer with strided convolutions.""" + + def encode(self, inputs, target_space, hparams): + """Add two layers strided convolutions ontop of encode.""" + hparams.num_compress_steps = 2 + compressed_inputs = transformer_vae.compress(inputs, c=None, is_2d=True, + hparams=hparams, + name="convolutions") + + return super(TransformerSketch, self).encode( + compressed_inputs, target_space, hparams) + + +@registry.register_hparams +def transformer_sketch(): + """Basic transformer_sketch hparams.""" + hparams = transformer_n_da() + hparams.batch_size = 2048 + hparams.max_length = 784 + hparams.clip_grad_norm = 5. + hparams.learning_rate_decay_scheme = "noam" + hparams.learning_rate = 0.2 + hparams.learning_rate_warmup_steps = 10000 + hparams.num_hidden_layers = 6 + hparams.initializer = "orthogonal" + hparams.sampling_method = "random" + return hparams + + +@registry.register_hparams +def transformer_base_sketch(): + """Parameters based on base.""" + hparams = transformer_base() + hparams.batch_size = 2048 + hparams.max_length = 784 + hparams.clip_grad_norm = 5. + hparams.learning_rate_decay_scheme = "noam" + hparams.learning_rate_warmup_steps = 8000 + hparams.learning_rate = 0.2 + hparams.num_hidden_layers = 6 + hparams.initializer = "orthogonal" + hparams.sampling_method = "random" + return hparams + + +@registry.register_hparams +def transformer_small_sketch(): + """Modified transformer_small.""" + hparams = transformer_small() + hparams.batch_size = 2048 + hparams.max_length = 784 + hparams.clip_grad_norm = 5. 
+ hparams.learning_rate_decay_scheme = "noam" + hparams.learning_rate = 0.1 + hparams.initializer = "orthogonal" + hparams.sampling_method = "random" + hparams.learning_rate_warmup_steps = 10000 + return hparams + + +@registry.register_hparams +def transformer_sketch_2layer(): + hparams = transformer_sketch() + hparams.num_hidden_layers = 2 + return hparams + + +@registry.register_hparams +def transformer_sketch_4layer(): + hparams = transformer_sketch() + hparams.num_hidden_layers = 4 + return hparams + + +@registry.register_hparams +def transformer_sketch_6layer(): + hparams = transformer_sketch() + hparams.num_hidden_layers = 6 + return hparams + + +@registry.register_ranged_hparams("transformer_sketch_ranged") +def transformer_sketch_ranged(rhp): + """Range of hparams for vizier.""" + + hparams = transformer_sketch() + common_hparams.fill_ranged_hparams_from_hparams(hparams, rhp) + + rhp.set_categorical("ffn_layer", + ["conv_hidden_relu_with_sepconv", "conv_hidden_relu"]) + rhp.set_discrete("batch_size", [1024, 2048, 4096]) + rhp.set_discrete("num_hidden_layers", [2, 3, 4, 5, 6]) + rhp.set_discrete("hidden_size", [32, 64, 128, 256, 512, 1024], + scale=rhp.LOG_SCALE) + rhp.set_discrete("kernel_height", [1, 3, 5, 7]) + rhp.set_discrete("kernel_width", [1, 3, 5, 7]) + rhp.set_discrete("compress_steps", [0, 1, 2]) + rhp.set_float("dropout", 0.0, 0.5) + rhp.set_float("weight_decay", 1e-4, .03, scale=rhp.LOG_SCALE) + rhp.set_float("label_smoothing", 0.0, 0.2) + rhp.set_float("clip_grad_norm", 0.01, 8.0, scale=rhp.LOG_SCALE) + rhp.set_float("learning_rate", 0.1, 1.0, scale=rhp.LOG_SCALE) + rhp.set_categorical("initializer", + ["uniform", "orthogonal", "uniform_unit_scaling"]) + rhp.set_float("initializer_gain", 0.5, 3.5) + rhp.set_categorical("learning_rate_decay_scheme", + ["none", "sqrt", "noam", "exp10k"]) + rhp.set_float("optimizer_adam_epsilon", 1e-7, 1e-2, scale=rhp.LOG_SCALE) + rhp.set_float("optimizer_adam_beta1", 0.8, 0.9) + rhp.set_float("optimizer_adam_beta2", 0.995, 0.999) + rhp.set_categorical("optimizer", [ + "Adam", "Adagrad", "Momentum", "RMSProp", "SGD", "YellowFin"]) + + +@registry.register_hparams +def transformer_opt(): + """Parameters that work better.""" + hparams = transformer_sketch() + hparams.batch_size = 1024 + hparams.learning_rate = 0.28 + hparams.num_hidden_layers = 3 + hparams.dropout = 0.35 + hparams.ffn_layer = "conv_hidden_relu_with_sepconv" + hparams.hidden_size = 128 + hparams.initializer_gain = 2.6 + hparams.weight_decay = 0. + return hparams diff --git a/tensor2tensor/models/transformer_vae.py b/tensor2tensor/models/transformer_vae.py index 67ec86ef5..d936ce72f 100644 --- a/tensor2tensor/models/transformer_vae.py +++ b/tensor2tensor/models/transformer_vae.py @@ -128,7 +128,7 @@ def dae(x, hparams, name): steps = hparams.kl_warmup_steps gumbel_samples *= common_layers.inverse_exp_decay(steps // 5) * 0.5 temperature = 1.2 - common_layers.inverse_lin_decay(steps) - # 30% of the time keep reasonably high temperature to keep learning. + # 10% of the time keep reasonably high temperature to keep learning. 
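# For intuition, a plain-Python sketch of the temperature schedule implemented
# by the lines above and the tf.cond that follows, assuming inverse_lin_decay
# ramps roughly linearly from 0 to 1 over kl_warmup_steps. The helper below is
# illustrative only, not part of transformer_vae.
import random

def gumbel_temperature(step, kl_warmup_steps):
  annealed = 1.2 - min(float(step) / kl_warmup_steps, 1.0)  # decays 1.2 -> 0.2
  if random.random() < 0.9:   # 90% of the time: use the annealed temperature.
    return annealed
  return random.uniform(0.5, 1.0)  # 10% of the time: keep it reasonably high.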
temperature = tf.cond(tf.less(tf.random_uniform([]), 0.9), lambda: temperature, lambda: tf.random_uniform([], minval=0.5, maxval=1.0)) @@ -216,6 +216,84 @@ def kmeans(x, means, hparams, name): return x_means_hot, tf.reduce_mean(kl) # * 10.0 +def bit_to_int(x_bit, nbits): + """Turn x_bit representing numbers bitwise (lower-endian) to int tensor.""" + x_l = tf.stop_gradient(tf.reshape(x_bit, [-1, nbits])) + x_labels = [] + for i in range(nbits): + x_labels.append(x_l[:, i] * 2**i) + res = sum(x_labels) + return tf.to_int32(tf.reshape(res, tf.shape(x_bit)[:-1])) + + +def int_to_bit(x_int, nbits): + """Turn x_int representing numbers into a bitwise (lower-endian) tensor.""" + x_l = tf.expand_dims(x_int, axis=-1) + x_labels = [] + for i in range(nbits): + x_labels.append(tf.floormod(tf.floordiv(x_l, 2**i), 2)) + res = tf.concat(x_labels, axis=-1) + return tf.to_float(res) + + +def bottleneck(x, hparams, filter_size, name): + """Bottleneck.""" + def embed1(x): + if hparams.bottleneck_kind == "semhash": + c = int_to_bit(x, c_size) + h1a = tf.layers.dense(c, filter_size, name="vch1a") + h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b") + return h1a + h1b + elif hparams.bottleneck_kind == "gumbel-softmax": + hot = tf.one_hot(x, hparams.v_size) + with tf.variable_scope(name, reuse=True): + return tf.layers.dense(hot, hparams.hidden_size, name="dae_dense") + + def embed(x): + with tf.variable_scope(name, reuse=True): + h1 = embed1(x) + h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2") + res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin") + return res + + with tf.variable_scope(name): + c_size = hparams.c_size + l = tf.constant(0.0) + if hparams.bottleneck_kind == "dense": + c = tf.layers.dense(x, c_size, name="vcc") + h1 = tf.layers.dense(c, filter_size, name="vch1") + if hparams.bottleneck_kind == "semhash": + c = tf.layers.dense(x, c_size, name="vcc") + y_clean = common_layers.saturating_sigmoid(c) + tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1])) + # l = tf.reduce_mean(y_clean * (1.0 - y_clean)) + if hparams.noise_dev > 0 and hparams.mode == tf.estimator.ModeKeys.TRAIN: + dev = hparams.noise_dev + noise = tf.truncated_normal(tf.shape(c), mean=0.0, stddev=dev) + y = common_layers.saturating_sigmoid(c + noise) + else: + y = y_clean + d = tf.to_float(tf.less(0.5, y)) + y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y) + pd = common_layers.inverse_exp_decay(hparams.startup_steps * 2) + pd *= hparams.d_mix + pd = pd if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0 + c = tf.cond(tf.less(tf.random_uniform([]), pd), + lambda: y_discrete, lambda: y) + h1a = tf.layers.dense(c, filter_size, name="vch1a") + h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b") + h1 = h1a + h1b + dx = tf.to_int32(tf.stop_gradient(d)) + c = bit_to_int(dx, c_size) + if hparams.bottleneck_kind == "gumbel-softmax": + _, hot, l = dae(x, hparams, name) + c = tf.argmax(hot, axis=-1) + h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense") + h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2") + res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin") + return res, c, l, embed + + def compress(x, c, is_2d, hparams, name): """Compress.""" with tf.variable_scope(name): @@ -272,6 +350,32 @@ def decode(cond_vec, cond_add, gold, c, ed, hparams, name): return transformer.transformer_decoder(decoder_input, c, bias, ed, hparams) +def decode_transformer(encoder_output, + encoder_decoder_attention_bias, + targets, + hparams, + name): + 
"""Original Transformer decoder.""" + with tf.variable_scope(name): + targets = common_layers.flatten4d3d(targets) + + decoder_input, decoder_self_bias = transformer.transformer_prepare_decoder( + targets, hparams) + + decoder_input = tf.nn.dropout(decoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + + decoder_output = transformer.transformer_decoder( + decoder_input, + encoder_output, + decoder_self_bias, + encoder_decoder_attention_bias, + hparams) + + # Expand since t2t expects 4d tensors. + return tf.expand_dims(decoder_output, axis=2) + + def expand_batch(x, mul): """Expand on batch by mul times.""" cx = tf.expand_dims(x, axis=1) @@ -298,18 +402,6 @@ def ae_compress(x, is_2d, hparams, name, reuse=None): hot, loss = bit_vae(cur, hparams, "bvae") else: hot, loss, _, _ = vae(cur, hparams.z_size, "vae") - # Do a second level vae with some probability. - if hparams.z_size2 > 0: - prob_z2 = common_layers.inverse_exp_decay(hparams.startup_steps*2) * 0.8 - if hparams.mode != tf.contrib.learn.ModeKeys.TRAIN: - prob_z2 = 1.0 - def vae2(): - hot2, loss2, _, _ = vae(hot, hparams.z_size2, "vae2") - ret = tf.layers.dense(hot2, hparams.z_size) - return mix(ret, hot, hparams.startup_steps * 2), loss2 - hot, loss2 = tf.cond(tf.less(tf.random_uniform([]), prob_z2), - vae2, lambda: (hot, tf.constant(0.0))) - loss += loss2 * 0.1 return cur, hot, loss if hparams.use_gumbel_softmax: _, hot, loss = dae(cur, hparams, "dae") @@ -389,90 +481,127 @@ def ffn(x, hparams, name): return common_layers.layer_postprocess(x, y, hparams) -def ae_transformer_internal(inputs, targets, target_space, hparams): +def multinomial_sample(x, vocab_size, temperature): + """Multinomial sampling from a n-dimensional tensor.""" + samples = tf.multinomial(tf.reshape(x, [-1, vocab_size]) / temperature, 1) + reshaped_samples = tf.reshape(samples, tf.shape(x)[:-1]) + return tf.to_int32(reshaped_samples) + + +def ae_latent_sample(t_c, inputs, ed, embed, iters, hparams): + """Sample from the latent space in the autoencoder.""" + t_pred = decode_transformer(inputs, ed, t_c, hparams, "extra") + t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits") + t_bit = multinomial_sample(t_pred, 2**16, hparams.sampling_temp) + for i in xrange(iters): + t_bit_prev = t_bit + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + t_c = embed(t_bit) + t_pred = decode_transformer(inputs, ed, t_c, hparams, "extra") + t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits") + t_bit = multinomial_sample(t_pred, 2**16, hparams.sampling_temp) + t_bit = tf.concat([t_bit_prev[:, :(i+1), :], + t_bit[:, (i+1):, :]], axis=1) + return t_bit + + +def ae_transformer_internal(inputs, targets, target_space, hparams, + beam_size, cache=None): """AE Transformer, main step used for training.""" + hparams.z_size = hparams.hidden_size with tf.variable_scope("ae_transformer"): # Prepare inputs, targets, k. - k = 2**hparams.num_compress_steps - _, targets = common_layers.pad_to_same_length( - targets, targets, final_length_divisible_by=k) - inputs = common_layers.flatten4d3d(inputs) - inputs, ed = encode(inputs, target_space, hparams, "input_enc") - - # Compress and ae. - ae, hot, kl = ae_compress(targets, hparams.is_2d, hparams, "ae") - tf.summary.histogram("hot", tf.reshape(tf.argmax(hot, axis=-1), [-1])) - emb = ae_embed(hot, hparams, "ae", reuse=True) - - # Compress context and run autoregressive decoder on emb-hot. 
- if hparams.do_vae: - reconstruct_loss = 0.0 + orig_targets = targets + batch_size = tf.shape(orig_targets)[0] + targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size]) + k = hparams.num_compress_steps + + # Encoder. + if inputs is not None: + inputs = common_layers.flatten4d3d(inputs) + inputs, ed = encode(inputs, target_space, hparams, "input_enc") + else: + ed = None + + # Autoencoding. + losses = {"vc": tf.constant(0.0), "sm": tf.constant(0.0)} + latent_len = hparams.latent_length + if hparams.do_ae: + targets_pad, _ = common_layers.pad_to_same_length( + targets, targets, final_length_divisible_by=latent_len * 2**k) + targets_c = compress(targets_pad, None, False, hparams, "compress") + targets_c = targets_c[:, :latent_len, :, :] + if hparams.mode != tf.estimator.ModeKeys.PREDICT: + # Compress and bottleneck. + t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2*2048, "vc") + tf.summary.histogram("bit0", tf.reshape(t_bit[:, 0, :], [-1])) + pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95 + pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0 + cond = tf.less(tf.random_uniform([]), pc) + t_c = tf.cond(cond, lambda: t_c, lambda: targets_c) + losses["vc"] = vc_loss * tf.to_float(cond) + # Extra loss predicting latent code from input. + t_pred = decode_transformer( + inputs, ed, tf.stop_gradient(t_c), hparams, "extra") + t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits") + losses["sm"] = tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=t_bit, logits=t_pred) + losses["sm"] = tf.reduce_mean(losses["sm"]) * 0.2 * tf.to_float(cond) + else: + _, _, _, embed = bottleneck(targets_c, hparams, 2*2048, "vc") + t_c = tf.zeros_like(targets_c) + if cache is None: + cache = ae_latent_sample(t_c, inputs, ed, embed, 3, hparams) + cache = cache[0, :, :] + cache = tf.reshape(cache, [1, latent_len, 1]) + cache = tf.tile(cache, [beam_size, 1, 1]) + t_c = embed(cache) + # Postprocess. + pos = tf.get_variable("pos", [1, latent_len + 1, 1, hparams.hidden_size]) + t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos + targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1) else: - emb_flat = tf.expand_dims(common_layers.flatten4d3d(emb), axis=2) - emb_flat = tf.stop_gradient(emb_flat) - dec_c = decode(None, None, emb_flat, inputs, ed, hparams, "dgold") - dec_c = tf.reshape(dec_c, tf.shape(emb)) - c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context") - reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits( - labels=hot, logits=c_z) - # If not training, use the predicted z instead of the autoregressive one. - if hparams.mode == tf.estimator.ModeKeys.PREDICT: - hot = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size) - - # Decompress, pass for ae loss. 
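# Rough shape arithmetic for the latent path above, assuming compress() halves
# the sequence length once per num_compress_steps (values taken from
# transformer_ae_small below: latent_length=4, num_compress_steps=3). This
# helper is only an illustration of the padding and truncation, not code from
# transformer_vae.
def latent_positions(target_len, latent_len=4, k=3):
  divisor = latent_len * 2**k                    # pad to a multiple of 32
  padded = -(-target_len // divisor) * divisor   # e.g. a length of 50 -> 64
  compressed = padded // 2**k                    # 64 -> 8 compressed positions
  return min(compressed, latent_len)             # truncated to 4 latents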
- z = ae_decompress(emb, ae, targets, hparams.is_2d, hparams, "ae") - if not (hparams.use_gumbel_softmax and hparams.softmax_k > 0): - kl *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.8), - min_value=0.0001) - reconstruct_loss *= common_layers.inverse_exp_decay(hparams.startup_steps) - losses = {"kl": kl, "reconstruction": reconstruct_loss * 0.1} - return z, losses + targets = tf.pad(targets, [[0, 0], [latent_len + 1, 0], [0, 0], [0, 0]]) + + res = decode_transformer(inputs, ed, targets, hparams, "decoder") + res = res[:, latent_len + 1:, :, :] + return res, losses, cache @registry.register_model class TransformerAE(t2t_model.T2TModel): + """Autoencoder-augmented Transformer.""" + + @property + def has_input(self): + return self._problem_hparams.input_modality def model_fn_body(self, features): - return ae_transformer_internal( - features["inputs"], features["targets"], features["target_space_id"], - self._hparams) - - def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1, - last_position_only=False, alpha=0.0): - """A inference method, see T2TModel.""" - if not features: - features = {} - inputs_old = None - if "inputs" in features and len(features["inputs"].shape) < 4: - inputs_old = features["inputs"] - features["inputs"] = tf.expand_dims(features["inputs"], 2) - - # Create an initial targets tensor. - if "partial_targets" in features: - initial_output = tf.convert_to_tensor(features["partial_targets"]) - else: - batch_size = tf.shape(features["inputs"])[0] - initial_output = tf.zeros((batch_size, 1, 1, 1), dtype=tf.int64) - - features["targets"] = initial_output - sharded_logits, _ = self.model_fn( - features, False, last_position_only=last_position_only) - sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4) - samples = tf.concat(sharded_samples, 0) - - # More steps. - how_many_more_steps = 2 - for _ in xrange(how_many_more_steps): - with tf.variable_scope(tf.get_variable_scope(), reuse=True): - features["targets"] = samples - sharded_logits, _ = self.model_fn( - features, False, last_position_only=last_position_only) - sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4) - samples = tf.concat(sharded_samples, 0) - - if inputs_old is not None: # Restore to not confuse Estimator. 
- features["inputs"] = inputs_old - return samples + inputs = features["inputs"] if "inputs" in features else None + if self._hparams.drop_inputs: + inputs = None + reuse = "cache_raw" in features + beam_size = self._decode_hparams.beam_size + with tf.variable_scope(tf.get_variable_scope(), reuse=reuse): + res, loss, _ = ae_transformer_internal( + inputs, features["targets"], features["target_space_id"], + self._hparams, beam_size, features.get("cache_raw", None)) + return res, loss + + def prepare_features_for_infer(self, features): + if not self._hparams.do_ae: + return features + beam_size = self._decode_hparams.beam_size + inputs = tf.zeros([beam_size, 1, 1, self._hparams.hidden_size]) + inputs = inputs if "inputs" in features else None + if self._hparams.drop_inputs or not self.has_input: + inputs = None + targets = tf.zeros([beam_size, 1, 1, self._hparams.hidden_size]) + with tf.variable_scope("body"): + _, _, cache = ae_transformer_internal( + inputs, targets, features["target_space_id"], + self._hparams, beam_size) + features["cache_raw"] = cache @registry.register_hparams @@ -481,12 +610,24 @@ def transformer_ae_small(): hparams = transformer.transformer_small() hparams.batch_size = 2048 hparams.learning_rate_warmup_steps = 4000 + hparams.num_hidden_layers = 3 + hparams.hidden_size = 384 + hparams.filter_size = 2048 + hparams.label_smoothing = 0.0 + hparams.add_hparam("c_size", 16) + hparams.add_hparam("latent_length", 4) + hparams.add_hparam("noise_dev", 1.0) + hparams.add_hparam("d_mix", 0.5) + # Bottleneck kinds supported: dense, semhash, gumbel-softmax. + hparams.add_hparam("bottleneck_kind", "semhash") + hparams.add_hparam("do_ae", int(True)) + hparams.add_hparam("drop_inputs", int(False)) hparams.add_hparam("z_size", 128) - hparams.add_hparam("z_size2", 0) - hparams.add_hparam("v_size", 1024*32) - hparams.add_hparam("num_compress_steps", 4) - hparams.add_hparam("kl_warmup_steps", 60000) - hparams.add_hparam("startup_steps", 30000) + hparams.add_hparam("v_size", 1024*64) + hparams.add_hparam("max_context_length", 64) + hparams.add_hparam("num_compress_steps", 3) + hparams.add_hparam("kl_steps", 35000) + hparams.add_hparam("startup_steps", 10000) hparams.add_hparam("kmeans_lr_factor", 0.002) hparams.add_hparam("z_dropout", 0.1) hparams.add_hparam("is_2d", 0) @@ -515,6 +656,7 @@ def transformer_ae_cifar(): hparams.is_2d = 1 hparams.learning_rate_warmup_steps = 8000 hparams.learning_rate = 0.2 + hparams.ffn_layer = "conv_hidden_relu_with_sepconv" return hparams @@ -522,11 +664,8 @@ def transformer_ae_cifar(): def transformer_ae_base(): """Set of hyperparameters.""" hparams = transformer_ae_small() + hparams.batch_size = 1024 hparams.hidden_size = 512 - hparams.filter_size = 2048 - hparams.attention_dropout = 0.0 - hparams.relu_dropout = 0.0 - hparams.dropout = 0.0 - hparams.num_hidden_layers = 4 - hparams.z_size = 256 + hparams.filter_size = 4096 + hparams.num_hidden_layers = 6 return hparams diff --git a/tensor2tensor/tpu/tpu_trainer.py b/tensor2tensor/tpu/tpu_trainer.py index 8cda597d4..d9b20ee75 100644 --- a/tensor2tensor/tpu/tpu_trainer.py +++ b/tensor2tensor/tpu/tpu_trainer.py @@ -23,7 +23,7 @@ # Dependency imports from tensor2tensor import models # pylint: disable=unused-import -from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import +from tensor2tensor import problems # pylint: disable=unused-import from tensor2tensor.tpu import tpu_trainer_lib as lib from tensor2tensor.utils import trainer_utils @@ -35,7 +35,7 @@ 
flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.") flags.DEFINE_string("output_dir", "", "Base output directory for run.") flags.DEFINE_string("master", "", "Address of TensorFlow master.") -flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.") +flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.") flags.DEFINE_integer("iterations_per_loop", 1000, "Number of iterations in a TPU training loop.") @@ -63,14 +63,29 @@ def main(unused_argv): batch_size=hparams.tpu_batch_size_per_shard * FLAGS.tpu_num_shards, log_device_placement=FLAGS.log_device_placement, iterations_per_loop=FLAGS.iterations_per_loop) - if FLAGS.train_steps: - estimator.train( - lambda params: input_fn(tf.estimator.ModeKeys.TRAIN, params), - steps=FLAGS.train_steps) - if FLAGS.eval_steps: + + if not FLAGS.train_steps: + assert FLAGS.eval_steps estimator.evaluate( lambda params: input_fn(tf.estimator.ModeKeys.EVAL, params), steps=FLAGS.eval_steps) + return + + num_rounds = FLAGS.train_steps // FLAGS.local_eval_frequency + steps_per_round = [FLAGS.local_eval_frequency] * num_rounds + remainder = FLAGS.train_steps % FLAGS.local_eval_frequency + if remainder: + steps_per_round.append(remainder) + + for num_steps in steps_per_round: + estimator.train( + lambda params: input_fn(tf.estimator.ModeKeys.TRAIN, params), + steps=num_steps) + if FLAGS.eval_steps: + estimator.evaluate( + lambda params: input_fn(tf.estimator.ModeKeys.EVAL, params), + steps=FLAGS.eval_steps) + tf.logging.info("Training and evaluation complete.") if __name__ == "__main__": diff --git a/tensor2tensor/tpu/tpu_trainer_lib.py b/tensor2tensor/tpu/tpu_trainer_lib.py index dca9f4de9..85a3cdf42 100644 --- a/tensor2tensor/tpu/tpu_trainer_lib.py +++ b/tensor2tensor/tpu/tpu_trainer_lib.py @@ -24,12 +24,10 @@ from __future__ import print_function import copy -import math # Dependency imports from tensor2tensor.layers import common_layers -from tensor2tensor.models import transformer from tensor2tensor.utils import data_reader from tensor2tensor.utils import metrics from tensor2tensor.utils import model_builder @@ -39,6 +37,17 @@ from tensorflow.python.util import nest +def create_dummy_vars(): + """Dummy vars for restore to work when not using TPU codepath.""" + with tf.variable_scope("losses_avg"): + with tf.variable_scope("problem_0"): + for var_name in ["total", "extra", "training"]: + tf.get_variable( + "%s_loss" % var_name, initializer=100.0, trainable=False) + with tf.variable_scope("train_stats"): + tf.get_variable("problem_0_steps", initializer=0, trainable=False) + + def get_input_fn(data_dir, problem, hparams): """Get basic T2T input fn.""" @@ -60,46 +69,11 @@ def input_fn(mode, params): }, } - def decode_record(record): - """Serialized Example to dict of .""" - data_fields, _ = problem.example_reading_spec() - decoded = tf.parse_single_example(record, features=data_fields) - decoded["inputs"] = decoded["inputs"].values - decoded["targets"] = decoded["targets"].values - return decoded - - data_files = tf.contrib.slim.parallel_reader.get_data_files( - problem.filepattern(data_dir, mode)) - dataset = tf.data.TFRecordDataset(data_files) - dataset = dataset.map(decode_record, num_parallel_calls=num_threads) - - def _preprocess(example, problem, hparams, mode): - example = problem.preprocess_example(example, mode, hparams) - # We do not want int64s as they are not supported on TPUs. 
- example = data_reader.cast_int64_to_int32(example) - return example - - dataset = dataset.map( - lambda ex: _preprocess(ex, problem, hparams, mode), - num_parallel_calls=num_threads) - def _valid_size(example): return data_reader.example_valid_size( example, batching_scheme["min_length"], batching_scheme["max_length"]) - dataset = dataset.filter(_valid_size) - if is_training: - dataset = dataset.shuffle(100) - # TODO(rsepassi): In eval mode, should not repeat - dataset = dataset.repeat(None) - dataset = data_reader.padded_batch(dataset, batch_size, - batching_scheme["padded_shapes"]) - - if not is_training: - dataset = dataset.map( - lambda f: pad_batch(f, batch_size), num_parallel_calls=num_threads) - - def shape_def(example): + def define_shapes(example): """Set the right shapes for the features.""" inputs = example["inputs"] targets = example["targets"] @@ -123,7 +97,22 @@ def shape_def(example): return example - dataset = dataset.map(shape_def, num_parallel_calls=num_threads) + dataset = problem.dataset( + mode=mode, data_dir=data_dir, num_threads=num_threads, hparams=hparams) + dataset = dataset.map( + data_reader.cast_int64_to_int32, num_threads=num_threads) + dataset = dataset.filter(_valid_size) + if is_training: + dataset = dataset.shuffle(100) + # TODO(rsepassi): In eval mode, should not repeat. Do so because TPU seems + # to crash if it runs out of data during eval. + dataset = dataset.repeat(None) + dataset = data_reader.padded_batch(dataset, batch_size, + batching_scheme["padded_shapes"]) + if not is_training: + dataset = dataset.map( + lambda f: pad_batch(f, batch_size), num_parallel_calls=num_threads) + dataset = dataset.map(define_shapes, num_parallel_calls=num_threads) dataset = dataset.prefetch(1) features = dataset.make_one_shot_iterator().get_next() @@ -155,6 +144,9 @@ def get_model_fn(model, hp, use_tpu=True): def model_fn(features, labels, mode, params, config): """Model fn.""" del params + del config + create_dummy_vars() + hparams = copy.deepcopy(hp) problem_hp = hparams.problems[0] orig_features = features @@ -168,9 +160,12 @@ def model_fn(features, labels, mode, params, config): # Transform features transformed_features = {} if input_modality is not None: - transformed_features["inputs"] = input_modality.bottom(features["inputs"]) - transformed_features["targets"] = target_modality.targets_bottom( - features["targets"]) + with tf.variable_scope(input_modality.name): + transformed_features["inputs"] = input_modality.bottom( + features["inputs"]) + with tf.variable_scope(target_modality.name): + transformed_features["targets"] = target_modality.targets_bottom( + features["targets"]) transformed_features["problem_choice"] = tf.constant(0) transformed_features["input_space_id"] = tf.constant( problem_hp.input_space_id) @@ -178,17 +173,19 @@ def model_fn(features, labels, mode, params, config): problem_hp.target_space_id) # Model construction - outputs = model_class.model_fn_body(transformed_features) - logits = target_modality.top(outputs, labels) + with tf.variable_scope("body"): + outputs = model_class.model_fn_body(transformed_features) + with tf.variable_scope(target_modality.name): + logits = target_modality.top(outputs, labels) - # Ensure the length is known statically - shape = [None] * logits.get_shape().ndims - shape[1] = hparams.max_length - logits.set_shape(logits.get_shape().merge_with(shape)) + # Ensure the length is known statically + shape = [None] * logits.get_shape().ndims + shape[1] = hparams.max_length + 
logits.set_shape(logits.get_shape().merge_with(shape)) - # Loss - loss_num, loss_den = target_modality.loss(logits, labels) - loss = loss_num / tf.maximum(1.0, loss_den) + # Loss + loss_num, loss_den = target_modality.loss(logits, labels) + loss = loss_num / tf.maximum(1.0, loss_den) if mode == tf.estimator.ModeKeys.EVAL: problem = hp.problem_instances[0] @@ -202,10 +199,7 @@ def model_fn(features, labels, mode, params, config): assert mode == tf.estimator.ModeKeys.TRAIN # Learning rate - num_shards = config.tpu_config.num_shards - lr = hparams.learning_rate * model_builder.learning_rate_decay( - hparams, num_worker_replicas=num_shards) - lr /= math.sqrt(float(num_shards)) + lr = hparams.learning_rate * model_builder.learning_rate_decay(hparams) # Optimizer opt = model_builder.ConditionalOptimizer(hparams.optimizer, lr, hparams) @@ -313,19 +307,3 @@ def make_estimator(model_fn, config=run_config, train_batch_size=batch_size, eval_batch_size=batch_size * 2) - - -@registry.register_hparams -def transformer_tpu(): - """HParams for Transformer model on TPU.""" - hp = transformer.transformer_base() - hp.use_pad_remover = int(False) # where op not supported - hp.optimizer = "TrueAdam" - hp.learning_rate = 0.4 - - # Inputs - # Each example in the batch will be of (padded) length hp.max_length - hp.max_length = 64 - hp.tpu_batch_size_per_shard = 20 - - return hp diff --git a/tensor2tensor/utils/devices.py b/tensor2tensor/utils/devices.py index 9fa322985..e296394da 100644 --- a/tensor2tensor/utils/devices.py +++ b/tensor2tensor/utils/devices.py @@ -118,8 +118,8 @@ def _replica_device_setter(worker_device): if FLAGS.locally_shard_to_cpu or FLAGS.worker_gpu < 1: datashard_devices += ["cpu:0"] caching_devices = None - elif FLAGS.sync: - assert FLAGS.ps_replicas > 0 + elif FLAGS.sync and FLAGS.ps_replicas > 0: + # compute on ps datashard_devices = [ _replica_device_setter(d) for d in ps_devices(all_workers=all_workers) ] @@ -131,7 +131,8 @@ def _replica_device_setter(worker_device): else: caching_devices = None else: - # old fashioned async - compute on worker + # compute on worker - this is either a single-worker setup or asynchronous + # with parameter servers. if FLAGS.worker_gpu > 1: datashard_devices = [ _replica_device_setter(FLAGS.worker_job + "/GPU:%d" % d) diff --git a/tensor2tensor/utils/model_builder.py b/tensor2tensor/utils/model_builder.py index 44a6f5208..ef362ed90 100644 --- a/tensor2tensor/utils/model_builder.py +++ b/tensor2tensor/utils/model_builder.py @@ -108,7 +108,8 @@ def nth_model(n): hparams.problems[n], n, dp, - devices.ps_devices(all_workers=True)) + devices.ps_devices(all_workers=True), + decode_hparams=decode_hparams) if mode == tf.estimator.ModeKeys.PREDICT: return model_class.infer( features, diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index 85f339511..07f4622d6 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -66,7 +66,8 @@ def __init__(self, problem_hparams, problem_idx=0, data_parallelism=None, - ps_devices=None): + ps_devices=None, + decode_hparams=None): """Create a T2TModel. Args: @@ -77,6 +78,7 @@ def __init__(self, data_parallelism: a expert_utils.parallelism (specifies devices for data parallelism). ps_devices: a list of devices to be used for experts + decode_hparams: a hyperparameter object with decoding parameters. 
Returns: a T2TModel @@ -103,6 +105,7 @@ def __init__(self, tf.logging.info("Unsetting shared_embedding_and_softmax_weights.") hparams.shared_embedding_and_softmax_weights = 0 self._hparams = hparams + self._decode_hparams = copy.copy(decode_hparams) self._data_parallelism = data_parallelism self._num_datashards = data_parallelism.n self._ps_devices = ps_devices @@ -146,6 +149,10 @@ def _create_modalities(self, problem_hparams, hparams): def has_input(self): return self._problem_hparams.input_modality + def prepare_features_for_infer(self, features): + """Called before inference to allow adding infer-specific features.""" + pass + def eval_autoregressive(self, features=None, decode_length=50, @@ -195,11 +202,11 @@ def infer(self, """ # TODO(rsepassi): Make decoding work with real-valued model outputs # (i.e. if the target modality is RealModality). - if not self.has_input: - # since there is no input, it is more interesting to see randomly - # generated sequences, than to see the most likely sequence repeatedly. - beam_size = 1 - self._hparams.sampling_method = "random" + self.prepare_features_for_infer(features) + if not self.has_input and beam_size > 1: + tf.logging.warn("Beam searching for a model with no inputs.") + if not self.has_input and self._hparams.sampling_method != "random": + tf.logging.warn("Non-random sampling for a model with no inputs.") if is_class_modality( self._hparams.problems[self._problem_idx].target_modality): beam_size = 1 # No use to run beam-search for a single class. @@ -540,6 +547,7 @@ def model_fn(self, features, skip=False, last_position_only=False): ] all_previous_modalities.extend(previous_modalities) do_reuse = input_modality.name in all_previous_modalities + transformed_features[key + "_raw"] = sharded_features[key] with tf.variable_scope(input_modality.name, reuse=do_reuse): transformed_features[key] = input_modality.bottom_sharded( sharded_features[key], dp) @@ -547,8 +555,13 @@ def model_fn(self, features, skip=False, last_position_only=False): # Target space id just gets copied to every shard. if "target_space_id" in features: - transformed_features["target_space_id"] = [features["target_space_id"] - ] * self._num_datashards + transformed_features["target_space_id"] = [ + features["target_space_id"]] * self._num_datashards + + # For features without a modality ending in "_raw", we pass them raw. + for key, feature in sharded_features.items(): + if key not in transformed_features and key.endswith("_raw"): + transformed_features[key] = feature # Targets are transformed by the autoregressive part of the modality previous_tgt_modalities = [ @@ -564,7 +577,7 @@ def model_fn(self, features, skip=False, last_position_only=False): sharded_features["targets"], dp) # Allows later access to pre-embedding raw targets. - transformed_features["raw_targets"] = sharded_features["targets"] + transformed_features["targets_raw"] = sharded_features["targets"] # Construct the model body. 
with tf.variable_scope("body", reuse=self._problem_idx > 0): diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py index e90e2dd10..57d45fb50 100644 --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -63,6 +63,19 @@ flags.DEFINE_string("data_dir", None, "Directory with training data.") flags.DEFINE_integer("train_steps", 250000, "The number of steps to run training for.") +flags.DEFINE_string("eval_early_stopping_metric", "loss", + "If --schedule=train_and_evaluate and " + "--eval_early_stopping_steps is not None, then stop when " + "--eval_early_stopping_metric has not decreased for " + "--eval_early_stopping_steps") +flags.DEFINE_integer("eval_early_stopping_steps", None, + "If --schedule=train_and_evaluate and " + "--eval_early_stopping_steps is not None, then stop when " + "--eval_early_stopping_metric has not decreased for " + "--eval_early_stopping_steps") +flags.DEFINE_bool("eval_early_stopping_metric_minimize", True, + "Whether to check for the early stopping metric going down " + "or up.") flags.DEFINE_bool("eval_run_autoregressive", False, "Run eval autoregressively where we condition on previous" "generated output instead of the actual target.") @@ -148,7 +161,20 @@ def create_experiment(data_dir, model_name, train_steps, eval_steps, hparams, save_steps=10, output_dir=run_config.model_dir, show_dataflow=True, - show_memory=True,)) + show_memory=True, + )) + if FLAGS.schedule == "train_and_evaluate": + if FLAGS.local_eval_frequency: + train_monitors.append( + tf.contrib.learn.monitors.ValidationMonitor( + input_fn=input_fns[tf.estimator.ModeKeys.EVAL], + eval_steps=eval_steps, + every_n_steps=FLAGS.local_eval_frequency, + hooks=eval_hooks, + early_stopping_rounds=FLAGS.eval_early_stopping_steps, + early_stopping_metric=FLAGS.eval_early_stopping_metric, + early_stopping_metric_minimize=FLAGS. + eval_early_stopping_metric_minimize)) optional_kwargs = {} if FLAGS.export_saved_model: @@ -164,7 +190,6 @@ def create_experiment(data_dir, model_name, train_steps, eval_steps, hparams, eval_input_fn=input_fns[tf.estimator.ModeKeys.EVAL], train_steps=train_steps, eval_steps=eval_steps, - min_eval_frequency=FLAGS.local_eval_frequency, train_monitors=train_monitors, eval_hooks=eval_hooks, eval_delay_secs=0, @@ -378,8 +403,9 @@ def is_chief(): def session_config(): """The TensorFlow Session config to use.""" - graph_options = tf.GraphOptions(optimizer_options=tf.OptimizerOptions( - opt_level=tf.OptimizerOptions.L1, do_function_inlining=False)) + graph_options = tf.GraphOptions( + optimizer_options=tf.OptimizerOptions( + opt_level=tf.OptimizerOptions.L1, do_function_inlining=False)) if FLAGS.experimental_optimize_placement: rewrite_options = tf.RewriterConfig(optimize_tensor_layout=True)