Skip to content

Commit

Permalink
restrict direction loss to random span masked region
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 5, 2024
1 parent b3c6014 commit a255741
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
15 changes: 13 additions & 2 deletions e2_tts_pytorch/e2_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

import einx
from einops.layers.torch import Rearrange
from einops import rearrange, repeat, reduce, pack, unpack
from einops import rearrange, repeat, reduce, einsum, pack, unpack

from x_transformers import (
Attention,
Expand Down Expand Up @@ -75,6 +75,9 @@ def exists(v):
def default(v, d):
return v if exists(v) else d

def l2norm(t):
return F.normalize(t, dim = -1)

def divisible_by(num, den):
return (num % den) == 0

Expand Down Expand Up @@ -106,6 +109,12 @@ def project(x, y):

return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)

# losses

def calc_direction_loss(pred, target):
# make direction loss at most 1.
return 0.5 * (1. - einsum(l2norm(pred), l2norm(target), '... d, ... d -> ...'))

# simple utf-8 tokenizer, since paper went character based

def list_str_to_tensor(
Expand Down Expand Up @@ -1329,7 +1338,9 @@ def forward(
direction_loss = self.zero

if self.add_direction_loss:
direction_loss = ((1. - F.cosine_similarity(pred, flow, dim = 1)) / 2).mean() # make direction loss at most 1.
direction_loss = calc_direction_loss(pred, flow)

direction_loss = direction_loss[rand_span_mask].mean()

# total loss and get breakdown

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "1.5.0"
version = "1.5.1"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit a255741

Please sign in to comment.