From acf06b12d28aadd19daa3e2bdbbf8877177ee265 Mon Sep 17 00:00:00 2001
From: bpopeters
Date: Wed, 30 Jan 2019 16:51:23 +0000
Subject: [PATCH 1/2] reimplement beam_accum, but with a bug

---
 onmt/translate/translator.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index d306c5b9ac..5e089d27a1 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -303,8 +303,8 @@ def translate(
 
         if self.dump_beam:
             import json
-            json.dump(self.translator.beam_accum,
-                      codecs.open(self.dump_beam, 'w', 'utf-8'))
+            with codecs.open(self.dump_beam, 'w', 'utf-8') as f:
+                json.dump(self.beam_accum, f)
         return all_scores, all_predictions
 
     def sample_with_temperature(self, logits, sampling_temp, keep_topk):
@@ -833,8 +833,7 @@ def _translate_batch(self, batch, data):
             select_indices_array = []
             # Loop over the batch_size number of beam
             for j, b in enumerate(beam):
-                b.advance(out[j, :],
-                          beam_attn.data[j, :, :memory_lengths[j]])
+                b.advance(out[j, :], beam_attn.data[j, :, :memory_lengths[j]])
                 select_indices_array.append(
                     b.get_current_origin() + j * beam_size)
             select_indices = torch.cat(select_indices_array)
@@ -854,6 +853,17 @@ def _translate_batch(self, batch, data):
             results["scores"].append(scores)
             results["attention"].append(attn)
 
+        if self.beam_accum is not None:
+            self.beam_accum["beam_parent_ids"].append(
+                [t.tolist() for t in b.prev_ks])
+            self.beam_accum["scores"].append([
+                ["%4f" % s for s in t.tolist()]
+                for t in b.all_scores][1:])
+            # ok, so what was the tgt_dict?
+            self.beam_accum["predicted_ids"].append(
+                [[vocab.itos[i] for i in t.tolist()]
+                 for t in b.next_ys][1:])
+
         return results
 
     def _score_target(self, batch, memory_bank, src_lengths, data, src_map):

From 5d30d9febe43db136ea0b4985733c933e9727fa5 Mon Sep 17 00:00:00 2001
From: bpopeters
Date: Wed, 30 Jan 2019 17:31:30 +0000
Subject: [PATCH 2/2] add assert for batch size with dump beam

---
 onmt/translate/translator.py | 20 ++++++++------------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 5e089d27a1..df15c175c9 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -120,17 +120,12 @@ def __init__(
         self.report_score = report_score
         self.logger = logger
 
-        self.use_filter_pred = False
-
-        # for debugging
-        self.beam_trace = self.dump_beam != ""
-        self.beam_accum = None
-        if self.beam_trace:
+        if self.dump_beam != "":
             self.beam_accum = {
-                "predicted_ids": [],
-                "beam_parent_ids": [],
-                "scores": [],
-                "log_probs": []}
+                "predicted_ids": [], "beam_parent_ids": [], "scores": []
+            }
+        else:
+            self.beam_accum = None
 
         set_random_seed(opt.seed, self.cuda)
 
@@ -185,7 +180,7 @@ def translate(
             window_size=self.window_size,
             window_stride=self.window_stride,
             window=self.window,
-            use_filter_pred=self.use_filter_pred,
+            use_filter_pred=False,
             image_channel_size=self.image_channel_size,
         )
 
@@ -561,7 +556,6 @@ def _fast_translate_batch(
     ):
         # TODO: support these blacklisted features.
         assert not self.dump_beam
-        assert not self.use_filter_pred
         assert self.block_ngram_repeat == 0
         assert self.global_scorer.beta == 0
 
@@ -761,6 +755,8 @@ def _translate_batch(self, batch, data):
         # And helper method for reducing verbosity.
         beam_size = self.beam_size
         batch_size = batch.batch_size
+        assert self.beam_accum is None or batch_size == 1, \
+            "Beam visualization currently only works with batch_size == 1"
         tgt_field = self.fields['tgt'][0][1].base_field
         vocab = tgt_field.vocab
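
For context, the beam trace these patches restore is just a JSON dump of the beam_accum dict, so it can be inspected with a few lines of Python after running translation with -dump_beam. A minimal sketch, not part of the patches; the path "beam_trace.json" is a hypothetical value for whatever was passed to -dump_beam:

    import codecs
    import json

    # Hypothetical path: whatever file was given to -dump_beam at translation time.
    with codecs.open("beam_trace.json", "r", "utf-8") as f:
        accum = json.load(f)

    # Keys mirror the dict built in __init__ above.
    print(list(accum.keys()))           # ['predicted_ids', 'beam_parent_ids', 'scores']
    print(len(accum["predicted_ids"]))  # one entry per translated batch (a single
                                        # sentence, given the batch_size == 1 assert)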