Skip to content

Commit

Permalink
support distil-whisper (#557)
Browse files Browse the repository at this point in the history
  • Loading branch information
metame-none authored Jan 24, 2024
1 parent 72ff979 commit ad3c830
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 3 deletions.
43 changes: 43 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ This implementation is up to 4 times faster than [openai/whisper](https://github

## Benchmark

### Whisper

For reference, here's the time and memory usage that are required to transcribe [**13 minutes**](https://www.youtube.com/watch?v=0u7tTptBo9I) of audio using different implementations:

* [openai/whisper](https://github.com/openai/whisper)@[6dea21fd](https://github.com/openai/whisper/commit/6dea21fd7f7253bfe450f1e2512a0fe47ee2d258)
Expand Down Expand Up @@ -36,6 +38,33 @@ For reference, here's the time and memory usage that are required to transcribe

*Executed with 8 threads on a Intel(R) Xeon(R) Gold 6226R.*


### Distil-whisper

| Implementation | Precision | Beam size | Time | Gigaspeech WER |
| --- | --- | --- | --- | --- |
| distil-whisper/distil-large-v2 | fp16 | 4 |- | 10.36 |
| [faster-distil-large-v2](https://huggingface.co/Systran/faster-distil-whisper-large-v2) | fp16 | 5 | - | 10.28 |
| distil-whisper/distil-medium.en | fp16 | 4 | - | 11.21 |
| [faster-distil-medium.en](https://huggingface.co/Systran/faster-distil-whisper-medium.en) | fp16 | 5 | - | 11.21 |

*Executed with CUDA 11.4 on a NVIDIA 3090.*

<details>
<summary>testing details (click to expand)</summary>

For `distil-whisper/distil-large-v2`, the WER is tested with code sample from [link](https://huggingface.co/distil-whisper/distil-large-v2#evaluation). for `faster-distil-whisper`, the WER is tested with setting:
```python
from faster_whisper import WhisperModel

model_size = "distil-large-v2"
# model_size = "distil-medium.en"
# Run on GPU with FP16
model = WhisperModel(model_size, device="cuda", compute_type="float16")
segments, info = model.transcribe("audio.mp3", beam_size=5, language="en")
```
</details>

## Requirements

* Python 3.8 or greater
Expand Down Expand Up @@ -101,6 +130,8 @@ pip install --force-reinstall "faster-whisper @ https://github.com/guillaumekln/

## Usage

### Faster-whisper

```python
from faster_whisper import WhisperModel

Expand Down Expand Up @@ -128,6 +159,18 @@ for segment in segments:
segments, _ = model.transcribe("audio.mp3")
segments = list(segments) # The transcription will actually run here.
```
### Faster-distil-whisper
For usage of `faster-ditil-whisper`, please refer to: https://github.com/guillaumekln/faster-whisper/issues/533

```python
model_size = "distil-large-v2"
# model_size = "distil-medium.en"
model = WhisperModel(model_size, device="cuda", compute_type="float16")
segments, info = model.transcribe("audio.mp3", beam_size=5,
language="en", max_new_tokens=128, condition_on_previous_text=False)

```
NOTE: emprically, `condition_on_previous_text=True` will degrade the performance of `faster-distil-whisper` for long audio.

### Word-level timestamps

Expand Down
6 changes: 5 additions & 1 deletion faster_whisper/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,15 @@ def stft(self, frames, window):
data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
return data.T

def __call__(self, waveform, padding=True):
def __call__(self, waveform, padding=True, chunk_length=None):
"""
Compute the log-Mel spectrogram of the provided audio, gives similar results
whisper's original torch implementation with 1e-5 tolerance.
"""
if chunk_length is not None:
self.n_samples = chunk_length * self.sampling_rate
self.nb_max_frames = self.n_samples // self.hop_length

if padding:
waveform = np.pad(waveform, [(0, self.n_samples)])

Expand Down
27 changes: 25 additions & 2 deletions faster_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class TranscriptionOptions(NamedTuple):
word_timestamps: bool
prepend_punctuations: str
append_punctuations: str
max_new_tokens: Optional[int]


class TranscriptionInfo(NamedTuple):
Expand Down Expand Up @@ -213,6 +214,8 @@ def transcribe(
append_punctuations: str = "\"'.。,,!!??::”)]}、",
vad_filter: bool = False,
vad_parameters: Optional[Union[dict, VadOptions]] = None,
max_new_tokens: Optional[int] = None,
chunk_length: Optional[int] = None,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
"""Transcribes an input file.
Expand Down Expand Up @@ -264,6 +267,10 @@ def transcribe(
https://github.com/snakers4/silero-vad.
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
parameters and default values in the class `VadOptions`).
max_new_tokens: Maximum number of new tokens to generate. If not set, the maximum will be
set by the default max_length.
chunk_length: The length of audio segments. If it is not None, it will overwrite the
default chunk_length of the FeatureExtractor.
Returns:
A tuple with:
Expand Down Expand Up @@ -313,7 +320,7 @@ def transcribe(
else:
speech_chunks = None

features = self.feature_extractor(audio)
features = self.feature_extractor(audio, chunk_length=chunk_length)

encoder_output = None
all_language_probs = None
Expand Down Expand Up @@ -379,6 +386,7 @@ def transcribe(
word_timestamps=word_timestamps,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
max_new_tokens=max_new_tokens,
)

segments = self.generate_segments(features, tokenizer, options, encoder_output)
Expand Down Expand Up @@ -642,6 +650,21 @@ def generate_with_fallback(
max_initial_timestamp_index = int(
round(options.max_initial_timestamp / self.time_precision)
)
if options.max_new_tokens is not None:
max_length = len(prompt) + options.max_new_tokens
else:
max_length = self.max_length

if max_length > self.max_length:
raise ValueError(
f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
f"and `max_new_tokens` is: {max_length}. This exceeds the "
f"`max_length` of the Whisper model: {self.max_length}. "
"You should either reduce the length of your prompt, or "
"reduce the value of `max_new_tokens`, "
f"so that their combined length is less that {self.max_length}."
)

for temperature in options.temperatures:
if temperature > 0:
Expand All @@ -663,7 +686,7 @@ def generate_with_fallback(
length_penalty=options.length_penalty,
repetition_penalty=options.repetition_penalty,
no_repeat_ngram_size=options.no_repeat_ngram_size,
max_length=self.max_length,
max_length=max_length,
return_scores=True,
return_no_speech_prob=True,
suppress_blank=options.suppress_blank,
Expand Down
3 changes: 3 additions & 0 deletions faster_whisper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
"large-v2": "Systran/faster-whisper-large-v2",
"large-v3": "Systran/faster-whisper-large-v3",
"large": "Systran/faster-whisper-large-v3",
"distil-large-v2": "Systran/faster-distil-whisper-large-v2",
"distil-medium.en": "Systran/faster-distil-whisper-medium.en",
"distil-small.en": "Systran/faster-distil-whisper-small.en",
}


Expand Down

0 comments on commit ad3c830

Please sign in to comment.