Skip to content

Commit

Permalink
- support different subdatasets to test some meta learning
Browse files Browse the repository at this point in the history
  • Loading branch information
dsmic committed Aug 19, 2019
1 parent dc76144 commit 9040ae4
Showing 1 changed file with 30 additions and 9 deletions.
39 changes: 30 additions & 9 deletions mini_imagenet_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@
from tqdm import trange
import imageio


import argparse
parser = argparse.ArgumentParser(description='train recurrent net.')
parser.add_argument('--pretrained_name', dest='pretrained_name', type=str, default=None)
parser.add_argument('--dataset', dest='dataset', type=str, default='train')
parser.add_argument('--lr', dest='lr', type=float, default=1e-3)
parser.add_argument('--epochs', dest='epochs', type=int, default=10)

args = parser.parse_args()

class MiniImageNetDataLoader(object):
def __init__(self, shot_num, way_num, episode_test_sample_num):
self.shot_num = shot_num
Expand Down Expand Up @@ -217,14 +227,14 @@ def get_batch(self, phase='train', idx=0):
cathegories = 5
dataloader = MiniImageNetDataLoader(shot_num=5, way_num=cathegories, episode_test_sample_num=15)

dataloader.generate_data_list(phase='train', episode_num = 20000)
#dataloader.generate_data_list(phase='val')
#dataloader.generate_data_list(phase='test')
dataloader.generate_data_list(phase='train', episode_num = 50000)
dataloader.generate_data_list(phase='val')
dataloader.generate_data_list(phase='test')

dataloader.load_list(phase='train')
dataloader.load_list(phase=args.dataset)

episode_train_img, episode_train_label, episode_test_img, episode_test_label = \
dataloader.get_batch(phase='train', idx=0)
dataloader.get_batch(phase=args.dataset, idx=0)

train_epoch_size = episode_train_img.shape[0]
test_epoch_size = episode_test_img.shape[0]
Expand Down Expand Up @@ -254,7 +264,7 @@ def generate_add_samples(self, phase = 'train'):
self.idx = 0
while True:
episode_train_img, episode_train_label, episode_test_img, episode_test_label = \
dataloader.get_batch(phase='train', idx=self.idx)
dataloader.get_batch(phase=args.dataset, idx=self.idx)
if phase == 'train':
self.idx += 1 # only train phase allowed to change
#print(episode_train_img.shape[0])
Expand Down Expand Up @@ -307,8 +317,9 @@ def generate_add_samples(self, phase = 'train'):

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation, Dense, Input, Flatten, Conv2D, Lambda, TimeDistributed, MaxPooling2D
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
import tensorflow.keras.backend as K
import tensorflow.keras

inputs = Input(shape=(None,84,84,3))
print('the shape', inputs.shape)
Expand Down Expand Up @@ -371,7 +382,13 @@ def call(x):

from tensorflow.keras import optimizers as op

lambda_model.compile(loss='categorical_crossentropy', optimizer=op.SGD(0.001), metrics=['categorical_accuracy'])
if args.pretrained_name is not None:
from tensorflow.keras.models import load_model
lambda_model = load_model(args.pretrained_name, custom_objects = { "keras": tensorflow.keras , "args":args})
print("loaded model",lambda_model)

#after loading to set learning rate
lambda_model.compile(loss='categorical_crossentropy', optimizer=op.SGD(args.lr), metrics=['categorical_accuracy'])
print(lambda_model.summary(line_length=180, positions = [.33, .55, .67, 1.]))

# testing with additional batch axis ?!
Expand All @@ -380,8 +397,12 @@ def call(x):
#
print('test lambda', K.eval(test_lambda))



checkpointer = ModelCheckpoint(filepath='checkpoints/model-{epoch:02d}.hdf5', verbose=1)
lambda_model.fit_generator(keras_gen_train.generate_add_samples(), train_epoch_size, 500, validation_data=keras_gen_train.generate_add_samples('test'), validation_steps=test_epoch_size, callbacks = [checkpointer])
tensorboard = TensorBoard()
lambda_model.fit_generator(keras_gen_train.generate_add_samples(), train_epoch_size, args.epochs,
validation_data=keras_gen_train.generate_add_samples('test'), validation_steps=test_epoch_size, callbacks = [checkpointer, tensorboard])


def get_weight_grad(model, inputs, outputs):
Expand Down

0 comments on commit 9040ae4

Please sign in to comment.