Commit e8e8220: biaslayer first test for old functionality ok

dsmic committed Aug 24, 2019
1 parent 4a8a444

Showing 1 changed file with 20 additions and 6 deletions: few_shot_tests.py

@@ -171,9 +171,11 @@ def generate_add_samples(self, phase = 'train'):
 
 class BiasLayer(Layer):
 
-    def __init__(self, proto_num, mult_bias = 1, **kwargs):
+    def __init__(self, proto_num, do_bias, bias_num, **kwargs):
         self.proto_num = proto_num
-        self.mult_bias = mult_bias
+        self.do_bias = do_bias
+        self.bias_num = bias_num
+        print('mult bias',do_bias, proto_num)
         super(BiasLayer, self).__init__(**kwargs)
 
     def build(self, input_shape):
@@ -186,13 +188,16 @@ def build(self, input_shape):
 
     def call(self, x):
         #return tf.expand_dims(self.bias, axis = 0)# let
-        return self.bias * self.mult_bias + (1-self.mult_bias) * x
+        if self.do_bias:
+            return self.bias + x * 0
+        else:
+            return self.bias * 0 + x
 
     def compute_output_shape(self, input_shape):
         return input_shape
 
     def get_config(self):
-        return {'proto_num': self.proto_num, 'mult_bias' : self.mult_bias}
+        return {'proto_num': self.proto_num, 'do_bias' : self.do_bias,'bias_num' : self.bias_num}
 
 inputs = Input(shape=(None,84,84,3))
 print('the shape', inputs.shape)
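
A minimal standalone sketch of what the rewritten call() does may help here: with do_bias=True the layer discards its input and emits the trainable bias (the x * 0 term keeps the input connected in the graph), while with do_bias=False it is a pure pass-through that still keeps the bias weight alive. Everything outside the class body shown in the diff is an assumption, in particular the tf.keras imports and the bias shape, since the original build() is not visible in this commit.

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import Layer

    class BiasLayerSketch(Layer):
        """Hypothetical stand-in for the patched BiasLayer."""
        def __init__(self, proto_num, do_bias, bias_num, **kwargs):
            super(BiasLayerSketch, self).__init__(**kwargs)
            self.proto_num = proto_num
            self.do_bias = do_bias
            self.bias_num = bias_num

        def build(self, input_shape):
            # Assumed shape: one scalar per prototype; the real build() is not in this hunk.
            self.bias = self.add_weight(name='bias', shape=(self.proto_num,),
                                        initializer='zeros', trainable=True)
            super(BiasLayerSketch, self).build(input_shape)

        def call(self, x):
            if self.do_bias:
                return self.bias + x * 0   # emit the bias; x * 0 keeps x in the graph
            else:
                return self.bias * 0 + x   # pass x through; bias * 0 keeps the weight in the graph

    x = tf.constant(np.ones((2, 3), dtype=np.float32))
    print(BiasLayerSketch(3, do_bias=False, bias_num=1)(x))  # equals x
    print(BiasLayerSketch(3, do_bias=True, bias_num=2)(x))   # zero-initialized bias, broadcast to x's shape
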
@@ -222,11 +227,11 @@ def get_config(self):
 input1 = Input(shape=(None,84,84,3))
 input2 = Input(shape=(None,84,84,3)) #, tensor = K.variable(episode_train_img[0:0]))
 
-input2b = BiasLayer(shots * cathegories, mult_bias = float(args.biaslayer1))(input2)
+input2b = BiasLayer(shots * cathegories, args.biaslayer1, 1)(input2)
 encoded_l = model_img(input1)
 encoded_r = model_img(input2b)
 
-encoded_rb = BiasLayer(shots * cathegories, mult_bias = float(args.biaslayer1))(encoded_r)
+encoded_rb = BiasLayer(shots * cathegories, args.biaslayer2, 2)(encoded_r)
 # Add a customized layer to compute the absolute difference between the encodings
 L1_layer = Lambda(lambda tensors:K.abs(tensors[0] - tensors[1]))
 L1_distance = L1_layer([encoded_l, encoded_rb])
@@ -291,6 +296,13 @@ def all_layers(model):
         p='timedi'
     if args.enable_only_layers_of_list is not None:
         l2.trainable = False
+    if isinstance(l2,BiasLayer):
+        print(l2.bias_num, l2.do_bias,args.biaslayer1,args.biaslayer2)
+        if (l2.bias_num == 1):
+            l2.do_bias = l2.trainable = args.biaslayer1
+        if (l2.bias_num == 2):
+            l2.do_bias = l2.trainable = args.biaslayer2
+
     print('{:10} {:10} {:20} {:10} {:10}'.format(l, p,l2.name, ("fixed", "trainable")[l2.trainable], l2.count_params()))
 
 if args.enable_only_layers_of_list is not None:
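
Continuing the sketch above: this loop pairs the bias_num tag given at construction time with the two command-line flags, so each BiasLayer can be switched independently, and one flag drives both the forward behavior (do_bias) and whether the bias weight is updated (trainable). A hedged helper version, with the biaslayer1/biaslayer2 flags assumed to be booleans parsed from the command line:

    def configure_bias_layers(layers, biaslayer1, biaslayer2):
        # One flag per tagged layer; note that a change to trainable only
        # takes effect the next time the model is compiled.
        for l2 in layers:
            if isinstance(l2, BiasLayerSketch):
                if l2.bias_num == 1:
                    l2.do_bias = l2.trainable = biaslayer1
                if l2.bias_num == 2:
                    l2.do_bias = l2.trainable = biaslayer2
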
@@ -325,6 +337,8 @@ def all_layers(model):
 lambda_model.fit_generator(keras_gen_train.generate_add_samples(), train_epoch_size, args.epochs,
                            validation_data=keras_gen_train.generate_add_samples('test'), validation_steps=test_epoch_size, callbacks = [tensorboard], workers = 0)
 #workers = 0 is a work around to correct the number of calls to the validation_data generator
+for l in range(len(lambda_model_layers)):
+    lambda_model_layers[l].trainable = True
 lambda_model.save(args.final_name+'.hdf5')
 
 # tools for debugging
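
One consequence of the new get_config() keys worth noting: the saved .hdf5 can only be deserialized if a BiasLayer class whose constructor accepts proto_num, do_bias and bias_num is supplied via custom_objects. A usage sketch, assuming tf.keras and a file name produced from args.final_name:

    from tensorflow.keras.models import load_model

    # The keys emitted by get_config() must match BiasLayer.__init__,
    # otherwise Keras cannot reconstruct the layer from the checkpoint.
    model = load_model('final_name.hdf5', custom_objects={'BiasLayer': BiasLayer})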
