Commit

gradient accumulation
lucidrains committed Jul 23, 2024
1 parent 54d76d6 commit e92797a
Showing 2 changed files with 24 additions and 21 deletions.
41 changes: 22 additions & 19 deletions e2_tts_pytorch/trainer.py
@@ -112,6 +112,7 @@ def __init__(
         model: E2TTS,
         optimizer,
         num_warmup_steps=20000,
+        grad_accumulation_steps=1,
         duration_predictor: DurationPredictor | None = None,
         checkpoint_path = None,
         log_file = "logs.txt",
@@ -125,6 +126,7 @@ def __init__(
 
         self.accelerator = Accelerator(
             log_with="all",
+            gradient_accumulation_steps=grad_accumulation_steps,
             **accelerate_kwargs
         )
 
@@ -146,6 +148,7 @@ def __init__(
         self.num_warmup_steps = num_warmup_steps
         self.checkpoint_path = default(checkpoint_path, 'model.pth')
         self.mel_spectrogram = MelSpec(sampling_rate=self.target_sample_rate)
+
         self.model, self.optimizer = self.accelerator.prepare(
             self.model, self.optimizer
         )
@@ -183,8 +186,7 @@ def load_checkpoint(self):
             self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
         return checkpoint['step']
 
-    def train(self, train_dataset, epochs, batch_size, grad_accumulation_steps=1, num_workers=12, save_step=1000):
-        # (todo) gradient accumulation needs to be accounted for
+    def train(self, train_dataset, epochs, batch_size, num_workers=12, save_step=1000):
 
         train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=num_workers, pin_memory=True)
         total_steps = len(train_dataloader) * epochs
@@ -204,23 +206,24 @@ def train(self, train_dataset, epochs, batch_size, grad_accumulation_steps=1, nu
             epoch_loss = 0.0
 
             for batch in progress_bar:
-                text_inputs = batch['text']
-                mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
-                mel_lengths = batch["mel_lengths"]
-
-                if self.duration_predictor is not None:
-                    dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
-                    self.writer.add_scalar('duration loss', dur_loss.item(), global_step)
-
-                loss = self.model(mel_spec, text=text_inputs, lens=mel_lengths)
-                self.accelerator.backward(loss)
-
-                if self.max_grad_norm > 0:
-                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
-
-                self.optimizer.step()
-                self.scheduler.step()
-                self.optimizer.zero_grad()
+                with self.accelerator.accumulate(self.model):
+                    text_inputs = batch['text']
+                    mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
+                    mel_lengths = batch["mel_lengths"]
+
+                    if self.duration_predictor is not None:
+                        dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
+                        self.writer.add_scalar('duration loss', dur_loss.item(), global_step)
+
+                    loss = self.model(mel_spec, text=text_inputs, lens=mel_lengths)
+                    self.accelerator.backward(loss)
+
+                    if self.max_grad_norm > 0:
+                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+
+                    self.optimizer.step()
+                    self.scheduler.step()
+                    self.optimizer.zero_grad()
 
                 if self.is_main:
                     self.ema_model.update()
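For context, below is a minimal, self-contained sketch of the accumulation pattern the new loop relies on (not the repository's code; the toy model, optimizer, and data are illustrative stand-ins, and accelerate >= 0.33.0 is assumed):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# Accumulate gradients over 4 micro-batches before each real optimizer step.
accelerator = Accelerator(gradient_accumulation_steps=4)

model = nn.Linear(16, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
dataset = TensorDataset(torch.randn(64, 16), torch.randn(64, 1))
dataloader = DataLoader(dataset, batch_size=8)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for inputs, targets in dataloader:
    # Inside accumulate(), backward() adds to the running gradients and the
    # prepared optimizer only performs a real step on the final micro-batch.
    with accelerator.accumulate(model):
        loss = nn.functional.mse_loss(model(inputs), targets)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
```

The effective batch size is then batch_size * gradient_accumulation_steps (8 * 4 = 32 in this sketch).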
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "0.1.2"
+version = "0.1.4"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
@@ -24,7 +24,7 @@ classifiers=[
 ]
 
 dependencies = [
-    'accelerate>=0.32.1',
+    'accelerate>=0.33.0',
     'einops>=0.8.0',
     'einx>=0.3.0',
     'ema-pytorch>=0.5.2',
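As a hypothetical sanity check (not part of the repo), one could verify at runtime that the installed accelerate satisfies the raised floor before relying on Accelerator-managed accumulation:

```python
from importlib.metadata import version

# Illustrative guard: the trainer change above assumes accelerate >= 0.33.0.
major, minor = (int(part) for part in version("accelerate").split(".")[:2])
assert (major, minor) >= (0, 33), "please upgrade accelerate to >= 0.33.0"
```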