diff --git a/e2_tts_pytorch/trainer.py b/e2_tts_pytorch/trainer.py
index 6d096fa..f68bffc 100644
--- a/e2_tts_pytorch/trainer.py
+++ b/e2_tts_pytorch/trainer.py
@@ -112,6 +112,7 @@ def __init__(
         model: E2TTS,
         optimizer,
         num_warmup_steps=20000,
+        grad_accumulation_steps=1,
         duration_predictor: DurationPredictor | None = None,
         checkpoint_path = None,
         log_file = "logs.txt",
@@ -125,6 +126,7 @@ def __init__(
 
         self.accelerator = Accelerator(
             log_with="all",
+            gradient_accumulation_steps=grad_accumulation_steps,
             **accelerate_kwargs
         )
 
@@ -146,6 +148,7 @@ def __init__(
         self.num_warmup_steps = num_warmup_steps
         self.checkpoint_path = default(checkpoint_path, 'model.pth')
         self.mel_spectrogram = MelSpec(sampling_rate=self.target_sample_rate)
+
         self.model, self.optimizer = self.accelerator.prepare(
             self.model, self.optimizer
         )
@@ -183,8 +186,7 @@ def load_checkpoint(self):
         self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
         return checkpoint['step']
 
-    def train(self, train_dataset, epochs, batch_size, grad_accumulation_steps=1, num_workers=12, save_step=1000):
-        # (todo) gradient accumulation needs to be accounted for
+    def train(self, train_dataset, epochs, batch_size, num_workers=12, save_step=1000):
 
         train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=num_workers, pin_memory=True)
         total_steps = len(train_dataloader) * epochs
@@ -204,23 +206,24 @@ def train(self, train_dataset, epochs, batch_size, grad_accumulation_steps=1, nu
             epoch_loss = 0.0
 
             for batch in progress_bar:
-                text_inputs = batch['text']
-                mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
-                mel_lengths = batch["mel_lengths"]
-
-                if self.duration_predictor is not None:
-                    dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
-                    self.writer.add_scalar('duration loss', dur_loss.item(), global_step)
-
-                loss = self.model(mel_spec, text=text_inputs, lens=mel_lengths)
-                self.accelerator.backward(loss)
-
-                if self.max_grad_norm > 0:
-                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
-
-                self.optimizer.step()
-                self.scheduler.step()
-                self.optimizer.zero_grad()
+                with self.accelerator.accumulate(self.model):
+                    text_inputs = batch['text']
+                    mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
+                    mel_lengths = batch["mel_lengths"]
+
+                    if self.duration_predictor is not None:
+                        dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
+                        self.writer.add_scalar('duration loss', dur_loss.item(), global_step)
+
+                    loss = self.model(mel_spec, text=text_inputs, lens=mel_lengths)
+                    self.accelerator.backward(loss)
+
+                    if self.max_grad_norm > 0 and self.accelerator.sync_gradients:  # clip only on the step that syncs gradients
+                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+
+                    self.optimizer.step()
+                    self.scheduler.step()
+                    self.optimizer.zero_grad()
 
                 if self.is_main:
                     self.ema_model.update()
diff --git a/pyproject.toml b/pyproject.toml
index 9a17d33..0160b81 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "0.1.2"
+version = "0.1.4"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -24,7 +24,7 @@ classifiers=[
 ]
 
 dependencies = [
-    'accelerate>=0.32.1',
+    'accelerate>=0.33.0',
     'einops>=0.8.0',
     'einx>=0.3.0',
     'ema-pytorch>=0.5.2',