Skip to content

Commit

Permalink
- fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dsmic committed Aug 20, 2019
1 parent bfab011 commit 7d602b0
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion mini_imagenet_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,22 @@ def one_hot(self, inp):
out[idx, inp[idx]] = 1
return out

def idx_to_big(self, phase, idx):
if phase=='train':
all_filenames = self.train_filenames
# labels = self.train_labels
elif phase=='val':
all_filenames = self.val_filenames
# labels = self.val_labels
elif phase=='test':
all_filenames = self.test_filenames
# labels = self.test_labels
else:
print('Please select vaild phase')

one_episode_sample_num = self.num_samples_per_class*self.shot_num
return ((idx+1)*one_episode_sample_num >= len(all_filenames))

def get_batch(self, phase='train', idx=0):
if phase=='train':
all_filenames = self.train_filenames
Expand Down Expand Up @@ -286,6 +302,12 @@ def generate_add_samples(self, phase = 'train'):
if phase == 'train':
if args.enable_idx_increase:
self.idx += 1 # only train phase allowed to change
if dataloader.idx_to_big(args.dataset, self.idx):
self.idx=0
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
print("all data used, starting from beginning")
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

#print(episode_train_img.shape[0])
#assert(episode_train_img.shape[0] == 25)
for i in range(episode_train_img.shape[0]):
Expand Down Expand Up @@ -417,6 +439,39 @@ def call(x):
lambda_model.compile(loss='categorical_crossentropy', optimizer=op.SGD(args.lr), metrics=['categorical_accuracy'])
print(lambda_model.summary(line_length=180, positions = [.33, .55, .67, 1.]))


# models in models forget the layer name, therefore one must use the automatically given layer name and iterate throught the models by hand
# here we can try setting the layer not trainable
def all_layers(model):
layers = []
for l in model.layers:
#print(l.name, l.trainable, isinstance(l,Model))
if isinstance(l, Model):
a = all_layers(l)
#print(a)
layers.extend(a)
else:
layers.append(l)
return layers

for l in all_layers(siamese_net):
l2=l
if isinstance(l,TimeDistributed):
l2=l.layer
print(l2.name,l2.trainable, len(l2.get_weights()))

for l in all_layers(lambda_model):
l2=l
p='normal'
if isinstance(l,TimeDistributed):
l2=l.layer
p='timedi'
# if (l2.name == 'dense_1'):
# l2.trainable = False
print(p,l2.name,l2.trainable, len(l2.get_weights()))

#lambda_model.get_layer("dense_1").trainable = False

# testing with additional batch axis ?!
i=1
test_lambda = lambda_model([K.expand_dims(K.variable(base_train_img[0:0+1]),axis=0),K.expand_dims(K.variable(base_train_img), axis=0), K.expand_dims(K.variable(base_train_label), axis=0)])
Expand All @@ -428,7 +483,8 @@ def call(x):
checkpointer = ModelCheckpoint(filepath='checkpoints/model-{epoch:02d}.hdf5', verbose=1)
tensorboard = TensorBoard()
lambda_model.fit_generator(keras_gen_train.generate_add_samples(), train_epoch_size, args.epochs,
validation_data=keras_gen_train.generate_add_samples('test'), validation_steps=test_epoch_size, callbacks = [tensorboard])
validation_data=keras_gen_train.generate_add_samples('test'), validation_steps=test_epoch_size, callbacks = [tensorboard], workers = 0)
#workers = 0 is a work around to correct the number of calls to the validation_data generator
lambda_model.save(args.final_name+'.hdf5')


Expand Down

0 comments on commit 7d602b0

Please sign in to comment.