Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #708 from rsepassi/push
Browse files Browse the repository at this point in the history
v1.5.7
  • Loading branch information
lukaszkaiser authored Apr 13, 2018
2 parents c4ca5a4 + 95aeb11 commit 120315c
Show file tree
Hide file tree
Showing 54 changed files with 2,401 additions and 718 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.5.6',
version='1.5.7',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/bin/t2t_datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@

import numpy as np

from tensor2tensor import problems as problems_lib # pylint: disable=unused-import
from tensor2tensor.data_generators import algorithmic_math
from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import
from tensor2tensor.data_generators import audio
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import snli
Expand Down
3 changes: 2 additions & 1 deletion tensor2tensor/bin/t2t_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def create_decode_hparams():

def decode(estimator, hparams, decode_hp):
if FLAGS.decode_interactive:
decoding.decode_interactively(estimator, hparams, decode_hp, checkpoint_path=FLAGS.checkpoint_path)
decoding.decode_interactively(estimator, hparams, decode_hp,
checkpoint_path=FLAGS.checkpoint_path)
elif FLAGS.decode_from_file:
decoding.decode_from_file(estimator, FLAGS.decode_from_file, hparams,
decode_hp, FLAGS.decode_to_file,
Expand Down
95 changes: 95 additions & 0 deletions tensor2tensor/bin/t2t_distill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Perform distillation for a teacher to student.
This script is intended to be used with --model=distillation. See the model for
example hyperparameters and usage.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports

from tensor2tensor import models # pylint: disable=unused-import
from tensor2tensor import problems as problems_lib # pylint: disable=unused-import
from tensor2tensor.bin import t2t_trainer
from tensor2tensor.utils import cloud_mlengine
from tensor2tensor.utils import flags as t2t_flags # pylint: disable=unused-import
from tensor2tensor.utils import trainer_lib
from tensor2tensor.utils import usr_dir

import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS


def main(argv):
  """Trains a teacher model, then trains a student in the distill phase.

  The run has two phases that share the same flags. Each phase builds fresh
  hparams, points `FLAGS.output_dir` at its own subdirectory of the original
  output directory ("teacher" and "student"), and executes the experiment
  schedule. The student hparams additionally carry `teacher_dir`.

  Args:
    argv: command-line arguments; when non-empty, everything after the first
      entry is handed to `t2t_trainer.set_hparams_from_args`.

  Returns:
    The result of `cloud_mlengine.launch()` when `FLAGS.cloud_mlengine` is
    set; otherwise None.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.log_registry()

  # Launching on Cloud ML Engine replaces the local run entirely.
  if FLAGS.cloud_mlengine:
    return cloud_mlengine.launch()

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  with t2t_trainer.maybe_cloud_tpu():
    # Remember the starting directory: FLAGS.output_dir is overwritten once
    # per phase below.
    base_output_dir = FLAGS.output_dir

    # ---- Phase 1: teacher ----
    phase_hparams = t2t_trainer.create_hparams()
    phase_hparams.distill_phase = "train"
    teacher_dir = os.path.join(base_output_dir, "teacher")
    FLAGS.output_dir = teacher_dir

    build_experiment = t2t_trainer.create_experiment_fn()
    phase_config = t2t_trainer.create_run_config(phase_hparams)
    experiment = build_experiment(phase_config, phase_hparams)
    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(phase_hparams)
    t2t_trainer.execute_schedule(experiment)

    # ---- Phase 2: student ----
    phase_hparams = t2t_trainer.create_hparams()
    # The student phase is told where the teacher's output was written.
    phase_hparams.add_hparam("teacher_dir", teacher_dir)
    phase_hparams.distill_phase = "distill"
    student_dir = os.path.join(base_output_dir, "student")
    FLAGS.output_dir = student_dir

    build_experiment = t2t_trainer.create_experiment_fn()
    phase_config = t2t_trainer.create_run_config(phase_hparams)
    experiment = build_experiment(phase_config, phase_hparams)

    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(phase_hparams)
    t2t_trainer.execute_schedule(experiment)
# ==========================


if __name__ == "__main__":
tf.app.run()
4 changes: 2 additions & 2 deletions tensor2tensor/data_generators/algorithmic_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ def generate_calculus_integrate_sample(vlist, ops, min_depth, max_depth,
# functions: Dict of special function names. Maps human readable string names to
# single char names used in flist.
# ops: Dict mapping op symbols (chars) to ExprOp instances.
# solve_ops: Encodes rules for how to algebraically cancel out each operation. See
# doc-string for `algebra_inverse_solve`.
# solve_ops: Encodes rules for how to algebraically cancel out each operation.
# See doc-string for `algebra_inverse_solve`.
# int_encoder: Function that maps a string to a list of tokens. Use this to
# encode an expression to feed into a model.
# int_decoder: Function that maps a list of tokens to a string. Use this to
Expand Down
85 changes: 43 additions & 42 deletions tensor2tensor/data_generators/all_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,48 @@
from __future__ import division
from __future__ import print_function

# pylint: disable=unused-import
from tensor2tensor.data_generators import algorithmic
from tensor2tensor.data_generators import algorithmic_math
from tensor2tensor.data_generators import audio
from tensor2tensor.data_generators import celeba
from tensor2tensor.data_generators import cifar
from tensor2tensor.data_generators import cipher
from tensor2tensor.data_generators import cnn_dailymail
from tensor2tensor.data_generators import desc2code
from tensor2tensor.data_generators import fsns
from tensor2tensor.data_generators import gym
from tensor2tensor.data_generators import ice_parsing
from tensor2tensor.data_generators import imagenet
from tensor2tensor.data_generators import imdb
from tensor2tensor.data_generators import librispeech
from tensor2tensor.data_generators import lm1b
from tensor2tensor.data_generators import mnist
from tensor2tensor.data_generators import mscoco
from tensor2tensor.data_generators import multinli
from tensor2tensor.data_generators import ocr
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.data_generators import ptb
from tensor2tensor.data_generators import snli
from tensor2tensor.data_generators import squad
from tensor2tensor.data_generators import translate_encs
from tensor2tensor.data_generators import translate_ende
from tensor2tensor.data_generators import translate_enfr
from tensor2tensor.data_generators import translate_enmk
from tensor2tensor.data_generators import translate_envi
from tensor2tensor.data_generators import translate_enzh
from tensor2tensor.data_generators import twentybn
from tensor2tensor.data_generators import wiki
from tensor2tensor.data_generators import wsj_parsing
import importlib


# Problem modules that require optional dependencies
# pylint: disable=g-import-not-at-top
try:
# Requires h5py
from tensor2tensor.data_generators import gene_expression
except ImportError:
pass
# pylint: enable=g-import-not-at-top
# pylint: enable=unused-import
# Package that holds every data-generator module listed below.
_DATA_GENERATORS_PACKAGE = "tensor2tensor.data_generators"

# Fully qualified names of the modules to import, in import order. They are
# imported only for their side effects (for example, problem classes that
# register themselves through a decorator when the module is loaded).
modules = [
    "%s.%s" % (_DATA_GENERATORS_PACKAGE, short_name)
    for short_name in (
        "algorithmic",
        "algorithmic_math",
        "audio",
        "celeba",
        "cifar",
        "cipher",
        "cnn_dailymail",
        "desc2code",
        "fsns",
        "gene_expression",
        "gym",
        "ice_parsing",
        "imagenet",
        "imdb",
        "librispeech",
        "lm1b",
        "mnist",
        "mscoco",
        "multinli",
        "ocr",
        "problem_hparams",
        "ptb",
        "snli",
        "squad",
        "translate_encs",
        "translate_ende",
        "translate_enfr",
        "translate_enmk",
        "translate_envi",
        "translate_enzh",
        "twentybn",
        "wiki",
        "wsj_parsing",
    )
]


# Import every module; a module that fails to import is reported on stdout
# and skipped, so one failing module does not prevent the rest from loading.
for module in modules:
  try:
    importlib.import_module(module)
  except ImportError as error:
    print("Did not import module: %s; Cause: %s" % (module, str(error)))
114 changes: 112 additions & 2 deletions tensor2tensor/data_generators/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,17 @@ def cifar_generator(cifar_version, tmp_dir, training, how_many, start_from=0):
test_files = _CIFAR10_TEST_FILES
prefix = _CIFAR10_PREFIX
image_size = _CIFAR10_IMAGE_SIZE
elif cifar_version == "cifar100":
label_key = "labels"
elif cifar_version == "cifar100" or cifar_version == "cifar20":
url = _CIFAR100_URL
train_files = _CIFAR100_TRAIN_FILES
test_files = _CIFAR100_TEST_FILES
prefix = _CIFAR100_PREFIX
image_size = _CIFAR100_IMAGE_SIZE
if cifar_version == "cifar100":
label_key = "fine_labels"
else:
label_key = "coarse_labels"

_get_cifar(tmp_dir, url)
data_files = train_files if training else test_files
Expand All @@ -97,7 +102,7 @@ def cifar_generator(cifar_version, tmp_dir, training, how_many, start_from=0):
all_images.extend([
np.squeeze(images[j]).transpose((1, 2, 0)) for j in xrange(num_images)
])
labels = data["labels" if cifar_version == "cifar10" else "fine_labels"]
labels = data[label_key]
all_labels.extend([labels[j] for j in xrange(num_images)])
return image_utils.image_generator(
all_images[start_from:start_from + how_many],
Expand Down Expand Up @@ -417,3 +422,108 @@ def hparams(self, defaults, unused_model_hparams):
p.max_expected_batch_size_per_shard = 4
p.input_space_id = 1
p.target_space_id = 1


@registry.register_problem
class ImageCifar20Tune(mnist.ImageMnistTune):
  """CIFAR-20 tuning problem: 20 classes, 3-channel images.

  Examples come from `cifar_generator("cifar20", ...)`. Both splits are read
  with `training=True`: the first 48000 examples are used for training and
  the 2000 examples starting at offset 48000 are held out for evaluation.
  """

  @property
  def num_classes(self):
    return 20

  @property
  def num_channels(self):
    return 3

  @property
  def class_labels(self):
    # One name per class; the list has `num_classes` (20) entries.
    return [
        "aquatic mammals", "fish", "flowers", "food containers",
        "fruit and vegetables", "household electrical devices",
        "household furniture", "insects", "large carnivores",
        "large man-made outdoor things", "large natural outdoor scenes",
        "large omnivores and herbivores", "medium-sized mammals",
        "non-insect invertebrates", "people", "reptiles", "small mammals",
        "trees", "vehicles 1", "vehicles 2",
    ]

  def preprocess_example(self, example, mode, unused_hparams):
    """Fixes the image shape, augments in TRAIN mode, and standardizes.

    Args:
      example: dict with the image under the "inputs" key.
      mode: a `tf.estimator.ModeKeys` value; augmentation is applied only
        for TRAIN.
      unused_hparams: ignored.

    Returns:
      The same `example` dict with "inputs" replaced by the processed image.
    """
    inputs = example["inputs"]
    inputs.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
    in_training = mode == tf.estimator.ModeKeys.TRAIN
    if in_training:
      inputs = image_utils.cifar_image_augmentation(inputs)
    # Standardization is skipped when the problem was reversed.
    if not self._was_reversed:
      inputs = tf.image.per_image_standardization(inputs)
    example["inputs"] = inputs
    return example

  def generator(self, data_dir, tmp_dir, is_training):
    """Returns the example generator for the requested split."""
    if not is_training:
      # Held-out tuning split: 2000 training-file examples from offset 48000.
      return cifar_generator("cifar20", tmp_dir, True, 2000, 48000)
    return cifar_generator("cifar20", tmp_dir, True, 48000)


@registry.register_problem
class ImageCifar20(ImageCifar20Tune):
  """CIFAR-20 on the full splits instead of the tuning splits.

  Training requests 50000 examples with `training=True`; evaluation requests
  10000 examples with `training=False`.
  """

  def generator(self, data_dir, tmp_dir, is_training):
    """Returns the example generator for the requested split."""
    if not is_training:
      return cifar_generator("cifar20", tmp_dir, False, 10000)
    return cifar_generator("cifar20", tmp_dir, True, 50000)


@registry.register_problem
class ImageCifar20Plain(ImageCifar20):
  """CIFAR-20 without train-time augmentation."""

  def preprocess_example(self, example, mode, unused_hparams):
    """Fixes the image shape and standardizes it; `mode` is not consulted."""
    inputs = example["inputs"]
    inputs.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
    # Standardization is skipped when the problem was reversed.
    if not self._was_reversed:
      inputs = tf.image.per_image_standardization(inputs)
    example["inputs"] = inputs
    return example


@registry.register_problem
class ImageCifar20PlainGen(ImageCifar20Plain):
  """CIFAR-20 32x32 for image generation, without standardization preprocessing."""

  def dataset_filename(self):
    # Reuse CIFAR-20 plain data.
    return "image_cifar20_plain"

  def preprocess_example(self, example, mode, unused_hparams):
    """Fixes the image shape and casts the pixels to int64."""
    inputs = example["inputs"]
    inputs.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
    example["inputs"] = tf.to_int64(inputs)
    return example


@registry.register_problem
class ImageCifar20Plain8(ImageCifar20):
  """CIFAR-20 rescaled to 8x8 for output: Conditional image generation."""

  def dataset_filename(self):
    # Reuse CIFAR-20 plain data.
    return "image_cifar20_plain"

  def preprocess_example(self, example, mode, unused_hparams):
    """Resizes the image to size 8 by area, then standardizes it."""
    resized = image_utils.resize_by_area(example["inputs"], 8)
    # Standardization is skipped when the problem was reversed.
    if not self._was_reversed:
      resized = tf.image.per_image_standardization(resized)
    example["inputs"] = resized
    return example
Loading

0 comments on commit 120315c

Please sign in to comment.