From 082c1846b5bda39807b7384c504c1aee4e367793 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 24 Aug 2023 10:14:26 -0700 Subject: [PATCH] if one uses -1 for padding --- setup.py | 2 +- soundstorm_pytorch/soundstorm.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 244a58c..ea972da 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'soundstorm-pytorch', packages = find_packages(exclude=[]), - version = '0.1.1', + version = '0.1.2', license='MIT', description = 'SoundStorm - Efficient Parallel Audio Generation from Google Deepmind, in Pytorch', author = 'Phil Wang', diff --git a/soundstorm_pytorch/soundstorm.py b/soundstorm_pytorch/soundstorm.py index 2898fbe..f844105 100644 --- a/soundstorm_pytorch/soundstorm.py +++ b/soundstorm_pytorch/soundstorm.py @@ -1052,11 +1052,17 @@ def forward( seq_mask = mask - if not exists(seq_mask) and exists(self.pad_id): - seq_mask = (x != self.pad_id).any(dim = -1) - elif not exists(seq_mask): + if not exists(seq_mask): seq_mask = torch.ones((b, n), device = device, dtype = torch.bool) + if exists(self.pad_id): + pad_mask = (x == self.pad_id).any(dim = -1) + seq_mask = seq_mask & ~pad_mask + + if self.pad_id < 0: + # if using say -1 for padding + x = torch.where(rearrange(pad_mask, 'b n -> b n 1'), 0, x) + # maybe condition cond_tokens = self.maybe_get_condition(cond_semantic_token_ids, length = x.shape[-2])