address potential issue with ds conv, thanks to @Jiang-Stan #26
lucidrains committed Nov 30, 2023
1 parent 63c7d09 commit 907a320
Showing 2 changed files with 27 additions and 9 deletions.
setup.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'soundstorm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.0',
+  version = '0.3.0',
   license='MIT',
   description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
   author = 'Phil Wang',
soundstorm_pytorch/soundstorm.py (34 changes: 26 additions & 8 deletions)

@@ -6,6 +6,7 @@
 from pathlib import Path

 import torch
+from torch.cuda.amp import autocast
 from torch import Tensor, nn, einsum
 import torch.nn.functional as F

Expand Down Expand Up @@ -124,6 +125,7 @@ def __init__(self, dim, theta = 10000):
def device(self):
return next(self.buffers()).device

@autocast(enabled = False)
def forward(self, seq_len):
t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
Expand All @@ -134,6 +136,7 @@ def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)

@autocast(enabled = False)
def apply_rotary_pos_emb(pos, t):
return (t * pos.cos()) + (rotate_half(t) * pos.sin())
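
Note on the two autocast additions above: the @autocast(enabled = False) decorators keep the rotary-embedding math in full precision even when the surrounding model runs under torch.cuda.amp mixed precision, since sin/cos position tables can lose accuracy in float16. A minimal standalone sketch of the effect (the helper name and the sizes are made up for illustration, not taken from the repository):

    import torch
    from torch.cuda.amp import autocast

    @autocast(enabled = False)
    def rotary_freqs_fp32(seq_len, inv_freq):
        # with autocast disabled here, t and the outer product stay float32
        t = torch.arange(seq_len, device = inv_freq.device).type_as(inv_freq)
        return torch.einsum('i , j -> i j', t, inv_freq)

    if torch.cuda.is_available():
        inv_freq = 1. / (10000 ** (torch.arange(0, 64, 2).float() / 64)).cuda()
        with autocast():
            freqs = rotary_freqs_fp32(1024, inv_freq)
        assert freqs.dtype == torch.float32   # einsum was not downcast to float16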

@@ -216,9 +219,18 @@ def __init__(self, chan_in, chan_out, kernel_size, padding):
         self.padding = padding
         self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)

-    def forward(self, x):
+    def forward(self, x, mask = None):
+        if exists(mask):
+            mask = mask[..., None]
+            x = x.masked_fill(~mask, 0.)
+
         x = F.pad(x, self.padding)
-        return self.conv(x)
+        out = self.conv(x)
+
+        if exists(mask):
+            out = out.masked_fill(~mask, 0.)
+
+        return out

 # attention, feedforward, and conv module

chenht2021 (Contributor) commented on the added line "mask = mask[..., None]", Dec 1, 2023:

    A mistake here.
    mask = rearrange(mask, "b n -> b 1 n")

lucidrains (Author, Owner) replied, Dec 1, 2023:

    oops, thanks!
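
To make the comment above concrete: inside ConformerConvModule the tensor has already been rearranged to (batch, channels, length) while the mask is (batch, length), so mask[..., None] gives (batch, length, 1) and broadcasts against the wrong axes, whereas rearrange(mask, 'b n -> b 1 n') lines it up with the time axis. A standalone illustration with toy sizes (not code from the diff):

    import torch
    from einops import rearrange

    b, c, n = 2, 8, 6
    x = torch.randn(b, c, n)                   # channels-first, as inside the conv module
    mask = torch.ones(b, n).bool()             # (batch, length)

    wrong = mask[..., None]                    # (2, 6, 1)
    right = rearrange(mask, 'b n -> b 1 n')    # (2, 1, 6)

    x.masked_fill(~right, 0.)                  # zeroes padded time steps as intended
    # x.masked_fill(~wrong, 0.) raises a broadcasting error here, and when c == n it
    # would silently mask along the channel axis instead of the time axis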

@@ -333,21 +345,27 @@ def __init__(
         inner_dim = dim * expansion_factor
         padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

-        self.net = nn.Sequential(
+        self.net1 = nn.Sequential(
             nn.LayerNorm(dim),
             Rearrange('b n c -> b c n'),
             nn.Conv1d(dim, inner_dim * 2, 1),
-            GLU(dim=1),
-            DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
+            GLU(dim=1)
+        )
+
+        self.ds_conv = DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding)
+
+        self.net2 = nn.Sequential(
             Swish(),
             ChanLayerNorm(inner_dim),
             nn.Conv1d(inner_dim, dim, 1),
             Rearrange('b c n -> b n c'),
             nn.Dropout(dropout)
         )

-    def forward(self, x):
-        return self.net(x)
+    def forward(self, x, mask = None):
+        x = self.net1(x)
+        x = self.ds_conv(x, mask = mask)
+        return self.net2(x)

 # Conformer Block
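
The point of splitting the old single nn.Sequential into net1, ds_conv and net2 is that only the depthwise convolution needs the mask: zeroing padded positions before (and after) the kernel keeps padding from bleeding into valid frames through the receptive field. A small standalone check of that effect, using a toy single-channel convolution with 'same' padding (not code from the repository):

    import torch
    import torch.nn.functional as F
    from torch import nn

    conv = nn.Conv1d(1, 1, kernel_size = 3, bias = False)
    nn.init.constant_(conv.weight, 1.)              # kernel that just sums its window

    x = torch.tensor([[[1., 1., 1., 9., 9.]]])      # last two steps are padding garbage
    mask = torch.tensor([[[True, True, True, False, False]]])

    unmasked = conv(F.pad(x, (1, 1)))
    masked   = conv(F.pad(x.masked_fill(~mask, 0.), (1, 1)))

    # frame 2 is a real frame, but its window reaches into the padded region
    print(unmasked[0, 0, 2].item(), masked[0, 0, 2].item())   # 11.0 vs. 2.0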

@@ -388,7 +406,7 @@ def forward(
     ):
         x = self.ff1(x) + x
         x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x
-        x = self.conv(x) + x
+        x = self.conv(x, mask = mask) + x
         x = self.ff2(x) + x
         x = self.post_norm(x)
         return x
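
With this change the padding mask that attention already receives is now also threaded into the convolution branch. A rough usage sketch, assuming ConformerBlock is importable as below, built with the file's default arguments, and that the mask reshape fix from the comment thread is applied (the sizes are made up):

    import torch
    from soundstorm_pytorch.soundstorm import ConformerBlock

    block = ConformerBlock(dim = 512)

    x = torch.randn(2, 256, 512)                             # (batch, seq, dim)
    lengths = torch.tensor([256, 180])
    mask = torch.arange(256)[None, :] < lengths[:, None]     # True on real frames

    out = block(x, mask = mask)    # mask now gates attention and the depthwise conv
    print(out.shape)               # torch.Size([2, 256, 512])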