Commit: prepared for BiasLayer

dsmic committed Aug 23, 2019
1 parent 2850bf8 commit e8b54ae
Showing 1 changed file with 41 additions and 12 deletions.
53 changes: 41 additions & 12 deletions few_shot_tests.py
@@ -10,7 +10,7 @@
import ast
from mini_imagenet_dataloader import MiniImageNetDataLoader
from tensorflow.keras.models import Model
-from tensorflow.keras.layers import Activation, Dense, Input, Flatten, Conv2D, Lambda, TimeDistributed, MaxPooling2D
+from tensorflow.keras.layers import Activation, Dense, Input, Flatten, Conv2D, Lambda, TimeDistributed, MaxPooling2D, Layer
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
import tensorflow.keras.backend as K
import tensorflow.keras
@@ -23,6 +23,7 @@
parser.add_argument('--dataset', dest='dataset', type=str, default='train')
parser.add_argument('--lr', dest='lr', type=float, default=1e-3)
parser.add_argument('--epochs', dest='epochs', type=int, default=10)
+parser.add_argument('--hidden_size', dest='hidden_size', type=int, default=64)
parser.add_argument('--final_name', dest='final_name', type=str, default='final_model')
parser.add_argument('--shuffle_images', dest='shuffle_images', action='store_true')
parser.add_argument('--enable_idx_increase', dest='enable_idx_increase', action='store_true')
@@ -59,7 +60,8 @@ def idx_to_big(self, phase, idx):


cathegories = 5
-dataloader = OurMiniImageNetDataLoader(shot_num=5 * 2, way_num=cathegories, episode_test_sample_num=args.episode_test_sample_num, shuffle_images = args.shuffle_images) #twice shot_num is because one might be uses as the base for the samples
+shots = 5
+dataloader = OurMiniImageNetDataLoader(shot_num=shots * 2, way_num=cathegories, episode_test_sample_num=args.episode_test_sample_num, shuffle_images = args.shuffle_images) #twice shot_num is because one might be used as the base for the samples

dataloader.generate_data_list(phase=args.dataset)

@@ -164,17 +166,43 @@ def generate_add_samples(self, phase = 'train'):
# Memory growth must be set before GPUs have been initialized
print(e)


+class BiasLayer(Layer):
+
+    def __init__(self, proto_num, mult_bias = 1, **kwargs):
+        self.proto_num = proto_num
+        self.mult_bias = mult_bias
+        super(BiasLayer, self).__init__(**kwargs)
+
+    def build(self, input_shape):
+        # Create a trainable weight variable for this layer.
+        self.bias = self.add_weight(name='bias',
+                                    shape=(self.proto_num, input_shape[2], input_shape[3], input_shape[4]),
+                                    initializer='zeros',
+                                    trainable=True)
+        super(BiasLayer, self).build(input_shape)  # Be sure to call this at the end
+
+    def call(self, x):
+        #return tf.expand_dims(self.bias, axis = 0)# let
+        return self.bias * self.mult_bias + (1 - self.mult_bias) * x
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+    def get_config(self):
+        return {'proto_num': self.proto_num, 'mult_bias': self.mult_bias}

inputs = Input(shape=(None,84,84,3))
print('the shape', inputs.shape)
-conv1 = TimeDistributed(Conv2D(64, 3, padding='same', activation = 'relu'))(inputs)
+conv1 = TimeDistributed(Conv2D(args.hidden_size, 3, padding='same', activation = 'relu'))(inputs)
pool1 = TimeDistributed(MaxPooling2D(pool_size = 2))(conv1)
-conv2 = TimeDistributed(Conv2D(64, 3, padding='same', activation = 'relu'))(pool1)
+conv2 = TimeDistributed(Conv2D(args.hidden_size, 3, padding='same', activation = 'relu'))(pool1)
pool2 = TimeDistributed(MaxPooling2D(pool_size = 2))(conv2)
-conv3 = TimeDistributed(Conv2D(64, 3, padding='same', activation = 'relu'))(pool2)
+conv3 = TimeDistributed(Conv2D(args.hidden_size, 3, padding='same', activation = 'relu'))(pool2)
pool3 = TimeDistributed(MaxPooling2D(pool_size = 2))(conv3)
-conv4 = TimeDistributed(Conv2D(64, 3, padding='same', activation = 'relu'))(pool3)
+conv4 = TimeDistributed(Conv2D(args.hidden_size, 3, padding='same', activation = 'relu'))(pool3)
pool4 = TimeDistributed(MaxPooling2D(pool_size = 2))(conv4)
-conv5 = TimeDistributed(Conv2D(64, 3, padding='same', activation = 'relu'))(pool4)
+conv5 = TimeDistributed(Conv2D(args.hidden_size, 3, padding='same', activation = 'relu'))(pool4)
pool5 = TimeDistributed(MaxPooling2D(pool_size = 2))(conv5)

flat = TimeDistributed(Flatten())(pool5)
@@ -192,8 +220,9 @@ def generate_add_samples(self, phase = 'train'):
input1 = Input(shape=(None,84,84,3))
input2 = Input(shape=(None,84,84,3)) #, tensor = K.variable(episode_train_img[0:0]))

+input2b = BiasLayer(shots * cathegories, mult_bias = 0)(input2)
encoded_l = model_img(input1)
-encoded_r = model_img(input2)
+encoded_r = model_img(input2b)

# Add a customized layer to compute the absolute difference between the encodings
L1_layer = Lambda(lambda tensors:K.abs(tensors[0] - tensors[1]))
@@ -233,7 +262,7 @@ def call(x):

if args.pretrained_name is not None:
    from tensorflow.keras.models import load_model
-    lambda_model = load_model(args.pretrained_name, custom_objects = { "keras": tensorflow.keras , "args":args})
+    lambda_model = load_model(args.pretrained_name, custom_objects = { "keras": tensorflow.keras , "args":args, "BiasLayer": BiasLayer})
    print("loaded model",lambda_model)

# models in models forget the layer name, therefore one must use the automatically given layer name and iterate through the models by hand
@@ -281,10 +310,10 @@ def all_layers(model):
#lambda_model.get_layer("dense_1").trainable = False

# testing with additional batch axis ?!
-i=1
-test_lambda = lambda_model([K.expand_dims(K.variable(base_train_img[0:0+1]),axis=0),K.expand_dims(K.variable(base_train_img), axis=0), K.expand_dims(K.variable(base_train_label), axis=0)])
+#i=1
+#test_lambda = lambda_model([K.expand_dims(K.variable(base_train_img[0:0+1]),axis=0),K.expand_dims(K.variable(base_train_img), axis=0), K.expand_dims(K.variable(base_train_label), axis=0)])
#
-print('test lambda', K.eval(test_lambda))
+#print('test lambda', K.eval(test_lambda))




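For readers following the change, here is a minimal standalone sketch (not part of the commit) of what the new BiasLayer does. With mult_bias=0, as wired into the model above, the layer is an identity, so this commit does not alter the network's output yet; with mult_bias=1 it replaces its input with a trainable prototype tensor of shape (proto_num, H, W, C), one slice per support sample. The sketch assumes a TF 1.x-era tf.keras, matching this repository; the toy tensor sizes are illustrative assumptions.

import numpy as np
from tensorflow.keras.layers import Input, Layer
from tensorflow.keras.models import Model

class BiasLayer(Layer):
    # Copied from the diff above.
    def __init__(self, proto_num, mult_bias = 1, **kwargs):
        self.proto_num = proto_num
        self.mult_bias = mult_bias
        super(BiasLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # One trainable tensor per support sample: (proto_num, H, W, C).
        self.bias = self.add_weight(name='bias',
                                    shape=(self.proto_num, input_shape[2], input_shape[3], input_shape[4]),
                                    initializer='zeros',
                                    trainable=True)
        super(BiasLayer, self).build(input_shape)

    def call(self, x):
        # Linear blend: mult_bias = 0 -> identity, mult_bias = 1 -> pure bias.
        return self.bias * self.mult_bias + (1 - self.mult_bias) * x

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        return {'proto_num': self.proto_num, 'mult_bias': self.mult_bias}

proto_num = 25  # shots * cathegories in the script above
x = np.random.rand(1, proto_num, 8, 8, 3).astype('float32')  # (batch, samples, H, W, C)
inp = Input(shape=(None, 8, 8, 3))

# mult_bias = 0: 0*bias + 1*x, i.e. the input passes through unchanged.
ident = Model(inp, BiasLayer(proto_num, mult_bias=0)(inp))
assert np.allclose(ident.predict(x), x)

# mult_bias = 1: 1*bias + 0*x, i.e. the input is discarded in favour of the
# trainable (zero-initialised) prototype tensor, broadcast over the batch axis.
proto = Model(inp, BiasLayer(proto_num, mult_bias=1)(inp))
assert np.allclose(proto.predict(x), 0.0)

Registering the class in custom_objects, as the commit now does for load_model, is what allows a saved model containing this layer to be deserialized again.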