Commit 1ace1ba

add what seems to be a new winning technique for higher classifier free guidance without oversaturation

lucidrains committed Oct 6, 2024
1 parent 7969f7a commit 1ace1ba
Showing 3 changed files with 44 additions and 2 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -104,3 +104,12 @@ sampled = e2tts.sample(mel[:, :5], text = text)
url = {https://api.semanticscholar.org/CorpusID:236171087}
}
```

```bibtex
@inproceedings{Sadat2024EliminatingOA,
title = {Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models},
author = {Seyedmorteza Sadat and Otmar Hilliges and Romann M. Weber},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273098845}
}
```
35 changes: 34 additions & 1 deletion e2_tts_pytorch/e2_tts.py
@@ -74,10 +74,34 @@ def default(v, d):
def divisible_by(num, den):
    return (num % den) == 0

def pack_one_with_inverse(x, pattern):
    # pack a single tensor and return a function that undoes the packing
    packed, packed_shape = pack([x], pattern)

    def inverse(x, inverse_pattern = None):
        inverse_pattern = default(inverse_pattern, pattern)
        return unpack(x, packed_shape, inverse_pattern)[0]

    return packed, inverse

class Identity(Module):
    def forward(self, x, **kwargs):
        return x

# tensor helpers

def project(x, y):
    # decompose x into its components parallel and orthogonal to y
    x, inverse = pack_one_with_inverse(x, 'b *')
    y, _ = pack_one_with_inverse(y, 'b *')

    dtype = x.dtype
    x, y = x.double(), y.double()
    unit = F.normalize(y, dim = -1)

    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
    orthogonal = x - parallel

    return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)

# simple utf-8 tokenizer, since paper went character based

def list_str_to_tensor(
@@ -944,6 +968,8 @@ def cfg_transformer_with_pred_head(
        self,
        *args,
        cfg_strength: float = 1.,
        remove_parallel_component: bool = True,
        keep_parallel_frac: float = 0.,
        **kwargs,
    ):

@@ -954,7 +980,14 @@ def cfg_transformer_with_pred_head(

        null_pred = self.transformer_with_pred_head(*args, drop_text_cond = True, **kwargs)

        return pred + (pred - null_pred) * cfg_strength
        cfg_update = pred - null_pred

        if remove_parallel_component:
            # https://arxiv.org/abs/2410.02416
            parallel, orthogonal = project(cfg_update, pred)
            cfg_update = orthogonal + parallel * keep_parallel_frac

        return pred + cfg_update * cfg_strength

    @torch.no_grad()
    def sample(
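
In plain terms, the diff replaces the usual classifier-free-guidance combine `pred + (pred - null_pred) * cfg_strength` with one that first splits the guidance update into components parallel and orthogonal to the conditional prediction, then discards (or down-weights) the parallel part, which the cited paper identifies as the source of oversaturation at high guidance scales. Below is a minimal, self-contained sketch of that update rule; `demo_project` and the toy tensors are illustrative stand-ins, not the repo's API.

```python
import torch
import torch.nn.functional as F

def demo_project(x, y):
    # split x into components parallel and orthogonal to y,
    # flattening all but the batch dimension (mirrors the diff's 'b *' packing)
    shape, dtype = x.shape, x.dtype
    x, y = x.flatten(1).double(), y.flatten(1).double()
    unit = F.normalize(y, dim = -1)
    parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
    orthogonal = x - parallel
    return parallel.reshape(shape).to(dtype), orthogonal.reshape(shape).to(dtype)

# toy stand-ins for the conditional and text-dropped predictions
pred      = torch.randn(2, 100, 80)
null_pred = torch.randn(2, 100, 80)

cfg_strength = 3.        # a high guidance scale, where oversaturation would normally bite
keep_parallel_frac = 0.  # 0. discards the parallel component entirely

cfg_update = pred - null_pred
parallel, orthogonal = demo_project(cfg_update, pred)
cfg_update = orthogonal + parallel * keep_parallel_frac

guided = pred + cfg_update * cfg_strength
```

With `keep_parallel_frac = 0.`, guidance no longer pushes the output further along `pred` itself, only perpendicular to it, which is the intuition behind raising the guidance scale without oversaturation that the commit message points at.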
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "1.0.6"
version = "1.1.0"
description = "E2-TTS in Pytorch"
authors = [
    { name = "Phil Wang", email = "[email protected]" }
