-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathforgery_detector.py
120 lines (111 loc) · 4.88 KB
/
forgery_detector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import numpy as np
from matplotlib import pyplot as plt
import torch
from procedures import ModelLayer1
from procedures import ModelLayer2
from scipy.signal import convolve2d
import cv2
import matplotlib as mpl
class Forgery_detection:
    """Score images tile by tile and reduce the tile scores to one score per image.

    Every entry of ``path`` is loaded with ``np.load``, centred on a
    zero-filled ``img_size`` x ``img_size`` canvas and scanned with
    overlapping ``tile_size`` x ``tile_size`` windows.  Each window accepted
    by :meth:`is_mammo` is scored by the ``Mammo_FT`` network; the resulting
    grid of tile scores is turned into a single score by ``ScoreClassifier``.
    :meth:`generate_heat_map` overlays a smoothed version of that grid on the
    image.

    NOTE(review): what the printed labels mean (presumably 1 = forged) is not
    visible in this file -- confirm against the model code.
    """

    def __init__(self, path):
        """Set up the tiling geometry and load both pretrained networks.

        Args:
            path: Directory whose entries are loaded by :meth:`load_imgs`.
        """
        self.path = path
        # Per-image results; index k of every list refers to the same image.
        self.imgs = []
        self.img_names = []
        self.score_arrs = []
        self.scores = []
        self.labels = []
        self.heat_map = []
        # Tiling geometry: square canvas, square tiles, fractional overlap
        # between neighbouring tiles.
        self.img_size = 4096
        self.tile_size = 256
        self.overlap = 0.75
        # Number of tile positions along one axis of the padded canvas.
        self.score_arr_size = int(np.floor((self.img_size - self.tile_size) / ((1 - self.overlap) * self.tile_size) + 1))
        # Distance in pixels between the origins of neighbouring tiles.
        self.step_size = int(np.floor(self.tile_size * (1 - self.overlap)))
        # Both networks only ever receive CPU tensors built from NumPy arrays,
        # so the checkpoints are mapped to the CPU; this also lets a checkpoint
        # saved on a GPU load on a machine without one.  os.path.join keeps
        # the relative checkpoint paths valid outside Windows.
        self.MammoFT = ModelLayer1.Mammo_FT()
        self.MammoFT.load_state_dict(
            torch.load(os.path.join("procedures", "mammo-FT.ckpt"), map_location="cpu"))
        self.MammoFT.eval()
        self.Classifier = ModelLayer2.ScoreClassifier()
        self.Classifier.load_state_dict(
            torch.load(os.path.join("procedures", "Score_Classifier.ckpt"), map_location="cpu"))
        self.Classifier.eval()

    def load_imgs(self):
        """Load every entry of ``self.path`` with ``np.load``.

        Names are visited in sorted order because ``os.listdir`` returns them
        in arbitrary order.  A name is recorded only after its array loaded
        successfully, so ``img_names`` and ``imgs`` stay aligned even if a
        load raises.
        """
        for name in sorted(os.listdir(self.path)):
            img = np.load(os.path.join(self.path, name))
            self.img_names.append(name)
            self.imgs.append(img)

    def classification(self):
        """Compute, store and print the image-level score of every loaded image.

        Appends the tile-score grid to ``score_arrs`` and the classifier
        output to ``scores``; the printed label is 1 when the score exceeds
        0.5 and 0 otherwise.
        """
        for i, img in enumerate(self.imgs):
            score_arr = self.get_score_array(self.padd(img))
            self.score_arrs.append(score_arr)
            # The classifier is fed a (1, 1, H, W) float32 tensor.
            batch = torch.from_numpy(score_arr[np.newaxis, np.newaxis].astype(np.float32))
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                score = self.Classifier(batch).item()
            self.scores.append(score)
            print("name: %s, score: %1.2f, label: %d"%
                  (self.img_names[i], score, 1 if score > 0.5 else 0))

    def padd(self, img):
        """Return ``img`` centred on a zero-filled ``img_size`` square canvas.

        Args:
            img: 2-D array no larger than ``img_size`` along either axis.

        Returns:
            A new ``(img_size, img_size)`` float array containing ``img``.

        Raises:
            ValueError: If ``img`` does not fit on the canvas.
        """
        rows, cols = img.shape
        if rows > self.img_size or cols > self.img_size:
            raise ValueError(
                "image of shape %s does not fit in a %dx%d canvas"
                % (img.shape, self.img_size, self.img_size))
        top = self.img_size // 2 - rows // 2
        left = self.img_size // 2 - cols // 2
        padd_img = np.zeros((self.img_size, self.img_size))
        padd_img[top:top + rows, left:left + cols] = img
        return padd_img

    def get_score_array(self, img):
        """Score every tile position of ``img`` with the ``Mammo_FT`` network.

        Args:
            img: 2-D array, normally the output of :meth:`padd`.

        Returns:
            A ``(score_arr_size, score_arr_size)`` array whose entry
            ``[i, j]`` is the score of the tile starting at row
            ``i * step_size`` and column ``j * step_size``.  Tiles rejected
            by :meth:`is_mammo` keep the score 0.
        """
        rows, cols = img.shape
        score_array = np.zeros((self.score_arr_size, self.score_arr_size))
        row_starts = range(0, rows - self.tile_size + 1, self.step_size)
        col_starts = range(0, cols - self.tile_size + 1, self.step_size)
        # Inference only: avoid building an autograd graph for every tile.
        with torch.no_grad():
            for i, r in enumerate(row_starts):
                for j, c in enumerate(col_starts):
                    im_tile = img[r:r + self.tile_size, c:c + self.tile_size]
                    if not self.is_mammo(im_tile):
                        continue
                    # Pixels are divided by 1100 before entering the network
                    # (presumably the scaling used during training -- confirm),
                    # then shaped to a (1, 1, tile, tile) float32 tensor.
                    batch = (im_tile / 1100)[np.newaxis, np.newaxis].astype(np.float32)
                    tile_score, _ = self.MammoFT(torch.from_numpy(batch))
                    score_array[i, j] = tile_score.item()
        return score_array

    def is_mammo(self, tile):
        """Return whether the mean value of ``tile`` exceeds 150.

        Tiles failing this test are skipped by :meth:`get_score_array`.
        """
        return np.mean(tile) > 150

    def generate_heat_map(self):
        """Print each image's score and show its heat map overlay.

        For every stored tile-score grid: centre it on a coarser canvas,
        smooth it with a box filter, resize it to ``img_size``, scale it by
        its maximum, crop it to the original image size and draw it over the
        normalised image.  ``plt.show()`` is called once per image.
        """
        # Box filter whose side equals the number of tile steps per tile.
        filter_size = int(1//(1-self.overlap))
        kernel = np.ones((filter_size, filter_size))/(filter_size**2)
        # Canvas geometry is the same for every image, so compute it once.
        pad_size = int(self.img_size//(self.tile_size*(1-self.overlap)))
        offset = pad_size//2 - self.score_arr_size//2
        for i, score in enumerate(self.score_arrs):
            pad_score = np.zeros((pad_size, pad_size))
            pad_score[offset:offset + self.score_arr_size,
                      offset:offset + self.score_arr_size] = score
            # convolve2d defaults to mode="full", so the result is larger
            # than pad_score before being resized to the canvas.
            base = convolve2d(pad_score, kernel)
            heatmap = cv2.resize(base, (self.img_size, self.img_size))
            # Scale by the peak value; skip when the peak is 0 (no tile was
            # scored) to avoid an all-NaN map from 0/0.
            peak = np.max(heatmap)
            if peak != 0:
                heatmap = heatmap/peak
            img = self.imgs[i]
            rows, cols = img.shape
            top = self.img_size//2 - rows//2
            left = self.img_size//2 - cols//2
            # Crop the region that :meth:`padd` filled with the image.
            cut_heatmap = np.copy(heatmap[top:top + rows, left:left + cols])
            img = self.display(img)
            print("name: %s, score: %1.2f, label: %d" %
                  (self.img_names[i], self.scores[i], 1 if self.scores[i] > 0.5 else 0))
            plt.figure()
            plt.imshow(img, cmap='gray')
            plt.imshow(cut_heatmap, cmap='jet', alpha=0.2)
            plt.show()

    def display(self, img_array):
        """Return a display copy of ``img_array`` rescaled to [0, 255].

        Values are clamped to the window [300, 800] and mapped linearly so
        that 300 becomes 0 and 800 becomes 255.  The input is not modified.
        """
        mammo = np.clip(img_array, 300, 800)
        mammo = ((mammo - 300) / (800 - 300)) * 255
        return mammo
if __name__ == "__main__":
    # Run the full pipeline only when executed as a script, so importing this
    # module does not load the models, scan the data directory or open
    # blocking plot windows.
    path ="D:\\data"
    Fd = Forgery_detection(path)
    Fd.load_imgs()
    # Prints one "name / score / label" line per image.
    Fd.classification()
    print("generate heat map")
    # Shows one overlay figure per image.
    Fd.generate_heat_map()