Commit
biaslayers seem to work, can turn them on and off, but they are trained slowly. Gradient may be increased?!
dsmic committed Aug 24, 2019
1 parent e8e8220 commit d4c8dac
Showing 1 changed file with 43 additions and 3 deletions.
46 changes: 43 additions & 3 deletions few_shot_tests.py
@@ -17,6 +17,7 @@
 from tensorflow.keras import optimizers as op
 import tensorflow as tf
 import argparse
+import numpy as np

 parser = argparse.ArgumentParser(description='train recurrent net.')
 parser.add_argument('--pretrained_name', dest='pretrained_name', type=str, default=None)
@@ -184,10 +185,29 @@ def build(self, input_shape):
                                     shape=(self.proto_num) + input_shape[2:],
                                     initializer='zeros',
                                     trainable=True)
+        if self.do_bias:
+            preset = 'ones'
+        else:
+            preset = 'zeros'
+        self.bias_enable = self.add_weight(name='bias_enable',
+                                           shape=(1,),
+                                           initializer=preset,
+                                           trainable=False)
         super(BiasLayer, self).build(input_shape)  # Be sure to call this at the end

+    def set_bias(self, do_bias):
+        was_weights = self.get_weights()
+        if do_bias:
+            self.set_weights([was_weights[0], np.array([1])])
+            self.trainable = True
+        else:
+            self.set_weights([was_weights[0], np.array([0])])
+            self.trainable = False
+
+
     def call(self, x):
         #return tf.expand_dims(self.bias, axis = 0)
+        return self.bias * self.bias_enable + x * (1 - self.bias_enable)
         if self.do_bias:
             return self.bias + x * 0
         else:
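The added return line blends the learned bias with the layer input through the non-trainable bias_enable weight, so the toggle lives in the graph itself and the old if/else branches below it become unreachable dead code. A minimal numpy sketch of what the blend computes, with made-up values:

import numpy as np

bias = np.array([0.5])        # the learned bias weight
x = np.array([2.0])           # the incoming activation
for enable in (1.0, 0.0):     # bias_enable = 1 -> emit bias, 0 -> pass x through
    print(bias * enable + x * (1 - enable))
# [0.5]   bias on
# [2.]    bias off (identity)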
@@ -297,11 +317,13 @@ def all_layers(model):
     if args.enable_only_layers_of_list is not None:
         l2.trainable = False
     if isinstance(l2, BiasLayer):
-        print(l2.bias_num, l2.do_bias, args.biaslayer1, args.biaslayer2)
+        print('pre ', l2.bias_num, l2.do_bias, args.biaslayer1, args.biaslayer2)
         if (l2.bias_num == 1):
-            l2.do_bias = l2.trainable = args.biaslayer1
+            l2.set_bias(args.biaslayer1)
         if (l2.bias_num == 2):
-            l2.do_bias = l2.trainable = args.biaslayer2
+            l2.set_bias(args.biaslayer2)
+        #print('get_weights = ', l2.get_weights())
+        print('past', l2.bias_num, l2.do_bias, args.biaslayer1, args.biaslayer2)  #, l2.bias)

     print('{:10} {:10} {:20} {:10} {:10}'.format(l, p, l2.name, ("fixed", "trainable")[l2.trainable], l2.count_params()))

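set_bias now flips the stored weight value instead of only setting the Python attribute, which would not reach an already-compiled graph. A quick check of the toggle (a sketch, assuming bl is a built BiasLayer instance):

bl.set_bias(True)
print(bl.get_weights()[1])    # -> [1.], the layer emits its bias
bl.set_bias(False)
print(bl.get_weights()[1])    # -> [0.], the layer passes its input through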
@@ -337,6 +359,24 @@ def all_layers(model):
 lambda_model.fit_generator(keras_gen_train.generate_add_samples(), train_epoch_size, args.epochs,
                            validation_data=keras_gen_train.generate_add_samples('test'), validation_steps=test_epoch_size, callbacks=[tensorboard], workers=0)
 #workers = 0 is a workaround to correct the number of calls to the validation_data generator
+for l in range(len(lambda_model_layers)):
+    l2 = lambda_model_layers[l]
+    p = 'normal'
+    if isinstance(l2, TimeDistributed):
+        l2 = l2.layer
+        p = 'timedi'
+    if args.enable_only_layers_of_list is not None:
+        l2.trainable = False
+    if isinstance(l2, BiasLayer):
+        print('pre ', l2.bias_num, l2.do_bias, args.biaslayer1, args.biaslayer2)
+        if (l2.bias_num == 1):
+            l2.do_bias = l2.trainable = args.biaslayer1
+        if (l2.bias_num == 2):
+            l2.do_bias = l2.trainable = args.biaslayer2
+        print('past', l2.bias_num, l2.do_bias, args.biaslayer1, args.biaslayer2, l2.bias)
+
+    print('{:10} {:10} {:20} {:10} {:10}'.format(l, p, l2.name, ("fixed", "trainable")[l2.trainable], l2.count_params()))
+
 for l in range(len(lambda_model_layers)):
     lambda_model_layers[l].trainable = True
 lambda_model.save(args.final_name+'.hdf5')
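The commit message suspects the bias weights train slowly and asks whether the gradient may be increased. One hedged option, not part of this commit: scale the gradient flowing into self.bias with tf.custom_gradient. make_grad_scaler and the factor 10.0 are assumptions for illustration only:

import tensorflow as tf

def make_grad_scaler(scale):
    """Identity on the forward pass; multiplies the incoming gradient by scale."""
    @tf.custom_gradient
    def scale_gradient(x):
        def grad(dy):
            return dy * scale   # boost the gradient reaching the bias weight
        return tf.identity(x), grad
    return scale_gradient

# Inside BiasLayer.call this could read (the factor is a guess):
#   boosted = make_grad_scaler(10.0)(self.bias)
#   return boosted * self.bias_enable + x * (1 - self.bias_enable)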
