Fix some incorrectness in Whisper decoding (#1081)
guillaumekln authored Feb 13, 2023
1 parent 7ee6a34 commit af0b976
Showing 5 changed files with 74 additions and 33 deletions.
9 changes: 8 additions & 1 deletion include/ctranslate2/decoding.h
@@ -30,6 +30,8 @@ namespace ctranslate2 {
const bool return_scores = false,
const bool return_attention = false,
const size_t num_hypotheses = 1,
+                        const bool include_eos_in_scores = true,
+                        const bool include_eos_in_hypotheses = true,
const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors = {},
const std::vector<std::vector<size_t>>* prefix_ids = nullptr) const = 0;
};
@@ -54,6 +56,8 @@ namespace ctranslate2 {
const bool return_scores = false,
const bool return_attention = false,
const size_t num_hypotheses = 1,
+                        const bool include_eos_in_scores = true,
+                        const bool include_eos_in_hypotheses = true,
const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors = {},
const std::vector<std::vector<size_t>>* prefix_ids = nullptr) const override;

@@ -101,6 +105,8 @@ namespace ctranslate2 {
const bool return_scores = false,
const bool return_attention = false,
const size_t num_hypotheses = 1,
+                        const bool include_eos_in_scores = true,
+                        const bool include_eos_in_hypotheses = true,
const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors = {},
const std::vector<std::vector<size_t>>* prefix_ids = nullptr) const override;

@@ -124,10 +130,11 @@ namespace ctranslate2 {
size_t sampling_topk = 1;
float sampling_temperature = 1;
size_t num_hypotheses = 1;
+    bool include_eos_in_scores = true;
+    bool include_eos_in_hypotheses = true;
bool return_scores = false;
bool return_attention = false;
bool return_alternatives = false;
-    bool return_prefix = true;
float min_alternative_expansion_prob = 0;
std::vector<size_t> disable_ids;
std::vector<size_t> disable_ids_begin;
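The two new options control whether the end-of-sequence token contributes to a hypothesis's returned score and whether it appears among the returned token ids. Defaulting both to true preserves the previous behavior of the decoding functions; the EOS stripping that used to happen unconditionally in decode() now lives in the model-specific result handling (see the language_model.cc and sequence_to_sequence.cc changes below). A minimal standalone sketch of the intended semantics, where the Hypothesis struct and finalize() helper are illustrative stand-ins rather than the actual CTranslate2 API:

#include <cstddef>
#include <iostream>
#include <vector>

struct Hypothesis {
  std::vector<size_t> ids;   // decoded token ids, possibly ending with EOS
  float score;               // cumulative log probability, including the EOS step
  float score_before_eos;    // cumulative log probability before the EOS step
};

// Apply the include_eos_* options to a finished hypothesis.
Hypothesis finalize(Hypothesis hyp,
                    size_t end_id,
                    bool include_eos_in_scores,
                    bool include_eos_in_hypotheses) {
  if (!hyp.ids.empty() && hyp.ids.back() == end_id) {
    if (!include_eos_in_scores)
      hyp.score = hyp.score_before_eos;  // report the score without the EOS step
    if (!include_eos_in_hypotheses)
      hyp.ids.pop_back();                // drop the EOS token itself
  }
  return hyp;
}

int main() {
  const size_t end_id = 2;
  Hypothesis hyp{{5, 7, 9, end_id}, -3.4f, -2.9f};
  hyp = finalize(hyp, end_id, /*include_eos_in_scores=*/false,
                 /*include_eos_in_hypotheses=*/false);
  std::cout << hyp.ids.size() << " tokens, score " << hyp.score << "\n";  // 3 tokens, score -2.9
}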
76 changes: 46 additions & 30 deletions src/decoding.cc
@@ -159,20 +159,22 @@ namespace ctranslate2 {

static std::vector<size_t> build_hypothesis(const StorageView& history,
const dim_t batch,
-                                              const dim_t beam) {
-      const auto length = history.dim(-1);
+                                              const dim_t beam,
+                                              const bool ignore_last) {
+      const auto length = history.dim(-1) - dim_t(ignore_last);
const auto* ids = history.index<int32_t>({batch, beam, 0});
return std::vector<size_t>(ids, ids + length);
}

static std::vector<std::vector<float>> build_attention(const StorageView& history,
const dim_t batch,
-                                                         const dim_t beam) {
+                                                         const dim_t beam,
+                                                         const bool ignore_last) {
if (!history)
return {};

const auto source_length = history.dim(-1);
-      const auto target_length = history.dim(-2);
+      const auto target_length = history.dim(-2) - dim_t(ignore_last);

std::vector<std::vector<float>> attention;
attention.reserve(target_length);
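Both helpers now take an ignore_last flag so a trailing EOS can be trimmed consistently from the token ids and from the attention history, whose rows correspond one-to-one to target tokens. A standalone sketch of the trimming logic, with plain vectors standing in for StorageView:

#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for the StorageView-based helpers above.
std::vector<size_t> trim_hypothesis(const std::vector<size_t>& history,
                                    bool ignore_last) {
  const size_t length = history.size() - (ignore_last ? 1 : 0);
  return std::vector<size_t>(history.begin(), history.begin() + length);
}

std::vector<std::vector<float>> trim_attention(
    const std::vector<std::vector<float>>& history,  // [target_len][source_len]
    bool ignore_last) {
  const size_t target_length = history.size() - (ignore_last ? 1 : 0);
  return std::vector<std::vector<float>>(history.begin(),
                                         history.begin() + target_length);
}

int main() {
  const std::vector<size_t> ids = {5, 7, 2};  // 2 = EOS
  const std::vector<std::vector<float>> attn = {{1.f}, {1.f}, {1.f}};
  assert(trim_hypothesis(ids, true).size() == 2);
  assert(trim_attention(attn, true).size() == 2);  // rows stay aligned with tokens
  return 0;
}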
@@ -385,6 +387,8 @@ namespace ctranslate2 {
const bool return_scores,
const bool return_attention,
const size_t num_hypotheses,
+                             const bool include_eos_in_scores,
+                             const bool include_eos_in_hypotheses,
const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors,
const std::vector<std::vector<size_t>>* prefix_ids) const {
PROFILE("beam_search");
@@ -480,12 +484,18 @@ namespace ctranslate2 {
}

// Multiply by the current beam log probs.
+      StorageView topk_scores_prev(dtype);
if (topk_scores) {
DEVICE_AND_TYPE_DISPATCH(log_probs.device(), log_probs.dtype(),
primitives<D>::add_depth_broadcast(topk_scores.to(device).data<T>(),
log_probs.data<T>(),
topk_scores.size(),
                                                         log_probs.size()));
+
+        if (!include_eos_in_scores) {
+          topk_scores_prev = topk_scores;
+          topk_scores_prev.reshape({cur_batch_size, _beam_size});
+        }
}

// Flatten the probs into a list of candidates.
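topk_scores is overwritten by the top-k selection at every step, so the cumulative scores from the previous step have to be snapshotted before the update; topk_scores_prev holds that snapshot, and it is the value reported when a hypothesis ends with EOS and include_eos_in_scores is false. Schematically, for a single beam (a hypothetical scalar version of the bookkeeping):

#include <iostream>

int main() {
  float cumulative = -1.2f;          // beam score after step t-1
  const float eos_log_prob = -0.4f;  // log P(EOS | prefix) at step t

  const float score_prev = cumulative;  // snapshot, like topk_scores_prev
  cumulative += eos_log_prob;           // like add_depth_broadcast + the top-k update

  const bool include_eos_in_scores = false;
  const float reported = include_eos_in_scores ? cumulative : score_prev;
  std::cout << reported << "\n";  // -1.2: the EOS step is excluded from the score
}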
@@ -549,11 +559,19 @@ namespace ctranslate2 {
if (k == 0)
top_beam_finished[i] = true;

+            bool ignore_last_score = false;
+            bool ignore_last_token = false;
+            if (last_id == end_id) {
+              ignore_last_score = !include_eos_in_scores;
+              ignore_last_token = !include_eos_in_hypotheses;
+            }
+
            // Register this hypothesis.
-            result.scores.emplace_back(topk_scores.scalar_at<float>({i, k}));
-            result.hypotheses.emplace_back(build_hypothesis(alive_seq, i, k));
+            const StorageView& scores = ignore_last_score ? topk_scores_prev : topk_scores;
+            result.scores.emplace_back(scores.scalar_at<float>({i, k}));
+            result.hypotheses.emplace_back(build_hypothesis(alive_seq, i, k, ignore_last_token));
            if (alive_attention)
-              result.attention.emplace_back(build_attention(alive_attention, i, k));
+              result.attention.emplace_back(build_attention(alive_attention, i, k, ignore_last_token));

// Move another active beam to this position.
for (dim_t j = secondary_candidates_offset; j < num_candidates; ++j) {
@@ -668,6 +686,8 @@ namespace ctranslate2 {
const bool return_scores,
const bool return_attention,
const size_t num_hypotheses,
+                             const bool include_eos_in_scores,
+                             const bool include_eos_in_hypotheses,
const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors,
const std::vector<std::vector<size_t>>* prefix_ids) const {
const dim_t batch_size = start_ids.size();
@@ -693,6 +713,8 @@ namespace ctranslate2 {
/*return_scores=*/true,
return_attention,
/*num_hypotheses=*/1,
+                                    include_eos_in_scores,
+                                    include_eos_in_hypotheses,
logits_processors,
prefix_ids ? &repeat_prefix_ids : nullptr);

@@ -789,12 +811,17 @@ namespace ctranslate2 {
const size_t batch_id = batch_offset[i];
const dim_t prefix_length = prefix_ids ? prefix_ids->at(batch_id).size() : 0;

-          results[batch_id].hypotheses[0].push_back(word_id);
-          if (return_scores)
-            results[batch_id].scores[0] += best_probs.scalar_at<float>({i, 0});
-          if (attention_step) {
-            const auto* attn = attention_step.index<float>({i, 0});
-            results[batch_id].attention[0].emplace_back(attn, attn + attention_step.dim(-1));
-          }
+          if (word_id != end_id || include_eos_in_hypotheses) {
+            results[batch_id].hypotheses[0].push_back(word_id);
+            if (attention_step) {
+              const auto* attn = attention_step.index<float>({i, 0});
+              results[batch_id].attention[0].emplace_back(attn, attn + attention_step.dim(-1));
+            }
+          }
+
+          if (word_id != end_id || include_eos_in_scores) {
+            if (return_scores)
+              results[batch_id].scores[0] += best_probs.scalar_at<float>({i, 0});
+          }

const bool is_finished = ((word_id == end_id && step >= prefix_length)
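Greedy search applies the same two options inline: the chosen token (and its attention row) is appended unless it is EOS and include_eos_in_hypotheses is off, and its log-probability is accumulated unless it is EOS and include_eos_in_scores is off. A compact sketch of that loop, with a hypothetical list of (token, log-prob) steps standing in for a real decoder:

#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  const size_t end_id = 2;
  const bool include_eos_in_hypotheses = false;
  const bool include_eos_in_scores = false;

  // Steps a decoder might emit, ending with EOS.
  const std::vector<std::pair<size_t, float>> steps = {
      {5, -0.3f}, {7, -0.5f}, {end_id, -0.1f}};

  std::vector<size_t> hypothesis;
  float score = 0.f;
  for (const auto& [word_id, log_prob] : steps) {
    if (word_id != end_id || include_eos_in_hypotheses)
      hypothesis.push_back(word_id);
    if (word_id != end_id || include_eos_in_scores)
      score += log_prob;
    if (word_id == end_id)
      break;
  }
  std::cout << hypothesis.size() << " tokens, score " << score << "\n";  // 2 tokens, -0.8
}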
@@ -1033,6 +1060,8 @@ namespace ctranslate2 {
/*return_scores=*/true,
options.return_attention,
options.num_hypotheses,
+                             options.include_eos_in_scores,
+                             options.include_eos_in_hypotheses,
logits_processors)[0];

start_ids.clear();
@@ -1084,6 +1113,8 @@ namespace ctranslate2 {
options.return_scores,
options.return_attention,
/*num_hypotheses=*/1,
+                             options.include_eos_in_scores,
+                             options.include_eos_in_hypotheses,
logits_processors);

// Update the result with the suffix decoding.
@@ -1174,6 +1205,8 @@ namespace ctranslate2 {
options.return_scores,
options.return_attention,
options.num_hypotheses,
+                             options.include_eos_in_scores,
+                             options.include_eos_in_hypotheses,
logits_processors,
prefix_ids.empty() ? nullptr : &prefix_ids);
}
@@ -1182,28 +1215,11 @@ namespace ctranslate2 {
auto& result = results[b];

for (size_t i = 0; i < result.hypotheses.size(); ++i) {
-        // Remove EOS token.
-        while (result.hypotheses[i].back() == end_id) {
-          result.hypotheses[i].pop_back();
-          if (!result.attention.empty())
-            result.attention[i].pop_back();
-        }
-
        // Restore original word ids.
        if (decoder.output_layer_is_updated()) {
          for (auto& id : result.hypotheses[i])
            id = decoder.to_original_word_id(id);
        }
-
-        // Remove the prefix if configured.
-        const size_t prefix_length = start_tokens[b].size() - 1;
-        if (!options.return_prefix && prefix_length > 0) {
-          result.hypotheses[i].erase(result.hypotheses[i].begin(),
-                                     result.hypotheses[i].begin() + prefix_length);
-          if (!result.attention.empty())
-            result.attention[i].erase(result.attention[i].begin(),
-                                      result.attention[i].begin() + prefix_length);
-        }
}
}

6 changes: 6 additions & 0 deletions src/models/language_model.cc
@@ -162,6 +162,12 @@ namespace ctranslate2 {
for (size_t i = 0; i < results.size(); ++i) {
auto& result = results[i];

+      // Remove EOS token.
+      for (auto& sequence : result.hypotheses) {
+        while (!sequence.empty() && sequence.back() == end_id)
+          sequence.pop_back();
+      }
+
// Forward the start token to the output if it is not the special BOS token.
if (!start_ids[i].empty() && start_ids[i][0] != vocabulary.bos_id()) {
for (auto& sequence : result.hypotheses)
10 changes: 10 additions & 0 deletions src/models/sequence_to_sequence.cc
@@ -374,6 +374,16 @@ namespace ctranslate2 {

for (size_t i = 0; i < batch_size; ++i) {
      DecodingResult& result = results[i];
+
+      // Remove EOS token.
+      for (size_t h = 0; h < result.hypotheses.size(); ++h) {
+        while (!result.hypotheses[h].empty() && result.hypotheses[h].back() == end_id) {
+          result.hypotheses[h].pop_back();
+          if (!result.attention.empty())
+            result.attention[h].pop_back();
+        }
+      }
+
auto hypotheses = target_vocabulary.to_tokens(result.hypotheses);

if (!result.attention.empty()) {
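Together with the language_model.cc change above, the EOS trimming that used to happen once inside decoding.cc is now done per model: the language model trims only the token ids, while the sequence-to-sequence path also drops the matching attention rows so tokens and attention stay the same length. A merged sketch of the two loops, with plain vectors standing in for DecodingResult:

#include <cassert>
#include <cstddef>
#include <vector>

struct Result {
  std::vector<std::vector<size_t>> hypotheses;
  std::vector<std::vector<std::vector<float>>> attention;  // empty for language models
};

void strip_trailing_eos(Result& result, size_t end_id) {
  for (size_t h = 0; h < result.hypotheses.size(); ++h) {
    auto& hyp = result.hypotheses[h];
    while (!hyp.empty() && hyp.back() == end_id) {
      hyp.pop_back();
      if (!result.attention.empty())
        result.attention[h].pop_back();  // keep attention aligned with tokens
    }
  }
}

int main() {
  const size_t end_id = 2;
  Result result{{{5, 7, end_id}}, {{{1.f}, {1.f}, {1.f}}}};
  strip_trailing_eos(result, end_id);
  assert(result.hypotheses[0].size() == 2);
  assert(result.attention[0].size() == 2);
  return 0;
}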
6 changes: 4 additions & 2 deletions src/models/whisper.cc
@@ -229,7 +229,8 @@ namespace ctranslate2 {
decoding_options.sampling_temperature = options.sampling_temperature;
decoding_options.num_hypotheses = options.num_hypotheses;
decoding_options.return_scores = options.return_scores;
-      decoding_options.return_prefix = false;
+      decoding_options.include_eos_in_scores = options.beam_size > 1;
+      decoding_options.include_eos_in_hypotheses = false;
for (const auto& id : _model->config["suppress_ids"])
decoding_options.disable_ids.push_back(id);
for (const auto& id : _model->config["suppress_ids_begin"])
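With these settings a Whisper result never contains the end-of-text token, and the EOT probability counts toward the score only under beam search (beam_size > 1), which appears intended to match how the reference openai/whisper implementation scores greedy versus beam-search output. A toy illustration of the wiring, with simplified stand-ins for the option structs:

#include <cstddef>
#include <iostream>

struct WhisperOptions { size_t beam_size = 5; };  // simplified stand-in
struct DecodingOptions {                          // simplified stand-in
  bool include_eos_in_scores = true;
  bool include_eos_in_hypotheses = true;
};

int main() {
  WhisperOptions options;
  DecodingOptions decoding_options;
  decoding_options.include_eos_in_scores = options.beam_size > 1;
  decoding_options.include_eos_in_hypotheses = false;  // EOT is never returned
  std::cout << std::boolalpha << decoding_options.include_eos_in_scores << "\n";  // true
}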
@@ -433,6 +434,7 @@
} else { // cannot be normal text tokens
for (size_t i = 0; i < _eot_id; ++i)
disable_tokens.add(batch_id, i);
+            check_timestamps_prob_for_batch.push_back(batch_id);
}
} else {
check_timestamps_prob_for_batch.push_back(batch_id);
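check_timestamps_prob_for_batch collects the batches that get Whisper's final timestamp rule: if the total probability mass on timestamp tokens exceeds the probability of the most likely text token, a timestamp must be sampled. The added line extends this check to the branch where the last token is a timestamp but the penultimate one is not ("cannot be normal text tokens"). A sketch of that rule under a toy vocabulary, where ids below the hypothetical timestamp_begin_id are text tokens:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

constexpr size_t timestamp_begin_id = 4;  // toy vocabulary: ids 0-3 are text tokens

bool must_emit_timestamp(const std::vector<float>& probs) {
  const float timestamp_mass =
      std::accumulate(probs.begin() + timestamp_begin_id, probs.end(), 0.f);
  const float max_text_prob =
      *std::max_element(probs.begin(), probs.begin() + timestamp_begin_id);
  return timestamp_mass > max_text_prob;
}

int main() {
  // 0.5 of total mass on timestamps vs. 0.3 for the best text token.
  std::cout << std::boolalpha
            << must_emit_timestamp({0.3f, 0.1f, 0.05f, 0.05f, 0.2f, 0.3f})
            << "\n";  // true
}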
@@ -443,7 +445,7 @@
const size_t token = sequences.at<int32_t>({batch_id, t});

if (token >= _timestamp_begin_id) {
-            for (size_t i = _timestamp_begin_id; i <= token; ++i)
+            for (size_t i = _timestamp_begin_id; i < token; ++i)
disable_tokens.add(batch_id, i);
break;
}
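Replacing <= with < stops the rule from disabling the current timestamp token itself: earlier timestamps stay forbidden (timestamps must be non-decreasing), but the same timestamp may now repeat, as it legitimately does at Whisper segment boundaries where one segment's end time equals the next segment's start time. A sketch of the corrected rule; the timestamp_begin_id value is hypothetical, standing in for the first timestamp id in the vocabulary:

#include <cassert>
#include <cstddef>
#include <set>

constexpr size_t timestamp_begin_id = 50364;  // hypothetical <|0.00|> id

// Timestamp tokens to disable after a timestamp `last` has been produced.
std::set<size_t> disabled_after(size_t last) {
  std::set<size_t> disabled;
  for (size_t i = timestamp_begin_id; i < last; ++i)  // was: i <= last
    disabled.insert(i);
  return disabled;
}

int main() {
  const size_t t = timestamp_begin_id + 100;
  const auto disabled = disabled_after(t);
  assert(!disabled.count(t));     // the same timestamp may repeat (segment boundary)
  assert(disabled.count(t - 1));  // earlier timestamps may not: non-decreasing
  return 0;
}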
