From aa271aad722a3df896abbf8b455eab9d72897060 Mon Sep 17 00:00:00 2001 From: Jordi Mas Date: Tue, 6 Jun 2023 07:25:31 +0200 Subject: [PATCH] Re-organize code to load model only once to work around Windows issue with CTranslate2 #11 --- Makefile | 2 +- src/whisper_ctranslate2/live.py | 21 +++++++++------ src/whisper_ctranslate2/transcribe.py | 27 ++++++++++--------- .../whisper_ctranslate2.py | 19 +++++++------ 4 files changed, 40 insertions(+), 29 deletions(-) diff --git a/Makefile b/Makefile index 9dbb9bc..fecffa8 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ run: pip3 install --force-reinstall . run-e2e-tests: - pip install --force-reinstall ctranslate2==3.13.0 + pip install --force-reinstall ctranslate2==3.14.0 pip install --force-reinstall faster-whisper==0.6.0 CT2_USE_MKL="False" CT2_FORCE_CPU_ISA='GENERIC' nose2 -s e2e-tests diff --git a/src/whisper_ctranslate2/live.py b/src/whisper_ctranslate2/live.py index dd98bc9..b5c5898 100644 --- a/src/whisper_ctranslate2/live.py +++ b/src/whisper_ctranslate2/live.py @@ -57,6 +57,7 @@ def __init__( self.speaking = False self.blocks_speaking = 0 self.buffers_to_process = [] + self.transcribe = None @staticmethod def is_available(): @@ -119,17 +120,21 @@ def process(self): if self.verbose: print("\n\033[90mTranscribing..\033[0m") - result = Transcribe().inference( + if not self.transcribe: + self.transcribe = Transcribe( + self.model_path, + self.device, + self.device_index, + self.compute_type, + self.threads, + self.cache_directory, + self.local_files_only, + ) + + result = self.transcribe.inference( audio=_buffer.flatten(), - model_path=self.model_path, - cache_directory=self.cache_directory, - local_files_only=self.local_files_only, task=self.task, language=self.language, - threads=self.threads, - device=self.device, - device_index=self.device_index, - compute_type=self.compute_type, verbose=self.verbose, live=True, options=self.options, diff --git a/src/whisper_ctranslate2/transcribe.py 
b/src/whisper_ctranslate2/transcribe.py index ca53fac..41d797c 100644 --- a/src/whisper_ctranslate2/transcribe.py +++ b/src/whisper_ctranslate2/transcribe.py @@ -94,23 +94,17 @@ def _get_vad_parameters_dictionary(self, options): return vad_parameters - def inference( + def __init__( self, - audio: Union[str, BinaryIO, np.ndarray], model_path: str, - cache_directory: str, - local_files_only: bool, - task: str, - language: str, - threads: int, device: str, device_index: Union[int, List[int]], compute_type: str, - verbose: bool, - live: bool, - options: TranscriptionOptions, + threads: int, + cache_directory: str, + local_files_only: bool, ): - model = WhisperModel( + self.model = WhisperModel( model_path, device=device, device_index=device_index, @@ -120,9 +114,18 @@ def inference( local_files_only=local_files_only, ) + def inference( + self, + audio: Union[str, BinaryIO, np.ndarray], + task: str, + language: str, + verbose: bool, + live: bool, + options: TranscriptionOptions, + ): vad_parameters = self._get_vad_parameters_dictionary(options) - segments, info = model.transcribe( + segments, info = self.model.transcribe( audio=audio, language=language, task=task, diff --git a/src/whisper_ctranslate2/whisper_ctranslate2.py b/src/whisper_ctranslate2/whisper_ctranslate2.py index 1ff8335..3d4c18c 100644 --- a/src/whisper_ctranslate2/whisper_ctranslate2.py +++ b/src/whisper_ctranslate2/whisper_ctranslate2.py @@ -513,18 +513,21 @@ def main(): return + transcribe = Transcribe( + model_dir, + device, + device_index, + compute_type, + threads, + cache_directory, + local_files_only, + ) + for audio_path in audio: - result = Transcribe().inference( + result = transcribe.inference( audio_path, - model_dir, - cache_directory, - local_files_only, task, language, - threads, - device, - device_index, - compute_type, verbose, False, options,