Skip to content

Commit

Permalink
- minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
dsmic committed Aug 20, 2019
1 parent 2058caa commit 5a583b5
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions mini_imagenet_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
parser.add_argument('--shuffle_images', dest='shuffle_images', action='store_true')
parser.add_argument('--enable_idx_increase', dest='enable_idx_increase', action='store_true')
parser.add_argument('--use_independent_base', dest='use_independent_base', action='store_true')
parser.add_argument('--train_indep_and_dependent', dest='train_indep_and_dependent', action='store_true')
parser.add_argument('--tensorboard_log_dir', dest='tensorboard_log_dir', type=str, default='./logs')
args = parser.parse_args()

class MiniImageNetDataLoader(object):
Expand Down Expand Up @@ -264,7 +266,9 @@ def get_batch(self, phase='train', idx=0):
base_train_img, base_train_label, base_test_img, base_test_label = \
dataloader.get_batch(phase='train', idx=0)

train_epoch_size = int(base_train_img.shape[0] / 2) # as double is generated for the base and train
train_epoch_size = base_train_img.shape[0]
if not args.train_indep_and_dependent:
train_epoch_size = int(train_epoch_size / 2) # as double is generated for the base and train
test_epoch_size = base_test_img.shape[0]

print("epoch training size:", train_epoch_size, base_train_label.shape[0], "epoch testing size", test_epoch_size)
Expand All @@ -291,23 +295,25 @@ def generate(self, phase='train'):
def generate_add_samples(self, phase = 'train'):
self.idx = 0
while True:
episode_train_img, episode_train_label, episode_test_img, episode_test_label = \
batch_train_img, batch_train_label, episode_test_img, episode_test_label = \
dataloader.get_batch(phase=args.dataset, idx=self.idx)

# this depends on what we are trying to train.
# care must be taken, that with a different dataset the labels have a different meaning. Thus if we use a new dataset, we must
# use network_base which fits to the database. Therefore there must be taken images with label from the same dataset.
network_base_img = episode_train_img[:train_epoch_size]
network_base_img = episode_train_label[:train_epoch_size]
network_base_img = batch_train_img[:train_epoch_size]
network_base_img = batch_train_label[:train_epoch_size]

#only half is used now, as the rest is reserved for independend base
episode_train_img = episode_train_img[train_epoch_size:]
episode_train_label = episode_train_label[train_epoch_size:]
episode_train_img = batch_train_img[train_epoch_size:]
episode_train_label = batch_train_label[train_epoch_size:]

if not args.use_independent_base:
network_base_img = episode_train_img
network_base_label = episode_train_label

if args.train_indep_and_dependent:
episode_train_img = batch_train_img
episode_train_label = batch_train_label
if phase == 'train':
if args.enable_idx_increase:
self.idx += 1 # only train phase allowed to change
Expand Down Expand Up @@ -490,7 +496,7 @@ def all_layers(model):


checkpointer = ModelCheckpoint(filepath='checkpoints/model-{epoch:02d}.hdf5', verbose=1)
tensorboard = TensorBoard()
tensorboard = TensorBoard(log_dir = args.tensorboard_log_dir)
lambda_model.fit_generator(keras_gen_train.generate_add_samples(), train_epoch_size, args.epochs,
validation_data=keras_gen_train.generate_add_samples('test'), validation_steps=test_epoch_size, callbacks = [tensorboard], workers = 0)
#workers = 0 is a work around to correct the number of calls to the validation_data generator
Expand Down

0 comments on commit 5a583b5

Please sign in to comment.