From df3d3371da105b76f6856e0cb8111058cb56b490 Mon Sep 17 00:00:00 2001 From: chingis Date: Fri, 18 Mar 2022 17:26:44 +0900 Subject: [PATCH 1/3] feat: dynamic arcface --- .../losses/__init__.py | 1 + .../losses/dynamic_arcface_loss.py | 37 +++++++++++++++++++ 2 files changed, 38 insertions(+) create mode 100644 src/pytorch_metric_learning/losses/dynamic_arcface_loss.py diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py index 1838b483..c97bc918 100644 --- a/src/pytorch_metric_learning/losses/__init__.py +++ b/src/pytorch_metric_learning/losses/__init__.py @@ -1,5 +1,6 @@ from .angular_loss import AngularLoss from .arcface_loss import ArcFaceLoss +from .dynamic_arcface_loss import DynamicArcFaceLoss from .base_metric_loss_function import BaseMetricLossFunction, MultipleLosses from .centroid_triplet_loss import CentroidTripletLoss from .circle_loss import CircleLoss diff --git a/src/pytorch_metric_learning/losses/dynamic_arcface_loss.py b/src/pytorch_metric_learning/losses/dynamic_arcface_loss.py new file mode 100644 index 00000000..705995fe --- /dev/null +++ b/src/pytorch_metric_learning/losses/dynamic_arcface_loss.py @@ -0,0 +1,37 @@ +import numpy as np +import torch + +from .subcenter_arcface_loss import SubCenterArcFaceLoss +from .arcface_loss import ArcFaceLoss + + +class DynamicArcFaceLoss(torch.nn.Module): + """ + Implementation of https://arxiv.org/pdf/2010.05350.pdf + """ + + def __init__(self, loss_fn, n, lambda0=0.25, a=0.5, b=0.05): + super().__init__() + assert isinstance(loss_fn, (ArcFaceLoss, SubCenterArcFaceLoss)), 'Loss function should be Arcface-based' + self.lambda0 = lambda0 + self.a = a + self.b = b + + self.loss_fn = loss_fn + self.n = n if len(n.shape) == 2 else n[..., None] + self.init_margins() + + def init_margins(self): + self.margins = self.a * self.n ** (-self.lambda0) + self.b + + def get_batch_margins(self, labels): + return self.margins[labels] + + def set_margins(self, batch_margins): + self.loss_fn.margin = batch_margins + + def forward(self, embeddings, labels): + batch_margins = self.get_batch_margins(labels) + self.set_margins(batch_margins) + return self.loss_fn(embeddings, labels) + From 6fa7cc221abcef01a408d081e1522eb863da9125 Mon Sep 17 00:00:00 2001 From: chingis Date: Sat, 19 Mar 2022 16:49:01 +0900 Subject: [PATCH 2/3] test: dynamic arcface --- .../losses/dynamic_arcface_loss.py | 22 +++--- tests/losses/test_dynamicarcface_loss.py | 67 +++++++++++++++++++ 2 files changed, 81 insertions(+), 8 deletions(-) create mode 100644 tests/losses/test_dynamicarcface_loss.py diff --git a/src/pytorch_metric_learning/losses/dynamic_arcface_loss.py b/src/pytorch_metric_learning/losses/dynamic_arcface_loss.py index 705995fe..1e7322ad 100644 --- a/src/pytorch_metric_learning/losses/dynamic_arcface_loss.py +++ b/src/pytorch_metric_learning/losses/dynamic_arcface_loss.py @@ -1,5 +1,6 @@ import numpy as np import torch +from ..utils import common_functions as c_f from .subcenter_arcface_loss import SubCenterArcFaceLoss from .arcface_loss import ArcFaceLoss @@ -10,15 +11,15 @@ class DynamicArcFaceLoss(torch.nn.Module): Implementation of https://arxiv.org/pdf/2010.05350.pdf """ - def __init__(self, loss_fn, n, lambda0=0.25, a=0.5, b=0.05): + def __init__(self, n, loss_func=SubCenterArcFaceLoss, lambda0=0.25, a=0.5, b=0.05, **kwargs): super().__init__() - assert isinstance(loss_fn, (ArcFaceLoss, SubCenterArcFaceLoss)), 'Loss function should be Arcface-based' + self.lambda0 = lambda0 self.a = a self.b = b - - self.loss_fn = loss_fn - self.n = n if len(n.shape) == 2 else n[..., None] + + self.loss_func = loss_func(**kwargs) + self.n = n.flatten() self.init_margins() def init_margins(self): @@ -28,10 +29,15 @@ def get_batch_margins(self, labels): return self.margins[labels] def set_margins(self, batch_margins): - self.loss_fn.margin = batch_margins + self.loss_func.margin = batch_margins + + def cast_types(self, tensor, dtype, device): + return c_f.to_device(tensor, device=device, dtype=dtype) - def forward(self, embeddings, labels): + def forward(self, embeddings, labels): batch_margins = self.get_batch_margins(labels) + dtype, device = embeddings.dtype, embeddings.device + batch_margins = self.cast_types(batch_margins, dtype, device) self.set_margins(batch_margins) - return self.loss_fn(embeddings, labels) + return self.loss_func(embeddings, labels) diff --git a/tests/losses/test_dynamicarcface_loss.py b/tests/losses/test_dynamicarcface_loss.py new file mode 100644 index 00000000..c85a8f52 --- /dev/null +++ b/tests/losses/test_dynamicarcface_loss.py @@ -0,0 +1,67 @@ +import unittest + +import numpy as np +import torch +import torch.nn.functional as F + +from pytorch_metric_learning.losses import DynamicArcFaceLoss, ArcFaceLoss, SubCenterArcFaceLoss + +from .. import TEST_DEVICE, TEST_DTYPES +from ..zzz_testing_utils.testing_utils import angle_to_coord + + +class TestDynamicArcFaceLoss(unittest.TestCase): + def test_dynamicarcface_loss(self): + scale = 64 + n = torch.tensor([4,5,2,3,1,2,3,4,4,5]) + sub_centers = 3 + num_classes = 10 + a = 0.5 + b = 0.05 + lambda0 = 0.25 + for loss_type in (ArcFaceLoss, SubCenterArcFaceLoss): + for dtype in TEST_DTYPES: + loss_func = DynamicArcFaceLoss( + n, + scale=scale, + num_classes=10, + embedding_size=2, + loss_func=loss_type, + a=a, + b=b, + lambda0=lambda0 + ) + embedding_angles = torch.arange(0, 180) + embeddings = torch.tensor( + [angle_to_coord(a) for a in embedding_angles], + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + labels = torch.randint(low=0, high=10, size=(180,)) + + loss = loss_func(embeddings, labels) + loss.backward() + + weights = F.normalize(loss_func.loss_func.W, p=2, dim=0) + logits = torch.matmul(F.normalize(embeddings), weights) + if loss_type == SubCenterArcFaceLoss: + logits = logits.view(-1, num_classes, sub_centers) + logits, _ = logits.max(axis=2) + class_margins = a * n ** (-lambda0) + b + batch_margins = class_margins[labels].to(dtype=dtype, device=TEST_DEVICE) + + + for i, c in enumerate(labels): + acos = torch.acos(torch.clamp(logits[i, c], -1, 1)) + logits[i, c] = torch.cos( + acos + batch_margins[i].to(TEST_DEVICE) + ) + + correct_loss = F.cross_entropy(logits * scale, labels.to(TEST_DEVICE)) + + rtol = 1e-2 if dtype == torch.float16 else 1e-5 + + self.assertTrue(torch.isclose(loss, correct_loss, rtol=rtol)) + From c0a95a56c351607f1f358a166a767ea7c952e937 Mon Sep 17 00:00:00 2001 From: chingis Date: Sat, 19 Mar 2022 16:52:11 +0900 Subject: [PATCH 3/3] fix: test --- tests/losses/test_dynamicarcface_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/losses/test_dynamicarcface_loss.py b/tests/losses/test_dynamicarcface_loss.py index c85a8f52..417605d2 100644 --- a/tests/losses/test_dynamicarcface_loss.py +++ b/tests/losses/test_dynamicarcface_loss.py @@ -13,9 +13,9 @@ class TestDynamicArcFaceLoss(unittest.TestCase): def test_dynamicarcface_loss(self): scale = 64 - n = torch.tensor([4,5,2,3,1,2,3,4,4,5]) sub_centers = 3 num_classes = 10 + n = torch.randint(low=1, high=200, size=(num_classes,)) a = 0.5 b = 0.05 lambda0 = 0.25