From 602b61694d2460f1b162ef442c6fe856a525fdb5 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 14 Sep 2024 18:35:28 -0700 Subject: [PATCH] fix torch autocast warning and cleanup --- setup.py | 2 +- soundstorm_pytorch/soundstorm.py | 57 +++++++++++++++++--------------- soundstorm_pytorch/trainer.py | 8 +++-- 3 files changed, 36 insertions(+), 31 deletions(-) diff --git a/setup.py b/setup.py index 92a4bb7..0093338 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'soundstorm-pytorch', packages = find_packages(exclude=[]), - version = '0.4.10', + version = '0.4.11', license='MIT', description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch', author = 'Phil Wang', diff --git a/soundstorm_pytorch/soundstorm.py b/soundstorm_pytorch/soundstorm.py index b7aaed3..68064b2 100644 --- a/soundstorm_pytorch/soundstorm.py +++ b/soundstorm_pytorch/soundstorm.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import math from random import random, randrange from functools import wraps @@ -6,8 +8,9 @@ from pathlib import Path import torch -from torch.cuda.amp import autocast +from torch.amp import autocast from torch import Tensor, nn, einsum +from torch.nn import Module, ModuleList import torch.nn.functional as F from einops import rearrange, reduce, repeat, unpack, pack @@ -15,7 +18,7 @@ from beartype import beartype from beartype.door import is_bearable -from beartype.typing import Union, Dict, Optional, List, Optional, Any +from beartype.typing import Any from soundstorm_pytorch.attend import Attend @@ -85,7 +88,7 @@ def coin_flip(): @beartype def get_mask_subset_prob( mask: Tensor, - prob: Union[float, Tensor], + prob: float | Tensor, min_mask: int = 0, min_keep_mask: int = 0 ): @@ -124,7 +127,7 @@ def cosine_schedule(t): # rotary embedding -class RotaryEmbedding(nn.Module): +class RotaryEmbedding(Module): def __init__(self, dim, theta = 10000): super().__init__() inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) @@ -134,7 +137,7 @@ def __init__(self, dim, theta = 10000): def device(self): return next(self.buffers()).device - @autocast(enabled = False) + @autocast('cuda', enabled = False) def forward(self, seq_len): t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq) freqs = torch.einsum('i , j -> i j', t, self.inv_freq) @@ -145,13 +148,13 @@ def rotate_half(x): x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) -@autocast(enabled = False) +@autocast('cuda', enabled = False) def apply_rotary_pos_emb(pos, t): return (t * pos.cos()) + (rotate_half(t) * pos.sin()) # t5 relative positional bias -class T5RelativePositionBias(nn.Module): +class T5RelativePositionBias(Module): def __init__( self, scale = 1., @@ -209,11 +212,11 @@ def forward(self, n): # conformer -class Swish(nn.Module): +class Swish(Module): def forward(self, x): return x * x.sigmoid() -class GLU(nn.Module): +class GLU(Module): def __init__(self, dim): super().__init__() self.dim = dim @@ -222,7 +225,7 @@ def forward(self, x): out, gate = x.chunk(2, dim=self.dim) return out * gate.sigmoid() -class DepthWiseConv1d(nn.Module): +class DepthWiseConv1d(Module): def __init__(self, chan_in, chan_out, kernel_size, padding): super().__init__() self.padding = padding @@ -243,7 +246,7 @@ def forward(self, x, mask = None): # attention, feedforward, and conv module -class Scale(nn.Module): +class Scale(Module): def __init__(self, scale, fn): super().__init__() self.fn = fn @@ -252,7 +255,7 @@ def __init__(self, scale, fn): def forward(self, x, **kwargs): return self.fn(x, **kwargs) * self.scale -class ChanLayerNorm(nn.Module): +class ChanLayerNorm(Module): def __init__(self, dim): super().__init__() self.gamma = nn.Parameter(torch.ones(1, dim, 1)) @@ -263,7 +266,7 @@ def forward(self, x): mean = torch.mean(x, dim = 1, keepdim = True) return (x - mean) * var.clamp(min = eps).rsqrt() * self.gamma -class PreNorm(nn.Module): +class PreNorm(Module): def __init__(self, dim, fn): super().__init__() self.fn = fn @@ -273,7 +276,7 @@ def forward(self, x, **kwargs): x = self.norm(x) return self.fn(x, **kwargs) -class Attention(nn.Module): +class Attention(Module): def __init__( self, dim, @@ -321,7 +324,7 @@ def forward( out = rearrange(out, 'b h n d -> b n (h d)') return self.to_out(out) -class FeedForward(nn.Module): +class FeedForward(Module): def __init__( self, dim, @@ -340,7 +343,7 @@ def __init__( def forward(self, x): return self.net(x) -class ConformerConvModule(nn.Module): +class ConformerConvModule(Module): def __init__( self, dim, @@ -378,7 +381,7 @@ def forward(self, x, mask = None): # Conformer Block -class ConformerBlock(nn.Module): +class ConformerBlock(Module): def __init__( self, *, @@ -430,7 +433,7 @@ def forward( # Conformer -class Conformer(nn.Module): +class Conformer(Module): def __init__( self, dim, @@ -454,7 +457,7 @@ def __init__( assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias' self.dim = dim - self.layers = nn.ModuleList([]) + self.layers = ModuleList([]) self.rotary_emb = RotaryEmbedding(dim_head) if not t5_rel_pos_bias else None self.rel_pos_bias = T5RelativePositionBias(dim_head ** 0.5, heads = heads) if t5_rel_pos_bias else None @@ -493,7 +496,7 @@ def forward(self, x, mask = None): # conformer with sum reduction across quantized tokens at the beginning, along with heads -class ConformerWrapper(nn.Module): +class ConformerWrapper(Module): @beartype def __init__( @@ -501,7 +504,7 @@ def __init__( *, codebook_size, num_quantizers, - conformer: Union[Conformer, Dict[str, Any]], + conformer: Conformer | dict[str, Any], grouped_quantizers = 1 ): super().__init__() @@ -614,7 +617,7 @@ def forward( # for main logits as well as self token critic -class LogitHead(nn.Module): +class LogitHead(Module): def __init__( self, net: ConformerWrapper, @@ -633,16 +636,16 @@ def forward(self, x): LossBreakdown = namedtuple('LossBreakdown', ['generator_loss', 'critic_loss']) -class SoundStorm(nn.Module): +class SoundStorm(Module): @beartype def __init__( self, net: ConformerWrapper, *, - soundstream: Optional[SoundStream] = None, - spear_tts_text_to_semantic: Optional[TextToSemantic] = None, - wav2vec: Optional[Union[HubertWithKmeans, FairseqVQWav2Vec]] = None, + soundstream: SoundStream | None = None, + spear_tts_text_to_semantic: TextToSemantic | None = None, + wav2vec: HubertWithKmeans | FairseqVQWav2Vec | None = None, steps = 18, self_cond = False, self_cond_train_prob = 0.75, @@ -794,7 +797,7 @@ def generate( num_latents = None, *, mask = None, - texts: Optional[Union[List[str], Tensor]] = None, + texts: list[str] | Tensor | None = None, cond_semantic_token_ids = None, prompt_acoustic_token_ids = None, seconds = None, diff --git a/soundstorm_pytorch/trainer.py b/soundstorm_pytorch/trainer.py index 7809fc2..63bd955 100644 --- a/soundstorm_pytorch/trainer.py +++ b/soundstorm_pytorch/trainer.py @@ -1,12 +1,14 @@ +from __future__ import annotations + from pathlib import Path import re from shutil import rmtree from beartype import beartype -from beartype.typing import Optional import torch from torch import nn +from torch.nn import Module from torch.optim.lr_scheduler import CosineAnnealingLR from torch.utils.data import Dataset, random_split @@ -58,7 +60,7 @@ def checkpoint_num_steps(checkpoint_path): return int(results[-1]) -class SoundStormTrainer(nn.Module): +class SoundStormTrainer(Module): @beartype def __init__( self, @@ -67,7 +69,7 @@ def __init__( num_train_steps, num_warmup_steps, batch_size, - dataset: Optional[Dataset] = None, + dataset: Dataset | None = None, only_train_generator = False, only_train_critic = False, lr = 3e-4,