diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py index a2ee996..99c551d 100644 --- a/e2_tts_pytorch/e2_tts.py +++ b/e2_tts_pytorch/e2_tts.py @@ -8,7 +8,7 @@ """ from __future__ import annotations -from typing import Literal +from typing import Literal, List from random import random import torch @@ -21,7 +21,7 @@ from torchdiffeq import odeint import einx -from einops import einsum, rearrange, repeat, reduce +from einops import rearrange, repeat, reduce from einops.layers.torch import Rearrange from x_transformers import ( diff --git a/e2_tts_pytorch/tensor_typing.py b/e2_tts_pytorch/tensor_typing.py index bb00ac7..5985007 100644 --- a/e2_tts_pytorch/tensor_typing.py +++ b/e2_tts_pytorch/tensor_typing.py @@ -3,8 +3,7 @@ from jaxtyping import ( Float, Int, - Bool, - jaxtyped + Bool ) # jaxtyping is a misnomer, works for pytorch diff --git a/e2_tts_pytorch/trainer.py b/e2_tts_pytorch/trainer.py index e5f8340..5d5924d 100644 --- a/e2_tts_pytorch/trainer.py +++ b/e2_tts_pytorch/trainer.py @@ -4,10 +4,7 @@ from tqdm import tqdm import torch -from torch import nn -from torch.nn import Module import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence from torch.utils.data import DataLoader, Dataset from torch.utils.tensorboard import SummaryWriter @@ -28,12 +25,13 @@ def collate_fn(batch): mel_specs = [item['mel_spec'].squeeze(0) for item in batch] - max_mel_length = max([spec.shape[-1] for spec in mel_specs]) + mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs]) + max_mel_length = mel_lengths.amax() padded_mel_specs = [] for spec in mel_specs: padding = (0, max_mel_length - spec.size(-1)) - padded_spec = torch.nn.functional.pad(spec, padding, mode='constant', value=0) + padded_spec = F.pad(spec, padding, value = 0) padded_mel_specs.append(padded_spec) mel_specs = torch.stack(padded_mel_specs) @@ -117,7 +115,7 @@ def __init__( self.duration_predictor = duration_predictor self.optimizer = optimizer self.checkpoint_path = checkpoint_path - self.mel_spectrogram = TorchMelSpectrogram(sampling_rate=self.target_sample_rate) + self.mel_spectrogram = MelSpec(sampling_rate=self.target_sample_rate) self.model, self.optimizer = self.accelerator.prepare( self.model, self.optimizer ) @@ -155,7 +153,6 @@ def train(self, train_dataset, epochs, batch_size, grad_accumulation_steps=1, nu epoch_loss = 0.0 for batch in progress_bar: text_inputs = batch['text'] - text_lengths = batch['text_lengths'] mel_spec = rearrange(batch['mel'], 'b d n -> b n d') mel_lengths = batch["mel_lengths"] diff --git a/pyproject.toml b/pyproject.toml index 6c0bb72..20bd812 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "e2-tts-pytorch" -version = "0.0.19" +version = "0.0.20" description = "E2-TTS in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }