Skip to content

Commit

Permalink
remove the eos token id from the semantic conditioning
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 29, 2023
1 parent 7ef65a9 commit cf62f10
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'soundstorm-pytorch',
packages = find_packages(exclude=[]),
- version = '0.0.18',
+ version = '0.0.19',
license='MIT',
description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch',
author = 'Phil Wang',
Expand All @@ -23,7 +23,7 @@
'beartype',
'classifier-free-guidance-pytorch>=0.1.5',
'einops>=0.6.1',
- 'spear-tts-pytorch',
+ 'spear-tts-pytorch>=0.0.4',
'torch>=1.6',
],
classifiers=[
Expand Down
6 changes: 6 additions & 0 deletions soundstorm_pytorch/soundstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,12 @@ def maybe_get_condition(self, token_ids = None, length = None):

with context():
mask = token_ids != self.semantic_pad_id

+ # also remove the eos semantic token id
+
+ if exists(self.text_to_semantic) and self.text_to_semantic.autoset_eos_id['speech']:
+     mask &= token_ids != self.num_semantic_token_ids

token_ids = token_ids.masked_fill(~mask, 0)

semantic_tokens = self.semantic_token_emb(token_ids)
Expand Down

0 comments on commit cf62f10

Please sign in to comment.