diff --git a/include/ctranslate2/decoding.h b/include/ctranslate2/decoding.h
index a9e99176c..62c5e226a 100644
--- a/include/ctranslate2/decoding.h
+++ b/include/ctranslate2/decoding.h
@@ -30,6 +30,8 @@ namespace ctranslate2 {
            const bool return_scores = false,
            const bool return_attention = false,
            const size_t num_hypotheses = 1,
+           const bool include_eos_in_scores = true,
+           const bool include_eos_in_hypotheses = true,
            const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors = {},
            const std::vector<std::vector<size_t>>* prefix_ids = nullptr) const = 0;
   };
@@ -54,6 +56,8 @@ namespace ctranslate2 {
            const bool return_scores = false,
            const bool return_attention = false,
            const size_t num_hypotheses = 1,
+           const bool include_eos_in_scores = true,
+           const bool include_eos_in_hypotheses = true,
            const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors = {},
            const std::vector<std::vector<size_t>>* prefix_ids = nullptr) const override;
 
@@ -101,6 +105,8 @@ namespace ctranslate2 {
            const bool return_scores = false,
            const bool return_attention = false,
            const size_t num_hypotheses = 1,
+           const bool include_eos_in_scores = true,
+           const bool include_eos_in_hypotheses = true,
            const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors = {},
            const std::vector<std::vector<size_t>>* prefix_ids = nullptr) const override;
 
@@ -124,10 +130,11 @@ namespace ctranslate2 {
     size_t sampling_topk = 1;
     float sampling_temperature = 1;
     size_t num_hypotheses = 1;
+    bool include_eos_in_scores = true;
+    bool include_eos_in_hypotheses = true;
     bool return_scores = false;
     bool return_attention = false;
     bool return_alternatives = false;
-    bool return_prefix = true;
     float min_alternative_expansion_prob = 0;
     std::vector<size_t> disable_ids;
     std::vector<size_t> disable_ids_begin;
diff --git a/src/decoding.cc b/src/decoding.cc
index 4c36b5e8e..a51aa4de0 100644
--- a/src/decoding.cc
+++ b/src/decoding.cc
@@ -159,20 +159,22 @@ namespace ctranslate2 {
 
   static std::vector<size_t> build_hypothesis(const StorageView& history,
                                               const dim_t batch,
-                                              const dim_t beam) {
-    const auto length = history.dim(-1);
+                                              const dim_t beam,
+                                              const bool ignore_last) {
+    const auto length = history.dim(-1) - dim_t(ignore_last);
     const auto* ids = history.index<int32_t>({batch, beam, 0});
     return std::vector<size_t>(ids, ids + length);
   }
 
   static std::vector<std::vector<float>> build_attention(const StorageView& history,
                                                          const dim_t batch,
-                                                         const dim_t beam) {
+                                                         const dim_t beam,
+                                                         const bool ignore_last) {
     if (!history)
       return {};
 
     const auto source_length = history.dim(-1);
-    const auto target_length = history.dim(-2);
+    const auto target_length = history.dim(-2) - dim_t(ignore_last);
 
     std::vector<std::vector<float>> attention;
     attention.reserve(target_length);
@@ -385,6 +387,8 @@ namespace ctranslate2 {
                                     const bool return_scores,
                                     const bool return_attention,
                                     const size_t num_hypotheses,
+                                    const bool include_eos_in_scores,
+                                    const bool include_eos_in_hypotheses,
                                     const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors,
                                     const std::vector<std::vector<size_t>>* prefix_ids) const {
     PROFILE("beam_search");
@@ -480,12 +484,18 @@ namespace ctranslate2 {
       }
 
       // Multiply by the current beam log probs.
+      StorageView topk_scores_prev(dtype);
       if (topk_scores) {
         DEVICE_AND_TYPE_DISPATCH(log_probs.device(), log_probs.dtype(),
                                  primitives<D>::add_depth_broadcast(topk_scores.to(device).data<T>(),
                                                                     log_probs.data<T>(),
                                                                     topk_scores.size(),
                                                                     log_probs.size()));
+
+        if (!include_eos_in_scores) {
+          topk_scores_prev = topk_scores;
+          topk_scores_prev.reshape({cur_batch_size, _beam_size});
+        }
       }
 
       // Flatten the probs into a list of candidates.
@@ -549,11 +559,19 @@ namespace ctranslate2 {
             if (k == 0)
               top_beam_finished[i] = true;
 
+            bool ignore_last_score = false;
+            bool ignore_last_token = false;
+            if (last_id == end_id) {
+              ignore_last_score = !include_eos_in_scores;
+              ignore_last_token = !include_eos_in_hypotheses;
+            }
+
             // Register this hypothesis.
-            result.scores.emplace_back(topk_scores.scalar_at<float>({i, k}));
-            result.hypotheses.emplace_back(build_hypothesis(alive_seq, i, k));
+            const StorageView& scores = ignore_last_score ? topk_scores_prev : topk_scores;
+            result.scores.emplace_back(scores.scalar_at<float>({i, k}));
+            result.hypotheses.emplace_back(build_hypothesis(alive_seq, i, k, ignore_last_token));
             if (alive_attention)
-              result.attention.emplace_back(build_attention(alive_attention, i, k));
+              result.attention.emplace_back(build_attention(alive_attention, i, k, ignore_last_token));
 
             // Move another active beam to this position.
             for (dim_t j = secondary_candidates_offset; j < num_candidates; ++j) {
@@ -668,6 +686,8 @@ namespace ctranslate2 {
                                       const bool return_scores,
                                       const bool return_attention,
                                       const size_t num_hypotheses,
+                                      const bool include_eos_in_scores,
+                                      const bool include_eos_in_hypotheses,
                                       const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors,
                                       const std::vector<std::vector<size_t>>* prefix_ids) const {
     const dim_t batch_size = start_ids.size();
@@ -693,6 +713,8 @@ namespace ctranslate2 {
                                       /*return_scores=*/true,
                                       return_attention,
                                       /*num_hypotheses=*/1,
+                                      include_eos_in_scores,
+                                      include_eos_in_hypotheses,
                                       logits_processors,
                                       prefix_ids ? &repeat_prefix_ids : nullptr);
 
@@ -789,12 +811,17 @@ namespace ctranslate2 {
           const size_t batch_id = batch_offset[i];
           const dim_t prefix_length = prefix_ids ? prefix_ids->at(batch_id).size() : 0;
 
-          results[batch_id].hypotheses[0].push_back(word_id);
-          if (return_scores)
-            results[batch_id].scores[0] += best_probs.scalar_at<float>({i, 0});
-          if (attention_step) {
-            const auto* attn = attention_step.index<float>({i, 0});
-            results[batch_id].attention[0].emplace_back(attn, attn + attention_step.dim(-1));
+          if (word_id != end_id || include_eos_in_hypotheses) {
+            results[batch_id].hypotheses[0].push_back(word_id);
+            if (attention_step) {
+              const auto* attn = attention_step.index<float>({i, 0});
+              results[batch_id].attention[0].emplace_back(attn, attn + attention_step.dim(-1));
+            }
+          }
+
+          if (word_id != end_id || include_eos_in_scores) {
+            if (return_scores)
+              results[batch_id].scores[0] += best_probs.scalar_at<float>({i, 0});
           }
 
           const bool is_finished = ((word_id == end_id && step >= prefix_length)
@@ -1033,6 +1060,8 @@ namespace ctranslate2 {
                                  /*return_scores=*/true,
                                  options.return_attention,
                                  options.num_hypotheses,
+                                 options.include_eos_in_scores,
+                                 options.include_eos_in_hypotheses,
                                  logits_processors)[0];
 
     start_ids.clear();
@@ -1084,6 +1113,8 @@ namespace ctranslate2 {
                                   options.return_scores,
                                   options.return_attention,
                                   /*num_hypotheses=*/1,
+                                  options.include_eos_in_scores,
+                                  options.include_eos_in_hypotheses,
                                   logits_processors);
 
       // Update the result with the suffix decoding.
@@ -1174,6 +1205,8 @@ namespace ctranslate2 {
                              options.return_scores,
                              options.return_attention,
                              options.num_hypotheses,
+                             options.include_eos_in_scores,
+                             options.include_eos_in_hypotheses,
                              logits_processors,
                              prefix_ids.empty() ? nullptr : &prefix_ids);
     }
@@ -1182,28 +1215,11 @@ namespace ctranslate2 {
     for (size_t b = 0; b < batch_size; ++b) {
       auto& result = results[b];
       for (size_t i = 0; i < result.hypotheses.size(); ++i) {
-        // Remove EOS token.
-        while (result.hypotheses[i].back() == end_id) {
-          result.hypotheses[i].pop_back();
-          if (!result.attention.empty())
-            result.attention[i].pop_back();
-        }
-
         // Restore original word ids.
         if (decoder.output_layer_is_updated()) {
           for (auto& id : result.hypotheses[i])
             id = decoder.to_original_word_id(id);
         }
-
-        // Remove the prefix if configured.
-        const size_t prefix_length = start_tokens[b].size() - 1;
-        if (!options.return_prefix && prefix_length > 0) {
-          result.hypotheses[i].erase(result.hypotheses[i].begin(),
-                                     result.hypotheses[i].begin() + prefix_length);
-          if (!result.attention.empty())
-            result.attention[i].erase(result.attention[i].begin(),
-                                      result.attention[i].begin() + prefix_length);
-        }
       }
     }
diff --git a/src/models/language_model.cc b/src/models/language_model.cc
index 2d8394f2b..dc6ca323e 100644
--- a/src/models/language_model.cc
+++ b/src/models/language_model.cc
@@ -162,6 +162,12 @@ namespace ctranslate2 {
       for (size_t i = 0; i < results.size(); ++i) {
         auto& result = results[i];
 
+        // Remove EOS token.
+        for (auto& sequence : result.hypotheses) {
+          while (!sequence.empty() && sequence.back() == end_id)
+            sequence.pop_back();
+        }
+
         // Forward the start token to the output if it is not the special BOS token.
         if (!start_ids[i].empty() && start_ids[i][0] != vocabulary.bos_id()) {
           for (auto& sequence : result.hypotheses)
diff --git a/src/models/sequence_to_sequence.cc b/src/models/sequence_to_sequence.cc
index 3df3b4b6c..f31490fc9 100644
--- a/src/models/sequence_to_sequence.cc
+++ b/src/models/sequence_to_sequence.cc
@@ -374,6 +374,16 @@ namespace ctranslate2 {
     for (size_t i = 0; i < batch_size; ++i) {
       DecodingResult& result = results[i];
+
+      // Remove EOS token.
+      for (size_t h = 0; h < result.hypotheses.size(); ++h) {
+        while (!result.hypotheses[h].empty() && result.hypotheses[h].back() == end_id) {
+          result.hypotheses[h].pop_back();
+          if (!result.attention.empty())
+            result.attention[h].pop_back();
+        }
+      }
+
       auto hypotheses = target_vocabulary.to_tokens(result.hypotheses);
 
       if (!result.attention.empty()) {
diff --git a/src/models/whisper.cc b/src/models/whisper.cc
index 0e1c16ea4..26b9fdb17 100644
--- a/src/models/whisper.cc
+++ b/src/models/whisper.cc
@@ -229,7 +229,8 @@ namespace ctranslate2 {
       decoding_options.sampling_temperature = options.sampling_temperature;
       decoding_options.num_hypotheses = options.num_hypotheses;
       decoding_options.return_scores = options.return_scores;
-      decoding_options.return_prefix = false;
+      decoding_options.include_eos_in_scores = options.beam_size > 1;
+      decoding_options.include_eos_in_hypotheses = false;
       for (const auto& id : _model->config["suppress_ids"])
         decoding_options.disable_ids.push_back(id);
       for (const auto& id : _model->config["suppress_ids_begin"])
@@ -433,6 +434,7 @@ namespace ctranslate2 {
           // cannot be normal text tokens
           for (size_t i = 0; i < _eot_id; ++i)
             disable_tokens.add(batch_id, i);
+          check_timestamps_prob_for_batch.push_back(batch_id);
         }
       } else {
         check_timestamps_prob_for_batch.push_back(batch_id);
@@ -443,7 +445,7 @@ namespace ctranslate2 {
           const size_t token = sequences.at<int32_t>({batch_id, t});
 
           if (token >= _timestamp_begin_id) {
-            for (size_t i = _timestamp_begin_id; i <= token; ++i)
+            for (size_t i = _timestamp_begin_id; i < token; ++i)
               disable_tokens.add(batch_id, i);
             break;
           }