Skip to content

Commit

Permalink
feat: solve some type hints issues
Browse files (browse the repository at this point in the history)
  • Loading branch information
Equipo45 committed Dec 4, 2024
1 parent 8327d8c commit 5f9bf53
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 31 deletions.
66 changes: 37 additions & 29 deletions faster_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dataclasses import asdict, dataclass
from inspect import signature
from math import ceil
from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
from typing import Any, BinaryIO, Iterable, List, Optional, Tuple, Union
from warnings import warn

import ctranslate2
Expand Down Expand Up @@ -81,11 +81,11 @@ class TranscriptionOptions:
compression_ratio_threshold: Optional[float]
condition_on_previous_text: bool
prompt_reset_on_temperature: float
temperatures: List[float]
temperatures: Union[List[float], Tuple[float, ...]]
initial_prompt: Optional[Union[str, Iterable[int]]]
prefix: Optional[str]
suppress_blank: bool
suppress_tokens: Optional[List[int]]
suppress_tokens: Union[List[int], Tuple[int, ...]]
without_timestamps: bool
max_initial_timestamp: float
word_timestamps: bool
Expand All @@ -106,7 +106,7 @@ class TranscriptionInfo:
duration_after_vad: float
all_language_probs: Optional[List[Tuple[str, float]]]
transcription_options: TranscriptionOptions
vad_options: VadOptions
vad_options: Optional[VadOptions]


class BatchedInferencePipeline:
Expand All @@ -121,7 +121,6 @@ def forward(self, features, tokenizer, chunks_metadata, options):
encoder_output, outputs = self.generate_segment_batched(
features, tokenizer, options
)

segmented_outputs = []
segment_sizes = []
for chunk_metadata, output in zip(chunks_metadata, outputs):
Expand All @@ -130,8 +129,8 @@ def forward(self, features, tokenizer, chunks_metadata, options):
segment_sizes.append(segment_size)
(
subsegments,
seek,
single_timestamp_ending,
_,
_,
) = self.model._split_segments_by_timestamps(
tokenizer=tokenizer,
tokens=output["tokens"],
Expand Down Expand Up @@ -295,7 +294,7 @@ def transcribe(
hallucination_silence_threshold: Optional[float] = None,
batch_size: int = 8,
hotwords: Optional[str] = None,
language_detection_threshold: Optional[float] = 0.5,
language_detection_threshold: float = 0.5,
language_detection_segments: int = 1,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
"""transcribe audio in chunks in batched fashion and return with language info.
Expand Down Expand Up @@ -582,7 +581,7 @@ def __init__(
num_workers: int = 1,
download_root: Optional[str] = None,
local_files_only: bool = False,
files: dict = None,
files: Optional[dict] = None,
**model_kwargs,
):
"""Initializes the Whisper model.
Expand Down Expand Up @@ -731,7 +730,7 @@ def transcribe(
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
hotwords: Optional[str] = None,
language_detection_threshold: Optional[float] = 0.5,
language_detection_threshold: float = 0.5,
language_detection_segments: int = 1,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
"""Transcribes an input file.
Expand Down Expand Up @@ -833,7 +832,7 @@ def transcribe(
elif isinstance(vad_parameters, dict):
vad_parameters = VadOptions(**vad_parameters)
speech_chunks = get_speech_timestamps(audio, vad_parameters)
audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
audio_chunks, _ = collect_chunks(audio, speech_chunks)
audio = np.concatenate(audio_chunks, axis=0)
duration_after_vad = audio.shape[0] / sampling_rate

Expand Down Expand Up @@ -925,7 +924,7 @@ def transcribe(
condition_on_previous_text=condition_on_previous_text,
prompt_reset_on_temperature=prompt_reset_on_temperature,
temperatures=(
temperature if isinstance(temperature, (list, tuple)) else [temperature]
temperature if isinstance(temperature, (List, Tuple)) else [temperature]
),
initial_prompt=initial_prompt,
prefix=prefix,
Expand Down Expand Up @@ -953,7 +952,8 @@ def transcribe(

if speech_chunks:
segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)

if isinstance(vad_parameters, dict):
vad_parameters = VadOptions(**vad_parameters)
info = TranscriptionInfo(
language=language,
language_probability=language_probability,
Expand All @@ -974,7 +974,7 @@ def _split_segments_by_timestamps(
segment_size: int,
segment_duration: float,
seek: int,
) -> List[List[int]]:
) -> Tuple[List[Any], int, bool]:
current_segments = []
single_timestamp_ending = (
len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
Expand Down Expand Up @@ -1517,8 +1517,8 @@ def add_word_timestamps(
num_frames: int,
prepend_punctuations: str,
append_punctuations: str,
last_speech_timestamp: float,
) -> float:
last_speech_timestamp: Union[float, None],
) -> Optional[float]:
if len(segments) == 0:
return

Expand Down Expand Up @@ -1665,9 +1665,11 @@ def find_alignment(
text_indices = np.array([pair[0] for pair in alignments])
time_indices = np.array([pair[1] for pair in alignments])

words, word_tokens = tokenizer.split_to_word_tokens(
text_token + [tokenizer.eot]
)
if isinstance(text_token, int):
tokens = [text_token] + [tokenizer.eot]
else:
tokens = text_token + [tokenizer.eot]
words, word_tokens = tokenizer.split_to_word_tokens(tokens)
if len(word_tokens) <= 1:
# return on eot only
# >>> np.pad([], (1, 0))
Expand Down Expand Up @@ -1715,7 +1717,7 @@ def detect_language(
audio: Optional[np.ndarray] = None,
features: Optional[np.ndarray] = None,
vad_filter: bool = False,
vad_parameters: Union[dict, VadOptions] = None,
vad_parameters: Optional[Union[dict, VadOptions]] = None,
language_detection_segments: int = 1,
language_detection_threshold: float = 0.5,
) -> Tuple[str, float, List[Tuple[str, float]]]:
Expand Down Expand Up @@ -1747,18 +1749,24 @@ def detect_language(
if audio is not None:
if vad_filter:
speech_chunks = get_speech_timestamps(audio, vad_parameters)
audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
audio_chunks, _ = collect_chunks(audio, speech_chunks)
audio = np.concatenate(audio_chunks, axis=0)

assert (
audio is not None
), "Audio have a problem while concatanating the audio_chunks; return None"
audio = audio[
: language_detection_segments * self.feature_extractor.n_samples
]
features = self.feature_extractor(audio)

assert (
features is not None
), "No features extracted from audio file; return None"
features = features[
..., : language_detection_segments * self.feature_extractor.nb_max_frames
]

assert (
features is not None
), "No features extracted when detectting language in audio segments; return None"
detected_language_info = {}
for i in range(0, features.shape[-1], self.feature_extractor.nb_max_frames):
encoder_output = self.encode(
Expand Down Expand Up @@ -1828,13 +1836,13 @@ def get_compression_ratio(text: str) -> float:

def get_suppressed_tokens(
tokenizer: Tokenizer,
suppress_tokens: Tuple[int],
) -> Optional[List[int]]:
if -1 in suppress_tokens:
suppress_tokens: Optional[List[int]],
) -> Tuple[int, ...]:
if suppress_tokens is None or len(suppress_tokens) == 0:
suppress_tokens = [] # interpret empty string as an empty list
elif -1 in suppress_tokens:
suppress_tokens = [t for t in suppress_tokens if t >= 0]
suppress_tokens.extend(tokenizer.non_speech_tokens)
elif suppress_tokens is None or len(suppress_tokens) == 0:
suppress_tokens = [] # interpret empty string as an empty list
else:
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"

Expand Down
7 changes: 5 additions & 2 deletions faster_whisper/vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -44,7 +44,7 @@ class VadOptions:

def get_speech_timestamps(
audio: np.ndarray,
vad_options: Optional[VadOptions] = None,
vad_options: Optional[Union[dict, VadOptions]] = None,
sampling_rate: int = 16000,
**kwargs,
) -> List[dict]:
Expand All @@ -62,6 +62,9 @@ def get_speech_timestamps(
if vad_options is None:
vad_options = VadOptions(**kwargs)

if isinstance(vad_options, dict):
vad_options = VadOptions(**vad_options)

threshold = vad_options.threshold
min_speech_duration_ms = vad_options.min_speech_duration_ms
max_speech_duration_s = vad_options.max_speech_duration_s
Expand Down

0 comments on commit 5f9bf53

Please sign in to comment.