Skip to content

Commit

Permalink
- was necessary to add changes to mini_imagenet_dataloader.py as we n…
Browse files Browse the repository at this point in the history
…eeded unsampled first training set to save model_img output to BiasLayer 2

- output saving seems to be OK
  • Loading branch information
dsmic committed Aug 25, 2019
1 parent c469b10 commit 86a6347
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
23 changes: 17 additions & 6 deletions few_shot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
parser.add_argument('--biaslayer2', dest='biaslayer2', action='store_true')
parser.add_argument('--shots', dest='shots', type=int, default=5)
parser.add_argument('--debug', dest='debug', action='store_true')
parser.add_argument('--set_model_img_to_weights', dest='set_model_img_to_weights', action='store_true')

args = parser.parse_args()

Expand Down Expand Up @@ -110,7 +111,7 @@ def generate_add_samples(self, phase = 'train'):
self.idx = 0
while True:
batch_train_img, batch_train_label, episode_test_img, episode_test_label = \
dataloader.get_batch(phase=args.dataset, idx=self.idx)
dataloader.get_batch(phase=args.dataset, idx=self.idx, dont_shuffle_batch = (self.idx==0))

# this depends on what we are trying to train.
# care must be taken, that with a different dataset the labels have a different meaning. Thus if we use a new dataset, we must
Expand Down Expand Up @@ -227,7 +228,7 @@ def build(self, input_shape):
shape=(1),
initializer=preset,
trainable=False)
#print('bias_enable',self.bias_enable, K.eval(self.bias_enable[0]),'bias',self.bias,'weights')
print('bias_enable',self.bias_enable, K.eval(self.bias_enable[0]),'bias',self.bias,'weights')
super(BiasLayer, self).build(input_shape) # Be sure to call this at the end

def set_bias(self, do_bias):
Expand Down Expand Up @@ -320,7 +321,7 @@ def call(x):

if args.pretrained_name is not None:
from tensorflow.keras.models import load_model
lambda_model = load_model(args.pretrained_name, custom_objects = { "keras": tensorflow.keras , "args":args, "BiasLayer": BiasLayer})
lambda_model = load_model(args.pretrained_name, custom_objects = { "keras": tensorflow.keras , "args":args, "BiasLayer": BiasLayer, "FindModel": FindModel})
print("loaded model",lambda_model)

# models in models forget the layer name, therefore one must use the automatically given layer name and iterate throught the models by hand
Expand Down Expand Up @@ -383,6 +384,7 @@ def all_layers(model):
#print('test lambda', K.eval(test_lambda))


print('vor fitting', lambda_model_layers[17].get_weights()[0])

checkpointer = ModelCheckpoint(filepath='checkpoints/model-{epoch:02d}.hdf5', verbose=1)
tensorboard = TensorBoard(log_dir = args.tensorboard_logdir)
Expand Down Expand Up @@ -426,9 +428,18 @@ def print_FindModels(model):
out_test = find_conv_model.output
functor = K.function([in_test], [out_test])

calc_out = functor([K.expand_dims(K.expand_dims(K.variable(keras_gen_train.n_b_i),axis=0),axis=0)])
calc_out = functor([K.expand_dims(K.variable(keras_gen_train.n_b_i),axis=0)])

print('calc_out',calc_out[0])

print('vor', lambda_model_layers[17].get_weights()[0])

if args.set_model_img_to_weights:
print('\n!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
lambda_model_layers[17].set_weights([calc_out[0][0],np.array([0])])
print('nach', lambda_model_layers[17].get_weights()[0])
print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n')

print(calc_out)



Expand All @@ -447,7 +458,7 @@ def print_FindModels(model):
l2.do_bias = l2.trainable = args.biaslayer1
if (l2.bias_num == 2):
l2.do_bias = l2.trainable = args.biaslayer2
print('past',l2.bias_num, l2.do_bias,args.biaslayer1,args.biaslayer2, l2.bias)
print('past',l2.bias_num, l2.do_bias,args.biaslayer1,args.biaslayer2, debug(l2.bias))

print('{:10} {:10} {:20} {:10} {:10}'.format(l, p,l2.name, ("fixed", "trainable")[l2.trainable], l2.count_params()), debug(l2.get_weights()))

Expand Down
11 changes: 6 additions & 5 deletions mini_imagenet_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,13 @@ def load_list(self, phase='train'):
else:
print('Please select vaild phase')

def process_batch(self, input_filename_list, input_label_list, batch_sample_num, reshape_with_one=True):
def process_batch(self, input_filename_list, input_label_list, batch_sample_num, reshape_with_one=True, dont_shuffle_batch = False):
new_path_list = []
new_label_list = []
for k in range(batch_sample_num):
class_idxs = list(range(0, self.way_num))
random.shuffle(class_idxs)
if not dont_shuffle_batch:
random.shuffle(class_idxs)
for class_idx in class_idxs:
true_idx = class_idx*batch_sample_num + k
new_path_list.append(input_filename_list[true_idx])
Expand Down Expand Up @@ -174,7 +175,7 @@ def one_hot(self, inp):
out[idx, inp[idx]] = 1
return out

def get_batch(self, phase='train', idx=0):
def get_batch(self, phase='train', idx=0, dont_shuffle_batch = False):
if phase=='train':
all_filenames = self.train_filenames
labels = self.train_labels
Expand Down Expand Up @@ -204,7 +205,7 @@ def get_batch(self, phase='train', idx=0):
this_task_te_filenames += this_class_filenames[epitr_sample_num:]
this_task_te_labels += this_class_label[epitr_sample_num:]

this_inputa, this_labela = self.process_batch(this_task_tr_filenames, this_task_tr_labels, epitr_sample_num, reshape_with_one=False)
this_inputb, this_labelb = self.process_batch(this_task_te_filenames, this_task_te_labels, epite_sample_num, reshape_with_one=False)
this_inputa, this_labela = self.process_batch(this_task_tr_filenames, this_task_tr_labels, epitr_sample_num, reshape_with_one=False, dont_shuffle_batch = dont_shuffle_batch)
this_inputb, this_labelb = self.process_batch(this_task_te_filenames, this_task_te_labels, epite_sample_num, reshape_with_one=False, dont_shuffle_batch = dont_shuffle_batch)

return this_inputa, this_labela, this_inputb, this_labelb

0 comments on commit 86a6347

Please sign in to comment.