Skip to content

Commit

Permalink
single file is fine
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 9, 2024
1 parent 4dc0c4c commit be1b780
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 34 deletions.
22 changes: 15 additions & 7 deletions e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
from functools import partial
from itertools import zip_longest
from collections import namedtuple

from typing import Literal, Callable

import jaxtyping
from beartype import beartype

import torch
from torch import nn, tensor, from_numpy
import torch.nn.functional as F
from torch import nn, tensor, Tensor, from_numpy
from torch.nn import Module, ModuleList, Sequential, Linear
from torch.nn.utils.rnn import pad_sequence

Expand All @@ -38,16 +41,21 @@

from x_transformers.x_transformers import RotaryEmbedding

from e2_tts_pytorch.tensor_typing import (
Float,
Int,
Bool
)

pad_sequence = partial(pad_sequence, batch_first = True)

# constants

class TorchTyping:
def __init__(self, abstract_dtype):
self.abstract_dtype = abstract_dtype

def __getitem__(self, shapes: str):
return self.abstract_dtype[Tensor, shapes]

Float = TorchTyping(jaxtyping.Float)
Int = TorchTyping(jaxtyping.Int)
Bool = TorchTyping(jaxtyping.Bool)

E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred'])

# helpers
Expand Down
26 changes: 0 additions & 26 deletions e2_tts_pytorch/tensor_typing.py

This file was deleted.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "0.9.7"
version = "0.9.8"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit be1b780

Please sign in to comment.