use a bias-less layernorm and remove batchnorm from conformer
lucidrains committed May 19, 2023
1 parent 0b43955 commit d5c6a07
Showing 3 changed files with 14 additions and 3 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -105,12 +105,12 @@ generated_audio = model.generate(seconds = 30, batch_size = 2) # generate 30 se
- [x] make sure grouped rvq is supported. concat embeddings rather than sum across group dimension
- [x] just copy conformer over and redo shaw's relative positional embedding with rotary embedding. nobody uses shaw anymore.
- [x] default flash attention to true
+- [x] remove batchnorm, and just use layernorm, but after the swish (as in normformer paper)

- [ ] option to return list of audio files when generating
- [ ] turn it into a command line tool
- [ ] add cross attention and adaptive layernorm conditioning
- [ ] trainer with accelerate
-- [ ] add ability to use conformer without batchnorm, substituting with groupnorm + weight standardization

## Citations

setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'soundstorm-pytorch',
packages = find_packages(exclude=[]),
-version = '0.0.11',
+version = '0.0.12',
license='MIT',
description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
author = 'Phil Wang',
soundstorm_pytorch/soundstorm.py (13 changes: 12 additions & 1 deletion)
@@ -159,6 +159,17 @@ def __init__(self, scale, fn):
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

+class ChanLayerNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.ones(1, dim, 1))
+
+    def forward(self, x):
+        eps = 1e-6 if x.dtype == torch.float32 else 1e-4
+        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+        mean = torch.mean(x, dim = 1, keepdim = True)
+        return (x - mean) * var.clamp(min = eps).rsqrt() * self.gamma

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
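A note on the class added above: ChanLayerNorm is a bias-less LayerNorm over the channel dimension of a channel-first (batch, channels, time) tensor, which lets it sit directly inside the Conv1d stack further down without rearranging to (batch, time, channels) and back; the looser eps in non-float32 dtypes keeps the rsqrt stable in half precision. A minimal usage sketch, assuming the package from this commit is installed and with made-up tensor shapes:

import torch
from soundstorm_pytorch.soundstorm import ChanLayerNorm

x = torch.randn(2, 512, 1024)    # (batch, channels, time) - illustrative shapes
norm = ChanLayerNorm(dim = 512)
out = norm(x)                    # zero mean, unit variance across channels at each time step,
                                 # scaled by the learned per-channel gamma; no bias term
print(out.shape)                 # torch.Size([2, 512, 1024])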
@@ -255,8 +266,8 @@ def __init__(
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1),
            DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
-           nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
            Swish(),
+           ChanLayerNorm(inner_dim),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout)
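With this change the convolution module follows the ordering called out in the README: pointwise conv, GLU, depthwise conv, Swish, channel LayerNorm (normformer-style, after the activation), pointwise conv, dropout, with no BatchNorm, so the causal and non-causal variants now share the same normalization. Below is a rough, self-contained sketch of that ordering using stock PyTorch stand-ins (nn.SiLU for the repo's Swish, nn.GLU(dim = 1) for its GLU, a grouped Conv1d for DepthWiseConv1d); the dims and kernel size are made up, and the repo's (b, n, c) <-> (b, c, n) rearranges are omitted since the input here is already channel-first.

import torch
from torch import nn
from soundstorm_pytorch.soundstorm import ChanLayerNorm

dim, expansion_factor, kernel_size = 512, 2, 31      # illustrative values
inner_dim = dim * expansion_factor

conv_module = nn.Sequential(
    nn.Conv1d(dim, inner_dim * 2, 1),                # pointwise expansion
    nn.GLU(dim = 1),                                 # gates half the channels, back down to inner_dim
    nn.Conv1d(inner_dim, inner_dim, kernel_size,
              padding = kernel_size // 2,
              groups = inner_dim),                   # depthwise conv
    nn.SiLU(),                                       # swish activation
    ChanLayerNorm(inner_dim),                        # bias-less channel layernorm replaces BatchNorm1d
                                                     # and now sits after the activation
    nn.Conv1d(inner_dim, dim, 1),                    # pointwise projection back to dim
    nn.Dropout(0.1)
)

x = torch.randn(2, dim, 1024)                        # (batch, channels, time)
print(conv_module(x).shape)                          # torch.Size([2, 512, 1024])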
