# loss_functions.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class AngularPenaltySMLoss(nn.Module):
    def __init__(self, in_features, out_features, loss_type='arcface', eps=1e-7, s=None, m=None):
        '''
        Angular Penalty Softmax Loss

        Three 'loss_types' available: ['arcface', 'sphereface', 'cosface']
        Per-sample target-class logits (theta is the angle between the
        embedding and the target-class weight vector):
            arcface:    s * cos(theta + m)
            sphereface: s * cos(m * theta)
            cosface:    s * (cos(theta) - m)

        These losses are described in the following papers:
        ArcFace: https://arxiv.org/abs/1801.07698
        SphereFace: https://arxiv.org/abs/1704.08063
        CosFace/Ad Margin: https://arxiv.org/abs/1801.05599
        '''
        super(AngularPenaltySMLoss, self).__init__()
        loss_type = loss_type.lower()
        assert loss_type in ['arcface', 'sphereface', 'cosface']
        # Use each paper's default scale s and margin m unless overridden.
        if loss_type == 'arcface':
            self.s = 64.0 if s is None else s
            self.m = 0.5 if m is None else m
        if loss_type == 'sphereface':
            self.s = 64.0 if s is None else s
            self.m = 1.35 if m is None else m
        if loss_type == 'cosface':
            self.s = 30.0 if s is None else s
            self.m = 0.4 if m is None else m
        self.loss_type = loss_type
        self.in_features = in_features
        self.out_features = out_features
        self.fc = nn.Linear(in_features, out_features, bias=False)
        self.eps = eps

    def forward(self, x, labels):
        '''
        input shape (N, in_features)
        '''
        assert len(x) == len(labels)
        assert torch.min(labels) >= 0
        assert torch.max(labels) < self.out_features

        # Normalize both the class weight vectors and the embeddings so that
        # wf[i, j] = cos(theta_ij). Note: reassigning the loop variable in
        # `for W in self.fc.parameters(): W = F.normalize(W, ...)` does not
        # modify the parameters, so the normalized weight is applied
        # functionally instead.
        W = F.normalize(self.fc.weight, p=2, dim=1)
        x = F.normalize(x, p=2, dim=1)
        wf = F.linear(x, W)

        # Cosine similarity between each sample and its target class.
        cos_target = wf[torch.arange(wf.size(0), device=wf.device), labels]

        if self.loss_type == 'cosface':
            numerator = self.s * (cos_target - self.m)
        if self.loss_type == 'arcface':
            numerator = self.s * torch.cos(
                torch.acos(torch.clamp(cos_target, -1.0 + self.eps, 1 - self.eps)) + self.m
            )
        if self.loss_type == 'sphereface':
            numerator = self.s * torch.cos(
                self.m * torch.acos(torch.clamp(cos_target, -1.0 + self.eps, 1 - self.eps))
            )

        # Logits for every class except each sample's target class.
        excl = torch.cat(
            [torch.cat((wf[i, :y], wf[i, y + 1:])).unsqueeze(0) for i, y in enumerate(labels)],
            dim=0,
        )
        denominator = torch.exp(numerator) + torch.sum(torch.exp(self.s * excl), dim=1)
        L = numerator - torch.log(denominator)
        return -torch.mean(L)
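

def _demo_angular_penalty_loss():
    # Minimal usage sketch (not part of the original file): the module acts as
    # the final classification head, taking backbone embeddings plus integer
    # labels. The embedding size (512), batch size (8), and class count (10)
    # below are illustrative assumptions.
    criterion = AngularPenaltySMLoss(in_features=512, out_features=10, loss_type='arcface')
    x = torch.randn(8, 512)                  # stand-in for backbone embeddings
    labels = torch.randint(0, 10, (8,))      # integer class labels in [0, 10)
    loss = criterion(x, labels)
    print('arcface loss:', loss.item())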


class CompLoss(nn.Module):
    '''Soft accuracy loss: 1 minus the mean per-class accuracy, computed from
    probabilistic predictions and targets of the same shape.'''

    def __init__(self, weights=None):
        super().__init__()
        self.weights = weights

    def forward(self, y_pred, y_true):
        if self.weights is not None:
            # Per-class weighting was left unimplemented in the original code.
            raise NotImplementedError('per-class weighting is not implemented')
        # Soft confusion-matrix entries, averaged over the batch per class.
        tp = torch.mean(y_true * y_pred, dim=0)
        tn = torch.mean((1 - y_true) * (1 - y_pred), dim=0)
        fp = torch.mean((1 - y_true) * y_pred, dim=0)
        fn = torch.mean(y_true * (1 - y_pred), dim=0)
        accuracy = (tp + tn) / (fp + fn + tp + tn)
        accuracy = torch.mean(accuracy)
        return 1 - accuracy
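

def _demo_comp_loss():
    # Hypothetical smoke test (not in the original file): CompLoss appears to
    # expect probabilities in [0, 1] and same-shaped (e.g. multi-hot) targets;
    # the batch size (8) and class count (5) are illustrative assumptions.
    criterion = CompLoss()
    y_pred = torch.sigmoid(torch.randn(8, 5))        # per-class probabilities
    y_true = torch.randint(0, 2, (8, 5)).float()     # multi-hot targets
    print('comp loss:', criterion(y_pred, y_true).item())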


class FocalLoss(nn.Module):
    '''Focal loss (https://arxiv.org/abs/1708.02002) over raw logits; with
    gamma=0 and alpha=None it reduces to standard cross-entropy.'''

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)                         # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))    # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1)
        logpt = logpt.gather(1, target)
        logpt = logpt.view(-1)
        # Detach pt so the modulating factor (1 - pt)**gamma is treated as a
        # constant weight (matches the original Variable-based behaviour).
        pt = logpt.detach().exp()

        if self.alpha is not None:
            if self.alpha.dtype != input.dtype or self.alpha.device != input.device:
                self.alpha = self.alpha.to(dtype=input.dtype, device=input.device)
            at = self.alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()
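

def _demo_focal_loss():
    # Sanity-check sketch (not in the original file): with gamma=0 and no
    # alpha, focal loss reduces to ordinary cross-entropy, so it should match
    # F.cross_entropy on the same logits; shapes below are illustrative.
    logits = torch.randn(8, 5)
    target = torch.randint(0, 5, (8,))
    fl = FocalLoss(gamma=0)(logits, target)
    ce = F.cross_entropy(logits, target)
    print('focal(gamma=0):', fl.item(), 'cross-entropy:', ce.item())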