Skip to content

Commit

Permalink
need to unwrap
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 13, 2024
1 parent 9599691 commit 8cea21c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions e2_tts_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def train(self, train_dataset, epochs, batch_size, num_workers=12, save_step=100

velocity_consistency_model = None
if self.need_velocity_consistent_loss and self.ema_model.initted:
velocity_consistency_model = self.ema_model.ema_model
velocity_consistency_model = self.accelerator.unwrap_model(self.ema_model).ema_model

loss, cond, pred, pred_data = self.model(
mel_spec,
Expand All @@ -262,7 +262,7 @@ def train(self, train_dataset, epochs, batch_size, num_workers=12, save_step=100
self.scheduler.step()
self.optimizer.zero_grad()

self.ema_model.update()
self.accelerator.unwrap_model(self.ema_model).update()

if self.accelerator.is_local_main_process:
logger.info(f"step {global_step+1}: loss = {loss.item():.4f}")
Expand Down

0 comments on commit 8cea21c

Please sign in to comment.