
Commit

Adds batched command line argument to enable faster batch inference
jordimas committed Nov 21, 2024
1 parent f34b3b4 commit 282953e
Showing 5 changed files with 26 additions and 4 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
 numpy
-faster-whisper>=1.0.2
+faster-whisper>=1.1.0
 ctranslate2
 tqdm
 sounddevice
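Note: the version bump is tied to this commit's code change; BatchedInferencePipeline, imported in transcribe.py below, is only available starting with the faster-whisper 1.1.0 release.
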
7 changes: 7 additions & 0 deletions src/whisper_ctranslate2/commandline.py
@@ -348,6 +348,13 @@ def read_command_line():
help="Hotwords/hint phrases to the model. Useful for names you want the model to priotize",
)

algorithm_args.add_argument(
"--batched",
type=CommandLine()._str2bool,
default="False",
help="Uses Batched transcription which can provide an additional 2x-3x speed increase",
)

vad_args = parser.add_argument_group("VAD filter arguments")

vad_args.add_argument(
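The CommandLine()._str2bool converter used above is not part of this diff, so only its role is visible here: it lets argparse parse string spellings of booleans. A minimal sketch of such a helper, with the accepted spellings assumed rather than taken from the repository:

    import argparse

    def _str2bool(value: str) -> bool:
        # Map common spellings so "--batched True" and "--batched false"
        # both parse; anything else raises a clear argparse error.
        normalized = value.strip().lower()
        if normalized in ("true", "yes", "y", "1"):
            return True
        if normalized in ("false", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

Note that default="False" works because argparse runs string defaults through the type converter, so the string "False" becomes the boolean False before the code ever sees it.
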
1 change: 1 addition & 0 deletions src/whisper_ctranslate2/live.py
@@ -130,6 +130,7 @@ def process(self):
             self.threads,
             self.cache_directory,
             self.local_files_only,
+            False,
         )
 
         result = self.transcribe.inference(
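Live mode constructs Transcribe directly rather than through the command line, so it passes False positionally for the new batched parameter, keeping batching disabled; batched inference presumably targets whole files rather than the short chunks produced by live capture.
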
18 changes: 15 additions & 3 deletions src/whisper_ctranslate2/transcribe.py
@@ -2,7 +2,7 @@
 from typing import NamedTuple, Optional, List, Union
 import tqdm
 import sys
-from faster_whisper import WhisperModel
+from faster_whisper import WhisperModel, BatchedInferencePipeline
 from .languages import LANGUAGES
 from typing import BinaryIO
 import numpy as np
@@ -107,6 +107,7 @@ def __init__(
         threads: int,
         cache_directory: str,
         local_files_only: bool,
+        batched: bool,
     ):
         self.model = WhisperModel(
             model_path,
@@ -117,6 +118,10 @@
             download_root=cache_directory,
             local_files_only=local_files_only,
         )
+        if batched:
+            self.batched_model = BatchedInferencePipeline(model=self.model)
+        else:
+            self.batched_model = None
 
     def inference(
         self,
@@ -129,7 +134,14 @@
     ):
         vad_parameters = self._get_vad_parameters_dictionary(options)
 
-        segments, info = self.model.transcribe(
+        if self.batched_model:
+            model = self.batched_model
+            vad = True
+        else:
+            model = self.model
+            vad = options.vad_filter
+
+        segments, info = model.transcribe(
             audio=audio,
             language=language,
             task=task,
@@ -154,7 +166,7 @@
             prepend_punctuations=options.prepend_punctuations,
             append_punctuations=options.append_punctuations,
             hallucination_silence_threshold=options.hallucination_silence_threshold,
-            vad_filter=options.vad_filter,
+            vad_filter=vad,
             vad_parameters=vad_parameters,
         )
 
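For context, a minimal sketch of what the new batched path amounts to when driven directly against faster-whisper >= 1.1.0; the model size, audio path, and batch_size are placeholder values, not values from this commit:

    from faster_whisper import WhisperModel, BatchedInferencePipeline

    model = WhisperModel("small", device="cpu", compute_type="int8")
    batched_model = BatchedInferencePipeline(model=model)

    # The batched pipeline splits the audio into voiced segments and decodes
    # them in parallel, which is consistent with the diff forcing
    # vad_filter=True whenever the batched path is taken.
    segments, info = batched_model.transcribe("audio.wav", batch_size=8)
    for segment in segments:
        print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
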
2 changes: 2 additions & 0 deletions src/whisper_ctranslate2/whisper_ctranslate2.py
@@ -113,6 +113,7 @@ def main():
     live_input_device: int = args.pop("live_input_device")
     hf_token = args.pop("hf_token")
     speaker_name = args.pop("speaker_name")
+    batched = args.pop("batched")
 
     if model == "large-v3-turbo":
         model = "deepdml/faster-whisper-large-v3-turbo-ct2"
@@ -214,6 +215,7 @@
         threads,
         cache_directory,
         local_files_only,
+        batched,
     )
 
     diarization = len(hf_token) > 0
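With the flag plumbed through main(), batched inference can be requested from the command line, e.g. (the audio filename is a placeholder):

    whisper-ctranslate2 audio.wav --batched True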
