focal_loss.py
## Modified FROM https://github.com/AdeelH/pytorch-multi-class-focal-loss/blob/master/focal_loss.py ##
from typing import Optional, Sequence
import torch
from torch import Tensor, nn
from torch.nn import functional as F


class FocalLoss(nn.Module):
"""Focal Loss, as described in https://arxiv.org/abs/1708.02002.
It is essentially an enhancement to cross entropy loss and is
useful for classification tasks when there is a large class imbalance.
x is expected to contain raw, unnormalized scores for each class.
y is expected to contain class labels.
Shape:
- x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
- y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
"""
def __init__(
self,
alpha: Optional[Tensor] = None,
gamma: float = 2,
reduction: str = "mean",
ignore_index: int = -100,
):
"""Constructor.
Args:
alpha (Tensor, optional): Weights for each class. Defaults to None.
            gamma (float, optional): A constant, as described in the paper.
                Defaults to 2.
reduction (str, optional): 'mean', 'sum' or 'none'.
Defaults to 'mean'.
            ignore_index (int, optional): class label to ignore. Retained for
                compatibility; the multi-label forward below does not use it.
                Defaults to -100.
"""
if reduction not in ("mean", "sum", "none"):
raise ValueError('Reduction must be one of: "mean", "sum", "none".')
super().__init__()
self.alpha = alpha
self.gamma = gamma
self.ignore_index = ignore_index
self.reduction = reduction
        # Retained from the original multi-class implementation; the
        # multi-label forward below does not use it.
        self.nll_loss = nn.NLLLoss(
            weight=alpha, reduction="none", ignore_index=ignore_index
        )

def __repr__(self):
arg_keys = ["alpha", "gamma", "ignore_index", "reduction"]
arg_vals = [self.__dict__[k] for k in arg_keys]
arg_strs = [f"{k}={v!r}" for k, v in zip(arg_keys, arg_vals)]
arg_str = ", ".join(arg_strs)
return f"{type(self).__name__}({arg_str})"

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        # Changed according to
        # https://www.kaggle.com/code/thedrcat/focal-multilabel-loss-in-pytorch-explained
        # x holds logits of shape (batch, num_labels);
        # y holds 0/1 targets of the same shape.
        # Transform logits into probabilities.
        p = torch.sigmoid(x)
        # pt: the probability assigned to the true label, so pt is close to 1
        # for good predictions and close to 0 for bad ones.
        pt = torch.where(y >= 0.5, p, 1 - p)
        # Per-element binary cross-entropy, -alpha * log(pt), weighted by the
        # per-label alpha (no reduction, so the shape stays batch x labels).
        ce = F.binary_cross_entropy_with_logits(
            x, y, weight=self.alpha, reduction="none"
        )
        # Focal modulating term (same shape).
        focal_term = (1 - pt) ** self.gamma
        # The full loss: -alpha * ((1 - pt)^gamma) * log(pt).
        loss = focal_term * ce
        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss
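

# A minimal sanity-check sketch (ours, not part of the original file): with
# gamma=0 and alpha=None the focal term (1 - pt)^gamma is identically 1, so
# the loss should reduce to plain binary cross-entropy. The helper name and
# the random shapes below are invented for illustration.
def _sanity_check_reduces_to_bce() -> None:
    torch.manual_seed(0)
    x = torch.randn(4, 3)                    # logits, (batch, num_labels)
    y = torch.randint(0, 2, (4, 3)).float()  # binary targets, same shape
    fl = FocalLoss(gamma=0)
    bce = F.binary_cross_entropy_with_logits(x, y)  # default "mean" reduction
    assert torch.allclose(fl(x, y), bce), "gamma=0 should match plain BCE"

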
def focal_loss(
alpha: Optional[Sequence] = None,
gamma: float = 0.0,
reduction: str = "mean",
ignore_index: int = -100,
device="cpu",
dtype=torch.float32,
) -> FocalLoss:
"""Factory function for FocalLoss.
Args:
alpha (Sequence, optional): Weights for each class. Will be converted
to a Tensor if not None. Defaults to None.
        gamma (float, optional): A constant, as described in the paper.
            Defaults to 0 here (note: the FocalLoss class itself defaults
            to 2).
reduction (str, optional): 'mean', 'sum' or 'none'.
Defaults to 'mean'.
        ignore_index (int, optional): class label to ignore (unused by the
            multi-label forward). Defaults to -100.
device (str, optional): Device to move alpha to. Defaults to 'cpu'.
dtype (torch.dtype, optional): dtype to cast alpha to.
Defaults to torch.float32.
Returns:
        A FocalLoss object.
"""
if alpha is not None:
if not isinstance(alpha, Tensor):
alpha = torch.tensor(alpha) # type: ignore
alpha = alpha.to(device=device, dtype=dtype) # type: ignore
fl = FocalLoss(
alpha=alpha, gamma=gamma, reduction=reduction, ignore_index=ignore_index # type: ignore
)
return fl
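

# Illustrative usage sketch (ours, not from the original file): builds the
# loss via the factory with a per-label alpha and applies it to random
# logits/targets. All shapes and values here are made up for demonstration;
# alpha of shape (num_labels,) broadcasts across the batch dimension.
if __name__ == "__main__":
    _sanity_check_reduces_to_bce()
    loss_fn = focal_loss(alpha=[0.25, 0.5, 1.0], gamma=2.0, reduction="mean")
    logits = torch.randn(8, 3)                     # (batch, num_labels)
    targets = torch.randint(0, 2, (8, 3)).float()  # 0/1 multi-label targets
    print(loss_fn)  # repr shows alpha, gamma, ignore_index, reduction
    print(loss_fn(logits, targets).item())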