import os
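# Make CUDA device numbering follow the PCI bus order and expose only physical GPU 1,
# so 'cuda:0' below refers to that card.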
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from torch import nn
import torch
import utils
from models import SegDecNet
import cv2
import numpy as np
import matplotlib.pyplot as plt
import time
INPUT_WIDTH = 256
INPUT_HEIGHT = 512
INPUT_CHANNELS = 1
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('using device:', device)
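# Filenames are binned into these lists as the loop runs: ground truth by name
# (names containing "line" are treated as negatives) and predictions by the
# decision score thresholded at 0.5. They are intersected at the end for TP/TN counts.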
GT_POS = []
GT_NEG = []
PRED_POS = []
PRED_NEG = []
rf_path = '/home/yjkim/HANSUNG/datasets/hs_data/SpecularRF_new/crop_total/'
path = '/home/yjkim/HANSUNG/datasets/hs_data/0220_zoom2.5/acc95/test_real/'
dirs = os.listdir(path)
file_list = [file for file in dirs if file.endswith('Normal.png')]
run_path = '/home/yjkim/HANSUNG/mixed-segdec-net-comind2021/results/HANS/runs_rf_new/' # runs_zoom2.5/'
# run_path = '/home/yjkim/HANSUNG/mixed-segdec-net-comind2021/results/HANS/runs_2.5_2/' # runs_zoom2.5/'
model_path = run_path + 'models/best_state_dict.pth'
outputs_path = os.path.join(run_path, "test_outputs/display_bbox") # "remove1/display_bbox")
print('HS dataset inference start...!')
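# Main flow: load the trained SegDecNet weights, run every '*Normal.png' test image
# through the network using its matching RF crop, save binarised/scaled segmentation
# outputs, and report a TP/TN-based accuracy over the whole set.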
if __name__ == "__main__":
    if not os.path.exists(outputs_path):
        os.makedirs(outputs_path)

    model = SegDecNet(device, INPUT_WIDTH, INPUT_HEIGHT, INPUT_CHANNELS)
    model.set_gradient_multipliers(0)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()  # inference mode: use the trained batch-norm/dropout statistics
    for filename in file_list:
        # Pair each '*Normal.png' name with its RF crop in rf_path.
        img_path = os.path.join(rf_path, filename[:-10] + 'RF.png')
        img = cv2.imread(img_path) if INPUT_CHANNELS == 3 else cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (INPUT_WIDTH, INPUT_HEIGHT))
        org_img = img.copy()
        img = np.transpose(img, (2, 0, 1)) if INPUT_CHANNELS == 3 else img[np.newaxis]
        img_t = torch.from_numpy(img)[np.newaxis].float() / 255.0  # must be [BATCH_SIZE x CHANNELS x HEIGHT x WIDTH]
        img_t = img_t.to(device)

        start = time.time()
        with torch.no_grad():  # gradients are not needed at inference time
            prediction, pred_seg = model(img_t)
        end = time.time()

        pred_seg = nn.Sigmoid()(pred_seg)
        prediction = nn.Sigmoid()(prediction)
        pred_seg = torch.squeeze(pred_seg)
        img_t = torch.squeeze(img_t)
        prediction = prediction.item()

        image = img_t.detach().cpu().numpy()
        pred_seg = pred_seg.detach().cpu().numpy()
        dsize = (INPUT_WIDTH, INPUT_HEIGHT)
        pred_seg = cv2.resize(pred_seg, dsize)
        # pred_seg = cv2.resize(pred_seg[0, 0, :, :], dsize) if len(pred_seg.shape) == 4 else cv2.resize(pred_seg[0, :, :], dsize)
        # print('original image size is ', org_img.shape)
        # print('segmentation size is ', pred_seg.shape)

        # Binarise against the top-left pixel and cast to uint8 so cv2.imwrite accepts the arrays.
        new_image = np.where(image != image[0][0], 255, 0).astype(np.uint8)
        new_seg = np.where(pred_seg != pred_seg[0][0], 255, 0).astype(np.uint8)
        cv2.imwrite(f"{outputs_path}/seg_{prediction}_{filename}.png", new_seg)
        cv2.imwrite(f"{outputs_path}/img_{prediction}_{filename}.png", new_image)

        vmax_value = max(1, np.max(pred_seg))
        print('segmentation vmax: ', vmax_value)
        # concat_img = np.hstack([org_img, pred_seg])
        jet_seg = cv2.applyColorMap((pred_seg * 255).astype(np.uint8), cv2.COLORMAP_JET)
        # cv2.imwrite(f"{outputs_path}/seg_{prediction}_{filename}.png", jet_seg)
        scaled = (new_seg / max(new_seg.max(), 1) * 255).astype(np.uint8)  # guard against an all-black mask
        cv2.imwrite(f"{outputs_path}/scaled_{prediction}_{filename}.png", scaled)

        # Accumulate ground-truth and predicted labels for the accuracy summary below.
        if "line" in filename:
            GT_NEG.append(filename)
        else:
            GT_POS.append(filename)
        if prediction < 0.5:
            PRED_NEG.append(filename)
        else:
            PRED_POS.append(filename)
        print(f'inference time: {end - start:.4f} s')
        # Earlier debugging variants:
        # dec_out, seg_out = model(img_t)
        # img_score = torch.sigmoid(dec_out)
        # cv2.imwrite('result.png', seg_out)
        # print(img_score)
        # cv2.imwrite('/home/yjkim/HANSUNG/mixed-segdec-net-comind2021/result.jpg', pred_seg)
        # plot_sample(img_path.split('/')[-1], org_img, seg_out, outputs_path, plot_seg=False)
        # plot_sample(img_path.split('/')[-1], org_img, pred_seg, outputs_path, prediction, plot_seg=False)
    # A file counts as TP/TN when its predicted label matches the ground-truth label.
    TP = list(set(PRED_POS).intersection(set(GT_POS)))
    TN = list(set(PRED_NEG).intersection(set(GT_NEG)))
    # print('TP: ', TP)
    # print('TN: ', TN)
    print('number of TP: ', len(TP))
    print('number of TN: ', len(TN))
    ACC = (len(TP) + len(TN)) / len(file_list)  # correct decisions over all evaluated images
    print('accuracy: ', ACC)