From 2058caa72e9e4a5e777fe67eb2c2602195a8363b Mon Sep 17 00:00:00 2001 From: detlef Date: Tue, 20 Aug 2019 15:24:37 +0200 Subject: [PATCH] - independend base support --- mini_imagenet_dataloader.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/mini_imagenet_dataloader.py b/mini_imagenet_dataloader.py index 86af73b..1abb9e5 100644 --- a/mini_imagenet_dataloader.py +++ b/mini_imagenet_dataloader.py @@ -30,6 +30,7 @@ parser.add_argument('--final_name', dest='final_name', type=str, default='final_model') parser.add_argument('--shuffle_images', dest='shuffle_images', action='store_true') parser.add_argument('--enable_idx_increase', dest='enable_idx_increase', action='store_true') +parser.add_argument('--use_independent_base', dest='use_independent_base', action='store_true') args = parser.parse_args() class MiniImageNetDataLoader(object): @@ -246,7 +247,7 @@ def get_batch(self, phase='train', idx=0): return this_inputa, this_labela, this_inputb, this_labelb cathegories = 5 -dataloader = MiniImageNetDataLoader(shot_num=5, way_num=cathegories, episode_test_sample_num=15) +dataloader = MiniImageNetDataLoader(shot_num=5 * 2, way_num=cathegories, episode_test_sample_num=15) #twice shot_num is because one might be uses as the base for the samples dataloader.generate_data_list(phase='train') dataloader.generate_data_list(phase='val') @@ -263,7 +264,7 @@ def get_batch(self, phase='train', idx=0): base_train_img, base_train_label, base_test_img, base_test_label = \ dataloader.get_batch(phase='train', idx=0) -train_epoch_size = base_train_img.shape[0] +train_epoch_size = int(base_train_img.shape[0] / 2) # as double is generated for the base and train test_epoch_size = base_test_img.shape[0] print("epoch training size:", train_epoch_size, base_train_label.shape[0], "epoch testing size", test_epoch_size) @@ -280,11 +281,11 @@ def generate(self, phase='train'): # dataloader.get_batch(phase='train', idx=idx) if phase == 'train': #print(episode_train_img.shape[0]) - for i in range(base_train_img.shape[0]): + for i in range(train_epoch_size): yield base_train_img[i:i+1], base_train_label[i:i+1] else: #print(episode_test_img.shape[0]) - for i in range(base_test_img.shape[0]): + for i in range(test_epoch_size): yield base_test_img[i:i+1], base_test_label[i:i+1] def generate_add_samples(self, phase = 'train'): @@ -296,8 +297,16 @@ def generate_add_samples(self, phase = 'train'): # this depends on what we are trying to train. # care must be taken, that with a different dataset the labels have a different meaning. Thus if we use a new dataset, we must # use network_base which fits to the database. Therefore there must be taken images with label from the same dataset. - network_base_img = episode_train_img - network_base_label = episode_train_label + network_base_img = episode_train_img[:train_epoch_size] + network_base_img = episode_train_label[:train_epoch_size] + + #only half is used now, as the rest is reserved for independend base + episode_train_img = episode_train_img[train_epoch_size:] + episode_train_label = episode_train_label[train_epoch_size:] + + if not args.use_independent_base: + network_base_img = episode_train_img + network_base_label = episode_train_label if phase == 'train': if args.enable_idx_increase: @@ -310,14 +319,14 @@ def generate_add_samples(self, phase = 'train'): #print(episode_train_img.shape[0]) #assert(episode_train_img.shape[0] == 25) - for i in range(episode_train_img.shape[0]): + for i in range(train_epoch_size): yield [[episode_train_img[i:i+1]], [network_base_img], [network_base_label]], episode_train_label[i:i+1] else: #print(episode_test_img.shape[0]) #assert(0) #assert(episode_test_img.shape[0] == 75) #assert(self.idx < 50) - for i in range(episode_test_img.shape[0]): + for i in range(test_epoch_size): #print('i',i) yield [[episode_test_img[i:i+1]], [network_base_img], [network_base_label]], episode_test_label[i:i+1] @@ -371,7 +380,7 @@ def generate_add_samples(self, phase = 'train'): print('the shape', inputs.shape) conv1 = TimeDistributed(Conv2D(50, 7, 1 , activation = 'relu'))(inputs) conv2 = TimeDistributed(MaxPooling2D(pool_size = (3,3)))(conv1) -conv3 = TimeDistributed(Conv2D(200, 7, 1 , activation = 'relu'))(conv2) +conv3 = TimeDistributed(Conv2D(100, 7, 1 , activation = 'relu'))(conv2) conv4 = TimeDistributed(MaxPooling2D(pool_size = (3,3)))(conv3) #conv3 = TimeDistributed(Conv2D(5, 5, (3,3) , padding='same', activation = 'relu'))(conv2) flat = TimeDistributed(Flatten())(conv4)