Skip to content

Commit

Permalink
end-to-end text-to-speech with audiolm, spear-tts, and soundstorm
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 29, 2023
1 parent cf62f10 commit 3e6b1e7
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 10 deletions.
36 changes: 36 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,42 @@ loss.backward()
generated_audio = model.generate(seconds = 30, batch_size = 2) # generate 30 seconds of audio (it will calculate the length in seconds based off the sampling frequency and cumulative downsamples in the soundstream passed in above)
```

Complete text-to-speech relies on a trained `TextToSemantic` encoder / decoder transformer. Load its trained weights, then pass the model into `SoundStorm` as `spear_tts_text_to_semantic`.

```python
from spear_tts_pytorch import TextToSemantic

# NOTE: `torch`, `SoundStorm`, `conformer` and `soundstream` are not defined in
# this snippet - presumably they carry over from the preceding README example

# text -> semantic token encoder / decoder transformer

text_to_semantic = TextToSemantic(
    dim = 512,
    source_depth = 12,
    target_depth = 12,
    num_text_token_ids = 50000,
    num_semantic_token_ids = 20000,
    use_openai_tokenizer = True
)

# restore the trained text-to-semantic weights from disk

trained_weights = torch.load('/path/to/trained/model.pt')
text_to_semantic.load_state_dict(trained_weights)

# hand the transformer to soundstorm, then move the whole model to the gpu

model = SoundStorm(
    conformer,
    soundstream = soundstream,
    spear_tts_text_to_semantic = text_to_semantic
)
model = model.cuda()

# and now you can generate state-of-the-art speech straight from raw text

prompts = [
    'the rain in spain stays mainly in the plain',
    'the quick brown fox jumps over the lazy dog'
]

generated_speech = model.generate(texts = prompts) # (2, n) - raw waveform decoded from soundstream
```

## Todo

- [x] integrate soundstream
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'soundstorm-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.19',
version = '0.0.20',
license='MIT',
description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
author = 'Phil Wang',
Expand All @@ -23,7 +23,7 @@
'beartype',
'classifier-free-guidance-pytorch>=0.1.5',
'einops>=0.6.1',
'spear-tts-pytorch>=0.0.4',
'spear-tts-pytorch>=0.0.6',
'torch>=1.6',
],
classifiers=[
Expand Down
39 changes: 31 additions & 8 deletions soundstorm_pytorch/soundstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from einops.layers.torch import Rearrange, EinMix

from beartype import beartype
from beartype.typing import Union, Dict, Optional
from beartype.door import is_bearable
from beartype.typing import Union, Dict, Optional, List, Optional

from soundstorm_pytorch.attend import Attend

Expand Down Expand Up @@ -657,11 +658,13 @@ def __init__(

if exists(spear_tts_text_to_semantic):
self.semantic_token_emb = spear_tts_text_to_semantic.semantic_token_emb
self.semantic_cond_to_model_dim = nn.Linear(spear_tts_text_to_semantic, net.dim)
self.semantic_pad_id = spear_tts_text_to_semantic.semantic_pad_id
self.num_semantic_token_ids = spear_tts_text_to_semantic.num_semantic_token_ids
self.semantic_cond_to_model_dim = nn.Linear(spear_tts_text_to_semantic.dim, net.dim)
self.semantic_pad_id = spear_tts_text_to_semantic.pad_id.get('speech')
else:
assert exists(num_semantic_token_ids), 'if you are conditioning, you must pass in the number of semantic token ids'
self.semantic_token_emb = nn.Embedding(num_semantic_token_ids, dim)
self.num_semantic_token_ids = num_semantic_token_ids
self.semantic_cond_to_model_dim = nn.Identity()
self.semantic_pad_id = semantic_pad_id

Expand Down Expand Up @@ -720,21 +723,33 @@ def generate(
self,
num_latents = None,
*,
texts: Optional[Union[List[str], Tensor]] = None,
cond_semantic_token_ids = None,
seconds = None,
batch_size = None,
start_temperature = 1.,
filter_thres = 0.7,
noise_level_scale = 1.,
text_to_semantic_generate_kwargs: dict = {},
**kwargs
):
assert not (exists(cond_semantic_token_ids) ^ self.should_condition), 'you either have text-conditioning turned on and have not passed in any conditioning semantic token ids, or vice versa'

assert exists(num_latents) ^ exists(seconds)
if not exists(cond_semantic_token_ids):
assert exists(texts) and exists(self.text_to_semantic)

if is_bearable(texts, List[str]):
assert exists(self.text_to_semantic.tokenizer_encode)
texts = self.text_to_semantic.tokenizer_encode(texts)
texts = texts.to(self.device)

cond_semantic_token_ids = self.text_to_semantic.generate(
texts,
source_type = 'text',
target_type = 'speech',
**text_to_semantic_generate_kwargs
)

if not exists(num_latents):
assert exists(self.soundstream), 'soundstream must be passed in to generate in seconds'
num_latents = (seconds * self.soundstream.target_sample_hz) // self.soundstream.seq_len_multiple_of
assert not (exists(cond_semantic_token_ids) ^ self.should_condition), 'you either have text-conditioning turned on and have not passed in any conditioning semantic token ids, or vice versa'

# maybe condition

Expand All @@ -749,6 +764,14 @@ def generate(
sample_one = not exists(batch_size)
batch_size = default(batch_size, 1)

assert exists(num_latents) ^ exists(seconds)

if not exists(num_latents):
assert exists(self.soundstream), 'soundstream must be passed in to generate in seconds'
num_latents = (seconds * self.soundstream.target_sample_hz) // self.soundstream.seq_len_multiple_of

# determine sequence length

seq_len = num_latents * self.grouped_quantizers * self.num_quantizers

# device and time
Expand Down

0 comments on commit 3e6b1e7

Please sign in to comment.