Skip to content

Commit

Permalink
throw in ema
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 22, 2024
1 parent 04b9c1c commit d8faa5e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
27 changes: 23 additions & 4 deletions e2_tts_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from einops import rearrange
from accelerate import Accelerator

from ema_pytorch import EMA

from loguru import logger

from e2_tts_pytorch.e2_tts import (
Expand Down Expand Up @@ -116,7 +118,8 @@ def __init__(
max_grad_norm = 1.0,
sample_rate = 22050,
tensorboard_log_dir = 'runs/e2_tts_experiment',
accelerate_kwargs: dict = dict()
accelerate_kwargs: dict = dict(),
ema_kwargs: dict = dict()
):
logger.add(log_file)

Expand All @@ -126,7 +129,12 @@ def __init__(
)

self.target_sample_rate = sample_rate

self.model = model

if self.is_main:
self.ema_model = EMA(model, **ema_kwargs)

self.duration_predictor = duration_predictor
self.optimizer = optimizer
self.num_warmup_steps = num_warmup_steps
Expand All @@ -139,11 +147,16 @@ def __init__(

self.writer = SummaryWriter(log_dir=tensorboard_log_dir)

@property
def is_main(self):
    """Report whether the current process is the accelerate main process.

    The trainer gates EMA bookkeeping on this flag (the EMA model is
    created and updated only on the main process), so exactly one copy
    of the EMA weights is maintained in distributed runs.
    """
    accelerator = self.accelerator
    return accelerator.is_main_process

def save_checkpoint(self, step, finetune=False):

checkpoint = dict(
model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
optimizer_state_dict = self.optimizer.state_dict(),
ema_model_state_dict = self.ema_model.state_dict(),
optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
scheduler_state_dict = self.scheduler.state_dict(),
step = step
)
Expand All @@ -156,7 +169,10 @@ def load_checkpoint(self):

checkpoint = torch.load(self.checkpoint_path)
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])

self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])

if self.scheduler:
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
return checkpoint['step']
Expand Down Expand Up @@ -199,7 +215,10 @@ def train(self, train_dataset, epochs, batch_size, grad_accumulation_steps=1, nu
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()


if self.is_main:
self.ema_model.update()

if self.accelerator.is_local_main_process:
logger.info(f"step {global_step+1}: loss = {loss.item():.4f}")
self.writer.add_scalar('loss', loss.item(), global_step)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "0.0.39"
version = "0.0.40"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand All @@ -27,6 +27,7 @@ dependencies = [
'accelerate>=0.32.1',
'einops>=0.8.0',
'einx>=0.3.0',
'ema-pytorch>=0.5.2',
'gateloop-transformer>=0.2.2',
'jaxtyping',
'loguru',
Expand Down

0 comments on commit d8faa5e

Please sign in to comment.