
Commit 8f288f0

Author: Stanislav

Update (#10)
* Gold sources
* Refactored. Results confirmed
* Fixed dataloading
* AUC 0.998037
* AUC 0.998164
* works
* Multi-gpu
* Refactored
* Cleanup
* AUC 0.998037
* Cleanup
* Cleanup
* Cleanup
* F1: 0.9389 on MNIST at 50% outliers.
1 parent 4439822 commit 8f288f0

20 files changed: +1017 −1260 lines

configs/mnist.yaml

+18
@@ -0,0 +1,18 @@
DATASET:
  FOLDS_COUNT: 5
  MEAN: 0.1307
  PATH: mnist
  STD: 0.3081
  TOTAL_CLASS_COUNT: 10
  PERCENTAGES: [10, 20, 30, 40, 50]
MODEL:
  LATENT_SIZE: 16
  Z_DISCRIMINATOR_CROSS_BATCH: False
  INPUT_IMAGE_SIZE: 32
  INPUT_IMAGE_CHANNELS: 1
OUTPUT_DIR: results
TRAIN:
  BASE_LEARNING_RATE: 0.002
  BATCH_SIZE: 128
  EPOCH_COUNT: 80
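This config overrides the defaults declared in defaults.py below; note LATENT_SIZE drops from the default 32 to 16 and BATCH_SIZE from 256 to 128. A minimal loading sketch, assuming the yacs pattern from defaults.py:

from defaults import get_cfg_defaults

cfg = get_cfg_defaults()
cfg.merge_from_file("configs/mnist.yaml")
print(cfg.MODEL.LATENT_SIZE)  # 16, overriding the default of 32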

dataloading.py

+138
@@ -0,0 +1,138 @@
# Copyright 2018-2020 Stanislav Pidhorskyi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch.utils.data
from net import *
import pickle
import numpy as np
from os import path
import dlutils
import warnings
class Dataset:
    @staticmethod
    def list_of_pairs_to_numpy(l):
        # Each element of l is a (label, image) pair.
        return np.asarray([x[1] for x in l], np.float32), np.asarray([x[0] for x in l], np.int64)

    def __init__(self, data):
        self.x, self.y = Dataset.list_of_pairs_to_numpy(data)

    def __getitem__(self, index):
        if isinstance(index, slice):
            return self.y[index.start:index.stop], self.x[index.start:index.stop]
        return self.y[index], self.x[index]

    def __len__(self):
        return len(self.y)

    def shuffle(self):
        # Apply the same permutation to labels and images, in place.
        permutation = np.random.permutation(self.y.shape[0])
        for x in [self.y, self.x]:
            np.take(x, permutation, axis=0, out=x)
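The (label, image) pair layout above is easiest to see on a toy input; a minimal sketch with made-up data:

import numpy as np

from dataloading import Dataset

pairs = [(3, np.zeros((32, 32), np.float32)), (7, np.ones((32, 32), np.float32))]
ds = Dataset(pairs)
print(len(ds))    # 2
y0, x0 = ds[0]    # label first, then image
ys, xs = ds[0:2]  # slicing returns arrays of labels and images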
def make_datasets(cfg, folding_id, inliner_classes):
    data_train = []
    data_valid = []

    # All folds except `folding_id` feed train/validation; the first
    # such fold becomes the validation set.
    for i in range(cfg.DATASET.FOLDS_COUNT):
        if i != folding_id:
            with open(path.join(cfg.DATASET.PATH, 'data_fold_%d.pkl' % i), 'rb') as pkl:
                fold = pickle.load(pkl)
            if len(data_valid) == 0:
                data_valid = fold
            else:
                data_train += fold

    outlier_classes = []
    for i in range(cfg.DATASET.TOTAL_CLASS_COUNT):
        if i not in inliner_classes:
            outlier_classes.append(i)

    # Train only on the inlier classes.
    data_train = [x for x in data_train if x[0] in inliner_classes]

    with open(path.join(cfg.DATASET.PATH, 'data_fold_%d.pkl' % folding_id), 'rb') as pkl:
        data_test = pickle.load(pkl)

    train_set = Dataset(data_train)
    valid_set = Dataset(data_valid)
    test_set = Dataset(data_test)

    return train_set, valid_set, test_set
def make_dataloader(dataset, batch_size, device):
    class BatchCollator(object):
        def __init__(self, device):
            self.device = device

        def __call__(self, batch):
            with torch.no_grad():
                y, x = batch
                # Scale pixel values to [0, 1] and move to the target device.
                x = torch.tensor(x / 255.0, requires_grad=True, dtype=torch.float32, device=self.device)
                y = torch.tensor(y, dtype=torch.int32, device=self.device)
                return y, x

    data_loader = dlutils.batch_provider(dataset, batch_size, BatchCollator(device))
    return data_loader
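A minimal usage sketch tying the two functions together (assuming the data_fold_*.pkl files referenced by make_datasets already exist under cfg.DATASET.PATH, and that iterating dlutils.batch_provider yields the collated (y, x) pairs):

import torch

from defaults import get_cfg_defaults
from dataloading import make_datasets, make_dataloader

cfg = get_cfg_defaults()
cfg.merge_from_file("configs/mnist.yaml")

# Digit 3 as the inlier class; fold 0 held out for testing.
train_set, valid_set, test_set = make_datasets(cfg, folding_id=0, inliner_classes=[3])
train_loader = make_dataloader(train_set, cfg.TRAIN.BATCH_SIZE, torch.device("cpu"))
for y, x in train_loader:
    print(y.shape, x.shape)  # one collated batch of labels and images
    break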
def create_set_with_outlier_percentage(dataset, inliner_classes, target_percentage, concervative=True):
    np.random.seed(0)
    dataset.shuffle()
    dataset_outlier = [x for x in dataset if x[0] not in inliner_classes]
    dataset_inliner = [x for x in dataset if x[0] in inliner_classes]

    def increase_length(data_list, target_length):
        # Tile the list until it is at least target_length long, then truncate.
        repeat = (target_length + len(data_list) - 1) // len(data_list)
        data_list = data_list * repeat
        data_list = data_list[:target_length]
        return data_list

    if not concervative:
        # Drop samples from the over-represented side.
        inliner_count = len(dataset_inliner)
        outlier_count = inliner_count * target_percentage // (100 - target_percentage)

        if len(dataset_outlier) > outlier_count:
            dataset_outlier = dataset_outlier[:outlier_count]
        else:
            outlier_count = len(dataset_outlier)
            inliner_count = outlier_count * (100 - target_percentage) // target_percentage
            dataset_inliner = dataset_inliner[:inliner_count]
    else:
        # Keep all samples and oversample the under-represented side.
        inliner_count = len(dataset_inliner)
        outlier_count = len(dataset_outlier)

        current_percentage = outlier_count * 100 / (outlier_count + inliner_count)

        if current_percentage < target_percentage:  # we don't have enough outliers
            outlier_count = int(inliner_count * target_percentage / (100.0 - target_percentage))
            dataset_outlier = increase_length(dataset_outlier, outlier_count)
        else:  # we don't have enough inliers
            inlier_count = int(outlier_count * (100.0 - target_percentage) / target_percentage)
            dataset_inliner = increase_length(dataset_inliner, inlier_count)

    dataset = Dataset(dataset_outlier + dataset_inliner)

    dataset.shuffle()

    # Post checks
    outlier_count = len([1 for x in dataset if x[0] not in inliner_classes])
    inliner_count = len([1 for x in dataset if x[0] in inliner_classes])
    real_percentage = outlier_count * 100.0 / (outlier_count + inliner_count)
    assert abs(real_percentage - target_percentage) < 0.01, "Didn't create dataset with requested percentage of outliers"

    return dataset
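To make the conservative-branch arithmetic concrete, here is a small sketch with made-up counts: 900 inliers and 10 outliers with a 50% target means 900 outliers are needed, so increase_length tiles the 10 available outliers 90 times before truncating.

inliner_count, outlier_count, target_percentage = 900, 10, 50
needed = int(inliner_count * target_percentage / (100.0 - target_percentage))
repeat = (needed + outlier_count - 1) // outlier_count
print(needed, repeat)  # 900 90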

defaults.py

+44
@@ -0,0 +1,44 @@
from yacs.config import CfgNode as CN


_C = CN()

_C.OUTPUT_DIR = "results"

_C.DATASET = CN()

_C.DATASET.PERCENTAGES = [10, 20, 30, 40, 50]

# Values for MNIST
_C.DATASET.MEAN = 0.1307
_C.DATASET.STD = 0.3081

_C.DATASET.PATH = "mnist"
_C.DATASET.TOTAL_CLASS_COUNT = 10
_C.DATASET.FOLDS_COUNT = 5

_C.MODEL = CN()
_C.MODEL.LATENT_SIZE = 32
_C.MODEL.INPUT_IMAGE_SIZE = 32
_C.MODEL.INPUT_IMAGE_CHANNELS = 1
# If True, use a zd discriminator that looks at the entire batch rather than
# at individual samples.
_C.MODEL.Z_DISCRIMINATOR_CROSS_BATCH = False


_C.TRAIN = CN()

_C.TRAIN.BATCH_SIZE = 256
_C.TRAIN.EPOCH_COUNT = 80
_C.TRAIN.BASE_LEARNING_RATE = 0.002

_C.TEST = CN()
_C.TEST.BATCH_SIZE = 1024

_C.MAKE_PLOTS = True


def get_cfg_defaults():
    """Get a yacs CfgNode object with default values for my_project."""
    # Return a clone so that the defaults will not be altered.
    # This is for the "local variable" use pattern.
    return _C.clone()
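As a usage note, and a sketch only: yacs configs built this way can also be overridden from a flat key/value list, which is convenient for command-line sweeps (the override values below are made up):

from defaults import get_cfg_defaults

cfg = get_cfg_defaults()
cfg.merge_from_list(["TRAIN.BATCH_SIZE", 64, "MODEL.LATENT_SIZE", 8])
cfg.freeze()  # guard against accidental mutation during training
print(cfg.TRAIN.BATCH_SIZE)  # 64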

evaluation.py

+128
@@ -0,0 +1,128 @@
import numpy as np
from sklearn.metrics import roc_auc_score
import pickle
import os
def get_f1(true_positive, false_positive, false_negative):
    if true_positive == 0:
        return 0.0
    precision = true_positive / (true_positive + false_positive)
    recall = true_positive / (true_positive + false_negative)
    return 2.0 * precision * recall / (precision + recall)
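A quick sanity check of the function above with made-up counts: 8 true positives, 2 false positives, and 2 false negatives give precision = recall = 0.8, hence F1 = 0.8.

print(get_f1(true_positive=8, false_positive=2, false_negative=2))  # 0.8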
def evaluate(logger, percentage_of_outliers, inliner_classes, prediction, threshold, gt_inlier):
    y = np.greater(prediction, threshold)

    gt_outlier = np.logical_not(gt_inlier)

    true_positive = np.sum(np.logical_and(y, gt_inlier))
    true_negative = np.sum(np.logical_and(np.logical_not(y), gt_outlier))
    false_positive = np.sum(np.logical_and(y, gt_outlier))
    false_negative = np.sum(np.logical_and(np.logical_not(y), gt_inlier))
    total_count = true_positive + true_negative + false_positive + false_negative

    accuracy = 100 * (true_positive + true_negative) / total_count

    y_true = gt_inlier
    y_scores = prediction

    try:
        auc = roc_auc_score(y_true, y_scores)
    except ValueError:
        # roc_auc_score raises if only one class is present in y_true.
        auc = 0

    logger.info("Percentage %f" % percentage_of_outliers)
    logger.info("Accuracy %f" % accuracy)
    f1 = get_f1(true_positive, false_positive, false_negative)
    logger.info("F1 %f" % f1)
    logger.info("AUC %f" % auc)

    # inliers
    X1 = [x[1] for x in zip(gt_inlier, prediction) if x[0]]

    # outliers
    Y1 = [x[1] for x in zip(gt_inlier, prediction) if not x[0]]

    minP = min(prediction) - 1
    maxP = max(prediction) + 1

    ##################################################################
    # FPR at TPR 95
    ##################################################################
    fpr95 = 0.0
    closest_tpr = 1.0
    dist_tpr = 1.0
    for threshold in np.arange(minP, maxP, 0.2):
        tpr = np.sum(np.greater_equal(X1, threshold)) / float(len(X1))
        fpr = np.sum(np.greater_equal(Y1, threshold)) / float(len(Y1))
        if abs(tpr - 0.95) < dist_tpr:
            dist_tpr = abs(tpr - 0.95)
            closest_tpr = tpr
            fpr95 = fpr

    logger.info("tpr: %f" % closest_tpr)
    logger.info("fpr95: %f" % fpr95)

    ##################################################################
    # Detection error
    ##################################################################
    error = 1.0
    for threshold in np.arange(minP, maxP, 0.2):
        tpr = np.sum(np.less(X1, threshold)) / float(len(X1))
        fpr = np.sum(np.greater_equal(Y1, threshold)) / float(len(Y1))
        error = np.minimum(error, (tpr + fpr) / 2.0)

    logger.info("Detection error: %f" % error)

    ##################################################################
    # AUPR IN
    ##################################################################
    auprin = 0.0
    recallTemp = 1.0
    for threshold in np.arange(minP, maxP, 0.2):
        tp = np.sum(np.greater_equal(X1, threshold))
        fp = np.sum(np.greater_equal(Y1, threshold))
        if tp + fp == 0:
            continue
        precision = tp / (tp + fp)
        recall = tp / float(len(X1))
        auprin += (recallTemp - recall) * precision
        recallTemp = recall
    auprin += recall * precision

    logger.info("auprin: %f" % auprin)

    ##################################################################
    # AUPR OUT
    ##################################################################
    # Negate scores so that outliers become the positive class.
    minP, maxP = -maxP, -minP
    X1 = [-x for x in X1]
    Y1 = [-x for x in Y1]
    auprout = 0.0
    recallTemp = 1.0
    for threshold in np.arange(minP, maxP, 0.2):
        tp = np.sum(np.greater_equal(Y1, threshold))
        fp = np.sum(np.greater_equal(X1, threshold))
        if tp + fp == 0:
            continue
        precision = tp / (tp + fp)
        recall = tp / float(len(Y1))
        auprout += (recallTemp - recall) * precision
        recallTemp = recall
    auprout += recall * precision

    logger.info("auprout: %f" % auprout)

    with open("results.txt", "a") as file:
        file.write(
            "Class: %s\n Percentage: %d\n"
            "Error: %f\n F1: %f\n AUC: %f\nfpr95: %f"
            "\nDetection: %f\nauprin: %f\nauprout: %f\n\n" %
            ("_".join([str(x) for x in inliner_classes]), percentage_of_outliers, error, f1, auc, fpr95, error, auprin, auprout))

    return dict(auc=auc, f1=f1, fpr95=fpr95, error=error, auprin=auprin, auprout=auprout)
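A hypothetical smoke test for evaluate() with synthetic, well-separated inlier/outlier score distributions (note the function also appends a summary to results.txt):

import logging

import numpy as np

from evaluation import evaluate

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("eval-demo")

rng = np.random.RandomState(0)
gt_inlier = np.concatenate([np.ones(900, dtype=bool), np.zeros(100, dtype=bool)])
prediction = np.concatenate([rng.normal(1.0, 0.3, 900), rng.normal(-1.0, 0.3, 100)])

results = evaluate(logger, percentage_of_outliers=10.0, inliner_classes=[3],
                   prediction=prediction, threshold=0.0, gt_inlier=gt_inlier)
print(results["auc"], results["f1"])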
