-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathlosses.py
60 lines (47 loc) · 2.26 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# This source file is part of DiffAI
# Copyright (c) 2018 Secure, Reliable, and Intelligent Systems Lab (SRI), ETH Zurich
# This software is distributed under the MIT License: https://opensource.org/licenses/MIT
# SPDX-License-Identifier: MIT
# For more information see https://github.com/eth-sri/diffai
# THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT ANY WARRANTY OF ANY KIND, EITHER
# EXPRESS, IMPLIED OR STATUTORY, INCLUDING BUT NOT LIMITED TO ANY WARRANTY
# THAT THE SOFTWARE WILL CONFORM TO SPECIFICATIONS OR BE ERROR-FREE AND ANY
# IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE,
# TITLE, OR NON-INFRINGEMENT. IN NO EVENT SHALL ETH ZURICH BE LIABLE FOR ANY
# DAMAGES, INCLUDING BUT NOT LIMITED TO DIRECT, INDIRECT,
# SPECIAL OR CONSEQUENTIAL DAMAGES, ARISING OUT OF, RESULTING FROM, OR IN
# ANY WAY CONNECTED WITH THIS SOFTWARE (WHETHER OR NOT BASED UPON WARRANTY,
# CONTRACT, TORT OR OTHERWISE).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import helpers as h
import domains
from domains import *
import math
# Every entry returned by ``h.getMethods(domains)`` for which
# ``h.hasMethod(entry, "attack")`` holds, followed by the plain torch tensor
# types.
POINT_DOMAINS = [
    *(dom for dom in h.getMethods(domains) if h.hasMethod(dom, "attack")),
    torch.FloatTensor,
    torch.Tensor,
    torch.cuda.FloatTensor,
]
# ``domains.Box`` followed by everything in POINT_DOMAINS.
SYMETRIC_DOMAINS = [domains.Box, *POINT_DOMAINS]
def domRes(outDom, target, **args): # TODO: make faster again by keeping sparse tensors sparse
    """Return lower bounds on (target output - other output), plus the one-hot target.

    Args:
        outDom: network output object providing ``size()``, ``bmm(...)`` and,
            on the difference of two ``bmm`` results, ``lb()``.
            NOTE(review): presumably an abstract-domain element from
            ``domains`` with shape (batch, classes) -- confirm.
        target: labels; ``target.data.long()`` is one-hot encoded with width
            ``outDom.size()[1]``.
        **args: accepted for call compatibility; not used here.

    Returns:
        ``(outDom.bmm(A) - outDom.bmm(B)).lb() + onehot`` where ``B`` is the
        identity with the target diagonal entry zeroed and ``A`` has ones on
        the target row in every non-target column.
    """
    onehot = h.one_hot(target.data.long(), outDom.size()[1]).to_dense()
    onehot_col = onehot.unsqueeze(2)

    # Outer product of the one-hot vector with itself: a single 1 at
    # (target, target) for each batch entry.
    outer = onehot_col.matmul(onehot.unsqueeze(1))
    batch_size = outer.size()[0]
    num_classes = outer.size()[1]

    # Identity with the target's diagonal entry removed: keeps every
    # non-target column as-is.
    keep_others = h.eye(num_classes).expand(batch_size, -1, -1) - outer

    # Ones on the target row for every non-target column: copies the target
    # entry into each non-target position.
    spread_target = onehot_col.expand(-1, -1, num_classes).bmm(keep_others)

    margins = outDom.bmm(spread_target) - outDom.bmm(keep_others)
    # Adding the one-hot puts a 1 at the target position.
    return margins.lb() + onehot
def isSafeDom(outDom, target, **args):
    """Return 1 if every entry of ``domRes(outDom, target)`` is positive, else 0.

    The minimum of the ``domRes`` result is taken along dimension 1 and
    compared against zero. Because ``.item()`` is called on the comparison,
    the reduced result must hold exactly one element.

    Args:
        outDom: forwarded to ``domRes``.
        target: forwarded to ``domRes``.
        **args: forwarded to ``domRes``.

    Returns:
        A Python int, 1 when the minimum is strictly greater than 0.0 and 0
        otherwise.
    """
    margins = domRes(outDom, target, **args)
    smallest = torch.min(margins, 1)[0]
    is_positive = smallest.gt(0.0)
    return is_positive.long().item()
def isSafeBox(target, net, inp, eps, dom):
    """Check whether ``net`` keeps the target label for ``inp`` under ``dom``.

    The reference label is the argmax along dimension 1 of ``target`` for the
    first batch entry.

    * If ``dom`` has an ``attack`` attribute, ``dom.attack(net, eps, inp,
      target)`` produces an input, and the result is whether the network's
      argmax on it equals the reference label (a ``bool``).
    * Otherwise ``net`` is run on ``dom.box(inp, eps)`` and the outcome is
      delegated to ``isSafeDom`` (which returns 0 or 1).

    Args:
        target: tensor whose argmax along dimension 1 gives the label.
        net: callable network.
        inp: network input.
        eps: value passed to ``dom.attack`` / ``dom.box``.
        dom: domain object or class providing ``attack`` or ``box``.

    Returns:
        ``bool`` on the attack path, the result of ``isSafeDom`` otherwise.
    """
    expected = target.argmax(1)[0].unsqueeze(0)

    if not hasattr(dom, "attack"):
        return isSafeDom(net(dom.box(inp, eps)), expected)

    attacked = dom.attack(net, eps, inp, target)
    # Index of the max log-probability for the first batch entry.
    predicted = net(attacked).argmax(1)[0].unsqueeze(0)
    return predicted.item() == expected.item()