diff --git a/few_shot_tests.py b/few_shot_tests.py
index ca89b09..b992a48 100755
--- a/few_shot_tests.py
+++ b/few_shot_tests.py
@@ -362,6 +362,191 @@ def compute_output_shape(self, input_shape):
     def get_config(self):
         return {'proto_num': self.proto_num, 'do_bias' : self.do_bias,'bias_num' : self.bias_num}
 
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras import regularizers
+from tensorflow.python.keras import activations
+from tensorflow.python.keras import constraints
+from tensorflow.python.keras.engine.input_spec import InputSpec
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import standard_ops
+from tensorflow.python.eager import context
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import nn
+
+class Dense_plasticity(Layer):
+  """Densely-connected NN layer (a copy of `tf.keras.layers.Dense`).
+
+  `Dense_plasticity` implements the operation:
+  `output = activation(dot(input, kernel) + bias)`
+  where `activation` is the element-wise activation function
+  passed as the `activation` argument, `kernel` is a weights matrix
+  created by the layer, and `bias` is a bias vector created by the layer
+  (only applicable if `use_bias` is `True`).
+
+  Note: If the input to the layer has a rank greater than 2, then
+  it is flattened prior to the initial dot product with `kernel`.
+
+  Example:
+
+  ```python
+  # as first layer in a sequential model:
+  model = Sequential()
+  model.add(Dense_plasticity(32, input_shape=(16,)))
+  # now the model will take as input arrays of shape (*, 16)
+  # and output arrays of shape (*, 32)
+
+  # after the first layer, you don't need to specify
+  # the size of the input anymore:
+  model.add(Dense_plasticity(32))
+  ```
+
+  Arguments:
+    units: Positive integer, dimensionality of the output space.
+    activation: Activation function to use.
+      If you don't specify anything, no activation is applied
+      (ie. "linear" activation: `a(x) = x`).
+    use_bias: Boolean, whether the layer uses a bias vector.
+    kernel_initializer: Initializer for the `kernel` weights matrix.
+    bias_initializer: Initializer for the bias vector.
+    kernel_regularizer: Regularizer function applied to
+      the `kernel` weights matrix.
+    bias_regularizer: Regularizer function applied to the bias vector.
+    activity_regularizer: Regularizer function applied to
+      the output of the layer (its "activation").
+    kernel_constraint: Constraint function applied to
+      the `kernel` weights matrix.
+    bias_constraint: Constraint function applied to the bias vector.
+
+  Input shape:
+    N-D tensor with shape: `(batch_size, ..., input_dim)`.
+    The most common situation would be
+    a 2D input with shape `(batch_size, input_dim)`.
+
+  Output shape:
+    N-D tensor with shape: `(batch_size, ..., units)`.
+    For instance, for a 2D input with shape `(batch_size, input_dim)`,
+    the output would have shape `(batch_size, units)`.
+ """ + + def __init__(self, + units, + activation=None, + use_bias=True, + kernel_initializer='glorot_uniform', + bias_initializer='zeros', + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs): + if 'input_shape' not in kwargs and 'input_dim' in kwargs: + kwargs['input_shape'] = (kwargs.pop('input_dim'),) + + super(Dense_plasticity, self).__init__( + activity_regularizer=regularizers.get(activity_regularizer), **kwargs) + self.units = int(units) + self.activation = activations.get(activation) + self.use_bias = use_bias + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.kernel_constraint = constraints.get(kernel_constraint) + self.bias_constraint = constraints.get(bias_constraint) + + self.supports_masking = True + self.input_spec = InputSpec(min_ndim=2) + + def build(self, input_shape): + dtype = dtypes.as_dtype(self.dtype or K.floatx()) + if not (dtype.is_floating or dtype.is_complex): + raise TypeError('Unable to build `Dense` layer with non-floating point ' + 'dtype %s' % (dtype,)) + input_shape = tensor_shape.TensorShape(input_shape) + if tensor_shape.dimension_value(input_shape[-1]) is None: + raise ValueError('The last dimension of the inputs to `Dense` ' + 'should be defined. Found `None`.') + last_dim = tensor_shape.dimension_value(input_shape[-1]) + self.input_spec = InputSpec(min_ndim=2, + axes={-1: last_dim}) + self.kernel = self.add_weight( + 'kernel', + shape=[last_dim, self.units], + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + dtype=self.dtype, + trainable=True) + if self.use_bias: + self.bias = self.add_weight( + 'bias', + shape=[self.units,], + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + dtype=self.dtype, + trainable=True) + else: + self.bias = None + self.built = True + + def call(self, inputs): + rank = len(inputs.shape) + if rank > 2: + # Broadcasting is required for the inputs. + outputs = standard_ops.tensordot(inputs, self.kernel, [[rank - 1], [0]]) + # Reshape the output back to the original ndim of the input. + if not context.executing_eagerly(): + shape = inputs.shape.as_list() + output_shape = shape[:-1] + [self.units] + outputs.set_shape(output_shape) + else: + # Cast the inputs to self.dtype, which is the variable dtype. We do not + # cast if `should_cast_variables` is True, as in that case the variable + # will be automatically casted to inputs.dtype. 
+      if not self._mixed_precision_policy.should_cast_variables:
+        inputs = math_ops.cast(inputs, self.dtype)
+      if K.is_sparse(inputs):
+        outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, self.kernel)
+      else:
+        outputs = gen_math_ops.mat_mul(inputs, self.kernel)
+    if self.use_bias:
+      outputs = nn.bias_add(outputs, self.bias)
+    if self.activation is not None:
+      return self.activation(outputs)  # pylint: disable=not-callable
+    return outputs
+
+  def compute_output_shape(self, input_shape):
+    input_shape = tensor_shape.TensorShape(input_shape)
+    input_shape = input_shape.with_rank_at_least(2)
+    if tensor_shape.dimension_value(input_shape[-1]) is None:
+      raise ValueError(
+          'The innermost dimension of input_shape must be defined, but saw: %s'
+          % input_shape)
+    return input_shape[:-1].concatenate(self.units)
+
+  def get_config(self):
+    config = {
+        'units': self.units,
+        'activation': activations.serialize(self.activation),
+        'use_bias': self.use_bias,
+        'kernel_initializer': initializers.serialize(self.kernel_initializer),
+        'bias_initializer': initializers.serialize(self.bias_initializer),
+        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
+        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+        'activity_regularizer':
+            regularizers.serialize(self.activity_regularizer),
+        'kernel_constraint': constraints.serialize(self.kernel_constraint),
+        'bias_constraint': constraints.serialize(self.bias_constraint)
+    }
+    base_config = super(Dense_plasticity, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+
 
 # Network definition starts here
 inputs = Input(shape=(None,84,84,3))
@@ -413,7 +598,7 @@ def get_config(self):
     L1_layer = Lambda(lambda tensors:K.binary_crossentropy(tensors[0], tensors[1]))
     print(encoded_l,encoded_rb_scale)
     L1_distance = L1_layer([encoded_l, encoded_rb_scale])
-    prediction = Dense(1, name = 'dense_siamese')(L1_distance)
+    prediction = Dense_plasticity(1, name = 'dense_siamese')(L1_distance)
 else:
     # Add a customized layer to compute the absolute difference between the encodings
     if args.square_siamese:
@@ -421,7 +606,7 @@ def get_config(self):
     else:
         L1_layer = Lambda(lambda tensors:K.abs(tensors[0] - tensors[1]))
     L1_distance = L1_layer([encoded_l, encoded_rb_scale])
-    prediction = Dense(1, name = 'dense_siamese')(L1_distance)
+    prediction = Dense_plasticity(1, name = 'dense_siamese')(L1_distance)
 
 # Connect the inputs with the outputs
 if args.load_subnet:
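A minimal usage sketch (not part of the patch): because `Dense_plasticity` is a custom layer defined only in `few_shot_tests.py`, any saved model containing the `dense_siamese` head has to be reloaded with `custom_objects`. The file name and input shape below are placeholders, and the snippet assumes the TF 1.x/2.0-era `tf.keras` API that the private imports above target.

```python
import numpy as np
import tensorflow as tf

# Build a tiny model around the custom layer and sanity-check its output shape.
x = tf.keras.Input(shape=(16,))
y = Dense_plasticity(1, name='dense_siamese')(x)  # currently a drop-in for Dense
model = tf.keras.Model(x, y)
print(model.predict(np.zeros((2, 16))).shape)  # expected: (2, 1)

# Saving works as usual; loading needs the custom class registered.
model.save('siamese_head.h5')  # placeholder path
reloaded = tf.keras.models.load_model(
    'siamese_head.h5', custom_objects={'Dense_plasticity': Dense_plasticity})
```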