Skip to content

Commit

Permalink
cleanup again
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 14, 2024
1 parent 0ce88ca commit c750289
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 12 deletions.
4 changes: 2 additions & 2 deletions e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

from __future__ import annotations
from typing import Literal
from typing import Literal, List
from random import random

import torch
Expand All @@ -21,7 +21,7 @@
from torchdiffeq import odeint

import einx
from einops import einsum, rearrange, repeat, reduce
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange

from x_transformers import (
Expand Down
3 changes: 1 addition & 2 deletions e2_tts_pytorch/tensor_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from jaxtyping import (
Float,
Int,
Bool,
jaxtyped
Bool
)

# jaxtyping is a misnomer, works for pytorch
Expand Down
11 changes: 4 additions & 7 deletions e2_tts_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
from tqdm import tqdm

import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

Expand All @@ -28,12 +25,13 @@

def collate_fn(batch):
mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
max_mel_length = max([spec.shape[-1] for spec in mel_specs])
mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
max_mel_length = mel_lengths.amax()

padded_mel_specs = []
for spec in mel_specs:
padding = (0, max_mel_length - spec.size(-1))
padded_spec = torch.nn.functional.pad(spec, padding, mode='constant', value=0)
padded_spec = F.pad(spec, padding, value = 0)
padded_mel_specs.append(padded_spec)

mel_specs = torch.stack(padded_mel_specs)
Expand Down Expand Up @@ -117,7 +115,7 @@ def __init__(
self.duration_predictor = duration_predictor
self.optimizer = optimizer
self.checkpoint_path = checkpoint_path
self.mel_spectrogram = TorchMelSpectrogram(sampling_rate=self.target_sample_rate)
self.mel_spectrogram = MelSpec(sampling_rate=self.target_sample_rate)
self.model, self.optimizer = self.accelerator.prepare(
self.model, self.optimizer
)
Expand Down Expand Up @@ -155,7 +153,6 @@ def train(self, train_dataset, epochs, batch_size, grad_accumulation_steps=1, nu
epoch_loss = 0.0
for batch in progress_bar:
text_inputs = batch['text']
text_lengths = batch['text_lengths']
mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
mel_lengths = batch["mel_lengths"]

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "0.0.19"
version = "0.0.20"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit c750289

Please sign in to comment.