Dynamic arcface #452

Open · wants to merge 3 commits into master

Changes from all commits
1 change: 1 addition & 0 deletions src/pytorch_metric_learning/losses/__init__.py
@@ -1,5 +1,6 @@
from .angular_loss import AngularLoss
from .arcface_loss import ArcFaceLoss
from .dynamic_arcface_loss import DynamicArcFaceLoss
from .base_metric_loss_function import BaseMetricLossFunction, MultipleLosses
from .centroid_triplet_loss import CentroidTripletLoss
from .circle_loss import CircleLoss
43 changes: 43 additions & 0 deletions src/pytorch_metric_learning/losses/dynamic_arcface_loss.py
@@ -0,0 +1,43 @@
import torch

from ..utils import common_functions as c_f
from .subcenter_arcface_loss import SubCenterArcFaceLoss


class DynamicArcFaceLoss(torch.nn.Module):
    """
    Implementation of the dynamic margin from https://arxiv.org/pdf/2010.05350.pdf

    Each class c gets its own margin based on its sample count n_c:
        margin_c = a * n_c ** (-lambda0) + b
    so rare classes receive larger margins than frequent ones.
    """

    def __init__(self, n, loss_func=SubCenterArcFaceLoss, lambda0=0.25, a=0.5, b=0.05, **kwargs):
        super().__init__()

        self.lambda0 = lambda0
        self.a = a
        self.b = b

        self.loss_func = loss_func(**kwargs)
        # n: per-class sample counts, indexed by class label
        self.n = n.flatten()
        self.init_margins()

    def init_margins(self):
        # Per-class margins: smaller n -> larger margin
        self.margins = self.a * self.n ** (-self.lambda0) + self.b

    def get_batch_margins(self, labels):
        # Look up the margin of each sample's class
        return self.margins[labels]

    def set_margins(self, batch_margins):
        self.loss_func.margin = batch_margins

    def cast_types(self, tensor, dtype, device):
        return c_f.to_device(tensor, device=device, dtype=dtype)

    def forward(self, embeddings, labels):
        # Set the wrapped loss's margin to this batch's per-sample margins,
        # then delegate to the wrapped loss
        batch_margins = self.get_batch_margins(labels)
        dtype, device = embeddings.dtype, embeddings.device
        batch_margins = self.cast_types(batch_margins, dtype, device)
        self.set_margins(batch_margins)
        return self.loss_func(embeddings, labels)
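
For reviewers, a minimal usage sketch (not part of the diff). The 10-class setup, counts, and sizes are made up for illustration; keyword arguments after `n` are forwarded to the wrapped loss (`SubCenterArcFaceLoss` by default):

```python
import torch

from pytorch_metric_learning.losses import DynamicArcFaceLoss

# Hypothetical 10-class training set: derive per-class counts from the labels.
# (Assumes every class appears at least once; a zero count makes n ** (-lambda0) inf.)
train_labels = torch.randint(0, 10, size=(1000,))
n = torch.bincount(train_labels, minlength=10)

loss_fn = DynamicArcFaceLoss(n, num_classes=10, embedding_size=128)

embeddings = torch.randn(32, 128, requires_grad=True)  # stand-in for trunk output
labels = train_labels[:32]
loss = loss_fn(embeddings, labels)  # each sample gets its class's dynamic margin
loss.backward()
```

Note that `forward` indexes `self.margins` (which lives wherever `n` does) with `labels`, so `labels` should be on the same device as `n`.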

67 changes: 67 additions & 0 deletions tests/losses/test_dynamicarcface_loss.py
@@ -0,0 +1,67 @@
import unittest

import torch
import torch.nn.functional as F

from pytorch_metric_learning.losses import (
    ArcFaceLoss,
    DynamicArcFaceLoss,
    SubCenterArcFaceLoss,
)

from .. import TEST_DEVICE, TEST_DTYPES
from ..zzz_testing_utils.testing_utils import angle_to_coord


class TestDynamicArcFaceLoss(unittest.TestCase):
    def test_dynamicarcface_loss(self):
        scale = 64
        sub_centers = 3
        num_classes = 10
        n = torch.randint(low=1, high=200, size=(num_classes,))
        a = 0.5
        b = 0.05
        lambda0 = 0.25
        for loss_type in (ArcFaceLoss, SubCenterArcFaceLoss):
            for dtype in TEST_DTYPES:
                loss_func = DynamicArcFaceLoss(
                    n,
                    scale=scale,
                    num_classes=num_classes,
                    embedding_size=2,
                    loss_func=loss_type,
                    a=a,
                    b=b,
                    lambda0=lambda0,
                )
                embedding_angles = torch.arange(0, 180)
                embeddings = torch.tensor(
                    [angle_to_coord(angle) for angle in embedding_angles],
                    requires_grad=True,
                    dtype=dtype,
                ).to(
                    TEST_DEVICE
                )  # 2D embeddings
                labels = torch.randint(low=0, high=num_classes, size=(180,))

                loss = loss_func(embeddings, labels)
                loss.backward()

                # Recompute the expected loss by hand: cosine logits, per-class
                # dynamic margins applied to each target logit, then cross-entropy.
                weights = F.normalize(loss_func.loss_func.W, p=2, dim=0)
                logits = torch.matmul(F.normalize(embeddings), weights)
                if loss_type == SubCenterArcFaceLoss:
                    logits = logits.view(-1, num_classes, sub_centers)
                    logits, _ = logits.max(dim=2)
                class_margins = a * n ** (-lambda0) + b
                batch_margins = class_margins[labels].to(dtype=dtype, device=TEST_DEVICE)

                for i, c in enumerate(labels):
                    acos = torch.acos(torch.clamp(logits[i, c], -1, 1))
                    logits[i, c] = torch.cos(acos + batch_margins[i])

                correct_loss = F.cross_entropy(logits * scale, labels.to(TEST_DEVICE))

                rtol = 1e-2 if dtype == torch.float16 else 1e-5

                self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol))
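
As a quick numeric check of the margin formula the test reproduces, with the defaults a=0.5, b=0.05, lambda0=0.25 (printed values rounded):

```python
import torch

n = torch.tensor([1, 16, 200])  # per-class sample counts
margins = 0.5 * n ** (-0.25) + 0.05
print(margins)  # ~[0.5500, 0.3000, 0.1830]: rarer classes get larger margins
```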