Skip to content

Commit

Permalink
Added:
Browse files Browse the repository at this point in the history
- pytorch training pipeline
- pytorch metrics (losses, optimisers, etc.)
- pytorch convnet
- batch iteration fro detection (to be tested)

Modified:
- initial path preparation
- toy ConvNet example
  • Loading branch information
nshepeleva committed Mar 18, 2019
1 parent 9849cf4 commit 29d1a6f
Show file tree
Hide file tree
Showing 12 changed files with 1,134 additions and 272 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,9 @@ venv.bak/

# mypy
.mypy_cache/

# idea
.idea/

# custom folders
experiments/
6 changes: 3 additions & 3 deletions configs/config_ConvNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def load_config():
config = ConfigFlags().return_flags()

config.net = 'ConvNet'
config.training_mode = False
config.training_mode = True
config.data_set = 'MNIST'
config.image_size = [28, 28, 1]

Expand All @@ -31,7 +31,7 @@ def load_config():
config.ref_patience = 1
config.batch_size = 32
config.num_epochs = 1000
config.loss = 'mse'
config.loss = 'softmax'
config.optimizer = 'adam'
config.gradcam_record = True
config.gradcam_layers = 6
Expand All @@ -44,7 +44,7 @@ def load_config():
config.upconv = 'upconv'
config.nonlin = 'relu'
config.task_type = 'classification'
config.accuracy = 'mse'
config.accuracy = 'percent'
config.augmentation = {'flip_hor': False,
'flip_vert': False}
config.data_split = 0.7
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from network.TrainRunner import TrainRunner
from network.InferenceRunner import InferenceRunner

EXPERIMENT_ID = 'ToyNet'
EXPERIMENT_ID = 'ConvNet'


def run(experiment_id):
Expand Down
52 changes: 36 additions & 16 deletions network/NetRunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import tensorflow as tf

from utils.DataParser import DataParser
from network.wrappers import ConvNet, UNet, VGG16
from network.wrappers import ConvNet, UNet, FakeDetection


class NetRunner:
Expand Down Expand Up @@ -152,8 +152,10 @@ def build_tensorflow_pipeline(self):
in_shape = NotImplementedError
gt_shape = NotImplementedError
elif self.task_type is 'detection':
in_shape = NotImplementedError
gt_shape = NotImplementedError
in_shape = [None, self.img_size[0], self.img_size[1], self.img_size[2]]
gt_shape = [None, 4+self.num_classes, None]
self._in_data = tf.placeholder(tf.float32, shape=in_shape, name='Input_train')
self._gt_data = tf.placeholder(tf.float32, shape=gt_shape, name='GT_train')
else:
raise ValueError('Task not supported')

Expand All @@ -177,12 +179,12 @@ def build_tensorflow_pipeline(self):
self.update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

self.network = self._pick_model()
self.pred_output = self.network.build_net_tf(self.in_data)
self.pred_output = self.network.build_net(self.in_data)
self.global_step = tf.train.get_or_create_global_step()

with tf.control_dependencies(self.update_ops):
self.loss = self.network.return_loss(y_pred=self.pred_output, y_true=self.gt_data)
self.train_op = self.network.return_optimizer(self.global_step)
self.train_op = self.network.return_optimizer(global_step=self.global_step)
self.accuracy = self.network.return_accuracy(y_pred=self.pred_output, y_true=self.gt_data,
b_s=self.batch_size, net_type=self.task_type,
loss_type=self.loss_type)
Expand Down Expand Up @@ -240,6 +242,11 @@ def build_tensorflow_pipeline(self):
self.graph_op = tf.global_variables_initializer()

def build_pytorch_pipeline(self):
self.learning_rate = self.lr
self.network = self._pick_model()

self.train_op = self.network.return_optimizer(net_param=self.network.parameters())

return NotImplementedError


Expand All @@ -248,17 +255,30 @@ def _pick_model(self):
Pick a deep model specified by self.network_type string
:return:
"""
if self.network_type == 'ConvNet':
return ConvNet.ConvNet(self.network_type, self.loss_type, self.accuracy_type, self.learning_rate,
training=self.training_mode, framework=self.framework, num_classes=self.num_classes, trainable_layers=self.trainable_layers)
elif self.network_type == 'UNet':
return UNet.UNet(self.network_type, self.loss_type, self.accuracy_type, self.learning_rate, training=self.training_mode, framework=self.framework,
trainable_layers=self.trainable_layers, num_classes=self.num_classes)
elif self.network_type == 'VGG16':
return VGG16.VGG16(self.network_type, self.loss_type, self.accuracy_type, self.learning_rate,
training=self.training_mode, framework=self.framework, num_classes=self.num_classes, trainable_layers=self.trainable_layers)
else:
return ValueError('Architecture does not exist')
if self.framework == "tensorflow":
if self.network_type == 'ConvNet':
return ConvNet.ConvNet(self.network_type, self.loss_type, self.accuracy_type, self.learning_rate, framework=self.framework,
training=self.is_training, num_filters=self.num_filters, nonlin=self.nonlin, num_classes=self.num_classes,
trainable_layers=self.trainable_layers)
elif self.network_type == 'UNet':
return UNet.UNet(self.network_type, self.loss_type, self.accuracy_type, self.learning_rate, framework=self.framework, training=self.training_mode,
trainable_layers=self.trainable_layers, num_classes=self.num_classes)
# elif self.network_type == 'VGG16':
# # return VGG16.VGG16(self.network_type, self.loss_type, self.accuracy_type, self.learning_rate,
# # training=self.training_mode, framework=self.framework, num_classes=self.num_classes, trainable_layers=self.trainable_layers)
elif self.network_type == 'FakeDetection':
return FakeDetection.FakeDetection(self.network_type, self.loss_type, self.accuracy_type, self.learning_rate, training=self.training_mode,
trainable_layers=self.trainable_layers, num_classes=self.num_classes)
else:
raise ValueError('Architecture does not exist')
elif self.framework == "pytorch":
if self.network_type == 'ConvNet':
return ConvNet.ConvNet_pt(self.network_type, self.loss_type, self.accuracy_type, self.learning_rate, framework=self.framework,
training=self.is_training, num_filters=self.num_filters, nonlin=self.nonlin, num_classes=self.num_classes,
trainable_layers=self.trainable_layers)
else:
raise ValueError('Architecture does not exist')


def _initialize_short_summary(self):
"""
Expand Down
Loading

0 comments on commit 29d1a6f

Please sign in to comment.