Skip to content

Commit

Permalink
Merge pull request #12 from wetdog/main
Browse files Browse the repository at this point in the history
Fix argument in duration predictor
  • Loading branch information
lucidrains authored Jul 22, 2024
2 parents d8faa5e + 73f5c87 commit 15c4de2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion e2_tts_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def train(self, train_dataset, epochs, batch_size, grad_accumulation_steps=1, nu
mel_lengths = batch["mel_lengths"]

if self.duration_predictor is not None:
dur_loss = self.duration_predictor(mel_spec, target_duration=batch.get('durations'))
dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
self.writer.add_scalar('duration loss', dur_loss.item(), global_step)

loss = self.model(mel_spec, text=text_inputs, lens=mel_lengths)
Expand Down

0 comments on commit 15c4de2

Please sign in to comment.