synth_multi_lingual.py
# TODO: This code has not been tested
import matplotlib.pyplot as plt
import IPython.display as ipd
import sys
sys.path.append('waveglow/')
import numpy as np
from scipy.io.wavfile import write
import librosa
import torch
from configs.hparams import create_hparams
from train import initiate_model
from waveglow.denoiser import Denoiser
from layers import TacotronSTFT
from data_utils import TextMelLoader, TextMelCollate
from text import cmudict, text_to_sequence

def panner(signal, angle):
    # constant-power stereo panner: angle in degrees, negative pans left, positive pans right
    angle = np.radians(angle)
    left = np.sqrt(2) / 2.0 * (np.cos(angle) - np.sin(angle)) * signal
    right = np.sqrt(2) / 2.0 * (np.cos(angle) + np.sin(angle)) * signal
    return np.dstack((left, right))[0]

def plot_mel_f0_alignment(mel_source, mel_outputs_postnet, f0s, alignments, figsize=(16, 16)):
    fig, axes = plt.subplots(4, 1, figsize=figsize)
    axes = axes.flatten()
    # matplotlib only accepts 'upper'/'lower' for origin
    axes[0].imshow(mel_source, aspect='auto', origin='lower', interpolation='none')
    axes[1].imshow(mel_outputs_postnet, aspect='auto', origin='lower', interpolation='none')
    axes[2].scatter(range(len(f0s)), f0s, alpha=0.5, color='red', marker='.', s=1)
    axes[2].set_xlim(0, len(f0s))
    axes[3].imshow(alignments, aspect='auto', origin='lower', interpolation='none')
    axes[0].set_title("Source Mel")
    axes[1].set_title("Predicted Mel")
    axes[2].set_title("Source pitch contour")
    axes[3].set_title("Source rhythm")
    plt.tight_layout()

def load_mel(path):
    # librosa returns float32 audio already scaled to [-1, 1] at the requested sampling rate,
    # so no further division by max_wav_value is needed before computing the mel spectrogram
    audio, sampling_rate = librosa.core.load(path, sr=hparams.sampling_rate)
    if sampling_rate != hparams.sampling_rate:
        raise ValueError("{} SR doesn't match target {} SR".format(
            sampling_rate, hparams.sampling_rate))
    audio_norm = torch.from_numpy(audio).unsqueeze(0)
    melspec = stft.mel_spectrogram(audio_norm)
    melspec = melspec.cuda()
    return melspec

hparams = create_hparams()
stft = TacotronSTFT(hparams.filter_length, hparams.hop_length, hparams.win_length,
                    hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
                    hparams.mel_fmax)

# load the trained Mellotron checkpoint and the pre-trained WaveGlow vocoder
checkpoint_path = "/home/administrator/projects/mellotron/out/checkpoint_249500"
mellotron = initiate_model(hparams).cuda().eval()
mellotron.load_state_dict(torch.load(checkpoint_path)['state_dict'])
waveglow_path = 'models/waveglow_256channels_v4.pt'
waveglow = torch.load(waveglow_path)['model'].cuda().eval()
denoiser = Denoiser(waveglow).cuda().eval()

# CMU pronunciation dictionary and the file list of source utterances
arpabet_dict = cmudict.CMUDict('data/cmu_dictionary')
audio_paths = 'data/examples_filelist.txt'
dataloader = TextMelLoader(audio_paths, hparams)
datacollate = TextMelCollate(1)

# pick one utterance from the file list: audio path, text, speaker id and language code
file_idx = 0
audio_path, text, sid, lang_code = dataloader.audiopaths_and_text[file_idx]
lang_code = torch.LongTensor([int(lang_code)]).cuda()

# get audio path, encoded text, pitch contour and mel for gst
text_encoded = torch.LongTensor(text_to_sequence(text, hparams.text_cleaners, lang_code, arpabet_dict))[None, :].cuda()
pitch_contour = dataloader[file_idx][3][None].cuda()
mel = load_mel(audio_path)
print(audio_path, text)

# load source data to obtain rhythm using tacotron 2 as a forced aligner
x, y = mellotron.parse_batch(datacollate([dataloader[file_idx]]))
ipd.Audio(audio_path, rate=hparams.sampling_rate)  # only renders when run inside a notebook

with torch.no_grad():
    # get rhythm (alignment map) using tacotron 2
    _, _, _, rhythm = mellotron.forward(x)
    rhythm = rhythm.permute(1, 0, 2)

# synthesize with the source rhythm and pitch contour but a different target speaker
speaker_id = torch.LongTensor([1]).cuda()

with torch.no_grad():
    mel_outputs, mel_outputs_postnet, gate_outputs, _ = mellotron.inference_noattention(
        (text_encoded, mel, speaker_id, lang_code, pitch_contour, rhythm))

plot_mel_f0_alignment(x[2].data.cpu().numpy()[0],
                      mel_outputs_postnet.data.cpu().numpy()[0],
                      pitch_contour.data.cpu().numpy()[0, 0],
                      rhythm.data.cpu().numpy()[:, 0].T)

# vocode the predicted mel with WaveGlow, denoise lightly and write the result to disk
with torch.no_grad():
    audio = denoiser(waveglow.infer(mel_outputs_postnet, sigma=0.8), 0.01)[:, 0]

audio = audio[0].cpu().numpy()  # drop the batch dimension so scipy writes a 1-D mono signal
audio = audio / np.max(np.abs(audio))
write("{}_{}.wav".format(file_idx, speaker_id.item()), hparams.sampling_rate, audio)
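
# A minimal sketch (untested, like the rest of this script) showing how the otherwise
# unused panner() helper defined above could place the mono output in the stereo field.
# The -30.0 angle and the "_stereo" filename are illustrative choices, not from the original.
stereo = panner(audio, angle=-30.0)
write("{}_{}_stereo.wav".format(file_idx, speaker_id.item()), hparams.sampling_rate, stereo)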