diff --git a/torchelie/nn/vq.py b/torchelie/nn/vq.py index 658311c..f6b9aec 100644 --- a/torchelie/nn/vq.py +++ b/torchelie/nn/vq.py @@ -31,6 +31,7 @@ def __init__(self, dim: int = 1, commitment: float = 0.25, init_mode: str = 'normal', + space="l2", return_indices: bool = True, max_age: int = 1000): super(VQ, self).__init__() @@ -45,6 +46,8 @@ def __init__(self, self.init_mode = init_mode self.register_buffer('age', torch.empty(num_tokens).fill_(max_age)) self.max_age = max_age + self.space = space + assert space in ["l2", "angular"] def update_usage(self, indices): with torch.no_grad(): @@ -107,7 +110,11 @@ def forward( if self.training: self.resample_dead(x) - codes, indices = quantize(x, codebook, self.commitment, self.dim) + if self.space == "angular": + codebook = F.normalize(codebook, dim=1) + x = F.normalize(x, dim=-1) + + codes, indices = quantize(x, codebook, self.commitment, -1) if self.training: self.update_usage(indices)