Skip to content

Commit

Permalink
Add script for n_best parameter in topp/topk (#2509)
Browse files Browse the repository at this point in the history
* Add script for n_best parameter in topp/topk

Co-authored-by: Thai Chau Truong <tctruong@dom_softissimo.lan>
  • Loading branch information
PC91 and Thai Chau Truong authored Nov 12, 2023
1 parent c5c84af commit f3059a5
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
7 changes: 6 additions & 1 deletion onmt/tests/test_greedy_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self):
2,
3,
1,
1,
batch_sz,
GlobalScorerStub(),
min_length,
Expand Down Expand Up @@ -100,6 +101,7 @@ def test_returns_correct_scores_deterministic(self):
2,
3,
1,
1,
batch_sz,
GlobalScorerStub(),
0,
Expand Down Expand Up @@ -186,6 +188,7 @@ def test_returns_correct_scores_non_deterministic(self):
2,
3,
1,
1,
batch_sz,
GlobalScorerStub(),
0,
Expand Down Expand Up @@ -297,6 +300,7 @@ def test_returns_correct_scores_non_deterministic_beams(self):
2,
3,
1,
1,
batch_sz,
GlobalScorerStub(),
0,
Expand Down Expand Up @@ -374,7 +378,7 @@ def test_returns_correct_scores_non_deterministic_beams(self):

samp.update_finished()
self.assertEqual(
[score for score, _, _ in samp.hypotheses[batch_sz - 1][-1:]],
[score for score, _, _ in samp.hypotheses[batch_sz - 1][:1]],
[valid_score_dist_2[0] / temp],
)

Expand Down Expand Up @@ -419,6 +423,7 @@ def test_returns_correct_scores_non_deterministic_topp(self):
2,
3,
1,
1,
batch_sz,
GlobalScorerStub(),
0,
Expand Down
10 changes: 9 additions & 1 deletion onmt/translate/greedy_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ class GreedySearch(DecodeStrategy):
eos (int): See base.
unk (int): See base.
start (int): See base.
n_best (int): Don't stop until at least this many beams have
reached EOS.
batch_size (int): See base.
global_scorer (onmt.translate.GNMTGlobalScorer): Scorer instance.
min_length (int): See base.
Expand All @@ -123,6 +125,7 @@ def __init__(
eos,
unk,
start,
n_best,
batch_size,
global_scorer,
min_length,
Expand Down Expand Up @@ -157,6 +160,7 @@ def __init__(
self.keep_topp = keep_topp
self.topk_scores = None
self.beam_size = beam_size
self.n_best = n_best

def initialize(
self, enc_out, src_len, src_map=None, device=None, target_prefix=None
Expand Down Expand Up @@ -265,10 +269,14 @@ def update_finished(self):
else []
)
self.hypotheses[b_orig].append((score, pred, attention))
if len(self.hypotheses[b_orig]) >= 2:
self.hypotheses[b_orig] = sorted(
self.hypotheses[b_orig], key=lambda x: x[0], reverse=True
)
self.done = self.is_finished.all()
if self.done:
for b in range(self.batch_size):
best_hyp = sorted(self.hypotheses[b], key=lambda x: x[0], reverse=True)
best_hyp = self.hypotheses[b][: self.n_best]
for score, pred, attn in best_hyp:
self.scores[b].append(score)
self.predictions[b].append(pred)
Expand Down
2 changes: 2 additions & 0 deletions onmt/translate/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,7 @@ def translate_batch(self, batch, attn_debug):
eos=self._tgt_eos_idx,
unk=self._tgt_unk_idx,
start=self._tgt_start_with,
n_best=self.n_best,
batch_size=len(batch["srclen"]),
global_scorer=self.global_scorer,
min_length=self.min_length,
Expand Down Expand Up @@ -1009,6 +1010,7 @@ def translate_batch(self, batch, attn_debug):
eos=self._tgt_eos_idx,
unk=self._tgt_unk_idx,
start=self._tgt_start_with,
n_best=self.n_best,
batch_size=len(batch["srclen"]),
global_scorer=self.global_scorer,
min_length=self.min_length,
Expand Down

0 comments on commit f3059a5

Please sign in to comment.