-
Notifications
You must be signed in to change notification settings - Fork 1
/
build_lisa_records.py
130 lines (108 loc) · 4.73 KB
/
build_lisa_records.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# import the packages
from config import lisa_config as config
from pipeline.utils import TFAnnotation
from sklearn.model_selection import train_test_split
from PIL import Image
import tensorflow as tf
import os
def main(_):
    """Build the LISA class-map file and train/test TFRecord files.

    Reads the LISA annotations CSV (``config.ANNOT_PATH``), groups the
    bounding boxes by image, splits the images into train/test sets, and
    serializes each image plus its boxes into the TFRecord files expected
    by the TensorFlow Object Detection API.
    """
    # write the class-label map in the Object Detection API's protobuf
    # text format; `with` guarantees the file is closed even on error
    with open(config.CLASSES_FILE, "w") as f:
        for (k, v) in config.CLASSES.items():
            # construct the class information and write to file
            item = ("item {\n"
                    "\tid: " + str(v) + "\n"
                    "\tname: '" + k + "'\n"
                    "}\n"
                    )
            f.write(item)

    # initialize a data dictionary used to map each image filename to all
    # bounding boxes associated with the image, then load the contents of
    # the annotations file (closed promptly via the context manager)
    D = {}
    with open(config.ANNOT_PATH) as annotFile:
        rows = annotFile.read().strip().split("\n")

    # loop over the individual rows, skipping the header
    for row in rows[1:]:
        # drop any trailing comma-separated fields, then break the
        # remainder into its semicolon-delimited components
        row = row.split(",")[0].split(";")
        (imagePath, label, startX, startY, endX, endY, _) = row
        (startX, startY) = (float(startX), float(startY))
        (endX, endY) = (float(endX), float(endY))

        # ignore labels we are not interested in
        if label not in config.CLASSES:
            continue

        # build the path to the input image, then grab any other bounding
        # boxes + labels already associated with that image path
        p = os.path.sep.join([config.BASE_PATH, imagePath])
        b = D.get(p, [])

        # build a tuple consisting of the label and bounding box,
        # then update the list and store it in the dictionary
        b.append((label, (startX, startY, endX, endY)))
        D[p] = b

    # create training and testing splits from our data dictionary
    # (fixed random_state keeps the split reproducible across runs)
    (trainKeys, testKeys) = train_test_split(list(D.keys()),
        test_size=config.TEST_SIZE, random_state=42)

    # initialize the data split files
    datasets = [
        ("train", trainKeys, config.TRAIN_RECORD),
        ("test", testKeys, config.TEST_RECORD),
    ]

    # loop over the datasets
    for (dType, keys, outputPath) in datasets:
        # initialize the TensorFlow writer and the total number of
        # examples written to file
        print("[INFO] processing '{}'...".format(dType))
        writer = tf.python_io.TFRecordWriter(outputPath)
        total = 0

        # try/finally ensures the writer is closed (and the record file
        # flushed) even if an image fails to load mid-loop
        try:
            # loop over all the keys in the current set
            for k in keys:
                # load the encoded image bytes from disk; the context
                # manager closes the file handle immediately after reading
                with tf.gfile.GFile(k, "rb") as imageFile:
                    encoded = bytes(imageFile.read())

                # load the image again as a PIL object to obtain its
                # spatial dimensions (closed after the size is read)
                with Image.open(k) as pilImage:
                    (w, h) = pilImage.size[:2]

                # parse the filename and encoding from the input path
                filename = k.split(os.path.sep)[-1]
                encoding = filename[filename.rfind(".") + 1:]

                # initialize the annotation object used to store
                # information regarding the bounding boxes + labels
                tfAnnot = TFAnnotation()
                tfAnnot.image = encoded
                tfAnnot.encoding = encoding
                tfAnnot.filename = filename
                tfAnnot.width = w
                tfAnnot.height = h

                # loop over the bounding boxes + labels for this image
                for (label, (startX, startY, endX, endY)) in D[k]:
                    # TensorFlow assumes all bounding boxes are in the
                    # range [0, 1], so scale by the image dimensions
                    xMin = startX / w
                    xMax = endX / w
                    yMin = startY / h
                    yMax = endY / h

                    # update the bounding boxes + labels lists
                    tfAnnot.xMins.append(xMin)
                    tfAnnot.xMaxs.append(xMax)
                    tfAnnot.yMins.append(yMin)
                    tfAnnot.yMaxs.append(yMax)
                    tfAnnot.textLabels.append(label.encode("utf8"))
                    tfAnnot.classes.append(config.CLASSES[label])
                    tfAnnot.difficult.append(0)

                # encode the data point attributes using the TensorFlow
                # helper functions and add the example to the writer
                features = tf.train.Features(feature=tfAnnot.build())
                example = tf.train.Example(features=features)
                writer.write(example.SerializeToString())

                # increment the total number of examples -- one per
                # serialized image, matching the diagnostic message below
                total += 1
        finally:
            # close the writer so the record file is flushed to disk
            writer.close()

        # print diagnostic information to the user
        print("[INFO] {} examples saved for '{}'".format(total, dType))
# check to see if the main thread should be started
if __name__ == "__main__":
    # tf.app.run() parses command-line flags and then invokes main() above
    tf.app.run()