-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
98 lines (80 loc) · 2.7 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import time
from glob import glob

import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm

from densenet_bc_moel import dense_net_bc_model
def load_for_one_folder(folder_path):
    """Load every ``*.JPG`` image in *folder_path* as one normalized batch.

    Images are ordered by the integer value of their file-name stem
    (``2.JPG`` before ``10.JPG``), resized to 64x64 and scaled to [0, 1].

    Args:
        folder_path: Directory containing the images. May be relative,
            absolute, or multi-component (e.g. ``./data/Dataset``).

    Returns:
        ``float32`` array of shape ``(n_images, 64, 64, 3)`` with values in
        ``[0, 1]``, in the channel order produced by ``cv2.imread``.

    Raises:
        FileNotFoundError: If the folder contains no ``*.JPG`` files.
        ValueError: If a file stem is not an integer, or OpenCV cannot read
            one of the images.
    """
    img_paths = glob(os.path.join(folder_path, "*.JPG"))
    if not img_paths:
        raise FileNotFoundError(f"no *.JPG images found in {folder_path!r}")

    def sort_func(path):
        # Sort numerically on the bare file stem. Taking the basename (instead
        # of split("/")[1]) works no matter how deep folder_path is.
        file_name = os.path.basename(path.replace("\\", "/"))
        return int(os.path.splitext(file_name)[0])

    img_paths = sorted(img_paths, key=sort_func)

    images = []
    for img_path in img_paths:
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise ValueError(f"could not read image {img_path!r}")
        images.append(cv2.resize(img, (64, 64)))

    # Stack once at the end; concatenating inside the loop re-copies the
    # whole array on every iteration (quadratic in the number of images).
    test_data = np.stack(images).astype("float32")
    test_data /= 255
    return test_data
def load_for_one_file(file_path):
    """Load a single image as a normalized batch of one.

    Args:
        file_path: Path to an image file readable by OpenCV.

    Returns:
        ``float32`` array of shape ``(1, 64, 64, 3)`` with values in ``[0, 1]``.

    Raises:
        ValueError: If OpenCV cannot read or decode *file_path*.
    """
    img = cv2.imread(file_path)
    # cv2.imread returns None (no exception) for missing/undecodable files;
    # fail here with a clear message instead of an opaque cv2.resize error.
    if img is None:
        raise ValueError(f"could not read image {file_path!r}")
    img = cv2.resize(img, (64, 64))
    test_data = img.reshape(1, 64, 64, 3).astype("float32")
    test_data /= 255
    return test_data
def main():
    """Classify every image in ``Dataset`` and compare against ``Label.csv``.

    Writes ``result.csv`` (the rows of ``Label.csv`` plus a ``result``
    column holding the predicted class), then prints the ``img_path`` of
    every row whose ``result`` differs from its ``label``, followed by the
    total run time. When the folder holds exactly one image, its predicted
    class is also printed.

    Raises:
        ValueError: If the number of images does not match the number of
            rows in ``Label.csv``.
    """
    start = time.perf_counter()
    # Model output index -> class name.
    map_dict = {
        0: "A",
        1: "B",
        2: "C"
    }
    test_data = load_for_one_folder("Dataset")
    original_df = pd.read_csv("Label.csv")

    # Load the model once; the original duplicated this in two branches, one
    # of which was unreachable (`shape == 1` compared a tuple with an int).
    model = dense_net_bc_model()
    model.load_weights("cifar10_densenet_model.46_100%.h5")
    # NOTE(review): presumably lets framework start-up logging flush before
    # the tqdm bar is drawn — confirm it is still needed.
    time.sleep(0.1)

    result = []
    for i in tqdm(range(test_data.shape[0]), ascii=True, desc="判讀進度", ncols=100):
        # Slicing keeps the leading batch axis: shape (1, H, W, C).
        probs = model.predict(test_data[i:i + 1])
        result.append(map_dict[int(np.argmax(probs[0]))])

    if len(result) == 1:
        print(result[0])

    # Predictions are matched to CSV rows by position. This assumes Label.csv
    # lists images in the same numeric file-name order that
    # load_for_one_folder uses — TODO confirm.
    if len(result) != len(original_df):
        raise ValueError(
            f"found {len(result)} images in 'Dataset' but "
            f"{len(original_df)} rows in 'Label.csv'"
        )
    original_df["result"] = result
    original_df.to_csv("result.csv", index=False)
    time.sleep(0.1)

    # Show predictions that differ from the ground truth.
    print("判讀錯誤的圖片".center(20, "*"))
    wrong = original_df["label"] != original_df["result"]
    for img_path in original_df.loc[wrong, "img_path"]:
        print(img_path)
    print(f"總執行時間{time.perf_counter() - start}秒")


if __name__ == "__main__":
    main()