Skip to content

Commit

Permalink
add angular mode to VQ
Browse files Browse the repository at this point in the history
  • Loading branch information
Vermeille authored Dec 11, 2024
1 parent c53de03 commit 3afab40
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion torchelie/nn/vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self,
dim: int = 1,
commitment: float = 0.25,
init_mode: str = 'normal',
space="l2",
return_indices: bool = True,
max_age: int = 1000):
super(VQ, self).__init__()
Expand All @@ -45,6 +46,8 @@ def __init__(self,
self.init_mode = init_mode
self.register_buffer('age', torch.empty(num_tokens).fill_(max_age))
self.max_age = max_age
self.space = space
assert space in ["l2", "angular"]

def update_usage(self, indices):
with torch.no_grad():
Expand Down Expand Up @@ -107,7 +110,11 @@ def forward(
if self.training:
self.resample_dead(x)

codes, indices = quantize(x, codebook, self.commitment, self.dim)
if self.space == "angular":
codebook = F.normalize(codebook, dim=1)
x = F.normalize(x, dim=-1)

codes, indices = quantize(x, codebook, self.commitment, -1)

if self.training:
self.update_usage(indices)
Expand Down

0 comments on commit 3afab40

Please sign in to comment.