From 8612f5ec854ea06306b97560a40c5fc3e70b613a Mon Sep 17 00:00:00 2001 From: Thai Chau Truong Date: Tue, 7 Nov 2023 22:34:45 -0500 Subject: [PATCH 1/3] Add script for n_best parameter in topp/topk --- onmt/tests/test_greedy_search.py | 2 +- onmt/translate/greedy_search.py | 10 +++++++++- onmt/translate/translator.py | 1 + 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/onmt/tests/test_greedy_search.py b/onmt/tests/test_greedy_search.py index fd4291e506..c790592b1e 100644 --- a/onmt/tests/test_greedy_search.py +++ b/onmt/tests/test_greedy_search.py @@ -374,7 +374,7 @@ def test_returns_correct_scores_non_deterministic_beams(self): samp.update_finished() self.assertEqual( - [score for score, _, _ in samp.hypotheses[batch_sz - 1][-1:]], + [score for score, _, _ in samp.hypotheses[batch_sz - 1][:1]], [valid_score_dist_2[0] / temp], ) diff --git a/onmt/translate/greedy_search.py b/onmt/translate/greedy_search.py index 8a5707ffa8..0cc58e4e37 100644 --- a/onmt/translate/greedy_search.py +++ b/onmt/translate/greedy_search.py @@ -98,6 +98,8 @@ class GreedySearch(DecodeStrategy): eos (int): See base. unk (int): See base. start (int): See base. + n_best (int): Return at most this many of the highest-scoring + finished hypotheses for each batch example. batch_size (int): See base. global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance. min_length (int): See base. 
@@ -123,6 +125,7 @@ def __init__( eos, unk, start, + n_best, batch_size, global_scorer, min_length, @@ -157,6 +160,7 @@ def __init__( self.keep_topp = keep_topp self.topk_scores = None self.beam_size = beam_size + self.n_best = n_best def initialize( self, enc_out, src_len, src_map=None, device=None, target_prefix=None @@ -265,10 +269,14 @@ def update_finished(self): else [] ) self.hypotheses[b_orig].append((score, pred, attention)) + if len(self.hypotheses[b_orig]) >= 2: + self.hypotheses[b_orig] = sorted( + self.hypotheses[b_orig], key=lambda x: x[0], reverse=True + ) self.done = self.is_finished.all() if self.done: for b in range(self.batch_size): - best_hyp = sorted(self.hypotheses[b], key=lambda x: x[0], reverse=True) + best_hyp = self.hypotheses[b][: self.n_best] for score, pred, attn in best_hyp: self.scores[b].append(score) self.predictions[b].append(pred) diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index be6903bcaf..1f2af5044b 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -810,6 +810,7 @@ def translate_batch(self, batch, attn_debug): eos=self._tgt_eos_idx, unk=self._tgt_unk_idx, start=self._tgt_start_with, + n_best=self.n_best, batch_size=len(batch["srclen"]), global_scorer=self.global_scorer, min_length=self.min_length, From 155f6b3e226f42619c96aac75858e0cd62fe7bf4 Mon Sep 17 00:00:00 2001 From: Thai Chau Truong Date: Thu, 9 Nov 2023 01:23:44 -0500 Subject: [PATCH 2/3] Update unittests --- onmt/tests/test_greedy_search.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onmt/tests/test_greedy_search.py b/onmt/tests/test_greedy_search.py index c790592b1e..d645740055 100644 --- a/onmt/tests/test_greedy_search.py +++ b/onmt/tests/test_greedy_search.py @@ -46,6 +46,7 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self): 2, 3, 1, + 1, batch_sz, GlobalScorerStub(), min_length, @@ -100,6 +101,7 @@ def test_returns_correct_scores_deterministic(self): 2, 3, 1, + 1, batch_sz, 
GlobalScorerStub(), 0, @@ -186,6 +188,7 @@ def test_returns_correct_scores_non_deterministic(self): 2, 3, 1, + 1, batch_sz, GlobalScorerStub(), 0, @@ -297,6 +300,7 @@ def test_returns_correct_scores_non_deterministic_beams(self): 2, 3, 1, + 1, batch_sz, GlobalScorerStub(), 0, @@ -419,6 +423,7 @@ def test_returns_correct_scores_non_deterministic_topp(self): 2, 3, 1, + 1, batch_sz, GlobalScorerStub(), 0, From 677a5c41075313dcd29841323031199a37ae9b62 Mon Sep 17 00:00:00 2001 From: Thai Chau Truong Date: Thu, 9 Nov 2023 01:32:17 -0500 Subject: [PATCH 3/3] Add n_best parameter for GreedySearchLM class --- onmt/translate/translator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py index 1f2af5044b..99a835e1aa 100644 --- a/onmt/translate/translator.py +++ b/onmt/translate/translator.py @@ -1010,6 +1010,7 @@ def translate_batch(self, batch, attn_debug): eos=self._tgt_eos_idx, unk=self._tgt_unk_idx, start=self._tgt_start_with, + n_best=self.n_best, batch_size=len(batch["srclen"]), global_scorer=self.global_scorer, min_length=self.min_length,