Skip to content

Commit e8117f9

Browse files
committed
modified dataset
1 parent 4f3f780 commit e8117f9

File tree

128 files changed

+116
-13
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

128 files changed

+116
-13
lines changed

config/conf.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"lr":1e-3,
44
"batch_size":4,
55
"optimizer":"adam",
6-
"metrics":"mIoU",
6+
"metrics":["mIoU"],
77
"loss":"dice",
88
"monitor_metric":"val_mIoU"
99
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
tensorflow
2-
glob
2+
glob
3+
numpy

src/datasets/dataset.py

+67-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,54 @@
1+
import json
12
import tensorflow as tf
23
from glob import glob
34
import os
5+
import numpy as np
6+
from PIL import Image, ImageDraw
47

5-
REQUIRED_FOLDERS = ['Defects', 'NoDefects', 'Annotations']
8+
REQUIRED_FOLDERS = ['Defects', 'NoDefects', 'annotations']
9+
10+
LABEL_DICT = {
11+
'BG':0,
12+
'HOLE':1,
13+
'VERTICAL':2,
14+
'HORIZONTAL':3,
15+
'SPATTERING':4,
16+
'INCANDESCENCE':5
17+
}
18+
19+
all_defects = []
20+
21+
def map_fn(file:str, save_defects=False):
22+
# load image
23+
img = tf.keras.utils.load_img(file)
24+
shape = img.size
25+
# build labels
26+
label = np.zeros((shape[0], shape[1]))
27+
if 'NoDefects' not in file:
28+
# load file
29+
annotation_file = file.replace('Defects', 'annotations').replace('jpg', 'json')
30+
with open(annotation_file, 'r') as f:
31+
annotations = json.load(f)
32+
# build annotation image & draw poligons
33+
img = Image.fromarray(label)
34+
for shape in annotations['shapes']:
35+
# use only one label name
36+
if shape['label'].upper() == 'VERTICAL DEFECT':
37+
shape['label'] = 'VERTICAL'
38+
if shape['label'].upper() == 'SPATTING':
39+
shape['label'] = 'SPATTERING'
40+
41+
points = [(x[0], x[1]) for x in shape['points']]
42+
label_id = LABEL_DICT[shape['label'].upper()]
43+
# append defects to all defects
44+
if save_defects:
45+
all_defects.append((label_id, points))
46+
# draw poligon on image
47+
ImageDraw.Draw(img).polygon(points, fill=label_id)
48+
label = np.array(img)
49+
label = tf.one_hot(label, len(LABEL_DICT))
50+
return (img, label)
51+
652

753
class AMDdataset():
854
'''Additive Manufactoring dataset class'''
@@ -19,3 +65,23 @@ def build(self):
1965
raise FileNotFoundError(f'Directory {self.path} does not contain correct folders. It must contains {REQUIRED_FOLDERS}')
2066

2167
# TODO split the dataset and load it into 3 tf.dataset: self.train, self.val, self.test
68+
train = []
69+
test = []
70+
val = []
71+
72+
# split folders
73+
for f in folders:
74+
files = np.array(glob(os.path.join(self.path, f,'*.jpg')))
75+
if len(files) == 0:
76+
continue
77+
n = len(files) // 3
78+
idx = np.random.permutation(np.arange(len(files)))
79+
80+
test.extend(files[idx[:n]])
81+
val.extend(files[idx[n:n*2]])
82+
train.extend(files[idx[n*2:]])
83+
84+
train = [map_fn(x, True) for x in train]
85+
test = [map_fn(x) for x in test]
86+
val = [map_fn(x, True) for x in val]
87+

src/network/gan.py

+34-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import tensorflow as tf
2-
from tensorflow.keras.applications.mobilenet_v3 import MobileNetV3
2+
from tensorflow.keras.applications import MobileNetV3Small
33

44
class ConvBlock(tf.keras.layers.Layer):
55
KERNEL_SIZE = 3
@@ -88,21 +88,49 @@ def call(self, inputs):
8888
x = self.base_model(inputs)
8989
x = self.result(x)
9090
return x
91+
9192
class GAN(tf.keras.Model):
9293
def __init__(
9394
self,
9495
name="gan",
95-
input_shape=(1024, 800, 1),
96+
# input_shape=(1024, 800, 1),
9697
**kwargs
9798
):
9899
super(GAN, self).__init__(name=name,**kwargs)
99100
self.generator = Generator()
100101
self.discriminator = Discriminator()
101102

102103
def train_step(self, inputs):
103-
# TODO
104-
return 0
104+
inputs_with_defects = inputs # TODO add function to insert defects
105+
# train the discriminator
106+
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
107+
y_true_pred = self.discriminator(inputs, training=True)
108+
gen_images = self.generator(inputs_with_defects, training=True)
109+
y_fake_pred = self.discriminator(gen_images, training=True)
110+
111+
y_true = tf.ones_like(y_true_pred)
112+
y_fake = tf.zeros_like(y_fake_pred)
113+
114+
# (the loss function is configured in compile())
115+
loss_d = self.loss['discriminator'](tf.concat(y_true, y_fake), tf.concat(y_true_pred, y_fake_pred), regularization_losses=self.losses)
116+
loss_g = self.loss['generator'](y_true, y_fake_pred, regularization_losses=self.losses)
117+
loss = tf.reduce_sum(loss_g, 0.5 * loss_d)
118+
119+
# compute gradients
120+
gen_gradients = gen_tape.gradient(loss_g, self.generator.trainable_variables)
121+
disc_gradients = disc_tape.gradient(loss_d, self.discriminator.trainable_variables)
122+
123+
# apply gradients
124+
# (the optimizer should be configured in compile())
125+
self.optimizer['generator'].apply_gradients(zip(gen_gradients, self.generator.trainable_variables))
126+
self.optimizer['discriminator'].apply_gradients(zip(disc_gradients, self.discriminator.trainable_variables))
127+
128+
# compute metrics
129+
out = {m.name: m.result() for m in self.metrics}
130+
out['loss'] = loss
131+
return out
105132

106133
def call(self, inputs):
107-
# TODO
108-
return {}
134+
gen_out = self.generator(inputs)
135+
disc_out = self.discriminator(gen_out)
136+
return [gen_out, disc_out]

src/train.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,34 @@
11
from telnetlib import GA
2-
import tensorflow
2+
import tensorflow as tf
33
from glob import glob
44
from utils.parser import parse_arguments
55
from datasets.dataset import AMDdataset
66
from network.unet import UNet
77
from network.gan import GAN
88
from utils.train_utils import TrainWrapper
99
from utils.conf_reader import read_conf
10+
import numpy as np
11+
1012

1113
MODELS={
1214
'unet':UNet,
1315
'gan':GAN
1416
}
1517

1618
if __name__ == '__main__':
19+
# make deterministic
20+
np.random.seed(0)
21+
tf.keras.utils.set_random_seed(1)
22+
tf.config.experimental.enable_op_determinism()
1723
# parsing the arguments
1824
args = parse_arguments()
1925
# read config file
2026
conf = read_conf(args.conf)
2127
# opening the dataset
2228
datasets = AMDdataset(args.dataset_folder)
29+
datasets.build()
2330
# define the model
24-
net = MODELS(args.model)
31+
net = MODELS[args.model]()
2532
# train
2633
train_wrapper = TrainWrapper(net, conf=conf, train_dataset=datasets.train, val_dataset=datasets.val, test_dataset=datasets.test)
2734

src/utils/conf_reader.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import tensorflow as tf
3-
from losses import LOSSES
3+
from utils.losses import LOSSES
44

55
OPTIMIZERS = {
66
'adam':tf.keras.optimizers.Adam,

src/utils/losses.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@ def call(self, y_true, y_pred):
1717

1818
LOSSES = {
1919
'dice':DiceLoss,
20-
'cross_entropy':tf.keras.losses.CategoricalCrossEntropy
20+
'categorical_cross_entropy':tf.keras.losses.CategoricalCrossentropy,
21+
'binary_crossentropy':tf.keras.losses.BinaryCrossentropy
2122
}

0 commit comments

Comments
 (0)