Skip to content

Commit

Permalink
- support of disabling layers for training
Browse files Browse the repository at this point in the history
  • Loading branch information
dsmic committed Aug 21, 2019
1 parent f96c9d0 commit c2f67d2
Showing 1 changed file with 31 additions and 13 deletions.
44 changes: 31 additions & 13 deletions mini_imagenet_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np
from tqdm import trange
import imageio

import ast

import argparse
parser = argparse.ArgumentParser(description='train recurrent net.')
Expand All @@ -33,6 +33,9 @@
parser.add_argument('--use_independent_base', dest='use_independent_base', action='store_true')
parser.add_argument('--train_indep_and_dependent', dest='train_indep_and_dependent', action='store_true')
parser.add_argument('--tensorboard_log_dir', dest='tensorboard_log_dir', type=str, default='./logs')
parser.add_argument('--enable_only_layers_of_list', dest='enable_only_layers_of_list', type=str, default=None)
parser.add_argument('--episode_test_sample_num', dest='episode_test_sample_num', type=int, default=15)

args = parser.parse_args()

class MiniImageNetDataLoader(object):
Expand Down Expand Up @@ -249,7 +252,7 @@ def get_batch(self, phase='train', idx=0):
return this_inputa, this_labela, this_inputb, this_labelb

cathegories = 5
dataloader = MiniImageNetDataLoader(shot_num=5 * 2, way_num=cathegories, episode_test_sample_num=15) #twice shot_num is because one might be uses as the base for the samples
dataloader = MiniImageNetDataLoader(shot_num=5 * 2, way_num=cathegories, episode_test_sample_num=args.episode_test_sample_num) #twice shot_num is because one might be uses as the base for the samples

dataloader.generate_data_list(phase='train')
dataloader.generate_data_list(phase='val')
Expand Down Expand Up @@ -452,9 +455,6 @@ def call(x):
lambda_model = load_model(args.pretrained_name, custom_objects = { "keras": tensorflow.keras , "args":args})
print("loaded model",lambda_model)

#after loading to set learning rate
lambda_model.compile(loss='categorical_crossentropy', optimizer=op.SGD(args.lr), metrics=['categorical_accuracy'])
print(lambda_model.summary(line_length=180, positions = [.33, .55, .67, 1.]))


# models in models forget the layer name, therefore one must use the automatically given layer name and iterate throught the models by hand
Expand All @@ -477,16 +477,34 @@ def all_layers(model):
l2=l.layer
print(l2.name,l2.trainable, len(l2.get_weights()))

for l in all_layers(lambda_model):
l2=l
lambda_model_layers = all_layers(lambda_model)
for l in range(len(lambda_model_layers)):
l2=lambda_model_layers[l]
p='normal'
if isinstance(l,TimeDistributed):
l2=l.layer
if isinstance(l2,TimeDistributed):
l2=l2.layer
p='timedi'
# if (l2.name == 'dense_1'):
# l2.trainable = False
print(p,l2.name,l2.trainable, len(l2.get_weights()))

if args.enable_only_layers_of_list is not None:
l2.trainable = False
print(l, p,l2.name,l2.trainable, len(l2.get_weights()))

if args.enable_only_layers_of_list is not None:
print('\nenable some layers for training')

for i in ast.literal_eval(args.enable_only_layers_of_list):
lambda_model_layers[i].trainable = True

for l in range(len(lambda_model_layers)):
l2=lambda_model_layers[l]
p='normal'
if isinstance(l2,TimeDistributed):
l2=l2.layer
p='timedi'
print(l, p,l2.name,l2.trainable, len(l2.get_weights()))

#after loading to set learning rate
lambda_model.compile(loss='categorical_crossentropy', optimizer=op.SGD(args.lr), metrics=['categorical_accuracy'])
print(lambda_model.summary(line_length=180, positions = [.33, .55, .67, 1.]))
#lambda_model.get_layer("dense_1").trainable = False

# testing with additional batch axis ?!
Expand Down

0 comments on commit c2f67d2

Please sign in to comment.