diff --git a/e2_tts_pytorch/trainer.py b/e2_tts_pytorch/trainer.py
index f73f034..342daab 100644
--- a/e2_tts_pytorch/trainer.py
+++ b/e2_tts_pytorch/trainer.py
@@ -14,6 +14,8 @@ from einops import rearrange
 
 from accelerate import Accelerator
 
+from ema_pytorch import EMA
+
 from loguru import logger
 
 from e2_tts_pytorch.e2_tts import (
@@ -116,7 +118,8 @@ def __init__(
         max_grad_norm = 1.0,
         sample_rate = 22050,
         tensorboard_log_dir = 'runs/e2_tts_experiment',
-        accelerate_kwargs: dict = dict()
+        accelerate_kwargs: dict = dict(),
+        ema_kwargs: dict = dict()
     ):
         logger.add(log_file)
 
@@ -126,7 +129,12 @@ def __init__(
         )
 
         self.target_sample_rate = sample_rate
 
+        self.model = model
+
+        if self.is_main:
+            self.ema_model = EMA(model, **ema_kwargs)
+
         self.duration_predictor = duration_predictor
         self.optimizer = optimizer
         self.num_warmup_steps = num_warmup_steps
@@ -139,11 +147,16 @@ def __init__(
 
         self.writer = SummaryWriter(log_dir=tensorboard_log_dir)
 
+    @property
+    def is_main(self):
+        return self.accelerator.is_main_process
+
     def save_checkpoint(self, step, finetune=False):
         checkpoint = dict(
             model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
-            optimizer_state_dict = self.optimizer.state_dict(),
+            ema_model_state_dict = self.ema_model.state_dict(),
+            optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
             scheduler_state_dict = self.scheduler.state_dict(),
             step = step
         )
 
@@ -156,7 +169,10 @@ def load_checkpoint(self):
 
         checkpoint = torch.load(self.checkpoint_path)
         self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
-        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
+
+        self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
+
         if self.scheduler:
             self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
         return checkpoint['step']
@@ -199,7 +215,10 @@ def train(self, train_dataset, epochs, batch_size, grad_accumulation_steps=1, nu
                     self.optimizer.step()
                     self.scheduler.step()
                     self.optimizer.zero_grad()
-
+
+                if self.is_main:
+                    self.ema_model.update()
+
                 if self.accelerator.is_local_main_process:
                     logger.info(f"step {global_step+1}: loss = {loss.item():.4f}")
                     self.writer.add_scalar('loss', loss.item(), global_step)
diff --git a/pyproject.toml b/pyproject.toml
index 1d3e976..1a05baa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "0.0.39"
+version = "0.0.40"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
 ]
@@ -27,6 +27,7 @@ dependencies = [
     'accelerate>=0.32.1',
     'einops>=0.8.0',
     'einx>=0.3.0',
+    'ema-pytorch>=0.5.2',
     'gateloop-transformer>=0.2.2',
     'jaxtyping',
     'loguru',