Skip to content

Commit

Permalink
add new function #82 and fixed bug #71
Browse files Browse the repository at this point in the history
  • Loading branch information
CheshireCC committed Mar 4, 2024
1 parent c088113 commit 6cf5513
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 20 deletions.
8 changes: 8 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
import whisperx
import faster_whisper_GUI

# explicit imports that fix bugs appearing after compiling with Nuitka
from av.audio.codeccontext import AudioCodecContext
from av.video.codeccontext import VideoCodecContext
from av.data import stream
from av.packet import Packet
from av.subtitles import stream
from av.subtitles import subtitle
from av.subtitles import codeccontext
15 changes: 8 additions & 7 deletions fasterWhisperGUIConfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"preciese": 0,
"thread_num": "4",
"num_worker": "1",
"download_root": "",
"download_root": "G:/Program Files (x86)/FasterWhisperGUI/cache",
"local_files_only": false
},
"vad_param": {
Expand All @@ -38,10 +38,11 @@
"themeColor": "#009faa"
},
"Transcription_param": {
"language": 0,
"aggregate_contents": true,
"language": 1,
"task": false,
"beam_size": "5",
"best_of": "6",
"best_of": "5",
"patience": "1.0",
"length_penalty": "1.0",
"temperature": "0.0,0.2,0.4,0.6,0.8,1.0",
Expand All @@ -66,11 +67,11 @@
"tabMovable": false,
"tabScrollable": false,
"tabShadowEnabled": false,
"tabMaxWidth": 220,
"tabMaxWidth": 367,
"closeDisplayMode": 0,
"whisperXMinSpeaker": 2,
"whisperXMaxSpeaker": 2,
"outputFormat": 0,
"whisperXMinSpeaker": 3,
"whisperXMaxSpeaker": 3,
"outputFormat": 2,
"outputEncoding": 1
}
}
12 changes: 7 additions & 5 deletions faster_whisper_GUI/mainWindows.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def getParam_model(self) -> dict:
compute_type: str = self.page_model.preciese_combox.currentText()
cpu_threads: int = int(self.page_model.LineEdit_cpu_threads.text().replace(" ", ""))
num_workers: int = int(self.page_model.LineEdit_num_workers.text().replace(" ", ""))
download_root: str = self.page_model.LineEdit_download_root.text().replace(" ", "")
download_root: str = self.page_model.LineEdit_download_root.text().strip()
local_files_only: bool = self.page_model.switchButton_local_files_only.isChecked()
use_v3_model: bool = self.page_model.switchButton_use_v3.isChecked()

Expand Down Expand Up @@ -359,9 +359,7 @@ def getParamWhisperX(self) -> dict:
if dict_WhisperXParams["min_speaker"] == 0 and dict_WhisperXParams["max_speaker"] == 0:
dict_WhisperXParams["min_speaker"] = None
dict_WhisperXParams["max_speaker"] = None
else:
dict_WhisperXParams["min_speaker"] = None
dict_WhisperXParams["max_speaker"] = None


return dict_WhisperXParams

Expand Down Expand Up @@ -845,9 +843,10 @@ def outputSubtitleFile(self):
output_dir = self.page_output.outputGroupWidget.LineEdit_output_dir.text()
code_ = self.page_output.combox_output_code.currentText()

aggregate_contents_according_to_the_speaker = self.page_transcribes.switchButton_aggregate_contents_according_to_the_speaker.isChecked()
result_to_write = self.current_result # self.result_faster_whisper if (self.result_whisperx_aligment is None and self.result_whisperx_speaker_diarize is None) else (self.result_whisperx_aligment if self.result_whisperx_speaker_diarize is None else self.result_whisperx_speaker_diarize)

self.outputWorker = OutputWorker(result_to_write, output_dir, format, code_,self)
self.outputWorker = OutputWorker(result_to_write, output_dir, format, code_, aggregate_contents_according_to_the_speaker , self)
self.outputWorker.signal_write_over.connect(self.outputOver)
self.outputWorker.start()
self.setStateTool(self.tr("保存文件"), self.tr("输出字幕文件"), False)
Expand Down Expand Up @@ -947,6 +946,9 @@ def whisperXDiarizeSpeakers(self):

if self.whisperXWorker is None:

print(f"min_speaker: {whisperParams['min_speaker']}")
print(f"max_speaker: {whisperParams['max_speaker']}")

self.whisperXWorker = WhisperXWorker(result_needed
, alignment=False
, speaker_diarize=True
Expand Down
9 changes: 8 additions & 1 deletion faster_whisper_GUI/modelLoad.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,20 @@ def run(self) -> None:
def stop(self):
self.isRunning = False

def loadModel(self,model_size_or_path:str=None):
def loadModel(self, model_size_or_path:str=None):

model = None
try:
if model_size_or_path is None:
model_size_or_path = self.model_size_or_path

# Try replacing spaces to handle paths that contain spaces (currently disabled)
# model_size_or_path = model_size_or_path.replace("\\", "/")
# model_size_or_path = model_size_or_path.replace(" ", "\ ")

# self.download_root = self.download_root.replace("\\", "/")
# self.download_root = self.download_root.replace(" ", "\ ")

model = WhisperModel(
model_size_or_path,
device=self.device,
Expand Down
13 changes: 13 additions & 0 deletions faster_whisper_GUI/tranccribePageNavigationInterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,15 @@ def setupUI(self):

widget_list.append(self.word_level_timestampels_param_widget)

# --------------------------------------------------------------------------------------------
self.switchButton_aggregate_contents_according_to_the_speaker = SwitchButton()
self.switchButton_aggregate_contents_according_to_the_speaker.setChecked(False)
self.aggregate_contents_param_widget = ParamWidget(self.__tr("根据说话人聚合内容"),
self.__tr("按顺序讲相同说话人的内容聚合到一起,仅支持 txt 格式输出"),
self.switchButton_aggregate_contents_according_to_the_speaker
)
widget_list.append(self.aggregate_contents_param_widget)

# =======================================================================================================
self.titleLabel_auditory_hallucination = TitleLabel(self.__tr("幻听参数"))
widget_list.append(self.titleLabel_auditory_hallucination)
Expand Down Expand Up @@ -472,6 +481,8 @@ def loadParamsFromFile(self):

def setParam(self, Transcribe_params:dict) -> None:

self.switchButton_aggregate_contents_according_to_the_speaker.setChecked(Transcribe_params["aggregate_contents"])

self.combox_language.setCurrentIndex(Transcribe_params["language"])
# Transcribe_params["language"] = language_index

Expand Down Expand Up @@ -546,6 +557,8 @@ def getParam(self) -> dict:
Transcribe_params = {}

# 从数据模型获取文件列表

Transcribe_params["aggregate_contents"] = self.switchButton_aggregate_contents_according_to_the_speaker.isChecked()

language = self.combox_language.currentIndex()
Transcribe_params["language"] = language
Expand Down
28 changes: 21 additions & 7 deletions faster_whisper_GUI/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(self,
output_dir:str,
format:str,
output_code = "UTF-8",
aggregate_contents = False,
parent=None
) -> None:

Expand All @@ -136,6 +137,7 @@ def __init__(self,
self.format = format
self.output_dir = output_dir
self.output_code = output_code
self.aggregate_contents = aggregate_contents

def stop(self):
self.is_running = False
Expand Down Expand Up @@ -178,6 +180,7 @@ def run(self):
, language=info.language
, fileName=path
, file_code = self.output_code
, aggregate_contents = self.aggregate_contents
)

print("\n【Over】")
Expand Down Expand Up @@ -208,7 +211,7 @@ def __init__(self
self.segments_path_info = []


def transcribe_file(self, file) -> (TranscriptionInfo, List):
def transcribe_file(self, file) -> (TranscriptionInfo, List): # type: ignore
# try:
# is_av_file = self.try_decode_avFile(file)
# if not is_av_file:
Expand Down Expand Up @@ -388,15 +391,16 @@ def writeSubtitles(outputFileName:str,
format:str,
language:str="",
fileName = "",
file_code = "UTF-8"
file_code = "UTF-8",
aggregate_contents = False
):

file_code = ENCODING_DICT[file_code]

if format == "SRT":
writeSRT(outputFileName, segments, file_code = file_code)
elif format == "TXT":
writeTXT(outputFileName, segments, file_code=file_code)
writeTXT(outputFileName, segments, file_code=file_code, aggregate_contents = aggregate_contents)
elif format == "VTT":
writeVTT(outputFileName, segments,language=language, file_code=file_code)
elif format == "LRC":
Expand Down Expand Up @@ -608,21 +612,31 @@ def writeVTT(fileName:str, segments:List[segment_Transcribe],language:str,file_c

vtt.save(fileName, file_code)

def writeTXT(fileName:str, segments, file_code="utf8", aggregate_contents=False):
    """Write transcription segments to a plain-text file.

    Args:
        fileName: Path of the text file to create (overwritten if present).
        segments: Iterable of objects exposing a ``text`` attribute and,
            optionally, a string ``speaker`` attribute.
        file_code: Encoding name handed to ``codecs.open``.
        aggregate_contents: When False, every entry is written as
            "<speaker>: <text>". When True, the speaker label is written
            once on its own line whenever the speaker changes, and the
            entries that follow carry only the text.
    """
    with codecs.open(fileName, "w", encoding=file_code) as f:
        # Label of the speaker whose header was written most recently.
        previous_label = ""

        for segment in segments:
            text: str = segment.text

            # Segments that were never diarized have no usable ``speaker``
            # (attribute missing, or not a string); treat them as unlabeled.
            try:
                label = segment.speaker + ": "
            except (AttributeError, TypeError):
                label = ""

            if aggregate_contents:
                # Emit the speaker header only when the speaker changes.
                if label != previous_label:
                    f.write(f"\n{label} \n")
                    previous_label = label
            else:
                text = label + text

            f.write(f"{text} \n\n")



def writeSRT(fileName:str, segments, file_code="UTF-8"):
index = 1
Expand Down

0 comments on commit 6cf5513

Please sign in to comment.