fix torch autocast warning and cleanup

lucidrains committed Sep 15, 2024
1 parent a559705 commit 602b616
Showing 3 changed files with 36 additions and 31 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'soundstorm-pytorch',
packages = find_packages(exclude=[]),
- version = '0.4.10',
+ version = '0.4.11',
license='MIT',
description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
author = 'Phil Wang',
57 changes: 30 additions & 27 deletions soundstorm_pytorch/soundstorm.py
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
import math
from random import random, randrange
from functools import wraps
@@ -6,16 +8,17 @@
from pathlib import Path

import torch
- from torch.cuda.amp import autocast
+ from torch.amp import autocast
from torch import Tensor, nn, einsum
+ from torch.nn import Module, ModuleList
import torch.nn.functional as F

from einops import rearrange, reduce, repeat, unpack, pack
from einops.layers.torch import Rearrange, EinMix

from beartype import beartype
from beartype.door import is_bearable
- from beartype.typing import Union, Dict, Optional, List, Optional, Any
+ from beartype.typing import Any

from soundstorm_pytorch.attend import Attend

@@ -85,7 +88,7 @@ def coin_flip():
@beartype
def get_mask_subset_prob(
mask: Tensor,
- prob: Union[float, Tensor],
+ prob: float | Tensor,
min_mask: int = 0,
min_keep_mask: int = 0
):
@@ -124,7 +127,7 @@ def cosine_schedule(t):

# rotary embedding

- class RotaryEmbedding(nn.Module):
+ class RotaryEmbedding(Module):
def __init__(self, dim, theta = 10000):
super().__init__()
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
@@ -134,7 +137,7 @@ def __init__(self, dim, theta = 10000):
def device(self):
return next(self.buffers()).device

- @autocast(enabled = False)
+ @autocast('cuda', enabled = False)
def forward(self, seq_len):
t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
@@ -145,13 +148,13 @@ def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)

- @autocast(enabled = False)
+ @autocast('cuda', enabled = False)
def apply_rotary_pos_emb(pos, t):
return (t * pos.cos()) + (rotate_half(t) * pos.sin())

# t5 relative positional bias

- class T5RelativePositionBias(nn.Module):
+ class T5RelativePositionBias(Module):
def __init__(
self,
scale = 1.,
@@ -209,11 +212,11 @@ def forward(self, n):

# conformer

- class Swish(nn.Module):
+ class Swish(Module):
def forward(self, x):
return x * x.sigmoid()

- class GLU(nn.Module):
+ class GLU(Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
@@ -222,7 +225,7 @@ def forward(self, x):
out, gate = x.chunk(2, dim=self.dim)
return out * gate.sigmoid()

- class DepthWiseConv1d(nn.Module):
+ class DepthWiseConv1d(Module):
def __init__(self, chan_in, chan_out, kernel_size, padding):
super().__init__()
self.padding = padding
@@ -243,7 +246,7 @@ def forward(self, x, mask = None):

# attention, feedforward, and conv module

- class Scale(nn.Module):
+ class Scale(Module):
def __init__(self, scale, fn):
super().__init__()
self.fn = fn
@@ -252,7 +255,7 @@ def __init__(self, scale, fn):
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale

- class ChanLayerNorm(nn.Module):
+ class ChanLayerNorm(Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.ones(1, dim, 1))
@@ -263,7 +266,7 @@ def forward(self, x):
mean = torch.mean(x, dim = 1, keepdim = True)
return (x - mean) * var.clamp(min = eps).rsqrt() * self.gamma

- class PreNorm(nn.Module):
+ class PreNorm(Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
@@ -273,7 +276,7 @@ def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)

- class Attention(nn.Module):
+ class Attention(Module):
def __init__(
self,
dim,
@@ -321,7 +324,7 @@ def forward(
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)

- class FeedForward(nn.Module):
+ class FeedForward(Module):
def __init__(
self,
dim,
@@ -340,7 +343,7 @@ def __init__(
def forward(self, x):
return self.net(x)

- class ConformerConvModule(nn.Module):
+ class ConformerConvModule(Module):
def __init__(
self,
dim,
@@ -378,7 +381,7 @@ def forward(self, x, mask = None):

# Conformer Block

- class ConformerBlock(nn.Module):
+ class ConformerBlock(Module):
def __init__(
self,
*,
@@ -430,7 +433,7 @@ def forward(

# Conformer

- class Conformer(nn.Module):
+ class Conformer(Module):
def __init__(
self,
dim,
@@ -454,7 +457,7 @@ def __init__(
assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias'

self.dim = dim
- self.layers = nn.ModuleList([])
+ self.layers = ModuleList([])

self.rotary_emb = RotaryEmbedding(dim_head) if not t5_rel_pos_bias else None
self.rel_pos_bias = T5RelativePositionBias(dim_head ** 0.5, heads = heads) if t5_rel_pos_bias else None
@@ -493,15 +496,15 @@ def forward(self, x, mask = None):

# conformer with sum reduction across quantized tokens at the beginning, along with heads

- class ConformerWrapper(nn.Module):
+ class ConformerWrapper(Module):

@beartype
def __init__(
self,
*,
codebook_size,
num_quantizers,
- conformer: Union[Conformer, Dict[str, Any]],
+ conformer: Conformer | dict[str, Any],
grouped_quantizers = 1
):
super().__init__()
@@ -614,7 +617,7 @@ def forward(

# for main logits as well as self token critic

- class LogitHead(nn.Module):
+ class LogitHead(Module):
def __init__(
self,
net: ConformerWrapper,
@@ -633,16 +636,16 @@ def forward(self, x):

LossBreakdown = namedtuple('LossBreakdown', ['generator_loss', 'critic_loss'])

- class SoundStorm(nn.Module):
+ class SoundStorm(Module):

@beartype
def __init__(
self,
net: ConformerWrapper,
*,
- soundstream: Optional[SoundStream] = None,
- spear_tts_text_to_semantic: Optional[TextToSemantic] = None,
- wav2vec: Optional[Union[HubertWithKmeans, FairseqVQWav2Vec]] = None,
+ soundstream: SoundStream | None = None,
+ spear_tts_text_to_semantic: TextToSemantic | None = None,
+ wav2vec: HubertWithKmeans | FairseqVQWav2Vec | None = None,
steps = 18,
self_cond = False,
self_cond_train_prob = 0.75,
@@ -794,7 +797,7 @@ def generate(
num_latents = None,
*,
mask = None,
- texts: Optional[Union[List[str], Tensor]] = None,
+ texts: list[str] | Tensor | None = None,
cond_semantic_token_ids = None,
prompt_acoustic_token_ids = None,
seconds = None,
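The autocast change above is what silences the warning: recent PyTorch releases emit a FutureWarning for `torch.cuda.amp.autocast(...)` and ask for `torch.amp.autocast(device_type, ...)` with an explicit device type instead. A minimal sketch of the migrated pattern, with an illustrative helper name that is not taken from the repository:

```python
import torch
from torch.amp import autocast  # replaces the deprecated `from torch.cuda.amp import autocast`

# Keep rotary frequency tables in full precision even under mixed-precision training.
# The decorator now takes the device type as its first argument.
# `rotary_freqs` is an illustrative helper, not a function from the repo.
@autocast('cuda', enabled = False)
def rotary_freqs(seq_len: int, inv_freq: torch.Tensor) -> torch.Tensor:
    t = torch.arange(seq_len, device = inv_freq.device).type_as(inv_freq)
    freqs = torch.einsum('i , j -> i j', t, inv_freq)
    return torch.cat((freqs, freqs), dim = -1)
```

Disabling autocast around the rotary embedding and `apply_rotary_pos_emb`, as the diff does, is the usual precaution against fp16 precision loss in the sin/cos position tables.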
8 changes: 5 additions & 3 deletions soundstorm_pytorch/trainer.py
@@ -1,12 +1,14 @@
+ from __future__ import annotations
+
from pathlib import Path
import re
from shutil import rmtree

from beartype import beartype
from beartype.typing import Optional

import torch
from torch import nn
+ from torch.nn import Module
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, random_split

@@ -58,7 +60,7 @@ def checkpoint_num_steps(checkpoint_path):
return int(results[-1])


- class SoundStormTrainer(nn.Module):
+ class SoundStormTrainer(Module):
@beartype
def __init__(
self,
@@ -67,7 +69,7 @@ def __init__(
num_train_steps,
num_warmup_steps,
batch_size,
- dataset: Optional[Dataset] = None,
+ dataset: Dataset | None = None,
only_train_generator = False,
only_train_critic = False,
lr = 3e-4,
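The rest of the commit is a typing cleanup: both modules add `from __future__ import annotations`, which postpones evaluation of annotations (PEP 563), and hints previously spelled with `Union[...]` / `Optional[...]` from `beartype.typing` are rewritten as PEP 604 unions (`X | Y`, `X | None`) with built-in generics (`dict[str, Any]`, `list[str]`). A rough before/after sketch, with illustrative names that are not taken from the repository:

```python
from __future__ import annotations  # hints below are stored as strings, not evaluated at import time

from torch import Tensor
from torch.utils.data import Dataset

# before: dataset: Optional[Dataset] = None, prob: Union[float, Tensor] = 0.5
# after:  the same hints written with `|` unions and built-in generics
def make_loader(
    dataset: Dataset | None = None,
    prob: float | Tensor = 0.5,
    texts: list[str] | None = None,
):
    # illustrative body only
    return dataset, prob, texts
```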
