diff --git a/few_shot_tests.py b/few_shot_tests.py index 7f15ba3..203baa1 100755 --- a/few_shot_tests.py +++ b/few_shot_tests.py @@ -57,6 +57,7 @@ def parser(): parser.add_argument('--EarlyStop', dest='EarlyStop', type=str, default='EarlyStop') parser.add_argument('--max_idx', dest='max_idx', type=int, default=-1) parser.add_argument('--dense_img_num', dest='dense_img_num', type=int, default=-1) + parser.add_argument('--binary_siamese', dest='binary_siamese', action='store_true') args = parser.parse_args() @@ -377,9 +378,13 @@ def get_config(self): flat = TimeDistributed(Flatten())(pool5) if args.dense_img_num > 0: - x = TimeDistributed(Dense(args.dense_img_num, activation = 'tanh'))(flat) + x = TimeDistributed(Dense(args.dense_img_num, activation = 'sigmoid'))(flat) else: - x = flat + if args.binary_siamese: + x = Activation('sigmoid')(flat) + else: + x = flat + #predictions = Activation('softmax')(x) model_img = FindModel(inputs=inputs, outputs=x) @@ -400,13 +405,20 @@ def get_config(self): encoded_rb_scale = ScaleGradientLayer()(encoded_rb) else: encoded_rb_scale = encoded_rb -# Add a customized layer to compute the absolute difference between the encodings -L1_layer = Lambda(lambda tensors:K.abs(tensors[0] - tensors[1])) -L1_distance = L1_layer([encoded_l, encoded_rb_scale]) - + # Add a dense layer with a sigmoid unit to generate the similarity score -prediction = Dense(1, name = 'dense_siamese')(L1_distance) - +if args.binary_siamese: + # Add a customized layer to compute the absolute difference between the encodings + L1_layer = Lambda(lambda tensors:K.binary_crossentropy(tensors[0], tensors[1])) + print(encoded_l,encoded_rb_scale) + L1_distance = L1_layer([encoded_l, encoded_rb_scale]) + prediction = Dense(1, name = 'dense_siamese')(L1_distance) +else: + # Add a customized layer to compute the absolute difference between the encodings + L1_layer = Lambda(lambda tensors:K.abs(tensors[0] - tensors[1])) + L1_distance = L1_layer([encoded_l, encoded_rb_scale]) + prediction = Dense(1, name = 'dense_siamese')(L1_distance) + # Connect the inputs with the outputs if args.load_subnet: submodel_name = "model_changed"