diff --git a/mini_imagenet_dataloader.py b/mini_imagenet_dataloader.py index 06fb251..9bb9d7f 100644 --- a/mini_imagenet_dataloader.py +++ b/mini_imagenet_dataloader.py @@ -302,16 +302,18 @@ def generate_add_samples(self, phase = 'train'): # care must be taken, that with a different dataset the labels have a different meaning. Thus if we use a new dataset, we must # use network_base which fits to the database. Therefore there must be taken images with label from the same dataset. network_base_img = batch_train_img[:train_epoch_size] - network_base_img = batch_train_label[:train_epoch_size] - - #only half is used now, as the rest is reserved for independend base + network_base_label = batch_train_label[:train_epoch_size] + + #only half is used now, as the rest is reserved for independend base episode_train_img = batch_train_img[train_epoch_size:] episode_train_label = batch_train_label[train_epoch_size:] if not args.use_independent_base: network_base_img = episode_train_img network_base_label = episode_train_label - if args.train_indep_and_dependent: + if args.train_indep_and_dependent: #train_epoch_size wrong, before should be old .... + network_base_img = batch_train_img[:int(train_epoch_size/2)] + network_base_label = batch_train_label[:int(train_epoch_size/2)] episode_train_img = batch_train_img episode_train_label = batch_train_label if phase == 'train':