From be1b78014eb49c3e3e8d60939622093f6e09395f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 9 Sep 2024 09:47:20 -0700 Subject: [PATCH] single file is fine --- e2_tts_pytorch/e2_tts.py | 22 +++++++++++++++------- e2_tts_pytorch/tensor_typing.py | 26 -------------------------- pyproject.toml | 2 +- 3 files changed, 16 insertions(+), 34 deletions(-) delete mode 100644 e2_tts_pytorch/tensor_typing.py diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py index 2198c6a..27789b0 100644 --- a/e2_tts_pytorch/e2_tts.py +++ b/e2_tts_pytorch/e2_tts.py @@ -13,12 +13,15 @@ from functools import partial from itertools import zip_longest from collections import namedtuple + from typing import Literal, Callable + +import jaxtyping from beartype import beartype import torch -from torch import nn, tensor, from_numpy import torch.nn.functional as F +from torch import nn, tensor, Tensor, from_numpy from torch.nn import Module, ModuleList, Sequential, Linear from torch.nn.utils.rnn import pad_sequence @@ -38,16 +41,21 @@ from x_transformers.x_transformers import RotaryEmbedding -from e2_tts_pytorch.tensor_typing import ( - Float, - Int, - Bool -) - pad_sequence = partial(pad_sequence, batch_first = True) # constants +class TorchTyping: + def __init__(self, abstract_dtype): + self.abstract_dtype = abstract_dtype + + def __getitem__(self, shapes: str): + return self.abstract_dtype[Tensor, shapes] + +Float = TorchTyping(jaxtyping.Float) +Int = TorchTyping(jaxtyping.Int) +Bool = TorchTyping(jaxtyping.Bool) + E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred']) # helpers diff --git a/e2_tts_pytorch/tensor_typing.py b/e2_tts_pytorch/tensor_typing.py deleted file mode 100644 index 5985007..0000000 --- a/e2_tts_pytorch/tensor_typing.py +++ /dev/null @@ -1,26 +0,0 @@ -from torch import Tensor - -from jaxtyping import ( - Float, - Int, - Bool -) - -# jaxtyping is a misnomer, works for pytorch - -class TorchTyping: - def __init__(self, abstract_dtype): - self.abstract_dtype = abstract_dtype - - def __getitem__(self, shapes: str): - return self.abstract_dtype[Tensor, shapes] - -Float = TorchTyping(Float) -Int = TorchTyping(Int) -Bool = TorchTyping(Bool) - -__all__ = [ - Float, - Int, - Bool -] diff --git a/pyproject.toml b/pyproject.toml index 4416547..bf902fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "e2-tts-pytorch" -version = "0.9.7" +version = "0.9.8" description = "E2-TTS in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }