From 8cea21c6e1007da38f27d2dec8aeb9af8eef6671 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sun, 13 Oct 2024 11:18:36 -0700 Subject: [PATCH] need to unwrap --- e2_tts_pytorch/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/e2_tts_pytorch/trainer.py b/e2_tts_pytorch/trainer.py index 4b1794f..a1ce3dd 100644 --- a/e2_tts_pytorch/trainer.py +++ b/e2_tts_pytorch/trainer.py @@ -244,7 +244,7 @@ def train(self, train_dataset, epochs, batch_size, num_workers=12, save_step=100 velocity_consistency_model = None if self.need_velocity_consistent_loss and self.ema_model.initted: - velocity_consistency_model = self.ema_model.ema_model + velocity_consistency_model = self.accelerator.unwrap_model(self.ema_model).ema_model loss, cond, pred, pred_data = self.model( mel_spec, @@ -262,7 +262,7 @@ def train(self, train_dataset, epochs, batch_size, num_workers=12, save_step=100 self.scheduler.step() self.optimizer.zero_grad() - self.ema_model.update() + self.accelerator.unwrap_model(self.ema_model).update() if self.accelerator.is_local_main_process: logger.info(f"step {global_step+1}: loss = {loss.item():.4f}")