From 662c5c3050e07e4e625cf344468cb5f44bfba5e4 Mon Sep 17 00:00:00 2001 From: NanYANG2015 Date: Tue, 12 Nov 2024 09:34:44 +0800 Subject: [PATCH] fix: tts_mel in streaming inference --- flow_inference.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/flow_inference.py b/flow_inference.py index ebee05a..e9abeab 100644 --- a/flow_inference.py +++ b/flow_inference.py @@ -60,20 +60,23 @@ def token2wav(self, token, uuid, prompt_token=torch.zeros(1, 0, dtype=torch.int3 # mel overlap fade in out if self.mel_overlap_dict[uuid] is not None: tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window) + if finalize is False: + self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:] + tts_mel = tts_mel[:, :, :-self.mel_overlap_len] + # append hift cache if self.hift_cache_dict[uuid] is not None: hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source'] - tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) - + hift_tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2) else: + hift_tts_mel = tts_mel hift_cache_source = torch.zeros(1, 1, 0) + # _tts_mel=tts_mel.contiguous() # keep overlap mel and hift cache - if finalize is False: - self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:] - tts_mel = tts_mel[:, :, :-self.mel_overlap_len] - tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source) + tts_speech, tts_source = self.hift.inference(mel=hift_tts_mel, cache_source=hift_cache_source) + if finalize is False: self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:], 'source': tts_source[:, :, -self.source_cache_len:], 'speech': tts_speech[:, -self.source_cache_len:]} @@ -82,7 +85,6 @@ def token2wav(self, token, uuid, prompt_token=torch.zeros(1, 0, dtype=torch.int3 tts_speech = tts_speech[:, :-self.source_cache_len] else: - tts_speech, tts_source = self.hift.inference(mel=tts_mel, cache_source=hift_cache_source) del self.hift_cache_dict[uuid] del self.mel_overlap_dict[uuid] # if uuid in self.hift_cache_dict.keys() and self.hift_cache_dict[uuid] is not None: @@ -137,6 +139,7 @@ def stream_inference(self, token): # Convert Mel spectrogram to audio using HiFi-GAN tts_speech = torch.cat(tts_speechs, dim=-1).cpu() - + tts_mel = torch.cat(tts_mels, dim=-1).cpu() + print(token.size(1), tts_mel.size(-1), tts_speech.size(-1)) + print(int(token.size(1) / self.flow.input_frame_rate * 22050 / 256), tts_mel.size(-1)*256) return tts_speech.cpu() -