diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py index 59d57cb..8ef3c64 100644 --- a/e2_tts_pytorch/e2_tts.py +++ b/e2_tts_pytorch/e2_tts.py @@ -342,6 +342,7 @@ def __init__( self, transformer: dict | Transformer, text_num_embeds = 256, + num_channels = None, mel_spec_kwargs: dict = dict() ): super().__init__() @@ -355,7 +356,7 @@ def __init__( # mel spec self.mel_spec = MelSpec(**mel_spec_kwargs) - self.num_channels = self.mel_spec.n_mel_channels + self.num_channels = default(num_channels, self.mel_spec.n_mel_channels) self.transformer = transformer dim = transformer.dim @@ -445,6 +446,7 @@ def __init__( ), text_num_embeds = 256, cond_drop_prob = 0.25, + num_channels = None, mel_spec_module: Module | None = None, mel_spec_kwargs: dict = dict(), immiscible = False @@ -480,7 +482,7 @@ def __init__( # mel spec self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs)) - num_channels = self.mel_spec.n_mel_channels + num_channels = default(num_channels, self.mel_spec.n_mel_channels) self.num_channels = num_channels diff --git a/pyproject.toml b/pyproject.toml index f894ba6..9a17d33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "e2-tts-pytorch" -version = "0.1.1" +version = "0.1.2" description = "E2-TTS in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }