Skip to content

Commit

Permalink
add the direction loss for flow matching, claimed to accelerate training, from a research group out of Wuhan, China
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 5, 2024
1 parent d3d5e38 commit b3c6014
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 4 deletions.
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,12 @@ sampled = e2tts.sample(mel[:, :5], text = text)
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```

```bibtex
@inproceedings{Yao2024FasterDiTTF,
title = {FasterDiT: Towards Faster Diffusion Transformers Training without Architecture Modification},
author = {Jingfeng Yao and Wang Cheng and Wenyu Liu and Xinggang Wang},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273346237}
}
```
25 changes: 22 additions & 3 deletions e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __getitem__(self, shapes: str):

# named tuples

LossBreakdown = namedtuple('LossBreakdown', ['flow', 'velocity_consistency'])
LossBreakdown = namedtuple('LossBreakdown', ['flow', 'velocity_consistency', 'direction'])

E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred_flow', 'pred_data', 'loss_breakdown'])

Expand Down Expand Up @@ -909,7 +909,9 @@ def __init__(
use_vocos = True,
pretrained_vocos_path = 'charactr/vocos-mel-24khz',
sampling_rate: int | None = None,
add_direction_loss = False,
velocity_consistency_weight = 0.,
direction_loss_weight = 1.
):
super().__init__()

Expand Down Expand Up @@ -986,6 +988,11 @@ def __init__(
self.register_buffer('zero', torch.tensor(0.), persistent = False)
self.velocity_consistency_weight = velocity_consistency_weight

# direction loss for flow matching

self.add_direction_loss = add_direction_loss
self.direction_loss_weight = direction_loss_weight

# default vocos for mel -> audio

self.vocos = Vocos.from_pretrained(pretrained_vocos_path) if use_vocos else None
Expand Down Expand Up @@ -1317,10 +1324,22 @@ def forward(

loss = loss[rand_span_mask].mean()

# maybe direction loss

direction_loss = self.zero

if self.add_direction_loss:
direction_loss = ((1. - F.cosine_similarity(pred, flow, dim = 1)) / 2).mean() # make direction loss at most 1.

# total loss and get breakdown

total_loss = loss + velocity_loss * self.velocity_consistency_weight
breakdown = LossBreakdown(loss, velocity_loss)
total_loss = (
loss +
direction_loss * self.direction_loss_weight +
velocity_loss * self.velocity_consistency_weight
)

breakdown = LossBreakdown(loss, velocity_loss, direction_loss)

# return total loss and bunch of intermediates

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "1.4.1"
version = "1.5.0"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit b3c6014

Please sign in to comment.