diff --git a/few_shot_tests.py b/few_shot_tests.py
index e8d894b..f2a6113 100755
--- a/few_shot_tests.py
+++ b/few_shot_tests.py
@@ -61,7 +61,7 @@ def parser():
     parser.add_argument('--binary_siamese', dest='binary_siamese', action='store_true') #seems to be a bad idea
     parser.add_argument('--square_siamese', dest='square_siamese', action='store_true')
     parser.add_argument('--eta', dest='eta', type=float, default=0)
-
+    parser.add_argument('--input_activation', dest='input_activation', type=str, default=None)
     args = parser.parse_args()
 
 parser()
@@ -445,6 +445,7 @@ def __init__(self,
                kernel_constraint=None,
                bias_constraint=None,
                eta = 0.0,
+               input_activation=None,
                **kwargs):
     if 'input_shape' not in kwargs and 'input_dim' in kwargs:
       kwargs['input_shape'] = (kwargs.pop('input_dim'),)
@@ -464,6 +465,7 @@ def __init__(self,
     self.supports_masking = True
     self.input_spec = InputSpec(min_ndim=2)
     self.eta = eta
+    self.input_activation = activations.get(input_activation)
 
   def build(self, input_shape):
     dtype = dtypes.as_dtype(self.dtype or K.floatx())
@@ -518,6 +520,8 @@ def build(self, input_shape):
     self.built = True
 
   def call(self, inputs):
+    if self.input_activation is not None:
+      inputs = self.input_activation(inputs)
     rank = len(inputs.shape)
     placticity = tf.multiply(self.kernel_p,self.hebb)
     if rank > 2:
@@ -616,7 +620,7 @@ def get_config(self):
     flat = TimeDistributed(Flatten())(pool5)
 
     if args.dense_img_num > 0:
-        x = TimeDistributed(Dense_plasticity(args.dense_img_num, eta = args.eta, activation = 'sigmoid'))(flat)
+        x = TimeDistributed(Dense_plasticity(args.dense_img_num, eta = args.eta, input_activation = args.input_activation, activation = 'sigmoid'))(flat)
     else:
         if args.binary_siamese:
             x = Activation('sigmoid')(flat)
@@ -650,7 +654,7 @@ def get_config(self):
         L1_layer = Lambda(lambda tensors:K.binary_crossentropy(tensors[0], tensors[1]))
         print(encoded_l,encoded_rb_scale)
         L1_distance = L1_layer([encoded_l, encoded_rb_scale])
-        prediction = Dense_plasticity(1, eta = args.eta, name = 'dense_siamese')(L1_distance)
+        prediction = Dense_plasticity(1, eta = args.eta, input_activation = args.input_activation, name = 'dense_siamese')(L1_distance)
     else:
         # Add a customized layer to compute the absolute difference between the encodings
         if args.square_siamese:
@@ -658,7 +662,7 @@ def get_config(self):
         else:
             L1_layer = Lambda(lambda tensors:K.abs(tensors[0] - tensors[1]))
         L1_distance = L1_layer([encoded_l, encoded_rb_scale])
-        prediction = Dense_plasticity(1, eta = args.eta, name = 'dense_siamese')(L1_distance)
+        prediction = Dense_plasticity(1, eta = args.eta, input_activation = args.input_activation, name = 'dense_siamese')(L1_distance)
 
     # Connect the inputs with the outputs
     if args.load_subnet:
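The patch threads a new `--input_activation` flag through to `Dense_plasticity`: the string is resolved via `activations.get()` in `__init__` and the resulting function is applied to the layer's inputs at the top of `call`, before the plasticity term. Below is a minimal sketch of that pattern in isolation, using the same TF-Keras API the file already imports; the `InputActivatedDense` class and its single kernel weight are illustrative stand-ins, not part of few_shot_tests.py, and the Hebbian machinery is omitted.

```python
import tensorflow as tf
from tensorflow.keras import activations, layers


class InputActivatedDense(layers.Layer):
    """Plain dense layer that optionally squashes its inputs first."""

    def __init__(self, units, input_activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # activations.get() accepts None, a name such as 'sigmoid', or a
        # callable. Note that None resolves to the identity ('linear')
        # function, so the is-not-None guard below (mirroring the patch,
        # which itself mirrors Keras Dense) is belt-and-braces: applying
        # 'linear' is a no-op anyway.
        self.input_activation = activations.get(input_activation)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name='kernel',
            shape=(int(input_shape[-1]), self.units),
            initializer='glorot_uniform')

    def call(self, inputs):
        if self.input_activation is not None:
            inputs = self.input_activation(inputs)  # squash before the matmul
        return tf.matmul(inputs, self.kernel)

    def get_config(self):
        config = super().get_config()
        config.update({
            'units': self.units,
            'input_activation': activations.serialize(self.input_activation),
        })
        return config


# Mirrors `--input_activation sigmoid` on the command line.
layer = InputActivatedDense(4, input_activation='sigmoid')
print(layer(tf.ones((2, 8))).shape)  # (2, 4)
```

Since the new guard runs before `tf.multiply(self.kernel_p, self.hebb)` in `call`, the point is presumably to bound the pre-synaptic values feeding the Hebbian trace, which would also make a saturating choice like `sigmoid` the natural value for the flag.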