Skip to content

Commit

Permalink
Revert "Fix support of English-only Whisper models (#1080)" (#1082)
Browse files Browse the repository at this point in the history
This reverts commit 71bc055.
  • Loading branch information
guillaumekln authored Feb 13, 2023
1 parent dfd4230 commit 7ee6a34
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 108 deletions.
4 changes: 0 additions & 4 deletions include/ctranslate2/models/whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ namespace ctranslate2 {

WhisperReplica(const std::shared_ptr<const WhisperModel>& model);

bool is_multilingual() const;

std::vector<WhisperGenerationResult>
generate(const StorageView& features,
const std::vector<std::vector<std::string>>& prompts,
Expand All @@ -110,8 +108,6 @@ namespace ctranslate2 {
public:
using ReplicaPool::ReplicaPool;

bool is_multilingual() const;

std::vector<std::future<WhisperGenerationResult>>
generate(StorageView features,
std::vector<std::vector<std::string>> prompts,
Expand Down
5 changes: 0 additions & 5 deletions include/ctranslate2/replica_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,6 @@ namespace ctranslate2 {
}

protected:
// Returns a const reference to the replica held by worker 0 of the thread
// pool. Used to query model properties without dispatching a job.
// NOTE(review): assumes the pool has at least one worker — TODO confirm
// the pool invariant guarantees this.
const Replica& get_first_replica() const {
// Workers are stored as the base worker type; downcast to access the replica.
auto& worker = static_cast<ReplicaWorker<Replica>&>(_thread_pool->get_worker(0));
return worker.replica();
}

template <typename Result, typename Func>
std::vector<std::future<Result>>
post_examples(const std::vector<Example>& examples,
Expand Down
10 changes: 0 additions & 10 deletions python/cpp/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@ namespace ctranslate2 {
public:
using ReplicaPoolHelper::ReplicaPoolHelper;

// Python-facing accessor: forwards to the underlying replica pool's
// is_multilingual() (exposed to Python as the read-only property below).
bool is_multilingual() const {
return _pool->is_multilingual();
}

std::variant<std::vector<models::WhisperGenerationResult>,
std::vector<AsyncResult<models::WhisperGenerationResult>>>
generate(StorageViewWrapper features,
Expand Down Expand Up @@ -101,9 +97,6 @@ namespace ctranslate2 {
https://github.com/openai/whisper
)pbdoc")

.def_property_readonly("is_multilingual", &WhisperWrapper::is_multilingual,
"Returns ``True`` if this model is multilingual.")

.def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, py::object>(),
py::arg("model_path"),
py::arg("device")="cpu",
Expand Down Expand Up @@ -193,9 +186,6 @@ namespace ctranslate2 {
Returns:
For each batch, a list of pairs (language, probability) ordered from
best to worst probability.
Raises:
RuntimeError: if the model is not multilingual.
)pbdoc")

;
Expand Down
41 changes: 7 additions & 34 deletions python/tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,10 +352,9 @@ def test_transformers_generator_suppress_sequences(tmpdir):
@test_utils.only_on_linux
@test_utils.on_available_devices
@pytest.mark.parametrize(
"model_name,prompts,expected_transcriptions,expected_no_speech_probs",
"prompts,expected_transcriptions,expected_no_speech_probs",
[
(
"openai/whisper-tiny",
[
[
"<|startoftranscript|>",
Expand All @@ -382,7 +381,6 @@ def test_transformers_generator_suppress_sequences(tmpdir):
],
),
(
"openai/whisper-tiny",
[
["<|startoftranscript|>", "<|en|>", "<|transcribe|>"],
["<|startoftranscript|>", "<|en|>", "<|transcribe|>"],
Expand All @@ -399,7 +397,6 @@ def test_transformers_generator_suppress_sequences(tmpdir):
],
),
(
"openai/whisper-tiny",
[
[
"<|startoftranscript|>",
Expand All @@ -425,32 +422,14 @@ def test_transformers_generator_suppress_sequences(tmpdir):
pytest.approx(0.06885894387960434, abs=1e-2),
],
),
(
"openai/whisper-tiny.en",
[["<|startoftranscript|>"], ["<|startoftranscript|>"]],
[
" Mr. Quilter is the apostle of the middle classes, and we are glad"
" to welcome his gospel.",
" And so, my fellow Americans ask not what your country can do for you"
" ask what you can do for your country.",
],
[
pytest.approx(0.02644546702504158, abs=1e-4),
pytest.approx(0.062380101531744, abs=1e-3),
],
),
],
)
def test_transformers_whisper(
tmpdir,
device,
model_name,
prompts,
expected_transcriptions,
expected_no_speech_probs,
tmpdir, device, prompts, expected_transcriptions, expected_no_speech_probs
):
import transformers

model_name = "openai/whisper-tiny"
converter = ctranslate2.converters.TransformersConverter(model_name)
output_dir = str(tmpdir.join("ctranslate2_model"))
output_dir = converter.convert(output_dir)
Expand All @@ -475,16 +454,10 @@ def _get_features(audio):

model = ctranslate2.models.Whisper(output_dir, device=device)

assert model.is_multilingual == (not model_name.endswith(".en"))

if model.is_multilingual:
for result in model.detect_language(features):
best_lang, best_prob = result[0]
assert best_lang == "<|en|>"
assert best_prob > 0.9
else:
with pytest.raises(RuntimeError, match="multilingual"):
model.detect_language(features)
for result in model.detect_language(features):
best_lang, best_prob = result[0]
assert best_lang == "<|en|>"
assert best_prob > 0.9

results = model.generate(
features,
Expand Down
3 changes: 1 addition & 2 deletions src/layers/transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,7 @@ namespace ctranslate2 {
const auto it = state.find("memory_lengths");
const StorageView* memory_lengths = it != state.end() ? &it->second : nullptr;

const auto cached_memory_proj_it = state.find("memory_keys_0");
if (cached_memory_proj_it == state.end() || cached_memory_proj_it->second.empty()) {
if (step <= 0) {
memory = &state.at("memory");

if (memory_lengths && allow_padding_removal) {
Expand Down
73 changes: 20 additions & 53 deletions src/models/whisper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,6 @@ namespace ctranslate2 {
{
}

// Heuristic multilingual check: decides by vocabulary size alone.
// 51865 is presumably the multilingual Whisper vocabulary size (the
// English-only ".en" models use a smaller vocabulary) — TODO confirm
// against the Whisper tokenizer definitions.
bool WhisperReplica::is_multilingual() const {
const auto& vocabulary = _model->get_vocabulary();
return vocabulary.size() == 51865;
}

StorageView WhisperReplica::encode(const StorageView& features) {
const Device device = _model->device();
const DataType dtype = _encoder->output_type();
Expand Down Expand Up @@ -183,8 +178,9 @@ namespace ctranslate2 {
"simply adapt the number of previous text tokens in each "
"batch.");

if (prompts[0].empty())
throw std::invalid_argument("The prompt cannot be empty");
if (prompts[0].size() < 3)
throw std::invalid_argument("The prompt should have at least 3 tokens: "
"START OF TRANSCRIPT, LANGUAGE TAG, and TRANSCRIBE/TRANSLATE");

const auto& vocabulary = _model->get_vocabulary();
const auto scoped_device_setter = _model->get_scoped_device_setter();
Expand All @@ -199,51 +195,30 @@ namespace ctranslate2 {
prefix_tokens.reserve(prompts.size());
start_tokens.reserve(prompts.size());
for (const auto& prompt : prompts) {
if (prompt.size() > 1)
prefix_tokens.emplace_back(prompt.begin(), prompt.end() - 1);
else if (options.return_no_speech_prob)
prefix_tokens.emplace_back(prompt);
prefix_tokens.emplace_back(prompt.begin(), prompt.end() - 1);
start_tokens.emplace_back(prompt.end() - 1, prompt.end());
}

std::vector<float> no_speech_probs;
dim_t start_step = 0;

if (!prefix_tokens.empty()) {
const Device device = _decoder->device();
const DataType dtype = _decoder->output_type();
StorageView inputs = layers::make_sequence_inputs(prefix_tokens, device);
StorageView outputs(dtype, device);

// Forward the prefix.
_decoder->forward_prompt(inputs, state, options.return_no_speech_prob ? &outputs : nullptr);

if (options.return_no_speech_prob) {
// Get the probability of the no speech token at the start of transcript step.
StorageView sot_index = get_sot_index(prefix_tokens, vocabulary.bos_id(), device);
size_t no_speech_id = vocabulary.to_id("<|nospeech|>");
if (no_speech_id == vocabulary.unk_id())
no_speech_id = vocabulary.to_id("<|nocaptions|>");
no_speech_probs = get_no_speech_probs(*_decoder, outputs, sot_index, no_speech_id);
}
const Device device = _decoder->device();
const DataType dtype = _decoder->output_type();
StorageView inputs = layers::make_sequence_inputs(prefix_tokens, device);
StorageView outputs(dtype, device);

if (prompts[0].size() > 1)
start_step = inputs.dim(1);
else {
// If the prompt only contains the start token, it means we only got here to retrieve
// the no speech probability. The decoding will start from this token again so we need
// to reset the decoder state.
for (auto& pair : state) {
const auto& name = pair.first;
auto& tensor = pair.second;
if (!starts_with(name, "memory"))
tensor.clear();
}
}
// Initialize the decoder state with the prompt.
_decoder->forward_prompt(inputs, state, options.return_no_speech_prob ? &outputs : nullptr);

std::vector<float> no_speech_probs;
if (options.return_no_speech_prob) {
// Get the probability of the no speech token at the start of transcript step.
StorageView sot_index = get_sot_index(prefix_tokens, vocabulary.bos_id(), device);
size_t no_speech_id = vocabulary.to_id("<|nospeech|>");
if (no_speech_id == vocabulary.unk_id())
no_speech_id = vocabulary.to_id("<|nocaptions|>");
no_speech_probs = get_no_speech_probs(*_decoder, outputs, sot_index, no_speech_id);
}

DecodingOptions decoding_options;
decoding_options.start_step = start_step;
decoding_options.start_step = inputs.dim(1);
decoding_options.beam_size = options.beam_size;
decoding_options.patience = options.patience;
decoding_options.length_penalty = options.length_penalty;
Expand Down Expand Up @@ -291,9 +266,6 @@ namespace ctranslate2 {

std::vector<std::vector<std::pair<std::string, float>>>
WhisperReplica::detect_language(const StorageView& features) {
if (!is_multilingual())
throw std::runtime_error("detect_language can only be called on multilingual models");

PROFILE("WhisperReplica::detect_language");

const auto scoped_device_setter = _model->get_scoped_device_setter();
Expand Down Expand Up @@ -357,11 +329,6 @@ namespace ctranslate2 {
}


// Pool-level accessor: queries the first replica for the multilingual flag.
// All replicas wrap the same model, so any one of them answers correctly.
bool Whisper::is_multilingual() const {
const auto& replica = get_first_replica();
return replica.is_multilingual();
}

std::vector<std::future<WhisperGenerationResult>>
Whisper::generate(StorageView features,
std::vector<std::vector<std::string>> prompts,
Expand Down

0 comments on commit 7ee6a34

Please sign in to comment.