diff --git a/tensorflow_addons/metrics/matthews_correlation_coefficient.py b/tensorflow_addons/metrics/matthews_correlation_coefficient.py index 2e1fb25a30..ed1b78673e 100644 --- a/tensorflow_addons/metrics/matthews_correlation_coefficient.py +++ b/tensorflow_addons/metrics/matthews_correlation_coefficient.py @@ -69,7 +69,7 @@ def __init__( ): """Creates a Matthews Correlation Coefficient instance.""" super().__init__(name=name, dtype=dtype) - self.num_classes = num_classes + self.num_classes = max(2, num_classes) self.conf_mtx = self.add_weight( "conf_mtx", shape=(self.num_classes, self.num_classes), @@ -82,9 +82,19 @@ def update_state(self, y_true, y_pred, sample_weight=None): y_true = tf.cast(y_true, dtype=self.dtype) y_pred = tf.cast(y_pred, dtype=self.dtype) + if y_true.shape[-1] == 1: + labels = tf.squeeze(tf.round(y_true), axis=-1) + else: + labels = tf.argmax(y_true, 1) + + if y_pred.shape[-1] == 1: + predictions = tf.squeeze(tf.round(y_pred), axis=-1) + else: + predictions = tf.argmax(y_pred, 1) + new_conf_mtx = tf.math.confusion_matrix( - labels=tf.argmax(y_true, 1), - predictions=tf.argmax(y_pred, 1), + labels=labels, + predictions=predictions, num_classes=self.num_classes, weights=sample_weight, dtype=self.dtype, @@ -126,7 +136,4 @@ def reset_states(self): """Resets all of the metric state variables.""" for v in self.variables: - K.set_value( - v, - np.zeros((self.num_classes, self.num_classes), v.dtype.as_numpy_dtype), - ) + K.set_value(v, np.zeros(v.shape, v.dtype.as_numpy_dtype)) diff --git a/tensorflow_addons/metrics/tests/matthews_correlation_coefficient_test.py b/tensorflow_addons/metrics/tests/matthews_correlation_coefficient_test.py index 1d5ad0f062..27551659b5 100644 --- a/tensorflow_addons/metrics/tests/matthews_correlation_coefficient_test.py +++ b/tensorflow_addons/metrics/tests/matthews_correlation_coefficient_test.py @@ -23,12 +23,12 @@ def test_config(): # mcc object - mcc1 = MatthewsCorrelationCoefficient(num_classes=1) - assert mcc1.num_classes == 1 + mcc1 = MatthewsCorrelationCoefficient(num_classes=2) + assert mcc1.num_classes == 2 assert mcc1.dtype == tf.float32 # check configure mcc2 = MatthewsCorrelationCoefficient.from_config(mcc1.get_config()) - assert mcc2.num_classes == 1 + assert mcc2.num_classes == 2 assert mcc2.dtype == tf.float32 @@ -36,6 +36,17 @@ def check_results(obj, value): np.testing.assert_allclose(value, obj.result().numpy(), atol=1e-6) +def test_binary_classes_sparse(): + gt_label = tf.constant([[1.0], [1.0], [1.0], [0.0]], dtype=tf.float32) + preds = tf.constant([[1.0], [0.0], [1.0], [1.0]], dtype=tf.float32) + # Initialize + mcc = MatthewsCorrelationCoefficient(1) + # Update + mcc.update_state(gt_label, preds) + # Check results + check_results(mcc, [-0.33333334]) + + def test_binary_classes(): gt_label = tf.constant( [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=tf.float32 @@ -91,6 +102,16 @@ def test_multiple_classes(): sklearn_result = sklearn_matthew(gt_label.argmax(axis=1), preds.argmax(axis=1)) check_results(mcc, sklearn_result) + gt_label_sparse = tf.constant( + [[0.0], [2.0], [0.0], [2.0], [1.0], [1.0], [0.0], [0.0], [2.0], [1.0]] + ) + preds_sparse = tf.constant( + [[2.0], [0.0], [2.0], [2.0], [2.0], [2.0], [2.0], [0.0], [2.0], [2.0]] + ) + mcc = MatthewsCorrelationCoefficient(3) + mcc.update_state(gt_label_sparse, preds_sparse) + check_results(mcc, sklearn_result) + # Keras model API check def test_keras_model():