
Commit

- input_activation for Dense_plasticity
dsmic committed Oct 12, 2019
1 parent 7cb4a24 commit f54e5cd
Showing 1 changed file with 8 additions and 4 deletions.
few_shot_tests.py: 12 changes (8 additions, 4 deletions)
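In short, the commit threads an optional input activation, given by name, from the argument parser into the Dense_plasticity layer, which resolves it with Keras' activations.get and applies it to its inputs at the start of call(). Below is a minimal illustrative sketch of that pattern; it is not the repository's Dense_plasticity (the Hebbian/plasticity machinery is omitted) and the class and variable names are made up for the example.

import tensorflow as tf
from tensorflow.keras import activations, layers

class DenseWithInputActivation(layers.Layer):
    """Illustrative stand-in for Dense_plasticity's new input_activation option."""

    def __init__(self, units, input_activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # activations.get turns a string such as 'sigmoid' into the function;
        # with None it falls back to the identity ('linear') activation.
        self.input_activation = activations.get(input_activation)

    def build(self, input_shape):
        self.kernel = self.add_weight(name='kernel',
                                      shape=(int(input_shape[-1]), self.units))

    def call(self, inputs):
        # The new step: squash the raw inputs first, then do the usual product.
        inputs = self.input_activation(inputs)
        return tf.matmul(inputs, self.kernel)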
@@ -61,7 +61,7 @@ def parser():
parser.add_argument('--binary_siamese', dest='binary_siamese', action='store_true') #seems to be a bad idea
parser.add_argument('--square_siamese', dest='square_siamese', action='store_true')
parser.add_argument('--eta', dest='eta', type=float, default=0)

+ parser.add_argument('--input_activation', dest='input_activation', type=str, default=None)
args = parser.parse_args()

parser()
@@ -445,6 +445,7 @@ def __init__(self,
kernel_constraint=None,
bias_constraint=None,
eta = 0.0,
+ input_activation=None,
**kwargs):
if 'input_shape' not in kwargs and 'input_dim' in kwargs:
kwargs['input_shape'] = (kwargs.pop('input_dim'),)
@@ -464,6 +465,7 @@ def __init__(self,
self.supports_masking = True
self.input_spec = InputSpec(min_ndim=2)
self.eta = eta
+ self.input_activation = activations.get(input_activation)

def build(self, input_shape):
dtype = dtypes.as_dtype(self.dtype or K.floatx())
@@ -518,6 +520,8 @@ def build(self, input_shape):
self.built = True

def call(self, inputs):
+ if self.input_activation is not None:
+     inputs=self.input_activation(inputs)
rank = len(inputs.shape)
placticity = tf.multiply(self.kernel_p,self.hebb)
if rank > 2:
@@ -616,7 +620,7 @@ def get_config(self):

flat = TimeDistributed(Flatten())(pool5)
if args.dense_img_num > 0:
- x = TimeDistributed(Dense_plasticity(args.dense_img_num, eta = args.eta, activation = 'sigmoid'))(flat)
+ x = TimeDistributed(Dense_plasticity(args.dense_img_num, eta = args.eta, input_activation = args.input_activation, activation = 'sigmoid'))(flat)
else:
if args.binary_siamese:
x = Activation('sigmoid')(flat)
@@ -650,15 +654,15 @@ def get_config(self):
L1_layer = Lambda(lambda tensors:K.binary_crossentropy(tensors[0], tensors[1]))
print(encoded_l,encoded_rb_scale)
L1_distance = L1_layer([encoded_l, encoded_rb_scale])
- prediction = Dense_plasticity(1, eta = args.eta, name = 'dense_siamese')(L1_distance)
+ prediction = Dense_plasticity(1, eta = args.eta, input_activation = args.input_activation, name = 'dense_siamese')(L1_distance)
else:
# Add a customized layer to compute the absolute difference between the encodings
if args.square_siamese:
L1_layer = Lambda(lambda tensors:K.pow(tensors[0] - tensors[1], 2))
else:
L1_layer = Lambda(lambda tensors:K.abs(tensors[0] - tensors[1]))
L1_distance = L1_layer([encoded_l, encoded_rb_scale])
- prediction = Dense_plasticity(1, eta = args.eta, name = 'dense_siamese')(L1_distance)
+ prediction = Dense_plasticity(1, eta = args.eta, input_activation = args.input_activation, name = 'dense_siamese')(L1_distance)

# Connect the inputs with the outputs
if args.load_subnet:
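As a hypothetical usage of the sketch shown above the diff (in few_shot_tests.py the activation name instead arrives via the new --input_activation flag and is forwarded, as the diff shows, to each Dense_plasticity instance):

# Assumes the DenseWithInputActivation sketch from above is in scope.
layer = DenseWithInputActivation(8, input_activation='sigmoid')
y = layer(tf.random.normal((2, 16)))  # builds the layer, applies sigmoid to the inputs, then the matmul
print(y.shape)                        # (2, 8)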
