diff --git a/README.md b/README.md
index 5b24773..16914ab 100644
--- a/README.md
+++ b/README.md
@@ -104,3 +104,12 @@ sampled = e2tts.sample(mel[:, :5], text = text)
     url = {https://api.semanticscholar.org/CorpusID:236171087}
 }
 ```
+
+```bibtex
+@inproceedings{Sadat2024EliminatingOA,
+    title = {Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models},
+    author = {Seyedmorteza Sadat and Otmar Hilliges and Romann M. Weber},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:273098845}
+}
+```
diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py
index f85adba..67a6671 100644
--- a/e2_tts_pytorch/e2_tts.py
+++ b/e2_tts_pytorch/e2_tts.py
@@ -74,10 +74,34 @@ def default(v, d):
 def divisible_by(num, den):
     return (num % den) == 0
 
+def pack_one_with_inverse(x, pattern):
+    packed, packed_shape = pack([x], pattern)
+
+    def inverse(x, inverse_pattern = None):
+        inverse_pattern = default(inverse_pattern, pattern)
+        return unpack(x, packed_shape, inverse_pattern)[0]
+
+    return packed, inverse
+
 class Identity(Module):
     def forward(self, x, **kwargs):
         return x
 
+# tensor helpers
+
+def project(x, y):
+    x, inverse = pack_one_with_inverse(x, 'b *')
+    y, _ = pack_one_with_inverse(y, 'b *')
+
+    dtype = x.dtype
+    x, y = x.double(), y.double()
+    unit = F.normalize(y, dim = -1)
+
+    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
+    orthogonal = x - parallel
+
+    return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
+
 # simple utf-8 tokenizer, since paper went character based
 
 def list_str_to_tensor(
@@ -944,6 +968,8 @@ def cfg_transformer_with_pred_head(
         self,
         *args,
         cfg_strength: float = 1.,
+        remove_parallel_component: bool = True,
+        keep_parallel_frac: float = 0.,
         **kwargs,
     ):
 
@@ -954,7 +980,14 @@ def cfg_transformer_with_pred_head(
 
         null_pred = self.transformer_with_pred_head(*args, drop_text_cond = True, **kwargs)
 
-        return pred + (pred - null_pred) * cfg_strength
+        cfg_update = pred - null_pred
+
+        if remove_parallel_component:
+            # https://arxiv.org/abs/2410.02416
+            parallel, orthogonal = project(cfg_update, pred)
+            cfg_update = orthogonal + parallel * keep_parallel_frac
+
+        return pred + cfg_update * cfg_strength
 
     @torch.no_grad()
     def sample(
diff --git a/pyproject.toml b/pyproject.toml
index 8bf7c4b..056f745 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "1.0.6"
+version = "1.1.0"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
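
The `e2_tts.py` change applies classifier-free guidance with the component of the guidance update parallel to the conditional prediction removed, following the paper linked in the diff (https://arxiv.org/abs/2410.02416). Below is a minimal standalone sketch of that projection step, assuming plain `reshape` in place of the einops `pack`/`unpack` helpers used in the repo, and toy tensors standing in for the model's conditional and unconditional predictions; the names here are illustrative, not the library's API.

```python
import torch
import torch.nn.functional as F

def project(x, y):
    # split x into components parallel and orthogonal to y, per batch element
    b = x.shape[0]
    x_flat, y_flat = x.reshape(b, -1).double(), y.reshape(b, -1).double()

    unit = F.normalize(y_flat, dim = -1)
    parallel = (x_flat * unit).sum(dim = -1, keepdim = True) * unit
    orthogonal = x_flat - parallel

    return parallel.reshape(x.shape).to(x.dtype), orthogonal.reshape(x.shape).to(x.dtype)

# toy tensors standing in for conditional / unconditional model outputs
pred      = torch.randn(2, 100, 80)
null_pred = torch.randn(2, 100, 80)

cfg_strength = 1.
keep_parallel_frac = 0.   # 0. drops the parallel component entirely

cfg_update = pred - null_pred
parallel, orthogonal = project(cfg_update, pred)
cfg_update = orthogonal + parallel * keep_parallel_frac

guided = pred + cfg_update * cfg_strength
```

With `keep_parallel_frac = 0.` the guidance only pushes in directions orthogonal to the conditional prediction, which is the paper's recipe for reducing oversaturation at high guidance scales; setting it closer to 1 recovers standard CFG.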