diff --git a/few_shot_tests.py b/few_shot_tests.py
index b992a48..3bf9c43 100755
--- a/few_shot_tests.py
+++ b/few_shot_tests.py
@@ -14,6 +14,7 @@ from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, Callback
 import tensorflow.keras.backend as K
 import tensorflow.keras
+from tensorflow import keras
 from tensorflow.keras import optimizers as op
 import tensorflow as tf
 import argparse
@@ -464,11 +465,11 @@ def __init__(self,
   def build(self, input_shape):
     dtype = dtypes.as_dtype(self.dtype or K.floatx())
     if not (dtype.is_floating or dtype.is_complex):
-      raise TypeError('Unable to build `Dense` layer with non-floating point '
+      raise TypeError('Unable to build `Dense_plasticity` layer with non-floating point '
                       'dtype %s' % (dtype,))
     input_shape = tensor_shape.TensorShape(input_shape)
     if tensor_shape.dimension_value(input_shape[-1]) is None:
-      raise ValueError('The last dimension of the inputs to `Dense` '
+      raise ValueError('The last dimension of the inputs to `Dense_plasticity` '
                        'should be defined. Found `None`.')
     last_dim = tensor_shape.dimension_value(input_shape[-1])
     self.input_spec = InputSpec(min_ndim=2,
@@ -481,6 +482,25 @@ def build(self, input_shape):
         constraint=self.kernel_constraint,
         dtype=self.dtype,
         trainable=True)
+
+    # plasticity: the effective kernel becomes kernel + kernel_p * hebb
+    self.kernel_p = self.add_weight(
+        'kernel_p',
+        shape=[last_dim, self.units],
+        initializer=keras.initializers.Constant(),  # plastic coefficients start at zero
+        regularizer=self.kernel_regularizer,
+        constraint=self.kernel_constraint,
+        dtype=self.dtype,
+        trainable=True)
+    self.hebb = self.add_weight(
+        'hebb',
+        shape=[last_dim, self.units],
+        initializer=self.kernel_initializer,
+        regularizer=self.kernel_regularizer,
+        constraint=self.kernel_constraint,
+        dtype=self.dtype,
+        trainable=False)  # Hebbian trace; not updated by the optimizer
+
     if self.use_bias:
       self.bias = self.add_weight(
           'bias',
@@ -496,9 +516,12 @@ def build(self, input_shape):

   def call(self, inputs):
     rank = len(inputs.shape)
+    plasticity = tf.multiply(self.kernel_p, self.hebb)
     if rank > 2:
       # Broadcasting is required for the inputs.
       outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]])
+      outputs2 = standard_ops.tensordot(inputs, plasticity, [[rank - 1], [0]])
+      outputs = tf.add(outputs, outputs2)
       # Reshape the output back to the original ndim of the input.
       if not context.executing_eagerly():
         shape = inputs.shape.as_list()
@@ -512,8 +535,12 @@ def call(self, inputs):
       inputs = math_ops.cast(inputs, self.dtype)
       if K.is_sparse(inputs):
         outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, self.kernel)
+        outputs2 = sparse_ops.sparse_tensor_dense_matmul(inputs, plasticity)
+        outputs = tf.add(outputs, outputs2)
       else:
         outputs = gen_math_ops.mat_mul(inputs, self.kernel)
+        outputs2 = gen_math_ops.mat_mul(inputs, plasticity)
+        outputs = tf.add(outputs, outputs2)
     if self.use_bias:
       outputs = nn.bias_add(outputs, self.bias)
     if self.activation is not None:
@@ -543,7 +570,7 @@ def get_config(self):
         'kernel_constraint': constraints.serialize(self.kernel_constraint),
         'bias_constraint': constraints.serialize(self.bias_constraint)
     }
-    base_config = super(Dense, self).get_config()
+    base_config = super(Dense_plasticity, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))


@@ -639,7 +666,7 @@ def call(x):

 if args.pretrained_name is not None:
     from tensorflow.keras.models import load_model
-    lambda_model = load_model(args.pretrained_name, custom_objects = { "keras": tensorflow.keras , "args":args, "BiasLayer": BiasLayer, "FindModel": FindModel})
+    lambda_model = load_model(args.pretrained_name, custom_objects = { "keras": tensorflow.keras , "args":args, "BiasLayer": BiasLayer, "FindModel": FindModel, "Dense_plasticity": Dense_plasticity})
     print("loaded model",lambda_model)

 if args.load_weights_name:
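
Taken together, the patch makes the layer compute activation(inputs . (kernel + kernel_p * hebb) + bias): kernel is the usual trainable weight matrix, kernel_p holds trainable per-connection plasticity coefficients (initialized to zero, so training starts from plain Dense behavior), and hebb is a non-trainable Hebbian trace, in the spirit of differentiable plasticity. The final hunk registers the class with load_model so saved models deserialize. Below is a minimal usage sketch, not from the source: it assumes Dense_plasticity can be imported from few_shot_tests and accepts the same constructor arguments as keras.layers.Dense; the shapes, file name, and data are illustrative.

    # Minimal sketch (assumptions: Dense_plasticity is importable from
    # few_shot_tests and has the same constructor signature as keras.layers.Dense).
    import numpy as np
    from tensorflow import keras
    from few_shot_tests import Dense_plasticity  # hypothetical import

    model = keras.Sequential([
        keras.Input(shape=(64,)),
        Dense_plasticity(32, activation='relu'),      # fixed + plastic weight path
        keras.layers.Dense(10, activation='softmax'),
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

    x = np.random.rand(16, 64).astype('float32')      # toy data
    y = np.random.randint(0, 10, size=(16,))
    model.fit(x, y, epochs=1, verbose=0)

    # Reloading must register the custom class, mirroring the load_model
    # change in the last hunk above.
    model.save('plastic_model.h5')
    restored = keras.models.load_model(
        'plastic_model.h5',
        custom_objects={'Dense_plasticity': Dense_plasticity})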