
Commit

- more preparation
dsmic committed Oct 10, 2019
1 parent ef5e262 commit a965bb5
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions few_shot_tests.py
@@ -14,6 +14,7 @@
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, Callback
import tensorflow.keras.backend as K
import tensorflow.keras
from tensorflow import keras
from tensorflow.keras import optimizers as op
import tensorflow as tf
import argparse
@@ -464,11 +465,11 @@ def __init__(self,
def build(self, input_shape):
dtype = dtypes.as_dtype(self.dtype or K.floatx())
if not (dtype.is_floating or dtype.is_complex):
raise TypeError('Unable to build `Dense` layer with non-floating point '
raise TypeError('Unable to build `Dense_plasticity` layer with non-floating point '
'dtype %s' % (dtype,))
input_shape = tensor_shape.TensorShape(input_shape)
if tensor_shape.dimension_value(input_shape[-1]) is None:
raise ValueError('The last dimension of the inputs to `Dense` '
raise ValueError('The last dimension of the inputs to `Dense_plasticity` '
'should be defined. Found `None`.')
last_dim = tensor_shape.dimension_value(input_shape[-1])
self.input_spec = InputSpec(min_ndim=2,
@@ -481,6 +482,25 @@ def build(self, input_shape):
constraint=self.kernel_constraint,
dtype=self.dtype,
trainable=True)

#plasticity
self.kernel_p = self.add_weight(
'kernel_p',
shape=[last_dim, self.units],
initializer=keras.initializers.Constant(),
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
dtype=self.dtype,
trainable=True)
self.hebb = self.add_weight(
'hebb',
shape=[last_dim, self.units],
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
dtype=self.dtype,
trainable=False)
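# kernel_p (zero-initialized via Constant()) scales the non-trainable Hebbian
# trace hebb; call() multiplies the two elementwise and projects the inputs
# through the product, so the plastic contribution starts out at zero.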

if self.use_bias:
self.bias = self.add_weight(
'bias',
@@ -496,9 +516,12 @@

def call(self, inputs):
rank = len(inputs.shape)
plasticity = tf.multiply(self.kernel_p, self.hebb)
if rank > 2:
# Broadcasting is required for the inputs.
outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]])
outputs2 = standard_ops.tensordot(inputs, plasticity, [[rank - 1], [0]])
outputs = tf.add(outputs, outputs2)
# Reshape the output back to the original ndim of the input.
if not context.executing_eagerly():
shape = inputs.shape.as_list()
@@ -512,8 +535,12 @@
inputs = math_ops.cast(inputs, self.dtype)
if K.is_sparse(inputs):
outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, self.kernel)
outputs2 = sparse_ops.sparse_tensor_dense_matmul(inputs, plasticity)
outputs.set_shape(output_shape)
else:
outputs = gen_math_ops.mat_mul(inputs, self.kernel)
outputs2 = gen_math_ops.mat_mul(inputs, plasticity)
outputs.set_shape(output_shape)
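# In these rank <= 2 branches outputs2 (the plastic term) is computed but not
# yet added to outputs; only the rank > 2 path above combines the two so far.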
if self.use_bias:
outputs = nn.bias_add(outputs, self.bias)
if self.activation is not None:
@@ -543,7 +570,7 @@ def get_config(self):
'kernel_constraint': constraints.serialize(self.kernel_constraint),
'bias_constraint': constraints.serialize(self.bias_constraint)
}
base_config = super(Dense, self).get_config()
base_config = super(Dense_plasticity, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
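The modified call() augments the usual dense projection with a plastic term: the inputs are also projected through the elementwise product of the plasticity coefficients kernel_p and the Hebbian trace hebb, and the two contributions are summed. A minimal, self-contained sketch of that computation for ordinary rank-2 inputs (the helper name dense_plasticity_forward is illustrative, not part of the commit):

import tensorflow as tf

def dense_plasticity_forward(x, kernel, kernel_p, hebb, bias=None, activation=None):
    # Plastic term: each connection's Hebbian trace scaled by its learned
    # plasticity coefficient.
    plasticity = tf.multiply(kernel_p, hebb)
    # Fixed (slow) weights plus the plastic contribution.
    outputs = tf.matmul(x, kernel) + tf.matmul(x, plasticity)
    if bias is not None:
        outputs = tf.nn.bias_add(outputs, bias)
    if activation is not None:
        outputs = activation(outputs)
    return outputs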


@@ -639,7 +666,7 @@ def call(x):

if args.pretrained_name is not None:
from tensorflow.keras.models import load_model
lambda_model = load_model(args.pretrained_name, custom_objects = { "keras": tensorflow.keras , "args":args, "BiasLayer": BiasLayer, "FindModel": FindModel})
lambda_model = load_model(args.pretrained_name, custom_objects = { "keras": tensorflow.keras , "args":args, "BiasLayer": BiasLayer, "FindModel": FindModel, "Dense_plasticity": Dense_plasticity})
print("loaded model",lambda_model)

if args.load_weights_name:
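Because Dense_plasticity is a custom layer, a saved model that contains it must be re-loaded with the class passed through custom_objects, which is what the updated load_model call above adds. A minimal usage sketch (the file name is a placeholder; args, BiasLayer, FindModel and Dense_plasticity are the definitions from few_shot_tests.py, assumed to be in scope):

from tensorflow.keras.models import load_model
import tensorflow

lambda_model = load_model(
    "my_model.hdf5",  # placeholder path to a model saved by this script
    custom_objects={
        "keras": tensorflow.keras,
        "args": args,
        "BiasLayer": BiasLayer,
        "FindModel": FindModel,
        "Dense_plasticity": Dense_plasticity,
    })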
