edit based on ava_nmt #3

Open · wants to merge 2 commits into master
14 changes: 10 additions & 4 deletions ava_nmt/build_inputs.py
@@ -15,6 +15,13 @@
import sequencing as sq
from sequencing import MODE, TIME_MAJOR

def build_vocab(vocab_file, embedding_dim, delimiter=' '):
# construct the vocab: one symbol per line, trailing newline stripped
with open(vocab_file, 'r') as f:
symbols = [s[:-1] for s in f.readlines()]
vocab = sq.Vocab(symbols, embedding_dim, delimiter)
return vocab

def build_source_char_inputs(src_vocab, src_data_file,
batch_size, buffer_size=16,
mode=MODE.TRAIN):
@@ -75,7 +82,7 @@ def _source_generator():
for li, l in enumerate(lines)],
dtype=numpy.int32)

src_sample_matrix_np = numpy.zeros((num_lines, max_word_length,
src_sample_matrix_np = numpy.zeros((num_lines, max_word_length,
src_len_np.max()),
dtype=numpy.float32)

@@ -166,7 +173,7 @@ def _parallel_generator():
tf.logging.info('Read from head ......')
src_data.close()
trg_data.close()

# shuffle and reopen
tf.logging.info('Shuffling ......')
subprocess.call(['./shuffle_data.sh', src_data_file, trg_data_file])
@@ -220,7 +227,7 @@ def _parallel_generator():
trg_len_np = numpy.asarray([len(l[1]) for l in lines],
dtype=numpy.int32)

src_sample_matrix_np = numpy.zeros((num_lines, max_word_length,
src_sample_matrix_np = numpy.zeros((num_lines, max_word_length,
src_len_np.max()),
dtype=numpy.float32)

@@ -271,4 +278,3 @@ def _parallel_generator():
yield read_buffer.pop()

return _parallel_generator()
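The new build_vocab helper above reads one symbol per line and strips the trailing newline with s[:-1]. A minimal usage sketch, assuming a character-level vocab file; the module path, file name and embedding size below are made up, not taken from the repo:

# Sketch only: illustrates the assumed call site for the build_vocab helper
# added in this diff. The path and embedding_dim are hypothetical.
from build_inputs import build_vocab

# data/char_vocab.txt is assumed to contain one symbol per line, e.g.
#   a
#   b
#   c
src_vocab = build_vocab('data/char_vocab.txt', embedding_dim=32,
                        delimiter='')  # empty delimiter -> character level
print(src_vocab.string_to_ids('abc'))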

12 changes: 5 additions & 7 deletions ava_nmt/build_model.py
@@ -31,7 +31,7 @@ def cross_entropy_sequence_loss(logits, targets, sequence_length):
losses = entropy_losses * loss_mask
# losses.shape: T * B
# sequence_length: B
total_loss_avg = tf.reduce_sum(losses) / batch_size
total_loss_avg = tf.reduce_sum(losses) / batch_size

return total_loss_avg

@@ -200,8 +200,8 @@ def build_attention_model(params, src_vocab, trg_vocab,
source_char_embedded = source_char_embedding_table(source_ids)

# encode char to word
char_encoder = sq.StackRNNEncoder(params['char_encoder'],
params['attention_key_size']['char'],
char_encoder = sq.StackRNNEncoder(params['encoder'],
params['encoder']['attention_key_size'],
name='char_rnn',
mode=mode)

@@ -227,14 +227,14 @@
char_attention_length = char_encoded_representation.attention_length

encoder = sq.StackBidirectionalRNNEncoder(params['encoder'],
params['attention_key_size']['word'],
params['encoder']['attention_key_size'],
name='stack_rnn',
mode=mode)
encoded_representation = encoder.encode(source_embedded, source_word_seq_length)
attention_keys = encoded_representation.attention_keys
attention_values = encoded_representation.attention_values
attention_length = encoded_representation.attention_length
encoder_final_states_bw = encoded_representation.final_state[-1][-1].h
encoder_final_states_bw = encoded_representation.final_state[-1][-1]

# feedback
if mode == MODE.RL:
@@ -412,5 +412,3 @@ def build_attention_model(params, src_vocab, trg_vocab,
sequence_length=target_seq_length)
return decoder_output, total_loss_avg, total_loss_avg, \
tf.to_float(0.), tf.to_float(0.)
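Both the char and word encoders are now configured from params['encoder'], with attention_key_size nested inside it instead of living in a separate top-level params['attention_key_size'] dict keyed by 'char'/'word'. A rough sketch of the params shape this indexing implies; the values are placeholders, not taken from the project's config:

# Assumed shape only, inferred from the new lookups in build_attention_model.
# The real settings live in the config_example module.
params = {
    'encoder': {
        'attention_key_size': 256,
        # plus whatever cell/layer settings sq.StackRNNEncoder and
        # sq.StackBidirectionalRNNEncoder read from this dict
    },
}

# old lookups (removed): params['char_encoder'], params['attention_key_size']['char']
# new lookup, shared by the char and word encoders:
key_size = params['encoder']['attention_key_size']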


8 changes: 4 additions & 4 deletions ava_nmt/nmt_infer.py
@@ -12,7 +12,7 @@
import numpy
import tensorflow as tf

import config
import config_example
from build_inputs import build_source_char_inputs
from build_model import build_attention_model
from sequencing import TIME_MAJOR, MODE, optimistic_restore
@@ -95,7 +95,7 @@ def infer(src_vocab, src_data_file, trg_vocab,
decoder_output_eval.beam_ids,
decoder_final_state.log_probs],
feed_dict=feed_dict)

src_len_np = current_input_dict['src_len']
data_batch_size = len(src_len_np)

@@ -139,7 +139,7 @@ def infer(src_vocab, src_data_file, trg_vocab,
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)

all_configs = [i for i in dir(config) if i.startswith('config_')]
all_configs = [i for i in dir(config_example) if i.startswith('config_')]

parser = argparse.ArgumentParser(description='Sequencing Training ...')
parser.add_argument('--config', choices=all_configs,
@@ -153,7 +153,7 @@ def infer(src_vocab, src_data_file, trg_vocab,
default='test.out')

args = parser.parse_args()
training_configs = getattr(config, args.config)()
training_configs = getattr(config_example, args.config)()

test_src_file = args.test_src if args.test_src else training_configs.test_src_file

15 changes: 9 additions & 6 deletions ava_nmt/nmt_train.py
100644 → 100755
@@ -13,7 +13,7 @@

import tensorflow as tf

import config
import config_example
from build_inputs import build_parallel_char_inputs
from build_model import build_attention_model
from sequencing import MODE, TIME_MAJOR, optimistic_restore
@@ -68,7 +68,7 @@ def train(src_vocab, src_data_file, trg_vocab, trg_data_file,

# attention model for training
_, total_loss_avg, entropy_loss_avg, reward_loss_rmse, reward_predicted = \
build_attention_model(params, src_vocab, trg_vocab,
build_attention_model(params, src_vocab, trg_vocab,
source_placeholders,
target_placeholders,
mode=mode,
@@ -148,7 +148,7 @@ def train(src_vocab, src_data_file, trg_vocab, trg_data_file,
reward_predicted, global_step_tensor],
feed_dict=feed_dict)
train_writer.add_summary(summary, global_step)

if numpy.isnan(gradients_norm_np):
print(gradients_norm_np, gradients_np)
break
@@ -172,6 +172,9 @@ def train(src_vocab, src_data_file, trg_vocab, trg_data_file,
for i in range(10):
pids = predicted_ids_np[:, i].tolist()
if TIME_MAJOR:
print(TIME_MAJOR)
print("shape",predicted_ids_np.shape)
print("pid is", pids)
sids = current_input_dict['src'][:, i].tolist()
tids = current_input_dict['trg'][:, i].tolist()
else:
@@ -189,7 +192,7 @@ def train(src_vocab, src_data_file, trg_vocab, trg_data_file,
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)

all_configs = [i for i in dir(config) if i.startswith('config_')]
all_configs = [i for i in dir(config_example) if i.startswith('config_')]

parser = argparse.ArgumentParser(description='Sequencing Training ...')
parser.add_argument('--config', choices=all_configs,
@@ -200,8 +203,8 @@ def train(src_vocab, src_data_file, trg_vocab, trg_data_file,

args = parser.parse_args()

training_configs = getattr(config, args.config)()

training_configs = getattr(config_example, args.config)()
print("training_configs.params is", training_configs.params)
if args.mode == 'rl':
mode = MODE.RL
else:
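The extra prints added in the training loop inspect predicted_ids_np before slicing. A toy illustration of why the slice is [:, i] when TIME_MAJOR is set: the decoder output is laid out as [max_time, batch_size], so sentence i is a column. The array below is fabricated:

import numpy

# Fabricated example: 3 decoding steps, batch of 2 sentences, time-major layout.
predicted_ids_np = numpy.array([[4, 7],
                                [5, 8],
                                [6, 9]], dtype=numpy.int32)
i = 0
pids = predicted_ids_np[:, i].tolist()   # ids of the first sentence over time
print("shape", predicted_ids_np.shape)   # (3, 2)
print("pid is", pids)                    # [4, 5, 6]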
16 changes: 8 additions & 8 deletions sequencing/data/vocab.py
@@ -28,17 +28,18 @@ def __init__(self, tokens, embedding_dim, delimiter=' ', vocab_size=None,
self.unk_token, self.unk_id = unk_token, 0
self.bos_token, self.bos_id = bos_token, 1
self.eos_token, self.eos_id = eos_token, 2

self.token_to_id_dict = {token: token_id + 3 for token_id, token in
self.space_token, self.space_id = " ", 4
self.token_to_id_dict = {token: token_id + 4 for token_id, token in
enumerate(tokens[:vocab_size])}
self.token_to_id_dict.update(**{self.unk_token: self.unk_id,
self.bos_token: self.bos_id,
self.eos_token: self.eos_id})
self.eos_token: self.eos_id,
self.space_token: self.space_id})

self.id_to_token_dict = {v: k for k, v in self.token_to_id_dict.items()}
self.vocab_size = vocab_size + 3
if self.delimiter == '':
self.space_id = self.token_to_id_dict[' ']
self.vocab_size = vocab_size + 4
#if self.delimiter == '':
self.space_id = self.token_to_id_dict[' ']

def _map_token_to_id_with_unk(self, token):
try:
@@ -55,7 +56,7 @@ def _token_to_id(self, tokens):
return list(map(self._map_token_to_id_with_unk, tokens))

def string_to_ids(self, token_string, bos=False):
if self.delimiter:
if self.delimiter: # delimiter=" "
tokens = token_string.strip().split(self.delimiter)
else:
# delimiter is '', character-level
@@ -84,4 +85,3 @@ def build_vocab(vocab_file, embedding_dim, delimiter=' ',
symbols = [s.split('\t')[0] for s in f.readlines()]
vocab = Vocab(symbols, embedding_dim, delimiter, vocab_size)
return vocab
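With this change the space character becomes a reserved symbol and ordinary tokens are offset by 4 instead of 3. As written, space_id is 4 while the first entry of tokens also maps to 4, and id 3 is never assigned. A toy walk-through of the resulting mapping; the token list and the UNK/BOS/EOS strings are stand-ins for the real default symbols:

# Toy illustration of the id layout produced by the updated Vocab.__init__.
tokens = ['a', 'b', 'c']

token_to_id = {token: token_id + 4 for token_id, token in enumerate(tokens)}
token_to_id.update({'UNK': 0, 'BOS': 1, 'EOS': 2, ' ': 4})

print(token_to_id)
# {'a': 4, 'b': 5, 'c': 6, 'UNK': 0, 'BOS': 1, 'EOS': 2, ' ': 4}
# id 3 is unused and id 4 is shared by 'a' and ' '; the reverse map
# {v: k for k, v in ...} keeps only the last key inserted for id 4.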