Skip to content

Commit

Permalink
sync : whisper.cpp (metal soft max fix + example prints)
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Dec 8, 2023
1 parent c57aa8e commit 95cdaf9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
26 changes: 15 additions & 11 deletions examples/whisper/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5028,6 +5028,7 @@ int whisper_full_with_state(
// basically don't process anything that is less than 1.0s
// see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
if (seek_end < seek_start + (params.speed_up ? 50 : 100)) {
WHISPER_PRINT_DEBUG("%s: input is too short - %d ms < 1000 ms\n", __func__, (seek_end - seek_start)*10);
return 0;
}

Expand Down Expand Up @@ -5455,6 +5456,7 @@ int whisper_full_with_state(

// do not allow to go back in time
if (has_ts && seek_delta > seek_delta_new && result_len < i) {
WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new);
failed = true; // TODO: maybe this is not a failure ?
continue;
}
Expand Down Expand Up @@ -5483,6 +5485,7 @@ int whisper_full_with_state(
if (seek + seek_delta + 100 >= seek_end) {
result_len = i + 1;
} else {
WHISPER_PRINT_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j);
failed = true;
continue;
}
Expand All @@ -5493,6 +5496,7 @@ int whisper_full_with_state(
seek_delta = 100*WHISPER_CHUNK_SIZE;
}

WHISPER_PRINT_DEBUG("%s: decoder %d completed\n", __func__, j);
completed = true;
continue;
}
Expand All @@ -5508,6 +5512,7 @@ int whisper_full_with_state(
// sometimes, the decoding can get stuck in a repetition loop
// this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
WHISPER_PRINT_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j);
failed = true;
continue;
}
Expand Down Expand Up @@ -5651,28 +5656,27 @@ int whisper_full_with_state(
WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id);
}

bool success = true;

// was the decoding successful for the current temperature?
// do fallback only if:
// - we are not at the last temperature
// - we are not at the end of the audio (3 sec)
if (it != (int) temperatures.size() - 1 &&
seek_end - seek > 10*WHISPER_CHUNK_SIZE) {
bool success = true;

if (it != (int) temperatures.size() - 1) {
const auto & decoder = state->decoders[best_decoder_id];

if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
WHISPER_PRINT_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
success = false;
state->n_fail_p++;
}
}

if (success) {
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
// WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
//}
if (success) {
//for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
// WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
//}

break;
}
break;
}

WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur);
Expand Down
6 changes: 5 additions & 1 deletion src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
ggml_metal_log_callback(level, buffer, ggml_metal_log_user_data);
} else {
char* buffer2 = malloc(len+1);
va_end(args);
va_start(args, format);
vsnprintf(buffer2, len+1, format, args);
buffer2[len] = 0;
ggml_metal_log_callback(level, buffer2, ggml_metal_log_user_data);
Expand Down Expand Up @@ -1193,7 +1195,9 @@ void ggml_metal_graph_compute(
const float scale = ((float *) dst->op_params)[0];

[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
if (id_src1) {
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
Expand Down

0 comments on commit 95cdaf9

Please sign in to comment.