From 7fb6cdb47d51e4d550a9766f878622536f6736d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Wed, 22 Jan 2020 15:25:18 +0100 Subject: [PATCH 01/16] create build_generator function --- onmt/model_builder.py | 52 ++++++++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/onmt/model_builder.py b/onmt/model_builder.py index 43cd9731c8..7a8539ea50 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -42,11 +42,14 @@ def build_embeddings(opt, text_field, for_encoder=True): word_padding_idx, feat_pad_indices = pad_indices[0], pad_indices[1:] num_embs = [len(f.vocab) for _, f in text_field] + print("NUM EMBS", num_embs) num_word_embeddings, num_feat_embeddings = num_embs[0], num_embs[1:] fix_word_vecs = opt.fix_word_vecs_enc if for_encoder \ else opt.fix_word_vecs_dec + print("FIELD", text_field.fields) + emb = Embeddings( word_vec_size=emb_dim, position_encoding=opt.position_encoding, @@ -87,6 +90,34 @@ def build_decoder(opt, embeddings): else opt.decoder_type return str2dec[dec_type].from_opt(opt, embeddings) +def build_generator(model_opt, fields): + # print(fields['tgt'].fields) + gen_sizes = [len(field[1].vocab) for field in fields['tgt'].fields] + print(gen_sizes) + exit() + if not model_opt.copy_attn: + if model_opt.generator_function == "sparsemax": + gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1) + else: + gen_func = nn.LogSoftmax(dim=-1) + generator = nn.Sequential( + nn.Linear(model_opt.dec_rnn_size, + len(fields["tgt"].base_field.vocab)), + Cast(torch.float32), + gen_func + ) + if model_opt.share_decoder_embeddings: + generator[0].weight = decoder.embeddings.word_lut.weight + else: + tgt_base_field = fields["tgt"].base_field + vocab_size = len(tgt_base_field.vocab) + pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token] + generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx) + if model_opt.share_decoder_embeddings: + generator.linear.weight = decoder.embeddings.word_lut.weight + + return generator + def load_test_model(opt, model_path=None): if model_path is None: @@ -172,26 +203,7 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None): model = onmt.models.NMTModel(encoder, decoder) # Build Generator. - if not model_opt.copy_attn: - if model_opt.generator_function == "sparsemax": - gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1) - else: - gen_func = nn.LogSoftmax(dim=-1) - generator = nn.Sequential( - nn.Linear(model_opt.dec_rnn_size, - len(fields["tgt"].base_field.vocab)), - Cast(torch.float32), - gen_func - ) - if model_opt.share_decoder_embeddings: - generator[0].weight = decoder.embeddings.word_lut.weight - else: - tgt_base_field = fields["tgt"].base_field - vocab_size = len(tgt_base_field.vocab) - pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token] - generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx) - if model_opt.share_decoder_embeddings: - generator.linear.weight = decoder.embeddings.word_lut.weight + generator = build_generator(model_opt, fields) # Load the model states from checkpoint or initialize them. if checkpoint is not None: From 44bb96248749126b2eab9a2255d57808d2d6d3ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Wed, 22 Jan 2020 16:27:10 +0100 Subject: [PATCH 02/16] create Generator class --- onmt/model_builder.py | 12 +++--------- onmt/modules/__init__.py | 3 ++- onmt/modules/generator.py | 27 +++++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 10 deletions(-) create mode 100644 onmt/modules/generator.py diff --git a/onmt/model_builder.py b/onmt/model_builder.py index 7a8539ea50..4de5c70723 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -13,7 +13,8 @@ from onmt.decoders import str2dec -from onmt.modules import Embeddings, VecEmbedding, CopyGenerator +from onmt.modules import Embeddings, VecEmbedding, \ + Generator, CopyGenerator from onmt.modules.util_class import Cast from onmt.utils.misc import use_gpu from onmt.utils.logging import logger @@ -93,19 +94,12 @@ def build_decoder(opt, embeddings): def build_generator(model_opt, fields): # print(fields['tgt'].fields) gen_sizes = [len(field[1].vocab) for field in fields['tgt'].fields] - print(gen_sizes) - exit() if not model_opt.copy_attn: if model_opt.generator_function == "sparsemax": gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1) else: gen_func = nn.LogSoftmax(dim=-1) - generator = nn.Sequential( - nn.Linear(model_opt.dec_rnn_size, - len(fields["tgt"].base_field.vocab)), - Cast(torch.float32), - gen_func - ) + generator = Generator(model_opt.rnn_size, gen_sizes, gen_func) if model_opt.share_decoder_embeddings: generator[0].weight = decoder.embeddings.word_lut.weight else: diff --git a/onmt/modules/__init__.py b/onmt/modules/__init__.py index 763ac8448a..646217fa69 100644 --- a/onmt/modules/__init__.py +++ b/onmt/modules/__init__.py @@ -3,6 +3,7 @@ from onmt.modules.gate import context_gate_factory, ContextGate from onmt.modules.global_attention import GlobalAttention from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention +from onmt.modules.generator import Generator from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \ CopyGeneratorLossCompute from onmt.modules.multi_headed_attn import MultiHeadedAttention @@ -13,6 +14,6 @@ __all__ = ["Elementwise", "context_gate_factory", "ContextGate", "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator", - "CopyGeneratorLoss", "CopyGeneratorLossCompute", + "Generator", "CopyGeneratorLoss", "CopyGeneratorLossCompute", "MultiHeadedAttention", "Embeddings", "PositionalEncoding", "WeightNormConv2d", "AverageAttention", "VecEmbedding"] diff --git a/onmt/modules/generator.py b/onmt/modules/generator.py new file mode 100644 index 0000000000..e4f8663899 --- /dev/null +++ b/onmt/modules/generator.py @@ -0,0 +1,27 @@ +""" Onmt NMT Model base class definition """ +import torch +import torch.nn as nn + +from torch.nn.modules.module import _addindent + +from onmt.modules.util_class import Cast + + +class Generator(nn.Module): + def __init__(self, rnn_size, sizes, gen_func): + super(Generator, self).__init__() + self.generators = nn.ModuleList() + for i, size in enumerate(sizes): + self.generators.append( + nn.Sequential( + nn.Linear(rnn_size, + size), + Cast(torch.float32), + gen_func + )) + + def forward(self, dec_out): + return [generator(dec_out) for generator in self.generators] + + def __getitem__(self, i): + return self.generators[0][i] From 772cee9636f5a6b6f81c092189520cfee3a36dec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Wed, 22 Jan 2020 16:51:27 +0100 Subject: [PATCH 03/16] remove some prints/comments --- onmt/model_builder.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/onmt/model_builder.py b/onmt/model_builder.py index 4de5c70723..330e033ef2 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -43,14 +43,11 @@ def build_embeddings(opt, text_field, for_encoder=True): word_padding_idx, feat_pad_indices = pad_indices[0], pad_indices[1:] num_embs = [len(f.vocab) for _, f in text_field] - print("NUM EMBS", num_embs) num_word_embeddings, num_feat_embeddings = num_embs[0], num_embs[1:] fix_word_vecs = opt.fix_word_vecs_enc if for_encoder \ else opt.fix_word_vecs_dec - print("FIELD", text_field.fields) - emb = Embeddings( word_vec_size=emb_dim, position_encoding=opt.position_encoding, @@ -92,7 +89,6 @@ def build_decoder(opt, embeddings): return str2dec[dec_type].from_opt(opt, embeddings) def build_generator(model_opt, fields): - # print(fields['tgt'].fields) gen_sizes = [len(field[1].vocab) for field in fields['tgt'].fields] if not model_opt.copy_attn: if model_opt.generator_function == "sparsemax": From db356cfe490f81fb12422c6cb3200c061067d44d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Thu, 23 Jan 2020 11:50:14 +0100 Subject: [PATCH 04/16] create OnmtBatch class --- onmt/inputters/inputter.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index d4c43342c0..dcc41773e7 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -582,6 +582,26 @@ def _pool(data, batch_size, batch_size_fn, batch_size_multiple, yield b +class OnmtBatch(torchtext.data.Batch): + def __init__(self, data=None, dataset=None, device=None): + super(OnmtBatch, self).__init__(data, dataset, device) + # we need to shift target features if needed + if self.tgt.size(-1) > 1: + # tokens: [ len x batch x 1] + tokens = self.tgt[:,:,0].unsqueeze(-1) + # feats: [ len x batch x num_feats ] + feats = self.tgt[:,:,1:] + # shift feats one step to the right + feats = torch.cat(( + feats[-1,:,:].unsqueeze(0), + feats[:-1,:,:] + )) + # build back target tensor + self.tgt = torch.cat(( + tokens, + feats + ), dim=-1) + class OrderedIterator(torchtext.data.Iterator): def __init__(self, @@ -627,7 +647,7 @@ def __iter__(self): """ Extended version of the definition in torchtext.data.Iterator. Added yield_raw_example behaviour to yield a torchtext.data.Example - instead of a torchtext.data.Batch object. + instead of an OnmtBatch object. """ while True: self.init_epoch() @@ -648,7 +668,7 @@ def __iter__(self): if self.yield_raw_example: yield minibatch[0] else: - yield torchtext.data.Batch( + yield OnmtBatch( minibatch, self.dataset, self.device) @@ -709,9 +729,9 @@ def __iter__(self): self.random_shuffler, self.pool_factor): minibatch = sorted(minibatch, key=self.sort_key, reverse=True) - yield torchtext.data.Batch(minibatch, - self.iterables[0].dataset, - self.device) + yield OnmtBatch(minibatch, + self.iterables[0].dataset, + self.device) class DatasetLazyIter(object): From 797bef6c1998d5bf3d571791a200b99adad92f9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Thu, 23 Jan 2020 15:27:07 +0100 Subject: [PATCH 05/16] handle no tgt case in OnmtBatch --- onmt/inputters/inputter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index dcc41773e7..e8d210dfff 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -586,7 +586,7 @@ class OnmtBatch(torchtext.data.Batch): def __init__(self, data=None, dataset=None, device=None): super(OnmtBatch, self).__init__(data, dataset, device) # we need to shift target features if needed - if self.tgt.size(-1) > 1: + if hasattr(self, 'tgt') and self.tgt.size(-1) > 1: # tokens: [ len x batch x 1] tokens = self.tgt[:,:,0].unsqueeze(-1) # feats: [ len x batch x num_feats ] From ef21fde4d27f3a8ec18ee3f5de821f2553c05d4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Thu, 23 Jan 2020 17:13:10 +0100 Subject: [PATCH 06/16] greedy decoding approximately working --- onmt/translate/decode_strategy.py | 14 ++++++++++---- onmt/translate/greedy_search.py | 32 ++++++++++++++++++++++++++----- onmt/translate/translation.py | 25 +++++++++++++++++------- onmt/translate/translator.py | 10 ++++++++-- 4 files changed, 63 insertions(+), 18 deletions(-) diff --git a/onmt/translate/decode_strategy.py b/onmt/translate/decode_strategy.py index 828e8a4d51..b78c5b6d87 100644 --- a/onmt/translate/decode_strategy.py +++ b/onmt/translate/decode_strategy.py @@ -83,7 +83,8 @@ def __init__(self, pad, bos, eos, batch_size, parallel_paths, self.done = False - def initialize(self, memory_bank, src_lengths, src_map=None, device=None): + def initialize(self, memory_bank, src_lengths, num_features, + src_map=None, device=None): """DecodeStrategy subclasses should override :func:`initialize()`. `initialize` should be called before all actions. @@ -91,20 +92,25 @@ def initialize(self, memory_bank, src_lengths, src_map=None, device=None): """ if device is None: device = torch.device('cpu') + # initialize to [ batch*beam x num_feats x 1] self.alive_seq = torch.full( - [self.batch_size * self.parallel_paths, 1], self.bos, + [self.batch_size * self.parallel_paths, num_features, 1], self.bos, dtype=torch.long, device=device) self.is_finished = torch.zeros( [self.batch_size, self.parallel_paths], dtype=torch.uint8, device=device) + # initialize features (we need to know num_features) + self.features = [[[] for _ in range(num_features - 1)] + for _ in range(self.batch_size)] return None, memory_bank, src_lengths, src_map def __len__(self): - return self.alive_seq.shape[1] + return self.alive_seq.shape[-1] def ensure_min_length(self, log_probs): if len(self) <= self.min_length: - log_probs[:, self.eos] = -1e20 + for probs in log_probs: + probs[:, self.eos] = -1e20 def ensure_max_length(self): # add one to account for BOS. Don't account for EOS because hitting diff --git a/onmt/translate/greedy_search.py b/onmt/translate/greedy_search.py index 8ebef32e15..354e34a232 100644 --- a/onmt/translate/greedy_search.py +++ b/onmt/translate/greedy_search.py @@ -91,7 +91,7 @@ def __init__(self, pad, bos, eos, batch_size, min_length, self.keep_topk = keep_topk self.topk_scores = None - def initialize(self, memory_bank, src_lengths, src_map=None, device=None): + def initialize(self, memory_bank, src_lengths, num_features, src_map=None, device=None): """Initialize for decoding.""" fn_map_state = None @@ -104,7 +104,7 @@ def initialize(self, memory_bank, src_lengths, src_map=None, device=None): self.memory_lengths = src_lengths super(GreedySearch, self).initialize( - memory_bank, src_lengths, src_map, device) + memory_bank, src_lengths, num_features, src_map, device) self.select_indices = torch.arange( self.batch_size, dtype=torch.long, device=device) self.original_batch_idx = torch.arange( @@ -113,7 +113,7 @@ def initialize(self, memory_bank, src_lengths, src_map=None, device=None): @property def current_predictions(self): - return self.alive_seq[:, -1] + return self.alive_seq[:, :, -1] @property def batch_offset(self): @@ -131,6 +131,19 @@ def advance(self, log_probs, attn): to 1.) attn (FloatTensor): Shaped ``(1, B, inp_seq_len)``. """ + # print("PROBS", [toto.size() for toto in log_probs]) + # print("ALIVE", self.alive_seq.size()) + + # we need to get the feature first + if len(log_probs) > 1: + features_id = [] + for logits in log_probs[1:]: + features_id.append(logits.topk(1, dim=-1)[1]) + features_id = torch.cat(features_id, dim=-1) + else: + features_id = None + # keep only log probs for tokens + log_probs = log_probs[0] self.ensure_min_length(log_probs) self.block_ngram_repeats(log_probs) @@ -139,7 +152,12 @@ def advance(self, log_probs, attn): self.is_finished = topk_ids.eq(self.eos) - self.alive_seq = torch.cat([self.alive_seq, topk_ids], -1) + if features_id is not None: + topk_ids = torch.cat(( + topk_ids, features_id + ), dim=-1) + + self.alive_seq = torch.cat([self.alive_seq, topk_ids.unsqueeze(-1)], -1) if self.return_attention: if self.alive_attn is None: self.alive_attn = attn @@ -147,6 +165,7 @@ def advance(self, log_probs, attn): self.alive_attn = torch.cat([self.alive_attn, attn], 0) self.ensure_max_length() + def update_finished(self): """Finalize scores and predictions.""" # shape: (sum(~ self.is_finished), 1) @@ -154,7 +173,10 @@ def update_finished(self): for b in finished_batches.view(-1): b_orig = self.original_batch_idx[b] self.scores[b_orig].append(self.topk_scores[b, 0]) - self.predictions[b_orig].append(self.alive_seq[b, 1:]) + self.predictions[b_orig].append(self.alive_seq[b, 0, 1:]) + # check on first item of the batch ot get num_features + for i in range(len(self.features[0])): + self.features[b_orig][i].append(self.alive_seq[b, 1+i, 1:]) self.attention[b_orig].append( self.alive_attn[:, b, :self.memory_lengths[b]] if self.alive_attn is not None else []) diff --git a/onmt/translate/translation.py b/onmt/translate/translation.py index 21eeb91e96..fca2ef1109 100644 --- a/onmt/translate/translation.py +++ b/onmt/translate/translation.py @@ -33,15 +33,25 @@ def __init__(self, data, fields, n_best=1, replace_unk=False, self.phrase_table = phrase_table self.has_tgt = has_tgt - def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn): - tgt_field = dict(self.fields)["tgt"].base_field + def _build_target_tokens(self, src, src_vocab, src_raw, pred, all_feats, attn): + # feats need do be shifted back one step to the left + all_feats = [list(feat[1:]) + [feat[0]] for feat in all_feats] # TODO find a better way + tgt_fields = dict(self.fields)["tgt"] + tgt_field = tgt_fields.base_field vocab = tgt_field.vocab + feats_vocabs = [field.vocab for name, field in tgt_fields.fields[1:]] tokens = [] - for tok in pred: + for tok_feats in zip(pred, *all_feats): + tok = tok_feats[0] if tok < len(vocab): - tokens.append(vocab.itos[tok]) + token = vocab.itos[tok] else: - tokens.append(src_vocab.itos[tok - len(vocab)]) + token = src_vocab.itos[tok - len(vocab)] + if len(tok_feats) > 1: + feats = tok_feats[1:] + for feat, fv in zip(feats, feats_vocabs): + token += u"│" + fv.itos[feat] + tokens.append(token) if tokens[-1] == tgt_field.eos_token: tokens = tokens[:-1] break @@ -63,8 +73,9 @@ def from_batch(self, translation_batch): len(translation_batch["predictions"])) batch_size = batch.batch_size - preds, pred_score, attn, align, gold_score, indices = list(zip( + preds, feats, pred_score, attn, align, gold_score, indices = list(zip( *sorted(zip(translation_batch["predictions"], + translation_batch["features"], translation_batch["scores"], translation_batch["attention"], translation_batch["alignment"], @@ -96,7 +107,7 @@ def from_batch(self, translation_batch): pred_sents = [self._build_target_tokens( src[:, b] if src is not None else None, src_vocab, src_raw, - preds[b][n], attn[b][n]) + preds[b][n], feats[b][n], attn[b][n]) for n in range(self.n_best)] gold_sent = None if tgt is not None: diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index 7fe4c5c09b..94de29666c 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -646,6 +646,7 @@ def _translate_batch_with_strategy( results = { "predictions": None, + "features": None, "scores": None, "attention": None, "batch": batch, @@ -654,15 +655,17 @@ def _translate_batch_with_strategy( enc_states, batch_size, src)} # (2) prep decode_strategy. Possibly repeat src objects. + num_features = batch.src[0].size(-1) src_map = batch.src_map if use_src_map else None fn_map_state, memory_bank, memory_lengths, src_map = \ - decode_strategy.initialize(memory_bank, src_lengths, src_map) + decode_strategy.initialize(memory_bank, src_lengths, + num_features, src_map) if fn_map_state is not None: self.model.decoder.map_state(fn_map_state) # (3) Begin decoding step by step: for step in range(decode_strategy.max_length): - decoder_input = decode_strategy.current_predictions.view(1, -1, 1) + decoder_input = decode_strategy.current_predictions.view(1, -1, num_features) log_probs, attn = self._decode_and_generate( decoder_input, @@ -674,6 +677,8 @@ def _translate_batch_with_strategy( step=step, batch_offset=decode_strategy.batch_offset) + # print("PROBS", [item.size() for item in log_probs]) + # Note: we may have probs over several features decode_strategy.advance(log_probs, attn) any_finished = decode_strategy.is_finished.any() if any_finished: @@ -702,6 +707,7 @@ def _translate_batch_with_strategy( results["scores"] = decode_strategy.scores results["predictions"] = decode_strategy.predictions + results["features"] = decode_strategy.features results["attention"] = decode_strategy.attention if self.report_align: results["alignment"] = self._align_forward( From 345dd1de7c678ea1befff2fdf8b2f244aca13583 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Thu, 23 Jan 2020 17:34:20 +0100 Subject: [PATCH 07/16] remove some print --- onmt/translate/translator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index 94de29666c..06369ccf90 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -677,7 +677,6 @@ def _translate_batch_with_strategy( step=step, batch_offset=decode_strategy.batch_offset) - # print("PROBS", [item.size() for item in log_probs]) # Note: we may have probs over several features decode_strategy.advance(log_probs, attn) any_finished = decode_strategy.is_finished.any() From 92c2f26bc7578d2a2224550bdda8eece99df30e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Fri, 24 Jan 2020 10:44:47 +0100 Subject: [PATCH 08/16] use several criterions for loss compute, adapt training --- onmt/model_builder.py | 5 ++-- onmt/trainer.py | 6 ++--- onmt/utils/loss.py | 59 ++++++++++++++++++++++++++++--------------- 3 files changed, 44 insertions(+), 26 deletions(-) diff --git a/onmt/model_builder.py b/onmt/model_builder.py index 330e033ef2..c58dc312b5 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -88,7 +88,7 @@ def build_decoder(opt, embeddings): else opt.decoder_type return str2dec[dec_type].from_opt(opt, embeddings) -def build_generator(model_opt, fields): +def build_generator(model_opt, fields, decoder): gen_sizes = [len(field[1].vocab) for field in fields['tgt'].fields] if not model_opt.copy_attn: if model_opt.generator_function == "sparsemax": @@ -96,6 +96,7 @@ def build_generator(model_opt, fields): else: gen_func = nn.LogSoftmax(dim=-1) generator = Generator(model_opt.rnn_size, gen_sizes, gen_func) + # TODO this can't work with target features ??? if model_opt.share_decoder_embeddings: generator[0].weight = decoder.embeddings.word_lut.weight else: @@ -193,7 +194,7 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None): model = onmt.models.NMTModel(encoder, decoder) # Build Generator. - generator = build_generator(model_opt, fields) + generator = build_generator(model_opt, fields, decoder) # Load the model states from checkpoint or initialize them. if checkpoint is not None: diff --git a/onmt/trainer.py b/onmt/trainer.py index 4328ca52ea..8024e732c8 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -31,10 +31,10 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None): used to save the model """ - tgt_field = dict(fields)["tgt"].base_field - train_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt) + tgt_fields = dict(fields)["tgt"] + train_loss = onmt.utils.loss.build_loss_compute(model, tgt_fields, opt) valid_loss = onmt.utils.loss.build_loss_compute( - model, tgt_field, opt, train=False) + model, tgt_fields, opt, train=False) trunc_size = opt.truncated_decoder # Badly named... shard_size = opt.max_generator_batches if opt.model_dtype == 'fp32' else 0 diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index c48f0d3d21..65a84b3e40 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -12,7 +12,7 @@ from onmt.modules.sparse_activations import LogSparsemax -def build_loss_compute(model, tgt_field, opt, train=True): +def build_loss_compute(model, tgt_fields, opt, train=True): """ Returns a LossCompute subclass which wraps around an nn.Module subclass (such as nn.NLLLoss) which defines the loss criterion. The LossCompute @@ -22,7 +22,7 @@ def build_loss_compute(model, tgt_field, opt, train=True): for when using a copy mechanism. """ device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu") - + tgt_field = tgt_fields.base_field padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token] unk_idx = tgt_field.vocab.stoi[tgt_field.unk_token] @@ -31,33 +31,42 @@ def build_loss_compute(model, tgt_field, opt, train=True): "order to use --lambda_coverage != 0" if opt.copy_attn: - criterion = onmt.modules.CopyGeneratorLoss( + criterions = [onmt.modules.CopyGeneratorLoss( len(tgt_field.vocab), opt.copy_attn_force, unk_index=unk_idx, ignore_index=padding_idx - ) + )] elif opt.label_smoothing > 0 and train: - criterion = LabelSmoothingLoss( + criterions = [LabelSmoothingLoss( opt.label_smoothing, len(tgt_field.vocab), ignore_index=padding_idx - ) + )] elif isinstance(model.generator[-1], LogSparsemax): - criterion = SparsemaxLoss(ignore_index=padding_idx, reduction='sum') + criterions = [SparsemaxLoss(ignore_index=padding_idx, reduction='sum')] else: - criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum') + criterions = [nn.NLLLoss(ignore_index=padding_idx, reduction='sum')] + + # we need to add as many additional criterion as we have features + for field in tgt_fields.fields[1:]: + padding_idx = field[1].vocab.stoi[field[1].pad_token] + criterions.append( + nn.NLLLoss(ignore_index=padding_idx, reduction='sum') + ) # if the loss function operates on vectors of raw logits instead of # probabilities, only the first part of the generator needs to be # passed to the NMTLossCompute. At the moment, the only supported # loss function of this kind is the sparsemax loss. - use_raw_logits = isinstance(criterion, SparsemaxLoss) + use_raw_logits = isinstance(criterions[0], SparsemaxLoss) + # TODO make this compatible with target features !!! loss_gen = model.generator[0] if use_raw_logits else model.generator if opt.copy_attn: + # TODO make this compatible with target features... compute = onmt.modules.CopyGeneratorLossCompute( - criterion, loss_gen, tgt_field.vocab, opt.copy_loss_by_seqlength, + criterions, loss_gen, tgt_field.vocab, opt.copy_loss_by_seqlength, lambda_coverage=opt.lambda_coverage ) else: compute = NMTLossCompute( - criterion, loss_gen, lambda_coverage=opt.lambda_coverage, + criterions, loss_gen, lambda_coverage=opt.lambda_coverage, lambda_align=opt.lambda_align) compute.to(device) @@ -83,14 +92,15 @@ class LossComputeBase(nn.Module): normalzation (str): normalize by "sents" or "tokens" """ - def __init__(self, criterion, generator): + def __init__(self, criterions, generator): super(LossComputeBase, self).__init__() - self.criterion = criterion + # We may have several criterions in the case of target word features + self.criterions = criterions self.generator = generator @property def padding_idx(self): - return self.criterion.ignore_index + return self.criterions[0].ignore_index def _make_shard_state(self, batch, output, range_, attns=None): """ @@ -178,7 +188,8 @@ def _stats(self, loss, scores, target): Returns: :obj:`onmt.utils.Statistics` : statistics for this batch. """ - pred = scores.max(1)[1] + # TODO we need to add some stats for features + pred = scores[0].max(1)[1] non_padding = target.ne(self.padding_idx) num_correct = pred.eq(target).masked_select(non_padding).sum().item() num_non_padding = non_padding.sum().item() @@ -214,7 +225,7 @@ def forward(self, output, target): output (FloatTensor): batch_size x n_classes target (LongTensor): batch_size """ - model_prob = self.one_hot.repeat(target.size(0), 1) + model_prob = self.one_hot.repeat(target.size(0), 1).to(target.device) model_prob.scatter_(1, target.unsqueeze(1), self.confidence) model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0) @@ -226,9 +237,9 @@ class NMTLossCompute(LossComputeBase): Standard NMT Loss Computation. """ - def __init__(self, criterion, generator, normalization="sents", + def __init__(self, criterions, generator, normalization="sents", lambda_coverage=0.0, lambda_align=0.0): - super(NMTLossCompute, self).__init__(criterion, generator) + super(NMTLossCompute, self).__init__(criterions, generator) self.lambda_coverage = lambda_coverage self.lambda_align = lambda_align @@ -237,6 +248,9 @@ def _make_shard_state(self, batch, output, range_, attns=None): "output": output, "target": batch.tgt[range_[0] + 1: range_[1], :, 0], } + if batch.tgt.size(-1) > 1: + shard_state["features"] = [batch.tgt[range_[0] + 1: range_[1], :, i+1] + for i in range(batch.tgt.size(-1) - 1)] if self.lambda_coverage != 0.0: coverage = attns.get("coverage", None) std = attns.get("std", None) @@ -275,15 +289,18 @@ def _make_shard_state(self, batch, output, range_, attns=None): }) return shard_state - def _compute_loss(self, batch, output, target, std_attn=None, + def _compute_loss(self, batch, output, target, features, std_attn=None, coverage_attn=None, align_head=None, ref_align=None): bottled_output = self._bottle(output) scores = self.generator(bottled_output) gtruth = target.view(-1) - - loss = self.criterion(scores, gtruth) + loss = self.criterions[0](scores[0], gtruth) + if features is not None: + for score, crit, feat in zip(scores[1:], self.criterions[1:], features): + truth = feat.view(-1) + loss += crit(score, truth) if self.lambda_coverage != 0.0: coverage_loss = self._compute_coverage_loss( std_attn=std_attn, coverage_attn=coverage_attn) From c27a921603ebe187b486b853d47791420fedfe93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Fri, 24 Jan 2020 18:14:27 +0100 Subject: [PATCH 09/16] fix preprocess multi corpus, translation eos condition --- onmt/bin/preprocess.py | 7 +++---- onmt/translate/greedy_search.py | 2 -- onmt/translate/translation.py | 2 +- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/onmt/bin/preprocess.py b/onmt/bin/preprocess.py index 1d48a20109..699d6c456d 100755 --- a/onmt/bin/preprocess.py +++ b/onmt/bin/preprocess.py @@ -247,10 +247,9 @@ def preprocess(opt): src_nfeats = 0 tgt_nfeats = 0 - for src, tgt in zip(opt.train_src, opt.train_tgt): - src_nfeats += count_features(src) if opt.data_type == 'text' \ - else 0 - tgt_nfeats += count_features(tgt) # tgt always text so far + src_nfeats = count_features(opt.train_src[0]) if opt.data_type == 'text' \ + else 0 + tgt_nfeats = count_features(opt.train_tgt[0]) # tgt always text so far logger.info(" * number of source features: %d." % src_nfeats) logger.info(" * number of target features: %d." % tgt_nfeats) diff --git a/onmt/translate/greedy_search.py b/onmt/translate/greedy_search.py index 354e34a232..ca8b3936c2 100644 --- a/onmt/translate/greedy_search.py +++ b/onmt/translate/greedy_search.py @@ -131,8 +131,6 @@ def advance(self, log_probs, attn): to 1.) attn (FloatTensor): Shaped ``(1, B, inp_seq_len)``. """ - # print("PROBS", [toto.size() for toto in log_probs]) - # print("ALIVE", self.alive_seq.size()) # we need to get the feature first if len(log_probs) > 1: diff --git a/onmt/translate/translation.py b/onmt/translate/translation.py index fca2ef1109..5af33e9593 100644 --- a/onmt/translate/translation.py +++ b/onmt/translate/translation.py @@ -52,7 +52,7 @@ def _build_target_tokens(self, src, src_vocab, src_raw, pred, all_feats, attn): for feat, fv in zip(feats, feats_vocabs): token += u"│" + fv.itos[feat] tokens.append(token) - if tokens[-1] == tgt_field.eos_token: + if token.split(u"│")[0] == tgt_field.eos_token: tokens = tokens[:-1] break if self.replace_unk and attn is not None and src is not None: From f4481d9928fc4a806259ce879abbf3f21a91a9ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Tue, 28 Jan 2020 18:32:30 +0100 Subject: [PATCH 10/16] refactor build_generator and Generator class, allow share_decoder_embeddings --- onmt/model_builder.py | 36 +++++++++++++++++------------- onmt/modules/generator.py | 46 +++++++++++++++++++++++++++++++-------- 2 files changed, 58 insertions(+), 24 deletions(-) diff --git a/onmt/model_builder.py b/onmt/model_builder.py index c58dc312b5..abce9e570b 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -90,22 +90,28 @@ def build_decoder(opt, embeddings): def build_generator(model_opt, fields, decoder): gen_sizes = [len(field[1].vocab) for field in fields['tgt'].fields] - if not model_opt.copy_attn: - if model_opt.generator_function == "sparsemax": - gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1) - else: - gen_func = nn.LogSoftmax(dim=-1) - generator = Generator(model_opt.rnn_size, gen_sizes, gen_func) - # TODO this can't work with target features ??? - if model_opt.share_decoder_embeddings: - generator[0].weight = decoder.embeddings.word_lut.weight + if model_opt.share_decoder_embeddings: + rnn_sizes = ([model_opt.rnn_size - (model_opt.feat_vec_size * (len(gen_sizes) -1) )] + + [model_opt.feat_vec_size] * (len(gen_sizes) - 1)) else: - tgt_base_field = fields["tgt"].base_field - vocab_size = len(tgt_base_field.vocab) - pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token] - generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx) - if model_opt.share_decoder_embeddings: - generator.linear.weight = decoder.embeddings.word_lut.weight + rnn_sizes = [model_opt.rnn_size] * len(gen_sizes) + + if model_opt.generator_function == "sparsemax": + gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1) + else: + gen_func = nn.LogSoftmax(dim=-1) + + tgt_base_field = fields["tgt"].base_field + pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token] + generator = Generator(rnn_sizes, gen_sizes, gen_func, + shared=model_opt.share_decoder_embeddings, + copy_attn=model_opt.copy_attn, + pad_idx=pad_idx) + + if model_opt.share_decoder_embeddings: + # share the weights + for gen, emb in zip(generator.generators, decoder.embeddings.emb_luts): + gen[0].weight = emb.weight return generator diff --git a/onmt/modules/generator.py b/onmt/modules/generator.py index e4f8663899..737ad90da2 100644 --- a/onmt/modules/generator.py +++ b/onmt/modules/generator.py @@ -6,22 +6,50 @@ from onmt.modules.util_class import Cast +from onmt.modules.copy_generator import CopyGenerator + class Generator(nn.Module): - def __init__(self, rnn_size, sizes, gen_func): + def __init__(self, rnn_sizes, gen_sizes, gen_func, shared=False, copy_attn=False, pad_idx=None): super(Generator, self).__init__() self.generators = nn.ModuleList() - for i, size in enumerate(sizes): + self.shared = shared + self.rnn_sizes = rnn_sizes + self.gen_sizes = gen_sizes + + def simple_generator(rnn_size, gen_size, gen_func): + return nn.Sequential( + nn.Linear(rnn_size, gen_size), + Cast(torch.float32), + gen_func) + + # create first generator + if copy_attn: + self.generators.append( + CopyGenerator(rnn_sizes[0], gen_sizes[0], pad_idx)) + else: self.generators.append( - nn.Sequential( - nn.Linear(rnn_size, - size), - Cast(torch.float32), - gen_func - )) + simple_generator(rnn_sizes[0], gen_sizes[0], gen_func)) + + # additional generators for features + for rnn_size, gen_size in zip(rnn_sizes[1:], gen_sizes[1:]): + self.generators.append( + simple_generator(rnn_size, gen_size, gen_func)) + def forward(self, dec_out): - return [generator(dec_out) for generator in self.generators] + # if shared_decoder_embeddings, we slice the decoder output + if self.shared: + outs = [] + offset = 0 + for generator, s in zip(self.generators, self.rnn_sizes): + sliced_dec_out = dec_out[:,offset:offset+s] + out = generator(sliced_dec_out) + offset += s + outs.append(out) + return outs + else: + return [generator(dec_out) for generator in self.generators] def __getitem__(self, i): return self.generators[0][i] From fe2dc10dbae15776b6d6e9ef5ae290bd702b87b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Tue, 28 Jan 2020 18:38:59 +0100 Subject: [PATCH 11/16] default features to None in compute_loss --- onmt/utils/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index 65a84b3e40..df4b3175ad 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -289,7 +289,7 @@ def _make_shard_state(self, batch, output, range_, attns=None): }) return shard_state - def _compute_loss(self, batch, output, target, features, std_attn=None, + def _compute_loss(self, batch, output, target, features=None, std_attn=None, coverage_attn=None, align_head=None, ref_align=None): bottled_output = self._bottle(output) From e83f487baaf9ce44961aa168340842460acb9c19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Tue, 28 Jan 2020 19:10:35 +0100 Subject: [PATCH 12/16] introduce feat_no_time_shift flag, fix translation for no feat case --- onmt/inputters/inputter.py | 52 +++++++++++++++++++++-------------- onmt/opts.py | 4 +++ onmt/translate/translation.py | 16 +++++++---- onmt/translate/translator.py | 10 +++++-- 4 files changed, 53 insertions(+), 29 deletions(-) diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index e8d210dfff..05d916e71a 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -583,24 +583,25 @@ def _pool(data, batch_size, batch_size_fn, batch_size_multiple, class OnmtBatch(torchtext.data.Batch): - def __init__(self, data=None, dataset=None, device=None): + def __init__(self, data=None, dataset=None, device=None, feat_no_time_shift=False): super(OnmtBatch, self).__init__(data, dataset, device) # we need to shift target features if needed - if hasattr(self, 'tgt') and self.tgt.size(-1) > 1: - # tokens: [ len x batch x 1] - tokens = self.tgt[:,:,0].unsqueeze(-1) - # feats: [ len x batch x num_feats ] - feats = self.tgt[:,:,1:] - # shift feats one step to the right - feats = torch.cat(( - feats[-1,:,:].unsqueeze(0), - feats[:-1,:,:] - )) - # build back target tensor - self.tgt = torch.cat(( - tokens, - feats - ), dim=-1) + if not(feat_no_time_shift): + if hasattr(self, 'tgt') and self.tgt.size(-1) > 1: + # tokens: [ len x batch x 1] + tokens = self.tgt[:,:,0].unsqueeze(-1) + # feats: [ len x batch x num_feats ] + feats = self.tgt[:,:,1:] + # shift feats one step to the right + feats = torch.cat(( + feats[-1,:,:].unsqueeze(0), + feats[:-1,:,:] + )) + # build back target tensor + self.tgt = torch.cat(( + tokens, + feats + ), dim=-1) class OrderedIterator(torchtext.data.Iterator): @@ -610,12 +611,14 @@ def __init__(self, pool_factor=1, batch_size_multiple=1, yield_raw_example=False, + feat_no_time_shift=False, **kwargs): super(OrderedIterator, self).__init__(dataset, batch_size, **kwargs) self.batch_size_multiple = batch_size_multiple self.yield_raw_example = yield_raw_example self.dataset = dataset self.pool_factor = pool_factor + self.feat_no_time_shift = feat_no_time_shift def create_batches(self): if self.train: @@ -671,7 +674,8 @@ def __iter__(self): yield OnmtBatch( minibatch, self.dataset, - self.device) + self.device, + feat_no_time_shift=self.feat_no_time_shift) if not self.repeat: return @@ -703,6 +707,7 @@ def __init__(self, self.sort_key = temp_dataset.sort_key self.random_shuffler = RandomShuffler() self.pool_factor = opt.pool_factor + self.feat_no_time_shift = opt.feat_no_time_shift del temp_dataset def _iter_datasets(self): @@ -731,7 +736,8 @@ def __iter__(self): minibatch = sorted(minibatch, key=self.sort_key, reverse=True) yield OnmtBatch(minibatch, self.iterables[0].dataset, - self.device) + self.device, + feat_no_time_shift=self.feat_no_time_shift) class DatasetLazyIter(object): @@ -749,7 +755,8 @@ class DatasetLazyIter(object): def __init__(self, dataset_paths, fields, batch_size, batch_size_fn, batch_size_multiple, device, is_train, pool_factor, - repeat=True, num_batches_multiple=1, yield_raw_example=False): + repeat=True, num_batches_multiple=1, feat_no_time_shift=False, + yield_raw_example=False): self._paths = dataset_paths self.fields = fields self.batch_size = batch_size @@ -761,6 +768,7 @@ def __init__(self, dataset_paths, fields, batch_size, batch_size_fn, self.num_batches_multiple = num_batches_multiple self.yield_raw_example = yield_raw_example self.pool_factor = pool_factor + self.feat_no_time_shift = feat_no_time_shift def _iter_dataset(self, path): logger.info('Loading dataset from %s' % path) @@ -778,7 +786,8 @@ def _iter_dataset(self, path): sort=False, sort_within_batch=True, repeat=False, - yield_raw_example=self.yield_raw_example + yield_raw_example=self.yield_raw_example, + shiftfeat_no_time_=self.feat_no_time_shift ) for batch in cur_iter: self.dataset = cur_iter.dataset @@ -872,7 +881,8 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False): opt.pool_factor, repeat=not opt.single_pass, num_batches_multiple=max(opt.accum_count) * opt.world_size, - yield_raw_example=multi) + yield_raw_example=multi, + feat_no_time_shift=opt.feat_no_time_shift) def build_dataset_iter_multiple(train_shards, fields, opt): diff --git a/onmt/opts.py b/onmt/opts.py index af47f79836..0a023deb17 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -56,6 +56,10 @@ def model_opts(parser): help="If -feat_merge_size is not set, feature " "embedding sizes will be set to N^feat_vec_exponent " "where N is the number of values the feature takes.") + group.add('--feat_no_time_shift', '-feat_no_time_shift', + action='store_true', + help="If set, do not shift the target features one step " + "to the right.") # Encoder-Decoder Options group = parser.add_argument_group('Model- Encoder-Decoder') diff --git a/onmt/translate/translation.py b/onmt/translate/translation.py index 5af33e9593..290eba5466 100644 --- a/onmt/translate/translation.py +++ b/onmt/translate/translation.py @@ -23,7 +23,7 @@ class TranslationBuilder(object): """ def __init__(self, data, fields, n_best=1, replace_unk=False, - has_tgt=False, phrase_table=""): + has_tgt=False, phrase_table="", feat_no_time_shift=False): self.data = data self.fields = fields self._has_text_src = isinstance( @@ -32,16 +32,22 @@ def __init__(self, data, fields, n_best=1, replace_unk=False, self.replace_unk = replace_unk self.phrase_table = phrase_table self.has_tgt = has_tgt + self.feat_no_time_shift = feat_no_time_shift - def _build_target_tokens(self, src, src_vocab, src_raw, pred, all_feats, attn): + def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn, all_feats=None): # feats need do be shifted back one step to the left - all_feats = [list(feat[1:]) + [feat[0]] for feat in all_feats] # TODO find a better way + if all_feats is not None: + if self.feat_no_time_shift: + all_feats = [list(feat[1:]) + [feat[0]] for feat in all_feats] + pred_iter = zip(pred, *all_feats) + else: + pred_iter = [(item,) for item in pred] tgt_fields = dict(self.fields)["tgt"] tgt_field = tgt_fields.base_field vocab = tgt_field.vocab feats_vocabs = [field.vocab for name, field in tgt_fields.fields[1:]] tokens = [] - for tok_feats in zip(pred, *all_feats): + for tok_feats in pred_iter: tok = tok_feats[0] if tok < len(vocab): token = vocab.itos[tok] @@ -107,7 +113,7 @@ def from_batch(self, translation_batch): pred_sents = [self._build_target_tokens( src[:, b] if src is not None else None, src_vocab, src_raw, - preds[b][n], feats[b][n], attn[b][n]) + preds[b][n], attn[b][n], feats[b][n] if len(feats[0]) > 0 else None) for n in range(self.n_best)] gold_sent = None if tgt is not None: diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index 06369ccf90..34f2468057 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -130,7 +130,8 @@ def __init__( report_align=False, report_score=True, logger=None, - seed=-1): + seed=-1, + feat_no_time_shift=False): self.model = model self.fields = fields tgt_field = dict(self.fields)["tgt"].base_field @@ -187,6 +188,8 @@ def __init__( self.use_filter_pred = False self._filter_pred = None + self.feat_no_time_shift = feat_no_time_shift + # for debugging self.beam_trace = self.dump_beam != "" self.beam_accum = None @@ -259,7 +262,8 @@ def from_opt( report_align=report_align, report_score=report_score, logger=logger, - seed=opt.seed) + seed=opt.seed, + feat_no_time_shift=vars(opt).get("feat_no_time_shift", False)) def _log(self, msg): if self.logger: @@ -334,7 +338,7 @@ def translate( xlation_builder = onmt.translate.TranslationBuilder( data, self.fields, self.n_best, self.replace_unk, tgt, - self.phrase_table + self.phrase_table, self.feat_no_time_shift ) # Statistics From bf0e8243bda0140a0481eb27629da2966529aeff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Tue, 28 Jan 2020 19:11:39 +0100 Subject: [PATCH 13/16] fix indent --- onmt/translate/translation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onmt/translate/translation.py b/onmt/translate/translation.py index 290eba5466..c0bf9179b8 100644 --- a/onmt/translate/translation.py +++ b/onmt/translate/translation.py @@ -39,7 +39,7 @@ def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn, all_feats=No if all_feats is not None: if self.feat_no_time_shift: all_feats = [list(feat[1:]) + [feat[0]] for feat in all_feats] - pred_iter = zip(pred, *all_feats) + pred_iter = zip(pred, *all_feats) else: pred_iter = [(item,) for item in pred] tgt_fields = dict(self.fields)["tgt"] From 342786005b57218bed97e39622de7e134ae17d3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Tue, 28 Jan 2020 19:13:19 +0100 Subject: [PATCH 14/16] fix typo --- onmt/inputters/inputter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index 05d916e71a..18b33d1f33 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -787,7 +787,7 @@ def _iter_dataset(self, path): sort_within_batch=True, repeat=False, yield_raw_example=self.yield_raw_example, - shiftfeat_no_time_=self.feat_no_time_shift + feat_no_time_shift=self.feat_no_time_shift ) for batch in cur_iter: self.dataset = cur_iter.dataset From 5e2363c298197833835dc370500692230364080c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Tue, 28 Jan 2020 19:21:32 +0100 Subject: [PATCH 15/16] fix some flake --- onmt/inputters/inputter.py | 13 +++++++------ onmt/model_builder.py | 8 ++++---- onmt/modules/generator.py | 8 +++----- onmt/translate/greedy_search.py | 8 +++++--- onmt/translate/translation.py | 6 ++++-- onmt/translate/translator.py | 6 ++++-- onmt/utils/loss.py | 13 ++++++++----- 7 files changed, 35 insertions(+), 27 deletions(-) diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index 18b33d1f33..ab61fce38c 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -583,19 +583,20 @@ def _pool(data, batch_size, batch_size_fn, batch_size_multiple, class OnmtBatch(torchtext.data.Batch): - def __init__(self, data=None, dataset=None, device=None, feat_no_time_shift=False): + def __init__(self, data=None, dataset=None, + device=None, feat_no_time_shift=False): super(OnmtBatch, self).__init__(data, dataset, device) # we need to shift target features if needed if not(feat_no_time_shift): if hasattr(self, 'tgt') and self.tgt.size(-1) > 1: # tokens: [ len x batch x 1] - tokens = self.tgt[:,:,0].unsqueeze(-1) + tokens = self.tgt[:, :, 0].unsqueeze(-1) # feats: [ len x batch x num_feats ] - feats = self.tgt[:,:,1:] + feats = self.tgt[:, :, 1:] # shift feats one step to the right feats = torch.cat(( - feats[-1,:,:].unsqueeze(0), - feats[:-1,:,:] + feats[-1, :, :].unsqueeze(0), + feats[:-1, :, :] )) # build back target tensor self.tgt = torch.cat(( @@ -603,8 +604,8 @@ def __init__(self, data=None, dataset=None, device=None, feat_no_time_shift=Fals feats ), dim=-1) -class OrderedIterator(torchtext.data.Iterator): +class OrderedIterator(torchtext.data.Iterator): def __init__(self, dataset, batch_size, diff --git a/onmt/model_builder.py b/onmt/model_builder.py index abce9e570b..7cacd9e4b3 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -13,9 +13,7 @@ from onmt.decoders import str2dec -from onmt.modules import Embeddings, VecEmbedding, \ - Generator, CopyGenerator -from onmt.modules.util_class import Cast +from onmt.modules import Embeddings, VecEmbedding, Generator from onmt.utils.misc import use_gpu from onmt.utils.logging import logger from onmt.utils.parse import ArgumentParser @@ -88,10 +86,12 @@ def build_decoder(opt, embeddings): else opt.decoder_type return str2dec[dec_type].from_opt(opt, embeddings) + def build_generator(model_opt, fields, decoder): gen_sizes = [len(field[1].vocab) for field in fields['tgt'].fields] if model_opt.share_decoder_embeddings: - rnn_sizes = ([model_opt.rnn_size - (model_opt.feat_vec_size * (len(gen_sizes) -1) )] + rnn_sizes = ([model_opt.rnn_size - + (model_opt.feat_vec_size * (len(gen_sizes) - 1))] + [model_opt.feat_vec_size] * (len(gen_sizes) - 1)) else: rnn_sizes = [model_opt.rnn_size] * len(gen_sizes) diff --git a/onmt/modules/generator.py b/onmt/modules/generator.py index 737ad90da2..33479a77cc 100644 --- a/onmt/modules/generator.py +++ b/onmt/modules/generator.py @@ -2,15 +2,14 @@ import torch import torch.nn as nn -from torch.nn.modules.module import _addindent - from onmt.modules.util_class import Cast from onmt.modules.copy_generator import CopyGenerator class Generator(nn.Module): - def __init__(self, rnn_sizes, gen_sizes, gen_func, shared=False, copy_attn=False, pad_idx=None): + def __init__(self, rnn_sizes, gen_sizes, gen_func, + shared=False, copy_attn=False, pad_idx=None): super(Generator, self).__init__() self.generators = nn.ModuleList() self.shared = shared @@ -36,14 +35,13 @@ def simple_generator(rnn_size, gen_size, gen_func): self.generators.append( simple_generator(rnn_size, gen_size, gen_func)) - def forward(self, dec_out): # if shared_decoder_embeddings, we slice the decoder output if self.shared: outs = [] offset = 0 for generator, s in zip(self.generators, self.rnn_sizes): - sliced_dec_out = dec_out[:,offset:offset+s] + sliced_dec_out = dec_out[:, offset:offset+s] out = generator(sliced_dec_out) offset += s outs.append(out) diff --git a/onmt/translate/greedy_search.py b/onmt/translate/greedy_search.py index ca8b3936c2..e732e44ec3 100644 --- a/onmt/translate/greedy_search.py +++ b/onmt/translate/greedy_search.py @@ -91,7 +91,8 @@ def __init__(self, pad, bos, eos, batch_size, min_length, self.keep_topk = keep_topk self.topk_scores = None - def initialize(self, memory_bank, src_lengths, num_features, src_map=None, device=None): + def initialize(self, memory_bank, src_lengths, num_features, + src_map=None, device=None): """Initialize for decoding.""" fn_map_state = None @@ -155,7 +156,9 @@ def advance(self, log_probs, attn): topk_ids, features_id ), dim=-1) - self.alive_seq = torch.cat([self.alive_seq, topk_ids.unsqueeze(-1)], -1) + self.alive_seq = torch.cat([ + self.alive_seq, + topk_ids.unsqueeze(-1)], -1) if self.return_attention: if self.alive_attn is None: self.alive_attn = attn @@ -163,7 +166,6 @@ def advance(self, log_probs, attn): self.alive_attn = torch.cat([self.alive_attn, attn], 0) self.ensure_max_length() - def update_finished(self): """Finalize scores and predictions.""" # shape: (sum(~ self.is_finished), 1) diff --git a/onmt/translate/translation.py b/onmt/translate/translation.py index c0bf9179b8..21f9a64647 100644 --- a/onmt/translate/translation.py +++ b/onmt/translate/translation.py @@ -34,7 +34,8 @@ def __init__(self, data, fields, n_best=1, replace_unk=False, self.has_tgt = has_tgt self.feat_no_time_shift = feat_no_time_shift - def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn, all_feats=None): + def _build_target_tokens(self, src, src_vocab, src_raw, + pred, attn, all_feats=None): # feats need do be shifted back one step to the left if all_feats is not None: if self.feat_no_time_shift: @@ -113,7 +114,8 @@ def from_batch(self, translation_batch): pred_sents = [self._build_target_tokens( src[:, b] if src is not None else None, src_vocab, src_raw, - preds[b][n], attn[b][n], feats[b][n] if len(feats[0]) > 0 else None) + preds[b][n], attn[b][n], + feats[b][n] if len(feats[0]) > 0 else None) for n in range(self.n_best)] gold_sent = None if tgt is not None: diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index 34f2468057..42ad98c8f8 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -662,14 +662,16 @@ def _translate_batch_with_strategy( num_features = batch.src[0].size(-1) src_map = batch.src_map if use_src_map else None fn_map_state, memory_bank, memory_lengths, src_map = \ - decode_strategy.initialize(memory_bank, src_lengths, + decode_strategy.initialize( + memory_bank, src_lengths, num_features, src_map) if fn_map_state is not None: self.model.decoder.map_state(fn_map_state) # (3) Begin decoding step by step: for step in range(decode_strategy.max_length): - decoder_input = decode_strategy.current_predictions.view(1, -1, num_features) + decoder_input = decode_strategy.current_predictions\ + .view(1, -1, num_features) log_probs, attn = self._decode_and_generate( decoder_input, diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index df4b3175ad..164d7f2128 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -249,8 +249,9 @@ def _make_shard_state(self, batch, output, range_, attns=None): "target": batch.tgt[range_[0] + 1: range_[1], :, 0], } if batch.tgt.size(-1) > 1: - shard_state["features"] = [batch.tgt[range_[0] + 1: range_[1], :, i+1] - for i in range(batch.tgt.size(-1) - 1)] + shard_state["features"] = [ + batch.tgt[range_[0] + 1: range_[1], :, i+1] + for i in range(batch.tgt.size(-1) - 1)] if self.lambda_coverage != 0.0: coverage = attns.get("coverage", None) std = attns.get("std", None) @@ -289,8 +290,9 @@ def _make_shard_state(self, batch, output, range_, attns=None): }) return shard_state - def _compute_loss(self, batch, output, target, features=None, std_attn=None, - coverage_attn=None, align_head=None, ref_align=None): + def _compute_loss(self, batch, output, target, features=None, + std_attn=None, coverage_attn=None, + align_head=None, ref_align=None): bottled_output = self._bottle(output) @@ -298,7 +300,8 @@ def _compute_loss(self, batch, output, target, features=None, std_attn=None, gtruth = target.view(-1) loss = self.criterions[0](scores[0], gtruth) if features is not None: - for score, crit, feat in zip(scores[1:], self.criterions[1:], features): + for score, crit, feat in zip(scores[1:], + self.criterions[1:], features): truth = feat.view(-1) loss += crit(score, truth) if self.lambda_coverage != 0.0: From 81fd85d707ea5715c25a6dc357289950271af2ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Thu, 30 Jan 2020 10:47:03 +0100 Subject: [PATCH 16/16] adapt beam search, fix some structure for greedy/beam compat --- onmt/translate/beam_search.py | 47 +++++++++++++++++++++++-------- onmt/translate/decode_strategy.py | 6 ++-- onmt/translate/greedy_search.py | 5 ++-- onmt/translate/translation.py | 2 +- 4 files changed, 43 insertions(+), 17 deletions(-) diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py index 9e9c89f563..d715b01545 100644 --- a/onmt/translate/beam_search.py +++ b/onmt/translate/beam_search.py @@ -93,11 +93,11 @@ def __init__(self, beam_size, batch_size, pad, bos, eos, n_best, not stepwise_penalty and self.global_scorer.has_cov_pen) self._cov_pen = self.global_scorer.has_cov_pen - def initialize(self, memory_bank, src_lengths, src_map=None, device=None): + def initialize(self, memory_bank, src_lengths, num_features, + src_map=None, device=None): """Initialize for decoding. Repeat src objects `beam_size` times. """ - def fn_map_state(state, dim): return tile(state, self.beam_size, dim=dim) @@ -115,7 +115,7 @@ def fn_map_state(state, dim): self.memory_lengths = tile(src_lengths, self.beam_size) super(BeamSearch, self).initialize( - memory_bank, self.memory_lengths, src_map, device) + memory_bank, self.memory_lengths, num_features, src_map, device) self.best_scores = torch.full( [self.batch_size], -1e10, dtype=torch.float, device=device) self._beam_offset = torch.arange( @@ -135,7 +135,7 @@ def fn_map_state(state, dim): @property def current_predictions(self): - return self.alive_seq[:, -1] + return self.alive_seq[:, :, -1] @property def current_backptr(self): @@ -148,6 +148,19 @@ def batch_offset(self): return self._batch_offset def advance(self, log_probs, attn): + # we need to get the features first + if len(log_probs) > 1: + # we take top 1 for feats + features_id = [] + for logits in log_probs[1:]: + features_id.append(logits.topk(1, dim=-1)[1]) + features_id = torch.cat(features_id, dim=-1) + else: + features_id = None + + # keep only log probs for tokens + log_probs = log_probs[0] + vocab_size = log_probs.size(-1) # using integer division to get an integer _B without casting @@ -174,7 +187,7 @@ def advance(self, log_probs, attn): curr_scores = log_probs / length_penalty # Avoid any direction that would repeat unwanted ngrams - self.block_ngram_repeats(curr_scores) + self.block_ngram_repeats(curr_scores) # TODO check compat with feats # Flatten probs into a list of possibilities. curr_scores = curr_scores.reshape(_B, self.beam_size * vocab_size) @@ -192,10 +205,18 @@ def advance(self, log_probs, attn): self.select_indices = self._batch_index.view(_B * self.beam_size) self.topk_ids.fmod_(vocab_size) # resolve true word ids + # Concatenate topk_ids for tokens and feats. + if features_id is not None: + topk_ids = torch.cat(( + self.topk_ids.view(_B * self.beam_size, 1), + features_id), dim=1) + else: + topk_ids = self.topk_ids.view(_B * self.beam_size, 1) + # Append last prediction. self.alive_seq = torch.cat( [self.alive_seq.index_select(0, self.select_indices), - self.topk_ids.view(_B * self.beam_size, 1)], -1) + topk_ids.unsqueeze(-1)], -1) self.maybe_update_forbidden_tokens() @@ -239,7 +260,7 @@ def update_finished(self): # it's faster to not move this back to the original device self.is_finished = self.is_finished.to('cpu') self.top_beam_finished |= self.is_finished[:, 0].eq(1) - predictions = self.alive_seq.view(_B_old, self.beam_size, step) + predictions = self.alive_seq.view(_B_old, self.beam_size, -1, step) attention = ( self.alive_attn.view( step - 1, _B_old, self.beam_size, self.alive_attn.size(-1)) @@ -256,9 +277,12 @@ def update_finished(self): self.best_scores[b] = s self.hypotheses[b].append(( self.topk_scores[i, j], - predictions[i, j, 1:], # Ignore start_token. + predictions[i, j, 0, 1:], # Ignore start_token. attention[:, i, j, :self.memory_lengths[i]] - if attention is not None else None)) + if attention is not None else None, + [predictions[i, 0, 1+k, 1:] + for k in range(self.num_features)] + if predictions.size(-2) > 1 else None)) # End condition is the top beam finished and we can return # n_best hypotheses. if self.ratio > 0: @@ -271,13 +295,14 @@ def update_finished(self): if finish_flag and len(self.hypotheses[b]) >= self.n_best: best_hyp = sorted( self.hypotheses[b], key=lambda x: x[0], reverse=True) - for n, (score, pred, attn) in enumerate(best_hyp): + for n, (score, pred, attn, feats) in enumerate(best_hyp): if n >= self.n_best: break self.scores[b].append(score) self.predictions[b].append(pred) # ``(batch, n_best,)`` self.attention[b].append( attn if attn is not None else []) + self.features[b].append(feats if feats is not None else []) else: non_finished_batch.append(i) non_finished = torch.tensor(non_finished_batch) @@ -297,7 +322,7 @@ def update_finished(self): self._batch_index = self._batch_index.index_select(0, non_finished) self.select_indices = self._batch_index.view(_B_new * self.beam_size) self.alive_seq = predictions.index_select(0, non_finished) \ - .view(-1, self.alive_seq.size(-1)) + .view(-1, self.alive_seq.size(-2), self.alive_seq.size(-1)) self.topk_scores = self.topk_scores.index_select(0, non_finished) self.topk_ids = self.topk_ids.index_select(0, non_finished) if self.alive_attn is not None: diff --git a/onmt/translate/decode_strategy.py b/onmt/translate/decode_strategy.py index b78c5b6d87..1bff35d4eb 100644 --- a/onmt/translate/decode_strategy.py +++ b/onmt/translate/decode_strategy.py @@ -68,6 +68,8 @@ def __init__(self, pad, bos, eos, batch_size, parallel_paths, self.predictions = [[] for _ in range(batch_size)] self.scores = [[] for _ in range(batch_size)] self.attention = [[] for _ in range(batch_size)] + # initialize features + self.features = [[] for _ in range(batch_size)] self.alive_attn = None @@ -99,9 +101,7 @@ def initialize(self, memory_bank, src_lengths, num_features, self.is_finished = torch.zeros( [self.batch_size, self.parallel_paths], dtype=torch.uint8, device=device) - # initialize features (we need to know num_features) - self.features = [[[] for _ in range(num_features - 1)] - for _ in range(self.batch_size)] + self.num_features = num_features - 1 # tokens are not features return None, memory_bank, src_lengths, src_map def __len__(self): diff --git a/onmt/translate/greedy_search.py b/onmt/translate/greedy_search.py index e732e44ec3..331b64858f 100644 --- a/onmt/translate/greedy_search.py +++ b/onmt/translate/greedy_search.py @@ -175,8 +175,9 @@ def update_finished(self): self.scores[b_orig].append(self.topk_scores[b, 0]) self.predictions[b_orig].append(self.alive_seq[b, 0, 1:]) # check on first item of the batch ot get num_features - for i in range(len(self.features[0])): - self.features[b_orig][i].append(self.alive_seq[b, 1+i, 1:]) + self.features[b_orig] = [[]] + for i in range(self.num_features): + self.features[b_orig][0].append(self.alive_seq[b, 1+i, 1:]) self.attention[b_orig].append( self.alive_attn[:, b, :self.memory_lengths[b]] if self.alive_attn is not None else []) diff --git a/onmt/translate/translation.py b/onmt/translate/translation.py index 21f9a64647..1838a98940 100644 --- a/onmt/translate/translation.py +++ b/onmt/translate/translation.py @@ -38,7 +38,7 @@ def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn, all_feats=None): # feats need do be shifted back one step to the left if all_feats is not None: - if self.feat_no_time_shift: + if not(self.feat_no_time_shift): all_feats = [list(feat[1:]) + [feat[0]] for feat in all_feats] pred_iter = zip(pred, *all_feats) else: