Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #29 from rsepassi/push
Browse files Browse the repository at this point in the history
Push 1.0.6
  • Loading branch information
lukaszkaiser authored Jun 23, 2017
2 parents 9d04261 + 2f4d5b7 commit 204b359
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 11 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@ issues](https://github.com/tensorflow/tensor2tensor/issues).
And chat with us and other users on
[Gitter](https://gitter.im/tensor2tensor/Lobby).

### Contents

* [Walkthrough](#walkthrough)
* [Installation](#installation)
* [Features](#features)
* [T2T Overview](#t2t-overview)
* [Datasets](#datasets)
* [Problems and Modalities](#problems-and-modalities)
* [Models](#models)
* [Hyperparameter Sets](#hyperparameter-sets)
* [Trainer](#trainer)
* [Adding your own components](#adding-your-own-components)
* [Adding a dataset](#adding-a-dataset)

---

## Walkthrough
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.0.5',
version='1.0.6',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
Expand Down
13 changes: 7 additions & 6 deletions tensor2tensor/data_generators/generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,15 @@ def generate_files(generator,


def download_report_hook(count, block_size, total_size):
  """Report hook for download progress.

  Intended as the `reporthook` callback for `urllib.urlretrieve`, which
  invokes it once per downloaded block.

  Args:
    count: current block number
    block_size: block size
    total_size: total size
  """
  if total_size <= 0:
    # urlretrieve passes -1 when the server reports no Content-Length;
    # a percentage cannot be computed, so skip the progress line.
    return
  percent = int(count * block_size * 100 / total_size)
  # Carriage returns (no newline) keep the progress on a single terminal line.
  print("\r%d%%" % percent + " completed", end="\r")


def maybe_download(directory, filename, url):
Expand All @@ -155,11 +155,12 @@ def maybe_download(directory, filename, url):
filepath = os.path.join(directory, filename)
if not tf.gfile.Exists(filepath):
tf.logging.info("Downloading %s to %s" % (url, filepath))
filepath, _ = urllib.urlretrieve(url, filepath,
reporthook=download_report_hook)

inprogress_filepath = filepath + ".incomplete"
inprogress_filepath, _ = urllib.urlretrieve(url, inprogress_filepath,
reporthook=download_report_hook)
# Print newline to clear the carriage return from the download progress
print()
tf.gfile.Rename(inprogress_filepath, filepath)
statinfo = os.stat(filepath)
tf.logging.info("Succesfully downloaded %s, %s bytes." % (filename,
statinfo.st_size))
Expand Down
7 changes: 5 additions & 2 deletions tensor2tensor/models/common_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,7 @@ def conv_hidden_relu(inputs,
hidden_size,
output_size,
kernel_size=(1, 1),
second_kernel_size=(1, 1),
summaries=True,
dropout=0.0,
**kwargs):
Expand All @@ -1090,7 +1091,8 @@ def conv_hidden_relu(inputs,
inputs = tf.expand_dims(inputs, 2)
else:
is_3d = False
h = conv(
conv_f1 = conv if kernel_size == (1, 1) else separable_conv
h = conv_f1(
inputs,
hidden_size,
kernel_size,
Expand All @@ -1103,7 +1105,8 @@ def conv_hidden_relu(inputs,
tf.summary.histogram("hidden_density_logit",
relu_density_logit(
h, list(range(inputs.shape.ndims - 1))))
ret = conv(h, output_size, (1, 1), name="conv2", **kwargs)
conv_f2 = conv if second_kernel_size == (1, 1) else separable_conv
ret = conv_f2(h, output_size, second_kernel_size, name="conv2", **kwargs)
if is_3d:
ret = tf.squeeze(ret, 2)
return ret
Expand Down
10 changes: 9 additions & 1 deletion tensor2tensor/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,15 @@ def transformer_ffn_layer(x, hparams):
hparams.filter_size,
hparams.num_heads,
hparams.attention_dropout)
elif hparams.ffn_layer == "conv_hidden_relu_with_sepconv":
return common_layers.conv_hidden_relu(
x,
hparams.filter_size,
hparams.hidden_size,
kernel_size=(3, 1),
second_kernel_size=(31, 1),
padding="LEFT",
dropout=hparams.relu_dropout)
else:
assert hparams.ffn_layer == "none"
return x
Expand Down Expand Up @@ -342,7 +351,6 @@ def transformer_parsing_base():
hparams.learning_rate_warmup_steps = 16000
hparams.hidden_size = 1024
hparams.learning_rate = 0.05
hparams.residual_dropout = 0.1
hparams.shared_embedding_and_softmax_weights = int(False)
return hparams

Expand Down
4 changes: 3 additions & 1 deletion tensor2tensor/utils/data_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,16 @@ def examples_queue(data_sources,
with tf.name_scope("examples_queue"):
# Read serialized examples using slim parallel_reader.
num_epochs = None if training else 1
data_files = tf.contrib.slim.parallel_reader.get_data_files(data_sources)
num_readers = min(4 if training else 1, len(data_files))
_, example_serialized = tf.contrib.slim.parallel_reader.parallel_read(
data_sources,
tf.TFRecordReader,
num_epochs=num_epochs,
shuffle=training,
capacity=2 * capacity,
min_after_dequeue=capacity,
num_readers=4 if training else 1)
num_readers=num_readers)

if data_items_to_decoders is None:
data_items_to_decoders = {
Expand Down

0 comments on commit 204b359

Please sign in to comment.