Skip to content

Commit

Permalink
take care of vocos decoding mel to audio by default
Browse files — browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 10, 2024
1 parent 02d2086 commit aecf03d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
39 changes: 35 additions & 4 deletions e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from torch.nn.utils.rnn import pad_sequence

import torchaudio
from torchaudio.functional import DB_to_amplitude, resample
from torchdiffeq import odeint

import einx
Expand All @@ -41,6 +42,8 @@

from x_transformers.x_transformers import RotaryEmbedding

from vocos import Vocos

pad_sequence = partial(pad_sequence, batch_first = True)

# constants
Expand Down Expand Up @@ -720,11 +723,13 @@ def __init__(
mel_spec_kwargs: dict = dict(),
frac_lengths_mask: tuple[float, float] = (0.7, 1.),
concat_cond = False,
text_num_embeds = None,
text_num_embeds: int | None = None,
tokenizer: (
Literal['char_utf8', 'phoneme_en'] |
Callable[[list[str]], Int['b nt']]
) = 'char_utf8'
) = 'char_utf8',
use_vocos = True,
pretrained_vocos_path = 'charactr/vocos-mel-24khz',
):
super().__init__()

Expand Down Expand Up @@ -793,6 +798,10 @@ def __init__(

self.embed_text = CharacterEmbed(dim_text, num_embeds = text_num_embeds, **char_embed_kwargs)

# default vocos for mel -> audio

self.vocos = Vocos.from_pretrained(pretrained_vocos_path) if use_vocos else None

@property
def device(self):
return next(self.parameters()).device
Expand Down Expand Up @@ -866,7 +875,10 @@ def sample(
steps = 32,
cfg_strength = 1., # they used a classifier free guidance strength of 1.
max_duration = 4096, # in case the duration predictor goes haywire
vocoder: Callable[Float['b d n'], Float['b nw']] | None = None
vocoder: Callable[[Float['b d n']], list[Float['_']]] | None = None
) -> (
Float['b n d'],
list[Float['nw']]
):
self.eval()

Expand Down Expand Up @@ -901,10 +913,13 @@ def sample(
duration = torch.full((batch,), duration, device = device, dtype = torch.long)

elif exists(self.duration_predictor):
duration = self.duration_predictor(cond, text = text, lens = lens).long()
duration = self.duration_predictor(cond, text = text, lens = lens, return_loss = False).long()

duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
duration = duration.clamp(max = max_duration)

assert duration.shape[0] == batch

max_duration = duration.amax()

cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
Expand Down Expand Up @@ -941,10 +956,26 @@ def fn(t, x):

out = torch.where(cond_mask, cond, out)

# take care of transforming mel to audio if `vocoder` is passed in, or if `use_vocos` is turned on

if exists(vocoder):
assert not exists(self.vocos), '`use_vocos` should not be turned on if you are passing in a custom `vocoder` on sampling'
out = rearrange(out, 'b n d -> b d n')
out = vocoder(out)

elif exists(self.vocos):

audio = []
for mel, one_mask in zip(out, mask):
one_out = DB_to_amplitude(mel[one_mask], ref = 1., power = 0.5)

one_out = rearrange(one_out, 'n d -> 1 d n')
one_audio = self.vocos.decode(one_out)
one_audio = rearrange(one_audio, '1 nt -> nt')
audio.append(one_audio)

out = audio

return out

def forward(
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "1.0.0"
version = "1.0.1"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down Expand Up @@ -37,6 +37,7 @@ dependencies = [
'torchdiffeq',
'torchaudio>=2.3.1',
'tqdm>=4.65.0',
'vocos',
'x-transformers>=1.31.14',
]

Expand Down

0 comments on commit aecf03d

Please sign in to comment.