From d5c6a0742324192c6e14d260f8718eb9934b6323 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 19 May 2023 13:36:28 -0700 Subject: [PATCH] use a bias-less layernorm and remove batchnorm from conformer --- README.md | 2 +- setup.py | 2 +- soundstorm_pytorch/soundstorm.py | 13 ++++++++++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 07f7cb9..eff1359 100644 --- a/README.md +++ b/README.md @@ -105,12 +105,12 @@ generated_audio = model.generate(seconds = 30, batch_size = 2) # generate 30 se - [x] make sure grouped rvq is supported. concat embeddings rather than sum across group dimension - [x] just copy conformer over and redo shaw's relative positional embedding with rotary embedding. nobody uses shaw anymore. - [x] default flash attention to true +- [x] remove batchnorm, and just use layernorm, but after the swish (as in normformer paper) - [ ] option to return list of audio files when generating - [ ] turn it into a command line tool - [ ] add cross attention and adaptive layernorm conditioning - [ ] trainer with accelerate -- [ ] add ability to use conformer without batchnorm, substituting with groupnorm + weight standardization ## Citations diff --git a/setup.py b/setup.py index 4c4564d..84f7c24 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'soundstorm-pytorch', packages = find_packages(exclude=[]), - version = '0.0.11', + version = '0.0.12', license='MIT', description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch', author = 'Phil Wang', diff --git a/soundstorm_pytorch/soundstorm.py b/soundstorm_pytorch/soundstorm.py index a493e50..84b93a0 100644 --- a/soundstorm_pytorch/soundstorm.py +++ b/soundstorm_pytorch/soundstorm.py @@ -159,6 +159,17 @@ def __init__(self, scale, fn): def forward(self, x, **kwargs): return self.fn(x, **kwargs) * self.scale +class ChanLayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.ones(1, dim, 1)) + + def forward(self, x): + eps = 1e-6 if x.dtype == torch.float32 else 1e-4 + var = torch.var(x, dim = 1, unbiased = False, keepdim = True) + mean = torch.mean(x, dim = 1, keepdim = True) + return (x - mean) * var.clamp(min = eps).rsqrt() * self.gamma + class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() @@ -255,8 +266,8 @@ def __init__( nn.Conv1d(dim, inner_dim * 2, 1), GLU(dim=1), DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding), - nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(), Swish(), + ChanLayerNorm(inner_dim), nn.Conv1d(inner_dim, dim, 1), Rearrange('b c n -> b n c'), nn.Dropout(dropout)