
Commit

make it work
lucidrains committed May 18, 2023
1 parent 48e8aad commit 84d735a
Showing 4 changed files with 80 additions and 5 deletions.
49 changes: 48 additions & 1 deletion README.md
@@ -57,9 +57,46 @@ loss.backward()
generated = model.generate(1024, batch_size = 2) # (2, 1024)
```

To train directly on raw audio, pass your pretrained `SoundStream` into `SoundStorm`. You can train your own `SoundStream` with <a href="https://github.com/lucidrains/audiolm-pytorch#soundstream--encodec">audiolm-pytorch</a>.

```python
import torch
from soundstorm_pytorch import SoundStorm, ConformerWrapper, Conformer, SoundStream

conformer = ConformerWrapper(
codebook_size = 1024,
num_quantizers = 4,
conformer = dict(
dim = 512,
depth = 2
),
)

soundstream = SoundStream(
codebook_size = 1024,
rq_num_quantizers = 4,
attn_window_size = 128,
attn_depth = 2
)

model = SoundStorm(
conformer,
soundstream = soundstream # pass in the soundstream
)

audio = torch.randn(2, 10080) # raw audio can now be passed directly into the model

loss, _ = model(audio)
loss.backward()

generated_audio = model.generate(1024, batch_size = 2) # the generated output is now also a raw waveform
```

## Todo

- [x] integrate soundstream

- [ ] when generating, make sure it can return an audio file, and that the length can be specified in seconds (taking the sampling frequency into account); see the sketch after this list
- [ ] turn it into a command line tool
- [ ] add cross attention and adaptive layernorm conditioning (just copy paste in the entire conformer repository, if conditioning adds too much cruft to the other repo)
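
For the seconds-to-length item above, here is a minimal sketch of the bookkeeping that would be involved, assuming the codec exposes its sample rate and overall downsampling factor. The names `target_sample_hz` and `downsample_factor` are placeholders rather than actual `SoundStream` attributes, and the sketch assumes the length passed to `generate` counts flattened tokens (codec frames times quantizers).

```python
# Hypothetical helper: convert a duration in seconds into a token count for `generate`.
# The sample rate and downsampling factor are assumed inputs here, not attributes
# read off the actual SoundStream class.

def seconds_to_num_tokens(seconds, target_sample_hz, downsample_factor, num_quantizers):
    num_samples = int(seconds * target_sample_hz)   # raw waveform samples
    num_frames = num_samples // downsample_factor   # codec frames after encoding
    return num_frames * num_quantizers              # flattened tokens to generate

# e.g. 2 seconds at 16 kHz with a 320x downsampling codec and 4 quantizers:
# (2 * 16000 // 320) * 4 = 400 tokens
```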

@@ -113,3 +150,13 @@ generated = model.generate(1024, batch_size = 2) # (2, 1024)
year = {2021}
}
```

```bibtex
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
```
4 changes: 2 additions & 2 deletions setup.py
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'soundstorm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.5',
+  version = '0.0.6',
   license='MIT',
   description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
   author = 'Phil Wang',
@@ -19,7 +19,7 @@
   ],
   install_requires=[
     'accelerate',
-    'audiolm-pytorch>=1.0.0',
+    'audiolm-pytorch>=1.0.1',
     'beartype',
     'conformer>=0.3.2',
     'einops>=0.6.1',
```
1 change: 1 addition & 0 deletions soundstorm_pytorch/__init__.py
```diff
@@ -1,5 +1,6 @@
 from soundstorm_pytorch.soundstorm import (
     SoundStorm,
+    SoundStream,
     ConformerWrapper,
     Conformer
 )
```
31 changes: 29 additions & 2 deletions soundstorm_pytorch/soundstorm.py
```diff
@@ -219,13 +219,21 @@ def __init__(
         critic_loss_weight = 1.
     ):
         super().__init__()
+        self.soundstream = soundstream
+
+        if exists(self.soundstream):
+            assert soundstream.rq_groups == 1, 'grouped residual vector quantized soundstream not supported, yet'
+            assert net.codebook_size == soundstream.codebook_size
+            assert net.num_quantizers == soundstream.num_quantizers
+
         assert not (self_token_critic and exists(token_critic))
 
         self.net = net
 
         dim = net.dim
         self.dim = dim
         self.num_tokens = net.codebook_size
+        self.num_quantizers = net.num_quantizers
 
         self.mask_id = net.codebook_size
 
@@ -348,10 +356,20 @@ def generate(
 
         self.train(was_training)
 
+        out = seq
+
+        if exists(self.soundstream):
+            seq = rearrange(seq, 'b (n q) -> b n q', q = self.num_quantizers)
+
+            with torch.no_grad():
+                self.soundstream.eval()
+                out = self.soundstream.decode_from_codebook_indices(seq)
+                out = rearrange(out, 'b 1 ... -> b ...')
+
         if sample_one:
-            seq = rearrange(seq, '1 n -> n')
+            out = rearrange(out, '1 ... -> ...')
 
-        return seq
+        return out
 
     def forward(
         self,
@@ -361,6 +379,15 @@
         generator_sample_temperature = None,
         **kwargs
     ):
+        is_raw_audio = x.dtype == torch.float
+
+        if is_raw_audio:
+            assert exists(self.soundstream)
+            with torch.no_grad():
+                self.soundstream.eval()
+                _, x, _ = self.soundstream(x, return_encoded = True)
+                x = rearrange(x, 'b n q -> b (n q)')
+
         b, n, device = *x.shape, x.device
 
         orig_seq = x.clone()
```
