diff --git a/.gitignore b/.gitignore
old mode 100755
new mode 100644
index 80f48a2fd..c9dd3db88
--- a/.gitignore
+++ b/.gitignore
@@ -1,14 +1,16 @@
 # Compiled python modules.
 *.pyc
 
+# Byte-compiled
+__pycache__/
+
 # Python egg metadata, regenerated from source files by setuptools.
 /*.egg-info
 
-# PyPI distribution artificats
+# PyPI distribution artifacts.
 build/
 dist/
 
 # Sublime project files
 *.sublime-project
 *.sublime-workspace
-
diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py
old mode 100755
new mode 100644
index 6cfa9e740..7b00a85d2
--- a/tensor2tensor/data_generators/text_encoder.py
+++ b/tensor2tensor/data_generators/text_encoder.py
@@ -28,11 +28,20 @@
 # Dependency imports
 
 import six
+from six import PY2
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensor2tensor.data_generators import tokenizer
 
 import tensorflow as tf
+
+# Conversion between Unicode and UTF-8, if required (on Python2)
+_native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s)
+
+
+_unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)
+
+
 # Reserved tokens for things like padding and EOS symbols.
 PAD = "<pad>"
 EOS = "<EOS>"
@@ -162,15 +171,36 @@ def _load_vocab_from_file(self, filename):
 
 
 class SubwordTextEncoder(TextEncoder):
-  """Class for breaking tokens into subtokens.
+  """Class for invertibly encoding text using a limited vocabulary.
 
-  Invertibly encodes a string as a sequence of subtokens from a limited
+  Invertibly encodes a native string as a sequence of subtokens from a limited
   vocabulary.
 
   A SubwordTextEncoder is built from a corpus (so it is tailored to the text in
   the corpus), and stored to a file. See text_encoder_build_subword.py.
 
   It can then be loaded and used to encode/decode any text.
+
+  Encoding has four phases:
+
+  1. Tokenize into a list of tokens.  Each token is a unicode string of either
+     all alphanumeric characters or all non-alphanumeric characters.  We drop
+     tokens consisting of a single space that are between two alphanumeric
+     tokens.
+
+  2. Escape each token.  This escapes away special and out-of-vocabulary
+     characters, and makes sure that each token ends with an underscore, and
+     has no other underscores.
+
+  3. Represent each escaped token as the concatenation of a list of subtokens
+     from the limited vocabulary.  Subtoken selection is done greedily from
+     beginning to end.  That is, we construct the list in order, always picking
+     the longest subtoken in our vocabulary that matches a prefix of the
+     remaining portion of the encoded token.
+
+  4. Concatenate these lists.  This concatenation is invertible due to the
+     fact that the trailing underscores indicate when one list is finished.
+
   """
 
   def __init__(self, filename=None, num_reserved_ids=2):
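Review note: a small worked example of phases 2-3 above, using a made-up
subtoken vocabulary (the vocabulary and the resulting splits are illustrative
only, not taken from a real SubwordTextEncoder):

    # Toy sketch of greedy longest-prefix subtoken matching on an escaped
    # token.  The trailing "_" is the sentinel added by escaping.
    vocab = set([u"und", u"er_", u"_", u"u", u"n", u"d", u"e", u"r"])

    def greedy_subtokens(escaped_token, vocab):
      ret, pos = [], 0
      while pos < len(escaped_token):
        end = len(escaped_token)
        while end > pos and escaped_token[pos:end] not in vocab:
          end -= 1
        assert end > pos  # every single character is in the vocabulary
        ret.append(escaped_token[pos:end])
        pos = end
      return ret

    print(greedy_subtokens(u"under_", vocab))  # [u"und", u"er_"]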
@@ -182,24 +212,26 @@ def __init__(self, filename=None, num_reserved_ids=2):
     super(SubwordTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
 
   def encode(self, raw_text):
-    """Converts a string to a list of subtoken ids.
+    """Converts a native string to a list of subtoken ids.
 
     Args:
-      raw_text: a string.
+      raw_text: a native string.
     Returns:
       a list of integers in the range [0, vocab_size)
     """
-    return self._tokens_to_subtokens(self._tokenizer.encode(raw_text))
+    return self._tokens_to_subtokens(self._tokenizer.encode(
+        _native_to_unicode(raw_text)))
 
   def decode(self, subtokens):
-    """Converts a sequence of subtoken ids to a string.
+    """Converts a sequence of subtoken ids to a native string.
 
     Args:
       subtokens: a list of integers in the range [0, vocab_size)
     Returns:
-      a string
+      a native string
     """
-    return self._tokenizer.decode(self._subtokens_to_tokens(subtokens))
+    return _unicode_to_native(self._tokenizer.decode(
+        self._subtokens_to_tokens(subtokens)))
 
   @property
   def vocab_size(self):
@@ -239,8 +271,8 @@ def subtoken_to_subtoken_string(self, subtoken):
     if subtoken_string:
       return subtoken_string
     if 0 <= subtoken < self._num_reserved_ids:
-      return "%s_" % RESERVED_TOKENS[subtoken]
-    return "ID%d_" % subtoken
+      return u"%s_" % RESERVED_TOKENS[subtoken]
+    return u"ID%d_" % subtoken
 
   def _escaped_token_to_subtokens(self, escaped_token):
     """Converts an escaped token string to a list of subtokens.
@@ -260,27 +292,11 @@ def _escaped_token_to_subtokens(self, escaped_token):
         if subtoken != -1:
           break
         end -= 1
-      if end > pos:
-        ret.append(subtoken)
-        pos = end
-      else:
-        # No subtoken in the vocabulary matches escaped_token[pos].
-        # This can happen if the token contains a Unicode character
-        # that did not occur in the vocabulary training set.
-        # The id self.vocab_size - 1 is decoded as Unicode uFFFD,
-        # REPLACEMENT_CHARACTER.
-        ret.append(self.vocab_size - 1)
-        # Ensure that the outer loop continues
-        pos += 1
-    return ret
+      assert end > pos
+      ret.append(subtoken)
+      pos = end
 
-  @classmethod
-  def alphabet(cls, token_counts):
-    """Return the set of Unicode characters that appear in the tokens."""
-    alphabet_set = set()
-    for token in six.iterkeys(token_counts):
-      alphabet_set |= set(token)
-    return alphabet_set
+    return ret
 
   @classmethod
   def build_to_target_size(cls,
@@ -304,23 +320,21 @@ def build_to_target_size(cls,
     Returns:
       a SubwordTextEncoder instance.
     """
-    # Calculate the alphabet, i.e. the set of all Unicode characters
-    # that appear in the tokens.
-    alphabet_set = cls.alphabet(token_counts)
-    tf.logging.info("Alphabet contains %d characters" % len(alphabet_set))
-
     def bisect(min_val, max_val):
+      """Bisection to find the right size."""
       present_count = (max_val + min_val) // 2
       tf.logging.info("Trying min_count %d" % present_count)
       subtokenizer = cls()
-      subtokenizer.build_from_token_counts(token_counts, alphabet_set,
+      subtokenizer.build_from_token_counts(token_counts,
                                            present_count, num_iterations)
       if min_val >= max_val or subtokenizer.vocab_size == target_size:
         return subtokenizer
+
       if subtokenizer.vocab_size > target_size:
         other_subtokenizer = bisect(present_count + 1, max_val)
       else:
         other_subtokenizer = bisect(min_val, present_count - 1)
+
       if (abs(other_subtokenizer.vocab_size - target_size) <
           abs(subtokenizer.vocab_size - target_size)):
         return other_subtokenizer
@@ -330,17 +344,29 @@ def bisect(min_val, max_val):
 
   def build_from_token_counts(self,
                               token_counts,
-                              alphabet_set,
                               min_count,
                               num_iterations=4):
     """Train a SubwordTextEncoder based on a dictionary of word counts.
 
     Args:
       token_counts: a dictionary of Unicode strings to int.
-      alphabet_set: the set of Unicode characters that appear in the tokens.
       min_count: an integer - discard subtokens with lower counts.
       num_iterations: an integer.  how many iterations of refinement.
     """
+    # First determine the alphabet to include all characters with count at
+    # least min_count in the dataset.
+    char_counts = defaultdict(int)
+    for token, count in six.iteritems(token_counts):
+      for c in token:
+        char_counts[c] += count
+    self._alphabet = set()
+    for c, count in six.iteritems(char_counts):
+      if count >= min_count:
+        self._alphabet.add(c)
+    # Make sure all characters needed for escaping are included
+    for c in u"\\_;0123456789":
+      self._alphabet.add(c)
+
     # We build iteratively.  On each iteration, we segment all the words,
     # then count the resulting potential subtokens, keeping the ones
     # with high enough counts for our new vocabulary.
@@ -364,43 +390,36 @@ def build_from_token_counts(self,
         for end in xrange(start + 1, len(escaped_token) + 1):
           subtoken_string = escaped_token[start:end]
           counts[subtoken_string] += count
+      # Make sure all characters needed for escaping are included
+      for c in self._alphabet:
+        counts[c] += min_count
       # Array of sets of candidate subtoken strings, by length
       len_to_subtoken_strings = []
       for subtoken_string, count in six.iteritems(counts):
         lsub = len(subtoken_string)
-        # All subtoken strings of length 1 are automatically included
-        # later, so we don't need to consider them here
-        if count < min_count or lsub <= 1:
-          continue
-        # Add this subtoken string to its length set
-        while len(len_to_subtoken_strings) <= lsub:
-          len_to_subtoken_strings.append(set())
-        len_to_subtoken_strings[lsub].add(subtoken_string)
+        if count >= min_count:
+          # Add this subtoken string to its length set
+          while len(len_to_subtoken_strings) <= lsub:
+            len_to_subtoken_strings.append(set())
+          len_to_subtoken_strings[lsub].add(subtoken_string)
       new_subtoken_strings = []
       # consider the candidates longest to shortest, so that if we accept
       # a longer subtoken string, we can decrement the counts of its prefixes.
-      for subtoken_strings in reversed(len_to_subtoken_strings[2:]):
+      for lsub in reversed(range(1, len(len_to_subtoken_strings))):
+        subtoken_strings = len_to_subtoken_strings[lsub]
         for subtoken_string in subtoken_strings:
           count = counts[subtoken_string]
-          if count < min_count:
-            continue
-          new_subtoken_strings.append((count, subtoken_string))
-          for l in xrange(1, len(subtoken_string)):
-            counts[subtoken_string[:l]] -= count
-      # Sort what we've got so far in decreasing order by count
+          if count >= min_count:
+            new_subtoken_strings.append((count, subtoken_string))
+            for l in xrange(1, lsub):
+              counts[subtoken_string[:l]] -= count
+      # Sort in decreasing order by count
       new_subtoken_strings.sort(reverse=True)
-      # Add the alphabet set at the end of the vocabulary list
-      for char in alphabet_set:
-        new_subtoken_strings.append((0, char))
-      # Also include the Unicode REPLACEMENT CHARACTER to use
-      # when encountering previously unseen Unicode characters
-      # in the input (i.e. input external to the tokenizer training
-      # set, which may thus contain characters not in the alphabet_set).
-      # This must be the last entry in the subtoken vocabulary list.
-      new_subtoken_strings.append((0, u"\uFFFD"))
       # Now we have a candidate vocabulary
+      old_alphabet = self._alphabet
       self._init_from_list([u""] * self._num_reserved_ids +
                            [p[1] for p in new_subtoken_strings])
+      assert old_alphabet == self._alphabet
       tf.logging.info("vocab_size = %d" % self.vocab_size)
 
       original = "This sentence was encoded by the SubwordTextEncoder."
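Review note: with alphabet_set gone from the signature, building a vocabulary
needs only the token counts; a minimal usage sketch (the token counts below
are made up):

    from tensor2tensor.data_generators import text_encoder

    token_counts = {u"dude": 12, u"cool": 9, u"that": 20, u"'": 15, u"s": 15}
    encoder = text_encoder.SubwordTextEncoder()
    encoder.build_from_token_counts(token_counts, min_count=2,
                                    num_iterations=4)
    print(encoder.vocab_size)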
@@ -423,46 +442,77 @@ def _init_from_list(self, subtoken_strings):
     self._all_subtoken_strings = subtoken_strings
     self._subtoken_string_to_id = {
         s: i for i, s in enumerate(subtoken_strings) if s}
+    self._alphabet = set([c for c in subtoken_strings if len(c) == 1])
 
   def _load_from_file(self, filename):
     """Load from a file."""
     subtoken_strings = []
     with tf.gfile.Open(filename) as f:
       for line in f:
-        if six.PY2:
-          subtoken_strings.append(line.strip()[1:-1].decode("utf-8"))
-        else:
-          subtoken_strings.append(line.strip()[1:-1])
+        subtoken_strings.append(_native_to_unicode(line.strip()[1:-1]))
     self._init_from_list(subtoken_strings)
 
   def store_to_file(self, filename):
     with tf.gfile.Open(filename, "w") as f:
       for subtoken_string in self._all_subtoken_strings:
-        if six.PY2:
-          f.write("'" + subtoken_string.encode("utf-8") + "'\n")
-        else:
-          f.write("'" + subtoken_string + "'\n")
+        f.write("'" + _unicode_to_native(subtoken_string) + "'\n")
 
   def _escape_token(self, token):
-    r"""Translate '\'->'\\' and '_'->'\u', then append '_'.
+    r"""Escape away underscores and OOV characters and append '_'.
+
+    This allows the token to be expressed as the concatenation of a list
+    of subtokens from the vocabulary.  The underscore acts as a sentinel
+    which allows us to invertibly concatenate multiple such lists.
 
     Args:
-      token: a string
+      token: a unicode string
     Returns:
-      escaped_token: a string
+      escaped_token: a unicode string
     """
-    return token.replace("\\", "\\\\").replace("_", "\\u") + "_"
+    token = token.replace("\\", "\\\\").replace("_", "\\u") + "_"
+    ret = u""
+    for c in token:
+      if c in self._alphabet:
+        ret += c
+      else:
+        ret += u"\\%d;" % ord(c)
+    return ret
 
   def _unescape_token(self, escaped_token):
-    r"""Remove '_' from end, then translate '\\'->'\' and '\u'->'_'.
+    r"""Inverse of _escape_token().
 
     Args:
-      escaped_token: a string
+      escaped_token: a unicode string
     Returns:
-      token: a string
+      token: a unicode string
     """
-    assert escaped_token[-1] == "_"
-    return escaped_token[:-1].replace("\\u", "_").replace("\\\\", "\\")
+    ret = u""
+    escaped_token = escaped_token[:-1]
+    pos = 0
+    while pos < len(escaped_token):
+      c = escaped_token[pos]
+      if c == "\\":
+        pos += 1
+        c = escaped_token[pos]
+        if c == u"u":
+          ret += u"_"
+          pos += 1
+        elif c == "\\":
+          ret += u"\\"
+          pos += 1
+        else:
+          semicolon_pos = escaped_token.find(u";", pos)
+          if semicolon_pos == -1:
+            continue
+          try:
+            ret += unichr(int(escaped_token[pos:semicolon_pos]))
+            pos = semicolon_pos + 1
+          except (ValueError, OverflowError) as _:
+            pass
+      else:
+        ret += c
+        pos += 1
+    return ret
 
   @classmethod
   def get_token_counts(cls, text_filepattern, corpus_max_lines):
@@ -474,7 +524,7 @@ def get_token_counts(cls, text_filepattern, corpus_max_lines):
     with tf.gfile.Open(text_filename) as f:
       for line in f:
         # The tokenizer updates token_counts in encode()
-        tok.encode(line.strip())
+        tok.encode(_native_to_unicode(line.strip()))
         lines_read += 1
         if corpus_max_lines > 0 and lines_read > corpus_max_lines:
           return tok.token_counts
diff --git a/tensor2tensor/data_generators/text_encoder_build_subword.py b/tensor2tensor/data_generators/text_encoder_build_subword.py
index 9b8da9364..659e9da14 100644
--- a/tensor2tensor/data_generators/text_encoder_build_subword.py
+++ b/tensor2tensor/data_generators/text_encoder_build_subword.py
@@ -59,8 +59,7 @@ def main(unused_argv):
     raise ValueError('Must provide --corpus_filepattern')
   token_counts = text_encoder.SubwordTextEncoder.get_token_counts(
      FLAGS.corpus_filepattern, FLAGS.corpus_max_lines)
-  alphabet_set = text_encoder.SubwordTextEncoder.alphabet(token_counts)
-  gs.build_from_token_counts(token_counts, alphabet_set,
+  gs.build_from_token_counts(token_counts,
                              FLAGS.min_count, FLAGS.num_iterations)
   gs.store_to_file(FLAGS.output_fn)
 
diff --git a/tensor2tensor/data_generators/tokenizer.py b/tensor2tensor/data_generators/tokenizer.py
index 0eaea4f58..8490ead19 100644
--- a/tensor2tensor/data_generators/tokenizer.py
+++ b/tensor2tensor/data_generators/tokenizer.py
@@ -14,32 +14,29 @@
 """A simple invertible tokenizer.
 
-Converts from a raw string to a list of tokens (represented as Unicode strings).
+Converts from a unicode string to a list of tokens
+(represented as Unicode strings).
 
 This tokenizer has the following desirable properties:
  - It is invertible.
- - Punctuation is broken away from adjacent letters.
+ - Alphanumeric characters are broken away from non-alphanumeric characters.
  - A single space between words does not produce an extra token.
  - The full Unicode punctuation and separator set is recognized.
 
 The tokenization algorithm is as follows:
 
-0. We classify the input characters into "word characters" and
-   "separator characters".  Separator characters are defined as the union of
-   Unicode punctuation and separators/white space.  All other characters are
-   "word characters".
-
-1. Split the text into a list of tokens, splitting at every boundary of a
-   "word character" and a "separator character".  This produces a list which
-   alternates between "word tokens" (strings of word codepoints) and
-   "separator tokens" (strings of of separator/punctuation codepoints).
+1. Split the text into a list of tokens, splitting at every boundary of an
+   alphanumeric character and a non-alphanumeric character.  This produces
+   a list which alternates between "alphanumeric tokens"
+   (strings of alphanumeric characters) and "non-alphanumeric tokens"
+   (strings of non-alphanumeric characters).
 
 2. Remove every token consisting of a single space, unless it is
    the very first or very last token in the list.  These tokens are now
-   implied by the fact that there are two adjacent word tokens.
+   implied by the fact that there are two adjacent alphanumeric tokens.
 
-e.g. "Dude - that's so cool."
-  -> ["Dude", " - ", "that", "'", "s", "so", "cool", "."]
+e.g. u"Dude - that's so cool."
+  -> [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."]
 """
 
 from __future__ import absolute_import
@@ -47,87 +44,66 @@ from __future__ import print_function
 
 from collections import defaultdict
-import re
 import sys
 import unicodedata
 
 # Dependency imports
 
-from six import PY2
 from six import unichr  # pylint: disable=redefined-builtin
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 
-# Regular expression that matches Unicode whitespace characters
-# (including ASCII whitespace) as defined in the Python run-time library
-_RE_WHITESPACE = re.compile(r"^\s$", re.UNICODE)
-
-
-# Set of Unicode whitespace code points
-UNICODE_WHITESPACE = set(unichr(i) for i in xrange(sys.maxunicode)
-                         if _RE_WHITESPACE.match(unichr(i)))
-
-
-# Set of Unicode punctuation code points
-UNICODE_PUNCTUATION = set(unichr(i) for i in xrange(sys.maxunicode)
-                          if unicodedata.category(unichr(i)).startswith("P"))
-
-
-# Conversion between Unicode and UTF-8, if required (on Python2)
-_decode_string = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s)
-
-
-_encode_string = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s)
-
-
 class Tokenizer(object):
   """Vocab for breaking words into Unicode wordpieces.
   """
 
-  _SEPARATOR_CHAR_SET = UNICODE_WHITESPACE | UNICODE_PUNCTUATION
+  # This set contains all letter and number characters.
+  _ALPHANUMERIC_CHAR_SET = set(
+      unichr(i) for i in xrange(sys.maxunicode)
+      if (unicodedata.category(unichr(i)).startswith("L") or
+          unicodedata.category(unichr(i)).startswith("N")))
 
   def __init__(self):
     self.token_counts = defaultdict(int)
 
-  def encode(self, raw_text):
-    """Encode a raw string as a list of tokens.
+  def encode(self, text):
+    """Encode a unicode string as a list of tokens.
 
     Args:
-      raw_text: a (Python2 or Python3 native) string
+      text: a unicode string
     Returns:
       a list of tokens as Unicode strings
     """
-    if not raw_text:
+    if not text:
       return []
     ret = []
     token_start = 0
-    unicode_text = _decode_string(raw_text)
     # Classify each character in the input string
-    is_sep = [c in self._SEPARATOR_CHAR_SET for c in unicode_text]
-    for pos in xrange(1, len(unicode_text)):
-      if is_sep[pos] != is_sep[pos - 1]:
-        token = unicode_text[token_start:pos]
+    is_alnum = [c in self._ALPHANUMERIC_CHAR_SET for c in text]
+    for pos in xrange(1, len(text)):
+      if is_alnum[pos] != is_alnum[pos - 1]:
+        token = text[token_start:pos]
         if token != u" " or token_start == 0:
          ret.append(token)
          self.token_counts[token] += 1
        token_start = pos
-    final_token = unicode_text[token_start:]
+    final_token = text[token_start:]
     ret.append(final_token)
     self.token_counts[final_token] += 1
     return ret
 
   def decode(self, tokens):
-    """Decode a list of tokens to a string.
+    """Decode a list of tokens to a unicode string.
 
     Args:
       tokens: a list of Unicode strings
     Returns:
-      a (Python2 or Python3 native) string
+      a unicode string
     """
     ret = u""
-    is_word = [t[0] not in self._SEPARATOR_CHAR_SET for t in tokens]
+    token_is_alnum = [t[0] in self._ALPHANUMERIC_CHAR_SET for t in tokens]
     for i, token in enumerate(tokens):
-      if i > 0 and is_word[i - 1] and is_word[i]:
+      if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
         ret += u" "
       ret += token
-    return _encode_string(ret)
+    return ret
diff --git a/tensor2tensor/data_generators/tokenizer_test.py b/tensor2tensor/data_generators/tokenizer_test.py
index 70c7d31eb..766630ba3 100644
--- a/tensor2tensor/data_generators/tokenizer_test.py
+++ b/tensor2tensor/data_generators/tokenizer_test.py
@@ -23,7 +23,6 @@
 
 # Dependency imports
 
-import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensor2tensor.data_generators import tokenizer
@@ -35,29 +34,28 @@ class TokenizerTest(tf.test.TestCase):
 
   def testEncode(self):
     t = tokenizer.Tokenizer()
     self.assertEqual(
-        t.encode("Dude - that's so cool."),
-        ["Dude", " - ", "that", "'", "s", "so", "cool", "."])
-    # TODO(lukaszkaiser): make it work again with Unicode.
-    # self.assertEqual(
-    #     t.encode("Łukasz est né en 1981."),
-    #     ["Łukasz", "est", "né", "en", "1981", "."])
+        t.encode(u"Dude - that's so cool."),
+        [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."])
     self.assertEqual(
-        t.encode(" Spaces at the ends "),
-        [" ", "Spaces", "at", "the", "ends", " "])
-    self.assertEqual(t.encode("802.11b"), ["802", ".", "11b"])
-    self.assertEqual(t.encode("two. \nlines"), ["two", ". \n", "lines"])
+        t.encode(u"Łukasz est né en 1981."),
+        [u"Łukasz", u"est", u"né", u"en", u"1981", u"."])
+    self.assertEqual(
+        t.encode(u" Spaces at the ends "),
+        [u" ", u"Spaces", u"at", u"the", u"ends", u" "])
+    self.assertEqual(t.encode(u"802.11b"), [u"802", u".", u"11b"])
+    self.assertEqual(t.encode(u"two. \nlines"), [u"two", u". \n", u"lines"])
 
   def testDecode(self):
     t = tokenizer.Tokenizer()
     self.assertEqual(
-        t.decode(["Dude", " - ", "that", "'", "s", "so", "cool", "."]),
-        "Dude - that's so cool.")
+        t.decode([u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."]),
+        u"Dude - that's so cool.")
 
   def testInvertibilityOnRandomStrings(self):
     t = tokenizer.Tokenizer()
     random.seed(123)
-    for _ in xrange(0):  # TODO(lukaszkaiser): make it work again with Unicode.
-      s = "".join([six.int2byte(random.randint(0, 255)) for _ in xrange(10)])
+    for _ in xrange(1000):
+      s = u"".join([unichr(random.randint(0, 65535)) for _ in xrange(10)])
       self.assertEqual(s, t.decode(t.encode(s)))
 
diff --git a/tensor2tensor/models/bluenet.py b/tensor2tensor/models/bluenet.py
index 19bed2032..8f4c89eac 100644
--- a/tensor2tensor/models/bluenet.py
+++ b/tensor2tensor/models/bluenet.py
@@ -77,7 +77,8 @@ def run_binary_modules(modules, cur1, cur2, hparams):
   """Run binary modules."""
   selection_var = tf.get_variable("selection", [len(modules)],
                                   initializer=tf.zeros_initializer())
-  inv_t = 100.0 * common_layers.inverse_exp_decay(100000, min_value=0.01)
+  inv_t = 100.0 * common_layers.inverse_exp_decay(
+      hparams.anneal_until, min_value=0.01)
   selected_weights = tf.nn.softmax(selection_var * inv_t)
   all_res = [modules[n](cur1, cur2, hparams) for n in xrange(len(modules))]
   all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
@@ -89,7 +90,8 @@ def run_unary_modules_basic(modules, cur, hparams):
   """Run unary modules."""
   selection_var = tf.get_variable("selection", [len(modules)],
                                   initializer=tf.zeros_initializer())
-  inv_t = 100.0 * common_layers.inverse_exp_decay(100000, min_value=0.01)
+  inv_t = 100.0 * common_layers.inverse_exp_decay(
+      hparams.anneal_until, min_value=0.01)
   selected_weights = tf.nn.softmax(selection_var * inv_t)
   all_res = [modules[n](cur, hparams) for n in xrange(len(modules))]
   all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
@@ -109,7 +111,8 @@ def run_unary_modules_sample(modules, cur, hparams, k):
                       lambda: tf.zeros_like(cur),
                       lambda i=n: modules[i](cur, hparams))
              for n in xrange(len(modules))]
-  inv_t = 100.0 * common_layers.inverse_exp_decay(100000, min_value=0.01)
+  inv_t = 100.0 * common_layers.inverse_exp_decay(
+      hparams.anneal_until, min_value=0.01)
   selected_weights = tf.nn.softmax(selection_var * inv_t - 1e9 * (1.0 - to_run))
   all_res = tf.concat([tf.expand_dims(r, axis=0) for r in all_res], axis=0)
   res = all_res * tf.reshape(selected_weights, [-1, 1, 1, 1, 1])
@@ -122,6 +125,14 @@ def run_unary_modules(modules, cur, hparams):
   return run_unary_modules_sample(modules, cur, hparams, 4)
 
 
+def batch_deviation(x):
+  """Average deviation of the batch."""
+  x_mean = tf.reduce_mean(x, axis=[0], keep_dims=True)
+  x_variance = tf.reduce_mean(
+      tf.square(x - x_mean), axis=[0], keep_dims=True)
+  return tf.reduce_mean(tf.sqrt(x_variance))
+
+
 @registry.register_model
 class BlueNet(t2t_model.T2TModel):
 
@@ -153,14 +164,15 @@ def run_unary(x, name):
       with tf.variable_scope("conv"):
         x = run_unary_modules(conv_modules, x, hparams)
         x.set_shape(x_shape)
-      return x
+      return tf.nn.dropout(x, 1.0 - hparams.dropout), batch_deviation(x)
 
-    cur1, cur2 = inputs, inputs
+    cur1, cur2, extra_loss = inputs, inputs, 0.0
     cur_shape = inputs.get_shape()
     for i in xrange(hparams.num_hidden_layers):
       with tf.variable_scope("layer_%d" % i):
-        cur1 = run_unary(cur1, "unary1")
-        cur2 = run_unary(cur2, "unary2")
+        cur1, loss1 = run_unary(cur1, "unary1")
+        cur2, loss2 = run_unary(cur2, "unary2")
+        extra_loss += (loss1 + loss2) / float(hparams.num_hidden_layers)
         with tf.variable_scope("binary1"):
           next1 = run_binary_modules(binary_modules, cur1, cur2, hparams)
           next1.set_shape(cur_shape)
@@ -169,7 +181,9 @@ def run_unary(x, name):
           next2.set_shape(cur_shape)
         cur1, cur2 = next1, next2
 
-    return cur1
+    anneal = common_layers.inverse_exp_decay(hparams.anneal_until)
+    extra_loss *= hparams.batch_deviation_loss_factor * anneal
+    return cur1, extra_loss
 
 
 @registry.register_hparams
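Review note: batch_deviation() added above is the mean, over all non-batch
dimensions, of the per-position standard deviation across the batch; the same
quantity in NumPy, for illustration only:

    import numpy as np

    def batch_deviation_np(x):
      # x: [batch, height, width, channels]
      variance = np.mean(np.square(x - x.mean(axis=0, keepdims=True)),
                         axis=0, keepdims=True)
      return float(np.mean(np.sqrt(variance)))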
@@ -185,7 +199,7 @@ def bluenet_base():
   hparams.num_hidden_layers = 8
   hparams.kernel_height = 3
   hparams.kernel_width = 3
-  hparams.learning_rate_decay_scheme = "exp50k"
+  hparams.learning_rate_decay_scheme = "exp10k"
   hparams.learning_rate = 0.05
   hparams.learning_rate_warmup_steps = 3000
   hparams.initializer_gain = 1.0
@@ -196,6 +210,8 @@ def bluenet_base():
   hparams.optimizer_adam_beta1 = 0.85
   hparams.optimizer_adam_beta2 = 0.997
   hparams.add_hparam("imagenet_use_2d", True)
+  hparams.add_hparam("anneal_until", 40000)
+  hparams.add_hparam("batch_deviation_loss_factor", 0.001)
   return hparams
 
 
diff --git a/tensor2tensor/models/common_layers.py b/tensor2tensor/models/common_layers.py
index 36d9b0b51..2e2b74268 100644
--- a/tensor2tensor/models/common_layers.py
+++ b/tensor2tensor/models/common_layers.py
@@ -135,9 +135,14 @@ def image_augmentation(images, do_colors=False):
 
 def cifar_image_augmentation(images):
   """Image augmentation suitable for CIFAR-10/100.
-  As described in https://arxiv.org/pdf/1608.06993v3.pdf (page 5)."""
-  images = tf.image.resize_image_with_crop_or_pad(
-      images, 40, 40)
+  As described in https://arxiv.org/pdf/1608.06993v3.pdf (page 5).
+
+  Args:
+    images: a Tensor.
+  Returns:
+    Tensor of the same shape as images.
+  """
+  images = tf.image.resize_image_with_crop_or_pad(images, 40, 40)
   images = tf.random_crop(images, [32, 32, 3])
   images = tf.image.random_flip_left_right(images)
   return images
diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py
index 88b45db9d..d09787ae4 100644
--- a/tensor2tensor/utils/data_reader.py
+++ b/tensor2tensor/utils/data_reader.py
@@ -203,12 +203,10 @@ def preprocess(img):
             lambda img=inputs: resize(img))
       else:
         examples["inputs"] = tf.to_int64(resize(inputs))
-
     elif ("image_cifar10" in data_file_pattern
-         and mode == tf.contrib.learn.ModeKeys.TRAIN):
+          and mode == tf.contrib.learn.ModeKeys.TRAIN):
       examples["inputs"] = common_layers.cifar_image_augmentation(
           examples["inputs"])
-
     elif "audio" in data_file_pattern:
       # Reshape audio to proper shape
       sample_count = tf.to_int32(examples.pop("audio/sample_count"))
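Review note: the two hparams added to bluenet_base() in this diff (anneal_until
and batch_deviation_loss_factor) can be tuned by registering a variant hparams
set; a hypothetical example (the name and values below are made up):

    from tensor2tensor.models import bluenet
    from tensor2tensor.utils import registry

    @registry.register_hparams
    def bluenet_slow_anneal():
      hparams = bluenet.bluenet_base()
      hparams.anneal_until = 80000
      hparams.batch_deviation_loss_factor = 0.0005
      return hparams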