From 264a1f2be50eb3d95e574e58066d82a7437f1038 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 3 Feb 2024 19:24:42 -0800 Subject: [PATCH] potential fix for an issue in generate, identified by @Jiang-Stan https://github.com/lucidrains/soundstorm-pytorch/issues/28 --- README.md | 2 ++ setup.py | 2 +- soundstorm_pytorch/soundstorm.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c37cc92..44d064e 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,8 @@ They basically applied MaskGiT to - Lucas Newman for basically training a small working Soundstorm with models across multiple repositories, showing it all works end-to-end. Models include SoundStream, Text-to-Semantic T5, and finally the SoundStorm transformer here. +- @Jiang-Stan for identifying a critical bug in the iterative demasking! + ## Install ```bash diff --git a/setup.py b/setup.py index f8a4c3c..188183b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'soundstorm-pytorch', packages = find_packages(exclude=[]), - version = '0.4.0', + version = '0.4.1', license='MIT', description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch', author = 'Phil Wang', diff --git a/soundstorm_pytorch/soundstorm.py b/soundstorm_pytorch/soundstorm.py index 1a62246..4d157ab 100644 --- a/soundstorm_pytorch/soundstorm.py +++ b/soundstorm_pytorch/soundstorm.py @@ -953,7 +953,7 @@ def generate( if not self.can_mask_prev_unmasked: scores = scores.masked_fill(~mask, mask_value) - scores_sorted = scores.argsort(dim = -1, descending = True) + scores_sorted = scores.argsort(dim = -1, descending = True).argsort(dim = -1) mask_num_tokens = rearrange(mask_num_tokens, 'b -> b 1')