Skip to content

Commit

Permalink
- independend base support
Browse files Browse the repository at this point in the history
  • Loading branch information
dsmic committed Aug 20, 2019
1 parent 7d602b0 commit 2058caa
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions mini_imagenet_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
parser.add_argument('--final_name', dest='final_name', type=str, default='final_model')
parser.add_argument('--shuffle_images', dest='shuffle_images', action='store_true')
parser.add_argument('--enable_idx_increase', dest='enable_idx_increase', action='store_true')
parser.add_argument('--use_independent_base', dest='use_independent_base', action='store_true')
args = parser.parse_args()

class MiniImageNetDataLoader(object):
Expand Down Expand Up @@ -246,7 +247,7 @@ def get_batch(self, phase='train', idx=0):
return this_inputa, this_labela, this_inputb, this_labelb

cathegories = 5
dataloader = MiniImageNetDataLoader(shot_num=5, way_num=cathegories, episode_test_sample_num=15)
dataloader = MiniImageNetDataLoader(shot_num=5 * 2, way_num=cathegories, episode_test_sample_num=15) #twice shot_num is because one might be uses as the base for the samples

dataloader.generate_data_list(phase='train')
dataloader.generate_data_list(phase='val')
Expand All @@ -263,7 +264,7 @@ def get_batch(self, phase='train', idx=0):
base_train_img, base_train_label, base_test_img, base_test_label = \
dataloader.get_batch(phase='train', idx=0)

train_epoch_size = base_train_img.shape[0]
train_epoch_size = int(base_train_img.shape[0] / 2) # as double is generated for the base and train
test_epoch_size = base_test_img.shape[0]

print("epoch training size:", train_epoch_size, base_train_label.shape[0], "epoch testing size", test_epoch_size)
Expand All @@ -280,11 +281,11 @@ def generate(self, phase='train'):
# dataloader.get_batch(phase='train', idx=idx)
if phase == 'train':
#print(episode_train_img.shape[0])
for i in range(base_train_img.shape[0]):
for i in range(train_epoch_size):
yield base_train_img[i:i+1], base_train_label[i:i+1]
else:
#print(episode_test_img.shape[0])
for i in range(base_test_img.shape[0]):
for i in range(test_epoch_size):
yield base_test_img[i:i+1], base_test_label[i:i+1]

def generate_add_samples(self, phase = 'train'):
Expand All @@ -296,8 +297,16 @@ def generate_add_samples(self, phase = 'train'):
# this depends on what we are trying to train.
# care must be taken, that with a different dataset the labels have a different meaning. Thus if we use a new dataset, we must
# use network_base which fits to the database. Therefore there must be taken images with label from the same dataset.
network_base_img = episode_train_img
network_base_label = episode_train_label
network_base_img = episode_train_img[:train_epoch_size]
network_base_img = episode_train_label[:train_epoch_size]

#only half is used now, as the rest is reserved for independend base
episode_train_img = episode_train_img[train_epoch_size:]
episode_train_label = episode_train_label[train_epoch_size:]

if not args.use_independent_base:
network_base_img = episode_train_img
network_base_label = episode_train_label

if phase == 'train':
if args.enable_idx_increase:
Expand All @@ -310,14 +319,14 @@ def generate_add_samples(self, phase = 'train'):

#print(episode_train_img.shape[0])
#assert(episode_train_img.shape[0] == 25)
for i in range(episode_train_img.shape[0]):
for i in range(train_epoch_size):
yield [[episode_train_img[i:i+1]], [network_base_img], [network_base_label]], episode_train_label[i:i+1]
else:
#print(episode_test_img.shape[0])
#assert(0)
#assert(episode_test_img.shape[0] == 75)
#assert(self.idx < 50)
for i in range(episode_test_img.shape[0]):
for i in range(test_epoch_size):
#print('i',i)
yield [[episode_test_img[i:i+1]], [network_base_img], [network_base_label]], episode_test_label[i:i+1]

Expand Down Expand Up @@ -371,7 +380,7 @@ def generate_add_samples(self, phase = 'train'):
print('the shape', inputs.shape)
conv1 = TimeDistributed(Conv2D(50, 7, 1 , activation = 'relu'))(inputs)
conv2 = TimeDistributed(MaxPooling2D(pool_size = (3,3)))(conv1)
conv3 = TimeDistributed(Conv2D(200, 7, 1 , activation = 'relu'))(conv2)
conv3 = TimeDistributed(Conv2D(100, 7, 1 , activation = 'relu'))(conv2)
conv4 = TimeDistributed(MaxPooling2D(pool_size = (3,3)))(conv3)
#conv3 = TimeDistributed(Conv2D(5, 5, (3,3) , padding='same', activation = 'relu'))(conv2)
flat = TimeDistributed(Flatten())(conv4)
Expand Down

0 comments on commit 2058caa

Please sign in to comment.