Skip to content

Commit

Permalink
allow passing in as a dictionary of kwargs for conformer
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 17, 2023
1 parent 4777eeb commit 4e453df
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions soundstorm_pytorch/soundstorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from einops.layers.torch import Rearrange

from beartype import beartype
from beartype.typing import Union, Dict

from conformer import Conformer
from soundstorm_pytorch.attend import Attend
Expand All @@ -27,17 +28,22 @@ class ConformerWrapper(nn.Module):
def __init__(
self,
*,
conformer: Conformer,
conformer: Union[Conformer, Dict[str, any]],
num_tokens_reduce,
num_tokens_per_head = None,
):
super().__init__()
self.conformer = conformer

if isinstance(conformer, dict):
self.conformer = Conformer(**conformer)
else:
self.conformer = conformer

self.num_tokens_reduce = num_tokens_reduce
self.num_tokens_per_head = default(num_tokens_per_head, num_tokens_reduce)

dim = conformer.dim
dim = self.conformer.dim

self.heads = nn.Sequential(
nn.Linear(dim, dim * self.num_tokens_per_head),
Expand Down

0 comments on commit 4e453df

Please sign in to comment.