diff --git a/few_shot_tests.py b/few_shot_tests.py index 87c4c3a..76ef3d9 100755 --- a/few_shot_tests.py +++ b/few_shot_tests.py @@ -17,6 +17,7 @@ from tensorflow.keras import optimizers as op import tensorflow as tf import argparse +import numpy as np parser = argparse.ArgumentParser(description='train recurrent net.') parser.add_argument('--pretrained_name', dest='pretrained_name', type=str, default=None) @@ -184,10 +185,29 @@ def build(self, input_shape): shape=(self.proto_num) + input_shape[2:], initializer='zeros', trainable=True) + if self.do_bias: + preset = 'ones' + else: + preset = 'zeros' + self.bias_enable = self.add_weight(name='bias_enable', + shape=(1), + initializer=preset, + trainable=False) super(BiasLayer, self).build(input_shape) # Be sure to call this at the end + def set_bias(self, do_bias): + was_weights = self.get_weights() + if do_bias: + self.set_weights([was_weights[0],np.array([1])]) + self.trainable = True + else: + self.set_weights([was_weights[0],np.array([0])]) + self.trainable = False + + def call(self, x): #return tf.expand_dims(self.bias, axis = 0)# let + return self.bias * self.bias_enable + x * (1-self.bias_enable) if self.do_bias: return self.bias + x * 0 else: @@ -297,11 +317,13 @@ def all_layers(model): if args.enable_only_layers_of_list is not None: l2.trainable = False if isinstance(l2,BiasLayer): - print(l2.bias_num, l2.do_bias,args.biaslayer1,args.biaslayer2) + print('pre ',l2.bias_num, l2.do_bias,args.biaslayer1,args.biaslayer2) if (l2.bias_num == 1): - l2.do_bias = l2.trainable = args.biaslayer1 + l2.set_bias(args.biaslayer1) if (l2.bias_num == 2): - l2.do_bias = l2.trainable = args.biaslayer2 + l2.set_bias(args.biaslayer2) + #print('get_weights = ', l2.get_weights()) + print('past',l2.bias_num, l2.do_bias,args.biaslayer1,args.biaslayer2) #, l2.bias) print('{:10} {:10} {:20} {:10} {:10}'.format(l, p,l2.name, ("fixed", "trainable")[l2.trainable], l2.count_params())) @@ -337,6 +359,24 @@ def all_layers(model): lambda_model.fit_generator(keras_gen_train.generate_add_samples(), train_epoch_size, args.epochs, validation_data=keras_gen_train.generate_add_samples('test'), validation_steps=test_epoch_size, callbacks = [tensorboard], workers = 0) #workers = 0 is a work around to correct the number of calls to the validation_data generator +for l in range(len(lambda_model_layers)): + l2=lambda_model_layers[l] + p='normal' + if isinstance(l2,TimeDistributed): + l2=l2.layer + p='timedi' + if args.enable_only_layers_of_list is not None: + l2.trainable = False + if isinstance(l2,BiasLayer): + print('pre ',l2.bias_num, l2.do_bias,args.biaslayer1,args.biaslayer2) + if (l2.bias_num == 1): + l2.do_bias = l2.trainable = args.biaslayer1 + if (l2.bias_num == 2): + l2.do_bias = l2.trainable = args.biaslayer2 + print('past',l2.bias_num, l2.do_bias,args.biaslayer1,args.biaslayer2, l2.bias) + + print('{:10} {:10} {:20} {:10} {:10}'.format(l, p,l2.name, ("fixed", "trainable")[l2.trainable], l2.count_params())) + for l in range(len(lambda_model_layers)): lambda_model_layers[l].trainable = True lambda_model.save(args.final_name+'.hdf5')