From cac149f3913e9e074df5a7ba1904295612b79c4f Mon Sep 17 00:00:00 2001 From: fra31 Date: Thu, 9 Apr 2020 13:40:05 +0000 Subject: [PATCH 1/7] added apgd and fab --- advertorch/attacks/autopgd.py | 415 +++++++++++++++++++ advertorch/attacks/fab_with_threshold.py | 392 ++++++++++++++++++ advertorch/attacks/fast_adaptive_boundary.py | 2 +- 3 files changed, 808 insertions(+), 1 deletion(-) create mode 100644 advertorch/attacks/autopgd.py create mode 100644 advertorch/attacks/fab_with_threshold.py diff --git a/advertorch/attacks/autopgd.py b/advertorch/attacks/autopgd.py new file mode 100644 index 0000000..b9e0c22 --- /dev/null +++ b/advertorch/attacks/autopgd.py @@ -0,0 +1,415 @@ +# Copyright (c) 2020-present, Francesco Croce +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree +# + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import time +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base import Attack +from .base import LabelMixin + + +class APGDAttack(Attack, LabelMixin): + """ + AutoPGD + https://arxiv.org/abs/2003.01690 + + :param predict: forward pass function + :param norm: Lp-norm of the attack ('Linf', 'L2' supported) + :param n_restarts: number of random restarts + :param n_iter: number of iterations + :param eps: bound on the norm of perturbations + :param seed: random seed for the starting point + :param loss: loss to optimize ('ce', 'dlr' supported) + :param eot_iter: iterations for Expectation over Trasformation + :param rho: parameter for decreasing the step size + """ + + def __init__( + self, + predict, + n_iter=100, + norm='Linf', + n_restarts=1, + eps=None, + seed=0, + loss='ce', + eot_iter=1, + rho=.75, + verbose=False): + """ + AutoPGD implementation in PyTorch + """ + super(APGDAttack, self).__init__( + predict, loss_fn=None, clip_min=0., clip_max=1.) + + self.predict = predict + self.n_iter = n_iter + self.eps = eps + self.norm = norm + self.n_restarts = n_restarts + self.seed = seed + self.loss = loss + self.eot_iter = eot_iter + self.thr_decr = rho + self.verbose = verbose + + def init_hyperparam(self, x): + assert self.norm in ['Linf', 'L2'] + assert not self.eps is None + + self.device = x.device + self.orig_dim = list(x.shape[1:]) + self.ndims = len(self.orig_dim) + if self.seed is None: + self.seed = time.time() + + ### set parameters for checkpoints + self.n_iter_2 = max(int(0.22 * self.n_iter), 1) + self.n_iter_min = max(int(0.06 * self.n_iter), 1) + self.size_decr = max(int(0.03 * self.n_iter), 1) + + def check_oscillation(self, x, j, k, y5, k3=0.75): + t = torch.zeros(x.shape[1]).to(self.device) + for counter5 in range(k): + t += (x[j - counter5] > x[j - counter5 - 1]).float() + + return (t <= k * k3 * torch.ones_like(t)).float() + + def check_shape(self, x): + return x if len(x.shape) > 0 else x.unsqueeze(0) + + def normalize(self, x): + if self.norm == 'Linf': + t = x.abs().view(x.shape[0], -1).max(1)[0] + return x / (t.view(-1, *([1] * self.ndims)) + 1e-12) + + elif self.norm == 'L2': + t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt() + return x / (t.view(-1, *([1] * self.ndims)) + 1e-12) + + def lp_norm(self, x): + if self.norm == 'L2': + t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt() + return t.view(-1, *([1] * self.ndims)) + + def dlr_loss(self, x, y): + x_sorted, ind_sorted = x.sort(dim=1) + ind = (ind_sorted[:, -1] == y).float() + u = torch.arange(x.shape[0]) + + return -(x[u, y] - x_sorted[:, -2] * ind - x_sorted[:, -1] * ( + 1. - ind)) / (x_sorted[:, -1] - x_sorted[:, -3] + 1e-12) + + def attack_single_run(self, x, y): + if len(x.shape) < self.ndims: + x = x.unsqueeze(0) + y = y.unsqueeze(0) + + if self.norm == 'Linf': + t = 2 * torch.rand(x.shape).to(self.device).detach() - 1 + x_adv = x + self.eps * torch.ones_like(x + ).detach() * self.normalize(t) + elif self.norm == 'L2': + t = torch.randn(x.shape).to(self.device).detach() + x_adv = x + self.eps * torch.ones_like(x + ).detach() * self.normalize(t) + + x_adv = x_adv.clamp(0., 1.) + x_best = x_adv.clone() + x_best_adv = x_adv.clone() + loss_steps = torch.zeros([self.n_iter, x.shape[0]] + ).to(self.device) + loss_best_steps = torch.zeros([self.n_iter + 1, x.shape[0]] + ).to(self.device) + acc_steps = torch.zeros_like(loss_best_steps) + + if self.loss == 'ce': + criterion_indiv = nn.CrossEntropyLoss(reduction='none') + elif self.loss == 'dlr': + criterion_indiv = self.dlr_loss + elif self.loss == 'dlr-targeted': + criterion_indiv = self.dlr_loss_targeted + else: + raise ValueError('unknowkn loss') + + x_adv.requires_grad_() + grad = torch.zeros_like(x) + for _ in range(self.eot_iter): + with torch.enable_grad(): + logits = self.predict(x_adv) + loss_indiv = criterion_indiv(logits, y) + loss = loss_indiv.sum() + + grad += torch.autograd.grad(loss, [x_adv])[0].detach() + + grad /= float(self.eot_iter) + grad_best = grad.clone() + + acc = logits.detach().max(1)[1] == y + acc_steps[0] = acc + 0 + loss_best = loss_indiv.detach().clone() + + step_size = 2. * self.eps * torch.ones([x.shape[0], *( + [1] * self.ndims)]).to(self.device).detach() + x_adv_old = x_adv.clone() + counter = 0 + k = self.n_iter_2 + 0 + counter3 = 0 + + loss_best_last_check = loss_best.clone() + reduced_last_check = torch.ones_like(loss_best) + n_reduced = 0 + + for i in range(self.n_iter): + ### gradient step + with torch.no_grad(): + x_adv = x_adv.detach() + grad2 = x_adv - x_adv_old + x_adv_old = x_adv.clone() + + a = 0.75 if i > 0 else 1.0 + + if self.norm == 'Linf': + x_adv_1 = x_adv + step_size * torch.sign(grad) + x_adv_1 = torch.clamp(torch.min(torch.max(x_adv_1, + x - self.eps), x + self.eps), 0.0, 1.0) + x_adv_1 = torch.clamp(torch.min(torch.max( + x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a), + x - self.eps), x + self.eps), 0.0, 1.0) + + elif self.norm == 'L2': + x_adv_1 = x_adv + step_size * self.normalize(grad) + x_adv_1 = torch.clamp(x + self.normalize(x_adv_1 - x + ) * torch.min(self.eps * torch.ones_like(x).detach(), + self.lp_norm(x_adv_1 - x)), 0.0, 1.0) + x_adv_1 = x_adv + (x_adv_1 - x_adv) * a + grad2 * (1 - a) + x_adv_1 = torch.clamp(x + self.normalize(x_adv_1 - x + ) * torch.min(self.eps * torch.ones_like(x).detach(), + self.lp_norm(x_adv_1 - x)), 0.0, 1.0) + + x_adv = x_adv_1 + 0. + + ### get gradient + x_adv.requires_grad_() + grad = torch.zeros_like(x) + for _ in range(self.eot_iter): + with torch.enable_grad(): + logits = self.predict(x_adv) + loss_indiv = criterion_indiv(logits, y) + loss = loss_indiv.sum() + + grad += torch.autograd.grad(loss, [x_adv])[0].detach() + + grad /= float(self.eot_iter) + + pred = logits.detach().max(1)[1] == y + acc = torch.min(acc, pred) + acc_steps[i + 1] = acc + 0 + ind_pred = (pred == 0).nonzero().squeeze() + x_best_adv[ind_pred] = x_adv[ind_pred] + 0. + if self.verbose: + print('iteration: {} - Best loss: {:.6f}'.format( + i, loss_best.sum())) + + ### check step size + with torch.no_grad(): + y1 = loss_indiv.detach().clone() + loss_steps[i] = y1 + 0 + ind = (y1 > loss_best).nonzero().squeeze() + x_best[ind] = x_adv[ind].clone() + grad_best[ind] = grad[ind].clone() + loss_best[ind] = y1[ind] + 0 + loss_best_steps[i + 1] = loss_best + 0 + + counter3 += 1 + + if counter3 == k: + fl_oscillation = self.check_oscillation(loss_steps, i, k, + loss_best, k3=self.thr_decr) + fl_reduce_no_impr = (1. - reduced_last_check) * ( + loss_best_last_check >= loss_best).float() + fl_oscillation = torch.max(fl_oscillation, + fl_reduce_no_impr) + reduced_last_check = fl_oscillation.clone() + loss_best_last_check = loss_best.clone() + + if fl_oscillation.sum() > 0: + ind_fl_osc = (fl_oscillation > 0).nonzero().squeeze() + step_size[ind_fl_osc] /= 2.0 + n_reduced = fl_oscillation.sum() + + x_adv[ind_fl_osc] = x_best[ind_fl_osc].clone() + grad[ind_fl_osc] = grad_best[ind_fl_osc].clone() + + counter3 = 0 + k = max(k - self.size_decr, self.n_iter_min) + + return (x_best, acc, loss_best, x_best_adv) + + def perturb(self, x, y=None, best_loss=False): + """ + :param x: clean images + :param y: clean labels, if None we use the predicted labels + :param best_loss: if True the points attaining highest loss + are returned, otherwise adversarial examples + """ + + assert self.loss in ['ce', 'dlr'] + self.init_hyperparam(x) + + x = x.detach().clone().float().to(self.device) + if y is None: + y_pred = self._get_predicted_label(x) + y = y_pred.detach().clone().long().to(self.device) + else: + y = y.detach().clone().long().to(self.device) + + adv = x.clone() + acc = self.predict(x).max(1)[1] == y + loss = -1e10 * torch.ones_like(acc).float() + if self.verbose: + print('-------------------------- ', + 'running {}-attack with epsilon {:.5f}'.format( + self.norm, self.eps), + '--------------------------') + print('initial accuracy: {:.2%}'.format(acc.float().mean())) + + startt = time.time() + if not best_loss: + torch.random.manual_seed(self.seed) + torch.cuda.random.manual_seed(self.seed) + + for counter in range(self.n_restarts): + ind_to_fool = acc.nonzero().squeeze() + if len(ind_to_fool.shape) == 0: + ind_to_fool = ind_to_fool.unsqueeze(0) + if ind_to_fool.numel() != 0: + x_to_fool = x[ind_to_fool].clone() + y_to_fool = y[ind_to_fool].clone() + res_curr = self.attack_single_run(x_to_fool, y_to_fool) + best_curr, acc_curr, loss_curr, adv_curr = res_curr + ind_curr = (acc_curr == 0).nonzero().squeeze() + + acc[ind_to_fool[ind_curr]] = 0 + adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone() + if self.verbose: + print('restart {} - robust accuracy: {:.2%}'.format( + counter, acc.float().mean()), + '- cum. time: {:.1f} s'.format( + time.time() - startt)) + + return adv + + else: + adv_best = x.detach().clone() + loss_best = torch.ones([x.shape[0]]).to( + self.device) * (-float('inf')) + for counter in range(self.n_restarts): + best_curr, _, loss_curr, _ = self.attack_single_run(x, y) + ind_curr = (loss_curr > loss_best).nonzero().squeeze() + adv_best[ind_curr] = best_curr[ind_curr] + 0. + loss_best[ind_curr] = loss_curr[ind_curr] + 0. + + if self.verbose: + print('restart {} - loss: {:.5f}'.format( + counter, loss_best.sum())) + + return adv_best + +class APGDTargeted(APGDAttack): + def __init__( + self, + predict, + n_iter=100, + norm='Linf', + n_restarts=1, + eps=None, + seed=0, + eot_iter=1, + rho=.75, + n_target_classes=9, + verbose=False): + """ + AutoPGD on the targeted DLR loss + """ + super(APGDTargeted, self).__init__(predict, n_iter=n_iter, norm=norm, + n_restarts=n_restarts, eps=eps, seed=seed, loss='dlr-targeted', + eot_iter=eot_iter, rho=rho, verbose=verbose) + + self.y_target = None + self.n_target_classes = n_target_classes + + def dlr_loss_targeted(self, x, y): + x_sorted, ind_sorted = x.sort(dim=1) + u = torch.arange(x.shape[0]) + + return -(x[u, y] - x[u, self.y_target]) / (x_sorted[:, -1] - .5 * ( + x_sorted[:, -3] + x_sorted[:, -4]) + 1e-12) + + def perturb(self, x, y=None): + """ + :param x: clean images + :param y: clean labels, if None we use the predicted labels + """ + + assert self.loss in ['dlr-targeted'] + self.init_hyperparam(x) + + x = x.detach().clone().float().to(self.device) + if y is None: + y_pred = self._get_predicted_label(x) + y = y_pred.detach().clone().long().to(self.device) + else: + y = y.detach().clone().long().to(self.device) + + adv = x.clone() + acc = self.predict(x).max(1)[1] == y + if self.verbose: + print('-------------------------- ', + 'running {}-attack with epsilon {:.5f}'.format( + self.norm, self.eps), + '--------------------------') + print('initial accuracy: {:.2%}'.format(acc.float().mean())) + + startt = time.time() + + torch.random.manual_seed(self.seed) + torch.cuda.random.manual_seed(self.seed) + + for target_class in range(2, self.n_target_classes + 2): + for counter in range(self.n_restarts): + ind_to_fool = acc.nonzero().squeeze() + if len(ind_to_fool.shape) == 0: + ind_to_fool = ind_to_fool.unsqueeze(0) + if ind_to_fool.numel() != 0: + x_to_fool = x[ind_to_fool].clone() + y_to_fool = y[ind_to_fool].clone() + output = self.predict(x_to_fool) + self.y_target = output.sort(dim=1)[1][:, -target_class] + + res_curr = self.attack_single_run(x_to_fool, y_to_fool) + best_curr, acc_curr, loss_curr, adv_curr = res_curr + ind_curr = (acc_curr == 0).nonzero().squeeze() + + acc[ind_to_fool[ind_curr]] = 0 + adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone() + if self.verbose: + print('target class {}'.format(target_class), + '- restart {} - robust accuracy: {:.2%}'.format( + counter, acc.float().mean()), + '- cum. time: {:.1f} s'.format( + time.time() - startt)) + + return adv + diff --git a/advertorch/attacks/fab_with_threshold.py b/advertorch/attacks/fab_with_threshold.py new file mode 100644 index 0000000..38190e6 --- /dev/null +++ b/advertorch/attacks/fab_with_threshold.py @@ -0,0 +1,392 @@ +# Copyright (c) 2019-present, Francesco Croce +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import torch +from torch.autograd.gradcheck import zero_gradients +import time + +from .fast_adaptive_boundary import FABAttack + + +class FABWithThreshold(FABAttack): + """ + Fast Adaptive Boundary Attack (Linf, L2, L1) + plus bound on the norm of the perturbations + https://arxiv.org/abs/1907.02044 + + :param predict: forward pass function + :param norm: Lp-norm to minimize ('Linf', 'L2', 'L1' supported) + :param n_restarts: number of random restarts + :param n_iter: number of iterations + :param eps: upper bound on the norm of the perturbations + :param alpha_max: alpha_max + :param eta: overshooting + :param beta: backward step + """ + + def __init__( + self, + predict, + norm='Linf', + n_restarts=1, + n_iter=100, + eps=None, + alpha_max=0.1, + eta=1.05, + beta=0.9, + verbose=False, + seed=0): + + super(FABWithThreshold, self).__init__( + predict=predict, norm=norm, n_restarts=n_restarts, + n_iter=n_iter, eps=eps, alpha_max=alpha_max, eta=eta, beta=beta, + verbose=verbose) + + self.seed = seed + + def init_hyperparam(self, x): + assert self.norm in ['Linf', 'L2', 'L1'] + assert not self.eps is None + + self.device = x.device + self.orig_dim = list(x.shape[1:]) + self.ndims = len(self.orig_dim) + if self.seed is None: + self.seed = time.time() + + def attack_single_run(self, x, y, use_rand_start=False): + startt = time.time() + if len(x.shape) == self.ndims: + x = x.unsqueeze(0) + y = y.unsqueeze(0) + + im2 = x.clone() + la2 = y.clone() + bs = im2.shape[0] + u1 = torch.arange(bs) + adv = im2.clone() + adv_c = x.clone() + res2 = 1e10 * torch.ones([bs]).to(self.device) + res_c = torch.zeros([x.shape[0]]).to(self.device) + x1 = im2.clone() + x0 = im2.clone().reshape([bs, -1]) + + if use_rand_start: + if self.norm == 'Linf': + t = 2 * torch.rand(x1.shape).to(self.device) - 1 + x1 = im2 + (torch.min(res2, + self.eps * torch.ones(res2.shape) + .to(self.device) + ).reshape([-1, *[1]*self.ndims]) + ) * t / (t.reshape([t.shape[0], -1]).abs() + .max(dim=1, keepdim=True)[0] + .reshape([-1, *[1]*self.ndims])) * .5 + elif self.norm == 'L2': + t = torch.randn(x1.shape).to(self.device) + x1 = im2 + (torch.min(res2, + self.eps * torch.ones(res2.shape) + .to(self.device) + ).reshape([-1, *[1]*self.ndims]) + ) * t / ((t ** 2) + .view(t.shape[0], -1) + .sum(dim=-1) + .sqrt() + .view(t.shape[0], *[1]*self.ndims)) * .5 + elif self.norm == 'L1': + t = torch.randn(x1.shape).to(self.device) + x1 = im2 + (torch.min(res2, + self.eps * torch.ones(res2.shape) + .to(self.device) + ).reshape([-1, *[1]*self.ndims]) + ) * t / (t.abs().view(t.shape[0], -1) + .sum(dim=-1) + .view(t.shape[0], *[1]*self.ndims)) / 2 + + x1 = x1.clamp(0.0, 1.0) + + counter_iter = 0 + while counter_iter < self.n_iter: + with torch.no_grad(): + df, dg = self.get_diff_logits_grads_batch(x1, la2) + if self.norm == 'Linf': + dist1 = df.abs() / (1e-12 + + dg.abs() + .view(dg.shape[0], dg.shape[1], -1) + .sum(dim=-1)) + elif self.norm == 'L2': + dist1 = df.abs() / (1e-12 + (dg ** 2) + .view(dg.shape[0], dg.shape[1], -1) + .sum(dim=-1).sqrt()) + elif self.norm == 'L1': + dist1 = df.abs() / (1e-12 + dg.abs().reshape( + [df.shape[0], df.shape[1], -1]).max(dim=2)[0]) + else: + raise ValueError('norm not supported') + ind = dist1.min(dim=1)[1] + dg2 = dg[u1, ind] + b = (- df[u1, ind] + (dg2 * x1).view(x1.shape[0], -1) + .sum(dim=-1)) + w = dg2.reshape([bs, -1]) + + if self.norm == 'Linf': + d3 = self.projection_linf( + torch.cat((x1.reshape([bs, -1]), x0), 0), + torch.cat((w, w), 0), + torch.cat((b, b), 0)) + elif self.norm == 'L2': + d3 = self.projection_l2( + torch.cat((x1.reshape([bs, -1]), x0), 0), + torch.cat((w, w), 0), + torch.cat((b, b), 0)) + elif self.norm == 'L1': + d3 = self.projection_l1( + torch.cat((x1.reshape([bs, -1]), x0), 0), + torch.cat((w, w), 0), + torch.cat((b, b), 0)) + d1 = torch.reshape(d3[:bs], x1.shape) + d2 = torch.reshape(d3[-bs:], x1.shape) + if self.norm == 'Linf': + a0 = d3.abs().max(dim=1, keepdim=True)[0]\ + .view(-1, *[1]*self.ndims) + elif self.norm == 'L2': + a0 = (d3 ** 2).sum(dim=1, keepdim=True).sqrt()\ + .view(-1, *[1]*self.ndims) + elif self.norm == 'L1': + a0 = d3.abs().sum(dim=1, keepdim=True)\ + .view(-1, *[1]*self.ndims) + a0 = torch.max(a0, 1e-8 * torch.ones( + a0.shape).to(self.device)) + a1 = a0[:bs] + a2 = a0[-bs:] + alpha = torch.min(torch.max(a1 / (a1 + a2), + torch.zeros(a1.shape) + .to(self.device))[0], + self.alpha_max * torch.ones(a1.shape) + .to(self.device)) + x1 = ((x1 + self.eta * d1) * (1 - alpha) + + (im2 + d2 * self.eta) * alpha).clamp(0.0, 1.0) + + is_adv = self._get_predicted_label(x1) != la2 + + if is_adv.sum() > 0: + ind_adv = is_adv.nonzero().squeeze() + ind_adv = self.check_shape(ind_adv) + if self.norm == 'Linf': + t = (x1[ind_adv] - im2[ind_adv]).reshape( + [ind_adv.shape[0], -1]).abs().max(dim=1)[0] + elif self.norm == 'L2': + t = ((x1[ind_adv] - im2[ind_adv]) ** 2)\ + .view(ind_adv.shape[0], -1).sum(dim=-1).sqrt() + elif self.norm == 'L1': + t = (x1[ind_adv] - im2[ind_adv])\ + .abs().view(ind_adv.shape[0], -1).sum(dim=-1) + adv[ind_adv] = x1[ind_adv] * (t < res2[ind_adv]).\ + float().reshape([-1, *[1]*self.ndims]) + adv[ind_adv]\ + * (t >= res2[ind_adv]).float().reshape( + [-1, *[1]*self.ndims]) + res2[ind_adv] = t * (t < res2[ind_adv]).float()\ + + res2[ind_adv] * (t >= res2[ind_adv]).float() + x1[ind_adv] = im2[ind_adv] + ( + x1[ind_adv] - im2[ind_adv]) * self.beta + + counter_iter += 1 + + ind_succ = res2 < 1e10 + if self.verbose: + print('success rate: {:.0f}/{:.0f}' + .format(ind_succ.float().sum(), bs) + + ' (on correctly classified points) in {:.1f} s' + .format(time.time() - startt)) + + res_c = res2 * ind_succ.float() + 1e10 * (1 - ind_succ.float()) + ind_succ = self.check_shape(ind_succ.nonzero().squeeze()) + adv_c[ind_succ] = adv[ind_succ].clone() + + return adv_c + + def perturb(self, x, y=None): + """ + :param x: clean images + :param y: clean labels, if None we use the predicted labels + """ + + self.init_hyperparam(x) + + x = x.detach().clone().float().to(self.device) + if y is None: + y_pred = self._get_predicted_label(x) + y = y_pred.detach().clone().long().to(self.device) + else: + y = y.detach().clone().long().to(self.device) + + adv = x.clone() + with torch.no_grad(): + acc = self.predict(x).max(1)[1] == y + + startt = time.time() + + torch.random.manual_seed(self.seed) + torch.cuda.random.manual_seed(self.seed) + + for counter in range(self.n_restarts): + ind_to_fool = acc.nonzero().squeeze() + if len(ind_to_fool.shape) == 0: + ind_to_fool = ind_to_fool.unsqueeze(0) + if ind_to_fool.numel() != 0: + x_to_fool = x[ind_to_fool].clone() + y_to_fool = y[ind_to_fool].clone() + + adv_curr = self.attack_single_run( + x_to_fool, y_to_fool, use_rand_start=(counter > 0)) + + output_curr = self.predict(adv_curr) + acc_curr = output_curr.max(1)[1] == y_to_fool + if self.norm == 'Linf': + res = (x_to_fool - adv_curr).abs().view( + x_to_fool.shape[0], -1).max(1)[0] + elif self.norm == 'L2': + res = ((x_to_fool - adv_curr) ** 2).view( + x_to_fool.shape[0], -1).sum(dim=-1).sqrt() + elif self.norm == 'L1': + res = (x_to_fool - adv_curr).abs().view( + x_to_fool.shape[0], -1).sum(-1) + acc_curr = torch.max(acc_curr, res > self.eps) + + ind_curr = (acc_curr == 0).nonzero().squeeze() + acc[ind_to_fool[ind_curr]] = 0 + adv[ind_to_fool[ind_curr]] = adv_curr[ + ind_curr].clone() + + if self.verbose: + print('restart {}'.format(counter), + '- target_class {}'.format(target_cl), + '- robust accuracy: {:.2%}'.format( + acc.float().mean()), + 'at eps = {:.5f}'.format(self.eps), + '- cum. time: {:.1f} s'.format( + time.time() - startt)) + + return adv + +class FABTargeted(FABWithThreshold): + def __init__( + self, + predict, + norm='Linf', + n_restarts=1, + n_iter=100, + eps=None, + alpha_max=0.1, + eta=1.05, + beta=0.9, + verbose=False, + seed=0, + n_target_classes=9): + """ + FAB with considering only one possible alternative class + """ + super(FABTargeted, self).__init__( + predict=predict, norm=norm, n_restarts=n_restarts, + n_iter=n_iter, eps=eps, alpha_max=alpha_max, eta=eta, beta=beta, + verbose=verbose, seed=seed) + + self.y_target = None + self.n_target_classes = n_target_classes + + def get_diff_logits_grads_batch(self, imgs, la): + la_target = self.y_target + u = torch.arange(imgs.shape[0]) + + im = imgs.clone().requires_grad_() + with torch.enable_grad(): + y = self.predict(im) + diffy = -(y[u, la] - y[u, la_target]) + sumdiffy = diffy.sum() + + zero_gradients(im) + sumdiffy.backward() + graddiffy = im.grad.data + df = diffy.detach().unsqueeze(1) + dg = graddiffy.unsqueeze(1) + + return df, dg + + def perturb(self, x, y=None): + """ + :param x: clean images + :param y: clean labels, if None we use the predicted labels + """ + + self.init_hyperparam(x) + + x = x.detach().clone().float().to(self.device) + if y is None: + y_pred = self._get_predicted_label(x) + y = y_pred.detach().clone().long().to(self.device) + else: + y = y.detach().clone().long().to(self.device) + + adv = x.clone() + with torch.no_grad(): + output = self.predict(x) + la_sorted = output.sort(1)[1] + acc = output.max(1)[1] == y + + startt = time.time() + + torch.random.manual_seed(self.seed) + torch.cuda.random.manual_seed(self.seed) + + for target_cl in range(2, self.n_target_classes + 2): + for counter in range(self.n_restarts): + ind_to_fool = acc.nonzero().squeeze() + if len(ind_to_fool.shape) == 0: + ind_to_fool = ind_to_fool.unsqueeze(0) + if ind_to_fool.numel() != 0: + x_to_fool = x[ind_to_fool].clone() + y_to_fool = y[ind_to_fool].clone() + + self.y_target = la_sorted[ + ind_to_fool, -target_cl].clone() + + adv_curr = self.attack_single_run(x_to_fool, + y_to_fool, use_rand_start=(counter > 0)) + + output_curr = self.predict(adv_curr) + acc_curr = output_curr.max(1)[1] == y_to_fool + if self.norm == 'Linf': + res = (x_to_fool - adv_curr).abs().view( + x_to_fool.shape[0], -1).max(1)[0] + elif self.norm == 'L2': + res = ((x_to_fool - adv_curr) ** 2).view( + x_to_fool.shape[0], -1).sum(dim=-1).sqrt() + elif self.norm == 'L1': + res = (x_to_fool - adv_curr).abs().view( + x_to_fool.shape[0], -1).sum(-1) + acc_curr = torch.max(acc_curr, res > self.eps) + + ind_curr = (acc_curr == 0).nonzero().squeeze() + acc[ind_to_fool[ind_curr]] = 0 + adv[ind_to_fool[ind_curr]] = adv_curr[ + ind_curr].clone() + + if self.verbose: + print('restart {}'.format(counter), + '- target_class {}'.format(target_cl), + '- robust accuracy: {:.2%}'.format( + acc.float().mean()), + 'at eps = {:.5f}'.format(self.eps), + '- cum. time: {:.1f} s'.format( + time.time() - startt)) + + return adv + diff --git a/advertorch/attacks/fast_adaptive_boundary.py b/advertorch/attacks/fast_adaptive_boundary.py index 54a77f2..96730f6 100644 --- a/advertorch/attacks/fast_adaptive_boundary.py +++ b/advertorch/attacks/fast_adaptive_boundary.py @@ -89,7 +89,7 @@ def get_diff_logits_grads_batch(self, imgs, la): g2[counter] = im.grad.data g2 = torch.transpose(g2, 0, 1).detach() - y2 = self.predict(imgs).detach() + y2 = y.detach() df = y2 - y2[torch.arange(imgs.shape[0]), la].unsqueeze(1) dg = g2 - g2[torch.arange(imgs.shape[0]), la].unsqueeze(1) df[torch.arange(imgs.shape[0]), la] = 1e10 From 53350ca8b82005e32e27d9948b5b08e08c45a8c3 Mon Sep 17 00:00:00 2001 From: fra31 Date: Thu, 9 Apr 2020 19:18:18 +0000 Subject: [PATCH 2/7] added square --- advertorch/attacks/autopgd.py | 1 - advertorch/attacks/square.py | 379 ++++++++++++++++++++++++++++++++++ 2 files changed, 379 insertions(+), 1 deletion(-) create mode 100644 advertorch/attacks/square.py diff --git a/advertorch/attacks/autopgd.py b/advertorch/attacks/autopgd.py index b9e0c22..91816ce 100644 --- a/advertorch/attacks/autopgd.py +++ b/advertorch/attacks/autopgd.py @@ -53,7 +53,6 @@ def __init__( super(APGDAttack, self).__init__( predict, loss_fn=None, clip_min=0., clip_max=1.) - self.predict = predict self.n_iter = n_iter self.eps = eps self.norm = norm diff --git a/advertorch/attacks/square.py b/advertorch/attacks/square.py new file mode 100644 index 0000000..c5c01ae --- /dev/null +++ b/advertorch/attacks/square.py @@ -0,0 +1,379 @@ +#Copyright (c) 2020-present, Francesco Croce +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import torch +import time +import math + +from .base import Attack +from .base import LabelMixin + + +class ModelAdapterSA(): + """ + Wrapper for Square Attack + """ + + def __init__(self, model): + self.model = model + + def predict(self, x): + return self.model(x) + + def fmargin(self, x, y): + logits = self.predict(x) + u = torch.arange(x.shape[0]) + y_corr = logits[u, y].clone() + logits[u, y] = -float('inf') + y_others = logits.max(dim=-1)[0] + + return y_corr - y_others + +class SquareAttack(Attack, LabelMixin): + """ + Square Attack + https://arxiv.org/abs/1912.00049 + + :param predict: forward pass function + :param norm: Lp-norm of the attack ('Linf', 'L2' supported) + :param n_restarts: number of random restarts + :param n_queries: max number of queries (each restart) + :param eps: bound on the norm of perturbations + :param seed: random seed for the starting point + :param p_init: parameter to control size of squares + """ + + def __init__( + self, + predict, + norm='Linf', + n_queries=5000, + eps=None, + p_init=.8, + n_restarts=1, + seed=0, + verbose=False): + """ + Square Attack implementation in PyTorch + """ + super(SquareAttack, self).__init__( + predict, loss_fn=None, clip_min=0., clip_max=1.) + + self.model = ModelAdapterSA(predict) + self.norm = norm + self.n_queries = n_queries + self.eps = eps + self.p_init = p_init + self.n_restarts = n_restarts + self.seed = seed + self.verbose = verbose + + def init_hyperparam(self, x): + assert self.norm in ['Linf', 'L2'] + assert not self.eps is None + + self.device = x.device + self.orig_dim = list(x.shape[1:]) + self.ndims = len(self.orig_dim) + if self.seed is None: + self.seed = time.time() + + def check_shape(self, x): + return x if len(x.shape) == (self.ndims + 1) else x.unsqueeze(0) + + def random_choice(self, shape): + t = 2 * torch.rand(shape).to(self.device) - 1 + return torch.sign(t) + + def random_int(self, low=0, high=1, shape=[1]): + t = low + (high - low) * torch.rand(shape).to(self.device) + return t.long() + + def normalize(self, x): + if self.norm == 'Linf': + t = x.abs().view(x.shape[0], -1).max(1)[0] + return x / (t.view(-1, *([1] * self.ndims)) + 1e-12) + + elif self.norm == 'L2': + t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt() + return x / (t.view(-1, *([1] * self.ndims)) + 1e-12) + + def lp_norm(self, x): + if self.norm == 'L2': + t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt() + return t.view(-1, *([1] * self.ndims)) + + def eta_rectangles(self, x, y): + delta = torch.zeros([x, y]).to(self.device) + x_c, y_c = x // 2 + 1, y // 2 + 1 + + counter2 = [x_c - 1, y_c - 1] + for counter in range(0, max(x_c, y_c)): + delta[max(counter2[0], 0):min(counter2[0] + (2*counter + 1), x), + max(0, counter2[1]):min(counter2[1] + (2*counter + 1), y) + ] += 1.0/(torch.Tensor([counter + 1]).view(1, 1).to( + self.device) ** 2) + counter2[0] -= 1 + counter2[1] -= 1 + + delta /= (delta ** 2).sum(dim=(0,1), keepdim=True).sqrt() + + return delta + + def eta(self, s): + delta = torch.zeros([s, s]).to(self.device) + delta[:s // 2] = self.eta_rectangles(s // 2, s) + delta[s // 2:] = -1. * self.eta_rectangles(s - s // 2, s) + delta /= (delta ** 2).sum(dim=(0, 1), keepdim=True).sqrt() + if torch.rand([1]) > 0.5: + delta = delta.permute([1, 0]) + + return delta + + def p_selection(self, it): + #it = int(it / self.n_queries * 10000) + + if 10 < it <= 50: + p = self.p_init / 2 + elif 50 < it <= 200: + p = self.p_init / 4 + elif 200 < it <= 500: + p = self.p_init / 8 + elif 500 < it <= 1000: + p = self.p_init / 16 + elif 1000 < it <= 2000: + p = self.p_init / 32 + elif 2000 < it <= 4000: + p = self.p_init / 64 + elif 4000 < it <= 6000: + p = self.p_init / 128 + elif 6000 < it <= 8000: + p = self.p_init / 256 + elif 8000 < it: + p = self.p_init / 512 + else: + p = self.p_init + + return p + + def attack_single_run(self, x, y): + with torch.no_grad(): + if self.norm == 'Linf': + adv = x.clone() + c, h, w = x.shape[1:] + n_features = c * h * w + n_ex_total = x.shape[0] + + x_best = torch.clamp(x + self.eps * self.random_choice( + [x.shape[0], c, 1, w]), 0., 1.) + margin_min = self.model.fmargin(x_best, y) + n_queries = torch.ones(x.shape[0]).to(self.device) + s_init = int(math.sqrt(self.p_init * n_features / c)) + + for i_iter in range(self.n_queries): + idx_to_fool = (margin_min > 0.0).nonzero().squeeze() + + x_curr = self.check_shape(x[idx_to_fool]) + x_best_curr = self.check_shape(x_best[idx_to_fool]) + y_curr = y[idx_to_fool] + margin_min_curr = margin_min[idx_to_fool] + + p = self.p_selection(i_iter) + s = max(int(round(math.sqrt(p * n_features / c))), 1) + vh = self.random_int(0, h - s) + vw = self.random_int(0, w - s) + new_deltas = torch.zeros([c, h, w]).to(self.device) + new_deltas[:, vh:vh + s, vw:vw + s + ] = 2. * self.eps * self.random_choice([c, 1, 1]) + + x_new = x_best_curr + new_deltas + x_new = torch.min(torch.max(x_new, x_curr - self.eps), + x_curr + self.eps) + x_new = torch.clamp(x_new, 0., 1.) + x_new = self.check_shape(x_new) + + margin = self.model.fmargin(x_new, y_curr) + + idx_improved = (margin < margin_min_curr).float() + margin_min[idx_to_fool] = idx_improved * margin + ( + 1. - idx_improved) * margin_min_curr + idx_improved = idx_improved.reshape([-1, + *[1]*len(x.shape[:-1])]) + x_best[idx_to_fool] = idx_improved * x_new + ( + 1. - idx_improved) * x_best_curr + n_queries[idx_to_fool] += 1. + + ind_succ = (margin_min <= 0.).nonzero().squeeze() + if self.verbose and ind_succ.numel() != 0: + print('{}'.format(i_iter + 1), + '- success rate={}/{} ({:.2%})'.format( + ind_succ.numel(), n_ex_total, + float(ind_succ.numel()) / n_ex_total), + '- avg # queries={:.1f}'.format( + n_queries[ind_succ].mean().item()), + '- med # queries={:.1f}'.format( + n_queries[ind_succ].median().item()), + '- loss={:.3f}'.format(margin_min.mean())) + + if ind_succ.numel() == n_ex_total: + break + + elif self.norm == 'L2': + adv = x.clone() + c, h, w = x.shape[1:] + n_features = c * h * w + n_ex_total = x.shape[0] + + delta_init = torch.zeros_like(x) + s = h // 5 + sp_init = (h - s * 5) // 2 + vh = sp_init + 0 + for _ in range(h // s): + vw = sp_init + 0 + for _ in range(w // s): + delta_init[:, :, vh:vh + s, vw:vw + s] += self.eta( + s).view(1, 1, s, s) * self.random_choice( + [x.shape[0], c, 1, 1]) + vw += s + vh += s + + x_best = torch.clamp(x + self.normalize(delta_init + ) * self.eps, 0., 1.) + margin_min = self.model.fmargin(x_best, y) + n_queries = torch.ones(x.shape[0]).to(self.device) + s_init = int(math.sqrt(self.p_init * n_features / c)) + + for i_iter in range(self.n_queries): + idx_to_fool = (margin_min > 0.0).nonzero().squeeze() + + x_curr = self.check_shape(x[idx_to_fool]) + x_best_curr = self.check_shape(x_best[idx_to_fool]) + y_curr = y[idx_to_fool] + margin_min_curr = margin_min[idx_to_fool] + + delta_curr = x_best_curr - x_curr + p = self.p_selection(i_iter) + s = max(int(round(math.sqrt(p * n_features / c))), 3) + if s % 2 == 0: + s += 1 + + + vh = self.random_int(0, h - s) + vw = self.random_int(0, w - s) + new_deltas_mask = torch.zeros_like(x_curr) + new_deltas_mask[:, :, vh:vh + s, vw:vw + s] = 1.0 + norms_window_1 = (delta_curr[:, :, vh:vh + s, vw:vw + s + ] ** 2).sum(dim=(-2, -1), keepdim=True).sqrt() + + vh2 = self.random_int(0, h - s) + vw2 = self.random_int(0, w - s) + new_deltas_mask_2 = torch.zeros_like(x_curr) + new_deltas_mask_2[:, :, vh2:vh2 + s, vw2:vw2 + s] = 1. + + norms_image = self.lp_norm(x_best_curr - x_curr) + mask_image = torch.max(new_deltas_mask, new_deltas_mask_2) + norms_windows = self.lp_norm(delta_curr * mask_image) + + new_deltas = torch.ones([x_curr.shape[0], c, s, s] + ).to(self.device) + new_deltas *= (self.eta(s).view(1, 1, s, s) * + self.random_choice([x_curr.shape[0], c, 1, 1])) + old_deltas = delta_curr[:, :, vh:vh + s, vw:vw + s] / ( + 1e-12 + norms_window_1) + new_deltas += old_deltas + new_deltas = new_deltas / (new_deltas ** 2).sum( + dim=(-2, -1), keepdim=True).sqrt() * (torch.max( + (self.eps * torch.ones_like(new_deltas)) ** 2 - + norms_image ** 2, torch.zeros_like(new_deltas)) / + c + norms_windows ** 2).sqrt() + delta_curr[:, :, vh2:vh2 + s, vw2:vw2 + s] = 0. + delta_curr[:, :, vh:vh + s, vw:vw + s] = new_deltas + 0 + + x_new = torch.clamp(x_curr + self.normalize(delta_curr + ) * self.eps, 0. ,1.) + x_new = self.check_shape(x_new) + + norms_image = self.lp_norm(x_new - x_curr) + + margin = self.model.fmargin(x_new, y_curr) + idx_improved = (margin < margin_min_curr).float() + margin_min[idx_to_fool] = idx_improved * margin + ( + 1. - idx_improved) * margin_min_curr + idx_improved = idx_improved.reshape([-1, + *[1]*len(x.shape[:-1])]) + x_best[idx_to_fool] = idx_improved * x_new + ( + 1. - idx_improved) * x_best_curr + n_queries[idx_to_fool] += 1. + + + ind_succ = (margin_min <= 0.).nonzero().squeeze() + if self.verbose and ind_succ.numel() != 0: + print('{}'.format(i_iter + 1), + '- success rate={}/{} ({:.2%})'.format( + ind_succ.numel(), n_ex_total, float( + ind_succ.numel()) / n_ex_total), + '- avg # queries={:.1f}'.format( + n_queries[ind_succ].mean().item()), + '- med # queries={:.1f}'.format( + n_queries[ind_succ].median().item()), + '- loss={:.3f}'.format(margin_min.mean())) + + if ind_succ.numel() == n_ex_total: + break + + return n_queries, x_best + + def perturb(self, x, y=None): + """ + :param x: clean images + :param y: clean labels, if None we use the predicted labels + """ + + self.init_hyperparam(x) + + adv = x.clone() + if y is None: + y_pred = self._get_predicted_label(x) + y = y_pred.detach().clone().long().to(self.device) + else: + y = y.detach().clone().long().to(self.device) + + acc = self.model.predict(x).max(1)[1] == y + + startt = time.time() + + torch.random.manual_seed(self.seed) + torch.cuda.random.manual_seed(self.seed) + + for counter in range(self.n_restarts): + ind_to_fool = acc.nonzero().squeeze() + if len(ind_to_fool.shape) == 0: + ind_to_fool = ind_to_fool.unsqueeze(0) + if ind_to_fool.numel() != 0: + x_to_fool = x[ind_to_fool].clone() + y_to_fool = y[ind_to_fool].clone() + + _, adv_curr = self.attack_single_run(x_to_fool, y_to_fool) + + output_curr = self.model.predict(adv_curr) + acc_curr = output_curr.max(1)[1] == y_to_fool + ind_curr = (acc_curr == 0).nonzero().squeeze() + + acc[ind_to_fool[ind_curr]] = 0 + adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone() + if self.verbose: + print('restart {} - robust accuracy: {:.2%}'.format( + counter, acc.float().mean()), + '- cum. time: {:.1f} s'.format( + time.time() - startt)) + + return adv \ No newline at end of file From 9afe0a5709413aedac04511176c9f01d195c1703 Mon Sep 17 00:00:00 2001 From: fra31 Date: Tue, 14 Apr 2020 15:57:13 +0000 Subject: [PATCH 3/7] updated square --- advertorch/attacks/square.py | 230 +++++++++++++++++++++-------------- 1 file changed, 140 insertions(+), 90 deletions(-) diff --git a/advertorch/attacks/square.py b/advertorch/attacks/square.py index c5c01ae..ce471be 100644 --- a/advertorch/attacks/square.py +++ b/advertorch/attacks/square.py @@ -1,4 +1,4 @@ -#Copyright (c) 2020-present, Francesco Croce +# Copyright (c) 2020-present, Francesco Croce # All rights reserved. # # This source code is licensed under the license found in the @@ -13,36 +13,17 @@ import torch import time import math +import torch.nn.functional as F from .base import Attack from .base import LabelMixin -class ModelAdapterSA(): - """ - Wrapper for Square Attack - """ - - def __init__(self, model): - self.model = model - - def predict(self, x): - return self.model(x) - - def fmargin(self, x, y): - logits = self.predict(x) - u = torch.arange(x.shape[0]) - y_corr = logits[u, y].clone() - logits[u, y] = -float('inf') - y_others = logits.max(dim=-1)[0] - - return y_corr - y_others - class SquareAttack(Attack, LabelMixin): """ Square Attack https://arxiv.org/abs/1912.00049 - + :param predict: forward pass function :param norm: Lp-norm of the attack ('Linf', 'L2' supported) :param n_restarts: number of random restarts @@ -50,8 +31,10 @@ class SquareAttack(Attack, LabelMixin): :param eps: bound on the norm of perturbations :param seed: random seed for the starting point :param p_init: parameter to control size of squares + :param loss: loss function optimized ('margin', 'ce' supported) + :param resc_schedule adapt schedule of p to n_queries """ - + def __init__( self, predict, @@ -61,14 +44,17 @@ def __init__( p_init=.8, n_restarts=1, seed=0, - verbose=False): + verbose=False, + targeted=False, + loss='margin', + resc_schedule=True): """ Square Attack implementation in PyTorch """ super(SquareAttack, self).__init__( predict, loss_fn=None, clip_min=0., clip_max=1.) - - self.model = ModelAdapterSA(predict) + + self.predict = predict self.norm = norm self.n_queries = n_queries self.eps = eps @@ -76,28 +62,62 @@ def __init__( self.n_restarts = n_restarts self.seed = seed self.verbose = verbose - + self.targeted = targeted + self.loss = loss + self.rescale_schedule = resc_schedule + + def margin_and_loss(self, x, y): + """ + :param y: correct labels if untargeted else target labels + """ + + logits = self.predict(x) + xent = F.cross_entropy(logits, y, reduction='none') + u = torch.arange(x.shape[0]) + y_corr = logits[u, y].clone() + logits[u, y] = -float('inf') + y_others = logits.max(dim=-1)[0] + + if not self.targeted: + if self.loss == 'ce': + return y_corr - y_others, -1. * xent + elif self.loss == 'margin': + return y_corr - y_others, y_corr - y_others + else: + return y_others - y_corr, xent + def init_hyperparam(self, x): assert self.norm in ['Linf', 'L2'] assert not self.eps is None + assert self.loss in ['ce', 'margin'] self.device = x.device self.orig_dim = list(x.shape[1:]) self.ndims = len(self.orig_dim) if self.seed is None: self.seed = time.time() - + + def random_target_classes(self, y_pred, n_classes): + y = torch.zeros_like(y_pred) + for counter in range(y_pred.shape[0]): + l = list(range(n_classes)) + l.remove(y_pred[counter]) + t = self.random_int(0, len(l)) + y[counter] = l[t] + + return y.long().to(self.device) + def check_shape(self, x): return x if len(x.shape) == (self.ndims + 1) else x.unsqueeze(0) - + def random_choice(self, shape): t = 2 * torch.rand(shape).to(self.device) - 1 return torch.sign(t) - + def random_int(self, low=0, high=1, shape=[1]): t = low + (high - low) * torch.rand(shape).to(self.device) return t.long() - + def normalize(self, x): if self.norm == 'Linf': t = x.abs().view(x.shape[0], -1).max(1)[0] @@ -111,11 +131,11 @@ def lp_norm(self, x): if self.norm == 'L2': t = (x ** 2).view(x.shape[0], -1).sum(-1).sqrt() return t.view(-1, *([1] * self.ndims)) - + def eta_rectangles(self, x, y): delta = torch.zeros([x, y]).to(self.device) x_c, y_c = x // 2 + 1, y // 2 + 1 - + counter2 = [x_c - 1, y_c - 1] for counter in range(0, max(x_c, y_c)): delta[max(counter2[0], 0):min(counter2[0] + (2*counter + 1), x), @@ -124,7 +144,7 @@ def eta_rectangles(self, x, y): self.device) ** 2) counter2[0] -= 1 counter2[1] -= 1 - + delta /= (delta ** 2).sum(dim=(0,1), keepdim=True).sqrt() return delta @@ -136,12 +156,15 @@ def eta(self, s): delta /= (delta ** 2).sum(dim=(0, 1), keepdim=True).sqrt() if torch.rand([1]) > 0.5: delta = delta.permute([1, 0]) - + return delta def p_selection(self, it): - #it = int(it / self.n_queries * 10000) - + """ schedule to decrease the parameter p """ + + if self.rescale_schedule: + it = int(it / self.n_queries * 10000) + if 10 < it <= 50: p = self.p_init / 2 elif 50 < it <= 200: @@ -162,20 +185,20 @@ def p_selection(self, it): p = self.p_init / 512 else: p = self.p_init - + return p def attack_single_run(self, x, y): with torch.no_grad(): + adv = x.clone() + c, h, w = x.shape[1:] + n_features = c * h * w + n_ex_total = x.shape[0] + if self.norm == 'Linf': - adv = x.clone() - c, h, w = x.shape[1:] - n_features = c * h * w - n_ex_total = x.shape[0] - x_best = torch.clamp(x + self.eps * self.random_choice( [x.shape[0], c, 1, w]), 0., 1.) - margin_min = self.model.fmargin(x_best, y) + margin_min, loss_min = self.margin_and_loss(x_best, y) n_queries = torch.ones(x.shape[0]).to(self.device) s_init = int(math.sqrt(self.p_init * n_features / c)) @@ -185,7 +208,10 @@ def attack_single_run(self, x, y): x_curr = self.check_shape(x[idx_to_fool]) x_best_curr = self.check_shape(x_best[idx_to_fool]) y_curr = y[idx_to_fool] + if len(y_curr.shape) == 0: + y_curr = y_curr.unsqueeze(0) margin_min_curr = margin_min[idx_to_fool] + loss_min_curr = loss_min[idx_to_fool] p = self.p_selection(i_iter) s = max(int(round(math.sqrt(p * n_features / c))), 1) @@ -201,11 +227,16 @@ def attack_single_run(self, x, y): x_new = torch.clamp(x_new, 0., 1.) x_new = self.check_shape(x_new) - margin = self.model.fmargin(x_new, y_curr) + margin, loss = self.margin_and_loss(x_new, y_curr) - idx_improved = (margin < margin_min_curr).float() + idx_improved = (loss < loss_min_curr).float() margin_min[idx_to_fool] = idx_improved * margin + ( 1. - idx_improved) * margin_min_curr + loss_min[idx_to_fool] = idx_improved * loss + ( + 1. - idx_improved) * loss_min_curr + + idx_miscl = (margin <= 0.).float() + idx_improved = torch.max(idx_improved, idx_miscl) idx_improved = idx_improved.reshape([-1, *[1]*len(x.shape[:-1])]) x_best[idx_to_fool] = idx_improved * x_new + ( @@ -222,17 +253,12 @@ def attack_single_run(self, x, y): n_queries[ind_succ].mean().item()), '- med # queries={:.1f}'.format( n_queries[ind_succ].median().item()), - '- loss={:.3f}'.format(margin_min.mean())) + '- loss={:.3f}'.format(loss_min.mean())) if ind_succ.numel() == n_ex_total: break elif self.norm == 'L2': - adv = x.clone() - c, h, w = x.shape[1:] - n_features = c * h * w - n_ex_total = x.shape[0] - delta_init = torch.zeros_like(x) s = h // 5 sp_init = (h - s * 5) // 2 @@ -245,44 +271,46 @@ def attack_single_run(self, x, y): [x.shape[0], c, 1, 1]) vw += s vh += s - + x_best = torch.clamp(x + self.normalize(delta_init ) * self.eps, 0., 1.) - margin_min = self.model.fmargin(x_best, y) + margin_min, loss_min = self.margin_and_loss(x_best, y) n_queries = torch.ones(x.shape[0]).to(self.device) s_init = int(math.sqrt(self.p_init * n_features / c)) - + for i_iter in range(self.n_queries): idx_to_fool = (margin_min > 0.0).nonzero().squeeze() - + x_curr = self.check_shape(x[idx_to_fool]) x_best_curr = self.check_shape(x_best[idx_to_fool]) y_curr = y[idx_to_fool] + if len(y_curr.shape) == 0: + y_curr = y_curr.unsqueeze(0) margin_min_curr = margin_min[idx_to_fool] - + loss_min_curr = loss_min[idx_to_fool] + delta_curr = x_best_curr - x_curr p = self.p_selection(i_iter) s = max(int(round(math.sqrt(p * n_features / c))), 3) if s % 2 == 0: s += 1 - - + vh = self.random_int(0, h - s) vw = self.random_int(0, w - s) new_deltas_mask = torch.zeros_like(x_curr) new_deltas_mask[:, :, vh:vh + s, vw:vw + s] = 1.0 norms_window_1 = (delta_curr[:, :, vh:vh + s, vw:vw + s ] ** 2).sum(dim=(-2, -1), keepdim=True).sqrt() - + vh2 = self.random_int(0, h - s) vw2 = self.random_int(0, w - s) new_deltas_mask_2 = torch.zeros_like(x_curr) new_deltas_mask_2[:, :, vh2:vh2 + s, vw2:vw2 + s] = 1. - + norms_image = self.lp_norm(x_best_curr - x_curr) mask_image = torch.max(new_deltas_mask, new_deltas_mask_2) norms_windows = self.lp_norm(delta_curr * mask_image) - + new_deltas = torch.ones([x_curr.shape[0], c, s, s] ).to(self.device) new_deltas *= (self.eta(s).view(1, 1, s, s) * @@ -297,24 +325,29 @@ def attack_single_run(self, x, y): c + norms_windows ** 2).sqrt() delta_curr[:, :, vh2:vh2 + s, vw2:vw2 + s] = 0. delta_curr[:, :, vh:vh + s, vw:vw + s] = new_deltas + 0 - + x_new = torch.clamp(x_curr + self.normalize(delta_curr ) * self.eps, 0. ,1.) x_new = self.check_shape(x_new) - norms_image = self.lp_norm(x_new - x_curr) - - margin = self.model.fmargin(x_new, y_curr) - idx_improved = (margin < margin_min_curr).float() + + margin, loss = self.margin_and_loss(x_new, y_curr) + + idx_improved = (loss < loss_min_curr).float() margin_min[idx_to_fool] = idx_improved * margin + ( 1. - idx_improved) * margin_min_curr + loss_min[idx_to_fool] = idx_improved * loss + ( + 1. - idx_improved) * loss_min_curr + + idx_miscl = (margin <= 0.).float() + idx_improved = torch.max(idx_improved, idx_miscl) + idx_improved = idx_improved.reshape([-1, *[1]*len(x.shape[:-1])]) x_best[idx_to_fool] = idx_improved * x_new + ( 1. - idx_improved) * x_best_curr n_queries[idx_to_fool] += 1. - - + ind_succ = (margin_min <= 0.).nonzero().squeeze() if self.verbose and ind_succ.numel() != 0: print('{}'.format(i_iter + 1), @@ -325,35 +358,48 @@ def attack_single_run(self, x, y): n_queries[ind_succ].mean().item()), '- med # queries={:.1f}'.format( n_queries[ind_succ].median().item()), - '- loss={:.3f}'.format(margin_min.mean())) - + '- loss={:.3f}'.format(loss_min.mean())) + if ind_succ.numel() == n_ex_total: break - + return n_queries, x_best - + def perturb(self, x, y=None): """ :param x: clean images - :param y: clean labels, if None we use the predicted labels + :param y: untargeted attack -> clean labels, + if None we use the predicted labels + targeted attack -> target labels, if None random classes, + different from the predicted ones, are sampled """ - + self.init_hyperparam(x) - + adv = x.clone() if y is None: - y_pred = self._get_predicted_label(x) - y = y_pred.detach().clone().long().to(self.device) + if not self.targeted: + y_pred = self._get_predicted_label(x) + y = y_pred.detach().clone().long().to(self.device) + else: + with torch.no_grad(): + output = self.predict(x) + n_classes = output.shape[-1] + y_pred = output.max(1)[1] + y = self.random_target_classes(y_pred, n_classes) else: y = y.detach().clone().long().to(self.device) - - acc = self.model.predict(x).max(1)[1] == y - + + if not self.targeted: + acc = self.predict(x).max(1)[1] == y + else: + acc = self.predict(x).max(1)[1] != y + startt = time.time() - + torch.random.manual_seed(self.seed) torch.cuda.random.manual_seed(self.seed) - + for counter in range(self.n_restarts): ind_to_fool = acc.nonzero().squeeze() if len(ind_to_fool.shape) == 0: @@ -361,13 +407,16 @@ def perturb(self, x, y=None): if ind_to_fool.numel() != 0: x_to_fool = x[ind_to_fool].clone() y_to_fool = y[ind_to_fool].clone() - + _, adv_curr = self.attack_single_run(x_to_fool, y_to_fool) - - output_curr = self.model.predict(adv_curr) - acc_curr = output_curr.max(1)[1] == y_to_fool + + output_curr = self.predict(adv_curr) + if not self.targeted: + acc_curr = output_curr.max(1)[1] == y_to_fool + else: + acc_curr = output_curr.max(1)[1] != y_to_fool ind_curr = (acc_curr == 0).nonzero().squeeze() - + acc[ind_to_fool[ind_curr]] = 0 adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone() if self.verbose: @@ -375,5 +424,6 @@ def perturb(self, x, y=None): counter, acc.float().mean()), '- cum. time: {:.1f} s'.format( time.time() - startt)) - - return adv \ No newline at end of file + + return adv + From 06e670589ff9658c12f393f066df800672ded0a9 Mon Sep 17 00:00:00 2001 From: fra31 Date: Tue, 21 Apr 2020 17:29:55 +0000 Subject: [PATCH 4/7] added tests for square --- advertorch/attacks/__init__.py | 9 +++++++++ advertorch/attacks/fab_with_threshold.py | 12 ++++++------ advertorch/attacks/square.py | 22 +++++++++++----------- advertorch/test_utils.py | 3 +++ 4 files changed, 29 insertions(+), 17 deletions(-) diff --git a/advertorch/attacks/__init__.py b/advertorch/attacks/__init__.py index 56b2bc4..dbf9d91 100644 --- a/advertorch/attacks/__init__.py +++ b/advertorch/attacks/__init__.py @@ -44,9 +44,18 @@ from .jsma import JSMA from .spsa import LinfSPSAAttack + from .fast_adaptive_boundary import FABAttack from .fast_adaptive_boundary import LinfFABAttack from .fast_adaptive_boundary import L2FABAttack from .fast_adaptive_boundary import L1FABAttack +from .fab_with_threshold import FABWithThreshold +from .fab_with_threshold import FABTargeted + +from .autopgd import APGDAttack +from .autopgd import APGDTargeted + +from .square import SquareAttack + from .utils import ChooseBestAttack diff --git a/advertorch/attacks/fab_with_threshold.py b/advertorch/attacks/fab_with_threshold.py index 38190e6..f5e06eb 100644 --- a/advertorch/attacks/fab_with_threshold.py +++ b/advertorch/attacks/fab_with_threshold.py @@ -45,14 +45,14 @@ def __init__( beta=0.9, verbose=False, seed=0): - + super(FABWithThreshold, self).__init__( predict=predict, norm=norm, n_restarts=n_restarts, n_iter=n_iter, eps=eps, alpha_max=alpha_max, eta=eta, beta=beta, verbose=verbose) - + self.seed = seed - + def init_hyperparam(self, x): assert self.norm in ['Linf', 'L2', 'L1'] assert not self.eps is None @@ -62,13 +62,13 @@ def init_hyperparam(self, x): self.ndims = len(self.orig_dim) if self.seed is None: self.seed = time.time() - + def attack_single_run(self, x, y, use_rand_start=False): startt = time.time() if len(x.shape) == self.ndims: x = x.unsqueeze(0) y = y.unsqueeze(0) - + im2 = x.clone() la2 = y.clone() bs = im2.shape[0] @@ -79,7 +79,7 @@ def attack_single_run(self, x, y, use_rand_start=False): res_c = torch.zeros([x.shape[0]]).to(self.device) x1 = im2.clone() x0 = im2.clone().reshape([bs, -1]) - + if use_rand_start: if self.norm == 'Linf': t = 2 * torch.rand(x1.shape).to(self.device) - 1 diff --git a/advertorch/attacks/square.py b/advertorch/attacks/square.py index ce471be..68e7d35 100644 --- a/advertorch/attacks/square.py +++ b/advertorch/attacks/square.py @@ -194,17 +194,17 @@ def attack_single_run(self, x, y): c, h, w = x.shape[1:] n_features = c * h * w n_ex_total = x.shape[0] - + if self.norm == 'Linf': x_best = torch.clamp(x + self.eps * self.random_choice( [x.shape[0], c, 1, w]), 0., 1.) margin_min, loss_min = self.margin_and_loss(x_best, y) n_queries = torch.ones(x.shape[0]).to(self.device) s_init = int(math.sqrt(self.p_init * n_features / c)) - + for i_iter in range(self.n_queries): idx_to_fool = (margin_min > 0.0).nonzero().squeeze() - + x_curr = self.check_shape(x[idx_to_fool]) x_best_curr = self.check_shape(x_best[idx_to_fool]) y_curr = y[idx_to_fool] @@ -212,7 +212,7 @@ def attack_single_run(self, x, y): y_curr = y_curr.unsqueeze(0) margin_min_curr = margin_min[idx_to_fool] loss_min_curr = loss_min[idx_to_fool] - + p = self.p_selection(i_iter) s = max(int(round(math.sqrt(p * n_features / c))), 1) vh = self.random_int(0, h - s) @@ -220,21 +220,21 @@ def attack_single_run(self, x, y): new_deltas = torch.zeros([c, h, w]).to(self.device) new_deltas[:, vh:vh + s, vw:vw + s ] = 2. * self.eps * self.random_choice([c, 1, 1]) - + x_new = x_best_curr + new_deltas x_new = torch.min(torch.max(x_new, x_curr - self.eps), x_curr + self.eps) x_new = torch.clamp(x_new, 0., 1.) x_new = self.check_shape(x_new) - + margin, loss = self.margin_and_loss(x_new, y_curr) - + idx_improved = (loss < loss_min_curr).float() margin_min[idx_to_fool] = idx_improved * margin + ( 1. - idx_improved) * margin_min_curr loss_min[idx_to_fool] = idx_improved * loss + ( 1. - idx_improved) * loss_min_curr - + idx_miscl = (margin <= 0.).float() idx_improved = torch.max(idx_improved, idx_miscl) idx_improved = idx_improved.reshape([-1, @@ -242,7 +242,7 @@ def attack_single_run(self, x, y): x_best[idx_to_fool] = idx_improved * x_new + ( 1. - idx_improved) * x_best_curr n_queries[idx_to_fool] += 1. - + ind_succ = (margin_min <= 0.).nonzero().squeeze() if self.verbose and ind_succ.numel() != 0: print('{}'.format(i_iter + 1), @@ -254,10 +254,10 @@ def attack_single_run(self, x, y): '- med # queries={:.1f}'.format( n_queries[ind_succ].median().item()), '- loss={:.3f}'.format(loss_min.mean())) - + if ind_succ.numel() == n_ex_total: break - + elif self.norm == 'L2': delta_init = torch.zeros_like(x) s = h // 5 diff --git a/advertorch/test_utils.py b/advertorch/test_utils.py index 665867c..8f89ced 100644 --- a/advertorch/test_utils.py +++ b/advertorch/test_utils.py @@ -31,6 +31,7 @@ from advertorch.attacks import LinfFABAttack from advertorch.attacks import L2FABAttack from advertorch.attacks import L1FABAttack +from advertorch.attacks import SquareAttack from advertorch.defenses import JPEGFilter from advertorch.defenses import BitSqueezing from advertorch.defenses import MedianSmoothing2D @@ -244,6 +245,7 @@ def generate_data_model_on_img(): image_only_attacks = [ SpatialTransformAttack, LocalSearchAttack, + SquareAttack, ] label_attacks = [ @@ -265,6 +267,7 @@ def generate_data_model_on_img(): LinfFABAttack, L2FABAttack, L1FABAttack, + SquareAttack, ] feature_attacks = [ From 6c860020ec902157fef6c62b018121e513d39a5d Mon Sep 17 00:00:00 2001 From: fra31 Date: Fri, 24 Apr 2020 16:53:58 +0200 Subject: [PATCH 5/7] added benchmarks --- .../attack_benchmarks/benchmark_autopgd.py | 175 +++++++++++++++ .../benchmark_fab_with_threshold.py | 200 ++++++++++++++++++ .../attack_benchmarks/benchmark_square.py | 117 ++++++++++ 3 files changed, 492 insertions(+) create mode 100644 advertorch_examples/attack_benchmarks/benchmark_autopgd.py create mode 100644 advertorch_examples/attack_benchmarks/benchmark_fab_with_threshold.py create mode 100644 advertorch_examples/attack_benchmarks/benchmark_square.py diff --git a/advertorch_examples/attack_benchmarks/benchmark_autopgd.py b/advertorch_examples/attack_benchmarks/benchmark_autopgd.py new file mode 100644 index 0000000..ed98ebe --- /dev/null +++ b/advertorch_examples/attack_benchmarks/benchmark_autopgd.py @@ -0,0 +1,175 @@ +# Copyright (c) 2018-present, Royal Bank of Canada and other authors. +# See the AUTHORS.txt file for a list of contributors. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# +# +# Automatically generated benchmark report (screen print of running this file) +# +# sysname: Linux +# release: 4.4.0-146-generic +# version: #172-Ubuntu SMP Wed Apr 3 09:00:08 UTC 2019 +# machine: x86_64 +# python: 3.6.8 +# torch: 1.4.0 +# torchvision: 0.5.0 +# advertorch: 0.2.2 + +# attack type: APGDAttack +# attack kwargs: norm=Linf +# n_restarts=1 +# n_iter=100 +# eps=0.3 +# loss=ce +# data: mnist_test, 10000 samples +# model: MNIST LeNet5 standard training +# accuracy: 98.89% +# attack success rate: 100.0% + +# attack type: APGDAttack +# attack kwargs: norm=L2 +# n_restarts=1 +# n_iter=100 +# eps=2.0 +# loss=ce +# data: mnist_test, 10000 samples +# model: MNIST LeNet5 standard training +# accuracy: 98.89% +# attack success rate: 90.76% + +# attack type: APGDTargeted +# attack kwargs: norm=Linf +# n_restarts=1 +# n_iter=20 +# eps=0.3 +# n_target_classes=9 +# data: mnist_test, 10000 samples +# model: MNIST LeNet5 standard training +# accuracy: 98.89% +# attack success rate: 100.0% + +# attack type: APGDTargeted +# attack kwargs: norm=L2 +# n_restarts=1 +# n_iter=20 +# eps=2.0 +# n_target_classes=9 +# data: mnist_test, 10000 samples +# model: MNIST LeNet5 standard training +# accuracy: 98.89% +# attack success rate: 91.61% + +# attack type: APGDAttack +# attack kwargs: norm=Linf +# n_restarts=1 +# n_iter=100 +# eps=0.3 +# loss=ce +# data: mnist_test, 10000 samples +# model: MNIST LeNet 5 PGD training according to Madry et al. 2018 +# accuracy: 98.64% +# attack success rate: 9.02% + +# attack type: APGDAttack +# attack kwargs: norm=L2 +# n_restarts=1 +# n_iter=100 +# eps=2.0 +# loss=ce +# data: mnist_test, 10000 samples +# model: MNIST LeNet 5 PGD training according to Madry et al. 2018 +# accuracy: 98.64% +# attack success rate: 15.81% + +# attack type: APGDTargeted +# attack kwargs: norm=Linf +# n_restarts=1 +# n_iter=20 +# eps=0.3 +# n_target_classes=9 +# data: mnist_test, 10000 samples +# model: MNIST LeNet 5 PGD training according to Madry et al. 2018 +# accuracy: 98.64% +# attack success rate: 7.96% + +# attack type: APGDTargeted +# attack kwargs: norm=L2 +# n_restarts=1 +# n_iter=20 +# eps=2.0 +# n_target_classes=9 +# data: mnist_test, 10000 samples +# model: MNIST LeNet 5 PGD training according to Madry et al. 2018 +# accuracy: 98.64% +# attack success rate: 12.66% + + + +from advertorch_examples.utils import get_mnist_test_loader +from advertorch_examples.utils import get_mnist_lenet5_clntrained +from advertorch_examples.utils import get_mnist_lenet5_advtrained +from advertorch_examples.benchmark_utils import get_benchmark_sys_info + +from advertorch.attacks import APGDAttack, APGDTargeted + +from advertorch_examples.benchmark_utils import benchmark_attack_success_rate + +batch_size = 1000 +device = "cuda" + +lst_attack = [ + (APGDAttack, dict( + norm='Linf', + n_restarts=1, + n_iter=100, + eps=.3, + loss='ce' + )), + (APGDAttack, dict( + norm='L2', + n_restarts=1, + n_iter=100, + eps=2., + loss='ce' + )), + (APGDTargeted, dict( + norm='Linf', + n_restarts=1, + n_iter=20, + eps=.3, + n_target_classes=9 + )), + (APGDTargeted, dict( + norm='L2', + n_restarts=1, + n_iter=20, + eps=2., + n_target_classes=9 + )), +] # each element in the list is the tuple (attack_class, attack_kwargs) + +mnist_clntrained_model = get_mnist_lenet5_clntrained().to(device) +mnist_advtrained_model = get_mnist_lenet5_advtrained().to(device) +mnist_test_loader = get_mnist_test_loader(batch_size=batch_size) + +lst_setting = [ + (mnist_clntrained_model, mnist_test_loader), + (mnist_advtrained_model, mnist_test_loader), +] + + +info = get_benchmark_sys_info() + +lst_benchmark = [] +for model, loader in lst_setting: + for attack_class, attack_kwargs in lst_attack: + lst_benchmark.append(benchmark_attack_success_rate( + model, loader, attack_class, attack_kwargs + )) + +print(info) +for item in lst_benchmark: + print(item) diff --git a/advertorch_examples/attack_benchmarks/benchmark_fab_with_threshold.py b/advertorch_examples/attack_benchmarks/benchmark_fab_with_threshold.py new file mode 100644 index 0000000..75b0d61 --- /dev/null +++ b/advertorch_examples/attack_benchmarks/benchmark_fab_with_threshold.py @@ -0,0 +1,200 @@ +# Copyright (c) 2018-present, Royal Bank of Canada and other authors. +# See the AUTHORS.txt file for a list of contributors. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# +# +# Automatically generated benchmark report (screen print of running this file) +# +# sysname: Linux +# release: 4.4.0-146-generic +# version: #172-Ubuntu SMP Wed Apr 3 09:00:08 UTC 2019 +# machine: x86_64 +# python: 3.6.8 +# torch: 1.4.0 +# torchvision: 0.5.0 +# advertorch: 0.2.2 + +# attack type: FABWithThreshold +# attack kwargs: norm=Linf +# n_restarts=1 +# n_iter=20 +# alpha_max=0.1 +# eta=1.05 +# beta=0.9 +# eps=0.4 +# data: mnist_test, 10000 samples +# model: MNIST LeNet5 standard training +# accuracy: 98.89% +# attack success rate: 100.0% +# Among successful attacks (Linf norm) on correctly classified examples: +# minimum distance: 0.0001393 +# median distance: 0.112 +# maximum distance: 0.2132 +# average distance: 0.1092 +# distance standard deviation: 0.03501 + +# attack type: FABWithThreshold +# attack kwargs: norm=L2 +# n_restarts=1 +# n_iter=20 +# alpha_max=0.1 +# eta=1.05 +# beta=0.9 +# eps=2.0 +# data: mnist_test, 10000 samples +# model: MNIST LeNet5 standard training +# accuracy: 98.89% +# attack success rate: 89.1% +# Among successful attacks (L2 norm) on correctly classified examples: +# minimum distance: 0.001727 +# median distance: 1.359 +# maximum distance: 2.0 +# average distance: 1.309 +# distance standard deviation: 0.4015 + +# attack type: FABWithThreshold +# attack kwargs: norm=L1 +# n_restarts=1 +# n_iter=20 +# alpha_max=0.1 +# eta=1.05 +# beta=0.9 +# eps=10.0 +# data: mnist_test, 10000 samples +# model: MNIST LeNet5 standard training +# accuracy: 98.89% +# attack success rate: 70.74% +# Among successful attacks (L1 norm) on correctly classified examples: +# minimum distance: 0.007687 +# median distance: 6.239 +# maximum distance: 9.997 +# average distance: 6.096 +# distance standard deviation: 2.302 + +# attack type: FABWithThreshold +# attack kwargs: norm=Linf +# n_restarts=1 +# n_iter=20 +# alpha_max=0.1 +# eta=1.05 +# beta=0.9 +# eps=0.4 +# data: mnist_test, 10000 samples +# model: MNIST LeNet 5 PGD training according to Madry et al. 2018 +# accuracy: 98.64% +# attack success rate: 92.5% +# Among successful attacks (Linf norm) on correctly classified examples: +# minimum distance: 0.001414 +# median distance: 0.3482 +# maximum distance: 0.3999 +# average distance: 0.3411 +# distance standard deviation: 0.04846 + +# attack type: FABWithThreshold +# attack kwargs: norm=L2 +# n_restarts=1 +# n_iter=20 +# alpha_max=0.1 +# eta=1.05 +# beta=0.9 +# eps=2.0 +# data: mnist_test, 10000 samples +# model: MNIST LeNet 5 PGD training according to Madry et al. 2018 +# accuracy: 98.64% +# attack success rate: 9.23% +# Among successful attacks (L2 norm) on correctly classified examples: +# minimum distance: 0.003937 +# median distance: 1.07 +# maximum distance: 1.999 +# average distance: 1.101 +# distance standard deviation: 0.6595 + +# attack type: FABWithThreshold +# attack kwargs: norm=L1 +# n_restarts=1 +# n_iter=20 +# alpha_max=0.1 +# eta=1.05 +# beta=0.9 +# eps=10.0 +# data: mnist_test, 10000 samples +# model: MNIST LeNet 5 PGD training according to Madry et al. 2018 +# accuracy: 98.64% +# attack success rate: 5.04% +# Among successful attacks (L1 norm) on correctly classified examples: +# minimum distance: 0.006217 +# median distance: 1.553 +# maximum distance: 9.708 +# average distance: 2.462 +# distance standard deviation: 2.407 + + + +from advertorch_examples.utils import get_mnist_test_loader +from advertorch_examples.utils import get_mnist_lenet5_clntrained +from advertorch_examples.utils import get_mnist_lenet5_advtrained +from advertorch_examples.benchmark_utils import get_benchmark_sys_info + +from advertorch.attacks import FABWithThreshold, FABTargeted + +from advertorch_examples.benchmark_utils import benchmark_margin + +batch_size = 1000 +device = "cuda" + +lst_attack = [ + (FABWithThreshold, dict( + norm='Linf', + n_restarts=1, + n_iter=20, + alpha_max=0.1, + eta=1.05, + beta=0.9, + eps=.4, + )), + (FABWithThreshold, dict( + norm='L2', + n_restarts=1, + n_iter=20, + alpha_max=0.1, + eta=1.05, + beta=0.9, + eps=2., + )), + (FABWithThreshold, dict( + norm='L1', + n_restarts=1, + n_iter=20, + alpha_max=0.1, + eta=1.05, + beta=0.9, + eps=10., + )), +] # each element in the list is the tuple (attack_class, attack_kwargs) + +mnist_clntrained_model = get_mnist_lenet5_clntrained().to(device) +mnist_advtrained_model = get_mnist_lenet5_advtrained().to(device) +mnist_test_loader = get_mnist_test_loader(batch_size=batch_size) + +lst_setting = [ + (mnist_clntrained_model, mnist_test_loader), + (mnist_advtrained_model, mnist_test_loader), +] + + +info = get_benchmark_sys_info() + +lst_benchmark = [] +for model, loader in lst_setting: + for attack_class, attack_kwargs in lst_attack: + lst_benchmark.append(benchmark_margin( + model, loader, attack_class, attack_kwargs, + norm=attack_kwargs["norm"])) + +print(info) +for item in lst_benchmark: + print(item) diff --git a/advertorch_examples/attack_benchmarks/benchmark_square.py b/advertorch_examples/attack_benchmarks/benchmark_square.py new file mode 100644 index 0000000..8a3aaa8 --- /dev/null +++ b/advertorch_examples/attack_benchmarks/benchmark_square.py @@ -0,0 +1,117 @@ +# Copyright (c) 2018-present, Royal Bank of Canada and other authors. +# See the AUTHORS.txt file for a list of contributors. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# +# +# Automatically generated benchmark report (screen print of running this file) +# +# sysname: Linux +# release: 4.4.0-146-generic +# version: #172-Ubuntu SMP Wed Apr 3 09:00:08 UTC 2019 +# machine: x86_64 +# python: 3.6.8 +# torch: 1.4.0 +# torchvision: 0.5.0 +# advertorch: 0.2.2 + +# attack type: SquareAttack +# attack kwargs: norm=Linf +# n_restarts=1 +# n_queries=1000 +# eps=0.3 +# p_init=0.8 +# data: mnist_test, 10000 samples +# model: MNIST LeNet5 standard training +# accuracy: 98.89% +# attack success rate: 100.0% + +# attack type: SquareAttack +# attack kwargs: norm=L2 +# n_restarts=1 +# n_queries=1000 +# eps=2.0 +# p_init=0.8 +# data: mnist_test, 10000 samples +# model: MNIST LeNet5 standard training +# accuracy: 98.89% +# attack success rate: 58.98% + +# attack type: SquareAttack +# attack kwargs: norm=Linf +# n_restarts=1 +# n_queries=1000 +# eps=0.3 +# p_init=0.8 +# data: mnist_test, 10000 samples +# model: MNIST LeNet 5 PGD training according to Madry et al. 2018 +# accuracy: 98.64% +# attack success rate: 10.71% + +# attack type: SquareAttack +# attack kwargs: norm=L2 +# n_restarts=1 +# n_queries=1000 +# eps=2.0 +# p_init=0.8 +# data: mnist_test, 10000 samples +# model: MNIST LeNet 5 PGD training according to Madry et al. 2018 +# accuracy: 98.64% +# attack success rate: 60.05% + + + +from advertorch_examples.utils import get_mnist_test_loader +from advertorch_examples.utils import get_mnist_lenet5_clntrained +from advertorch_examples.utils import get_mnist_lenet5_advtrained +from advertorch_examples.benchmark_utils import get_benchmark_sys_info + +from advertorch.attacks import SquareAttack + +from advertorch_examples.benchmark_utils import benchmark_attack_success_rate + +batch_size = 1000 +device = "cuda" + +lst_attack = [ + (SquareAttack, dict( + norm='Linf', + n_restarts=1, + n_queries=1000, + eps=.3, + p_init=.8 + )), + (SquareAttack, dict( + norm='L2', + n_restarts=1, + n_queries=1000, + eps=2., + p_init=.8 + )), +] # each element in the list is the tuple (attack_class, attack_kwargs) + +mnist_clntrained_model = get_mnist_lenet5_clntrained().to(device) +mnist_advtrained_model = get_mnist_lenet5_advtrained().to(device) +mnist_test_loader = get_mnist_test_loader(batch_size=batch_size) + +lst_setting = [ + (mnist_clntrained_model, mnist_test_loader), + (mnist_advtrained_model, mnist_test_loader), +] + + +info = get_benchmark_sys_info() + +lst_benchmark = [] +for model, loader in lst_setting: + for attack_class, attack_kwargs in lst_attack: + lst_benchmark.append(benchmark_attack_success_rate( + model, loader, attack_class, attack_kwargs + )) + +print(info) +for item in lst_benchmark: + print(item) From fb3f04c84147206c0de2a56ede5d50060e0fcbf2 Mon Sep 17 00:00:00 2001 From: fra31 Date: Tue, 28 Apr 2020 20:07:10 +0000 Subject: [PATCH 6/7] small fix square targeted --- advertorch/attacks/square.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/advertorch/attacks/square.py b/advertorch/attacks/square.py index 68e7d35..be3eac7 100644 --- a/advertorch/attacks/square.py +++ b/advertorch/attacks/square.py @@ -229,14 +229,19 @@ def attack_single_run(self, x, y): margin, loss = self.margin_and_loss(x_new, y_curr) + # update loss if new loss is better idx_improved = (loss < loss_min_curr).float() - margin_min[idx_to_fool] = idx_improved * margin + ( - 1. - idx_improved) * margin_min_curr + loss_min[idx_to_fool] = idx_improved * loss + ( 1. - idx_improved) * loss_min_curr + # update margin and x_best if new loss is better + # or misclassification idx_miscl = (margin <= 0.).float() idx_improved = torch.max(idx_improved, idx_miscl) + + margin_min[idx_to_fool] = idx_improved * margin + ( + 1. - idx_improved) * margin_min_curr idx_improved = idx_improved.reshape([-1, *[1]*len(x.shape[:-1])]) x_best[idx_to_fool] = idx_improved * x_new + ( @@ -333,15 +338,19 @@ def attack_single_run(self, x, y): margin, loss = self.margin_and_loss(x_new, y_curr) + # update loss if new loss is better idx_improved = (loss < loss_min_curr).float() - margin_min[idx_to_fool] = idx_improved * margin + ( - 1. - idx_improved) * margin_min_curr + loss_min[idx_to_fool] = idx_improved * loss + ( 1. - idx_improved) * loss_min_curr + # update margin and x_best if new loss is better + # or misclassification idx_miscl = (margin <= 0.).float() idx_improved = torch.max(idx_improved, idx_miscl) + margin_min[idx_to_fool] = idx_improved * margin + ( + 1. - idx_improved) * margin_min_curr idx_improved = idx_improved.reshape([-1, *[1]*len(x.shape[:-1])]) x_best[idx_to_fool] = idx_improved * x_new + ( From b7782e8871fb5cb15fe5410724821be9c402a3f8 Mon Sep 17 00:00:00 2001 From: fra31 Date: Fri, 26 Feb 2021 11:21:21 +0000 Subject: [PATCH 7/7] small fix fab --- advertorch/attacks/fab_with_threshold.py | 2 +- advertorch/attacks/fast_adaptive_boundary.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/advertorch/attacks/fab_with_threshold.py b/advertorch/attacks/fab_with_threshold.py index f5e06eb..3b110f0 100644 --- a/advertorch/attacks/fab_with_threshold.py +++ b/advertorch/attacks/fab_with_threshold.py @@ -169,7 +169,7 @@ def attack_single_run(self, x, y, use_rand_start=False): a2 = a0[-bs:] alpha = torch.min(torch.max(a1 / (a1 + a2), torch.zeros(a1.shape) - .to(self.device))[0], + .to(self.device)), self.alpha_max * torch.ones(a1.shape) .to(self.device)) x1 = ((x1 + self.eta * d1) * (1 - alpha) + diff --git a/advertorch/attacks/fast_adaptive_boundary.py b/advertorch/attacks/fast_adaptive_boundary.py index 96730f6..4231ebd 100644 --- a/advertorch/attacks/fast_adaptive_boundary.py +++ b/advertorch/attacks/fast_adaptive_boundary.py @@ -440,7 +440,7 @@ def perturb(self, x, y=None): a2 = a0[-bs:] alpha = torch.min(torch.max(a1 / (a1 + a2), torch.zeros(a1.shape) - .to(self.device))[0], + .to(self.device)), self.alpha_max * torch.ones(a1.shape) .to(self.device)) x1 = ((x1 + self.eta * d1) * (1 - alpha) +