diff --git a/mini_imagenet_dataloader.py b/mini_imagenet_dataloader.py index 1abb9e5..06fb251 100644 --- a/mini_imagenet_dataloader.py +++ b/mini_imagenet_dataloader.py @@ -31,6 +31,8 @@ parser.add_argument('--shuffle_images', dest='shuffle_images', action='store_true') parser.add_argument('--enable_idx_increase', dest='enable_idx_increase', action='store_true') parser.add_argument('--use_independent_base', dest='use_independent_base', action='store_true') +parser.add_argument('--train_indep_and_dependent', dest='train_indep_and_dependent', action='store_true') +parser.add_argument('--tensorboard_log_dir', dest='tensorboard_log_dir', type=str, default='./logs') args = parser.parse_args() class MiniImageNetDataLoader(object): @@ -264,7 +266,9 @@ def get_batch(self, phase='train', idx=0): base_train_img, base_train_label, base_test_img, base_test_label = \ dataloader.get_batch(phase='train', idx=0) -train_epoch_size = int(base_train_img.shape[0] / 2) # as double is generated for the base and train +train_epoch_size = base_train_img.shape[0] +if not args.train_indep_and_dependent: + train_epoch_size = int(train_epoch_size / 2) # as double is generated for the base and train test_epoch_size = base_test_img.shape[0] print("epoch training size:", train_epoch_size, base_train_label.shape[0], "epoch testing size", test_epoch_size) @@ -291,23 +295,25 @@ def generate(self, phase='train'): def generate_add_samples(self, phase = 'train'): self.idx = 0 while True: - episode_train_img, episode_train_label, episode_test_img, episode_test_label = \ + batch_train_img, batch_train_label, episode_test_img, episode_test_label = \ dataloader.get_batch(phase=args.dataset, idx=self.idx) # this depends on what we are trying to train. # care must be taken, that with a different dataset the labels have a different meaning. Thus if we use a new dataset, we must # use network_base which fits to the database. Therefore there must be taken images with label from the same dataset. - network_base_img = episode_train_img[:train_epoch_size] - network_base_img = episode_train_label[:train_epoch_size] + network_base_img = batch_train_img[:train_epoch_size] + network_base_img = batch_train_label[:train_epoch_size] #only half is used now, as the rest is reserved for independend base - episode_train_img = episode_train_img[train_epoch_size:] - episode_train_label = episode_train_label[train_epoch_size:] + episode_train_img = batch_train_img[train_epoch_size:] + episode_train_label = batch_train_label[train_epoch_size:] if not args.use_independent_base: network_base_img = episode_train_img network_base_label = episode_train_label - + if args.train_indep_and_dependent: + episode_train_img = batch_train_img + episode_train_label = batch_train_label if phase == 'train': if args.enable_idx_increase: self.idx += 1 # only train phase allowed to change @@ -490,7 +496,7 @@ def all_layers(model): checkpointer = ModelCheckpoint(filepath='checkpoints/model-{epoch:02d}.hdf5', verbose=1) -tensorboard = TensorBoard() +tensorboard = TensorBoard(log_dir = args.tensorboard_log_dir) lambda_model.fit_generator(keras_gen_train.generate_add_samples(), train_epoch_size, args.epochs, validation_data=keras_gen_train.generate_add_samples('test'), validation_steps=test_epoch_size, callbacks = [tensorboard], workers = 0) #workers = 0 is a work around to correct the number of calls to the validation_data generator