-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from davidwarshaw/dev
Metrics complete.
- Loading branch information
Showing
8 changed files
with
264 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
""" | ||
Metrics for evaluating hierarchical multi-classification performance. | ||
""" | ||
|
||
from __future__ import print_function | ||
from __future__ import division | ||
|
||
from sklearn import tree | ||
from sklearn import metrics as skmetrics | ||
from sklearn.utils import check_consistent_length | ||
from sklearn.utils import column_or_1d | ||
from sklearn.utils.multiclass import type_of_target | ||
|
||
from itertools import chain | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
def _check_targets_hmc(y_true, y_pred):
    """Validate a pair of label vectors and return them as 1-d arrays.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted class labels; they must have the
        same length.

    Returns
    -------
    (y_true, y_pred)
        Both inputs after being passed through ``column_or_1d``.

    Raises
    ------
    ValueError
        If either target type reported by ``type_of_target`` is anything
        other than "binary" or "multiclass".
    """
    check_consistent_length(y_true, y_pred)
    target_types = {type_of_target(y_true), type_of_target(y_pred)}
    # "binary" is just the two-label special case of "multiclass".  The
    # previous check accepted a binary/multiclass mix but rejected inputs
    # where *both* vectors happened to contain only two distinct labels.
    if not target_types <= {"binary", "multiclass"}:
        raise ValueError("{0} is not supported".format(target_types))
    y_true = column_or_1d(y_true)
    y_pred = column_or_1d(y_pred)
    return y_true, y_pred
|
||
## General Scores | ||
# Average accuracy | ||
def accuracy_score(class_hierarchy, y_true, y_pred):
    """Return the plain (flat) accuracy of ``y_pred`` against ``y_true``.

    ``class_hierarchy`` is accepted so the signature matches the other
    metrics in this module; it is not consulted here.  The labels are
    validated with ``_check_targets_hmc`` and then scored with
    scikit-learn's ``accuracy_score``.
    """
    labels_true, labels_pred = _check_targets_hmc(y_true, y_pred)
    return skmetrics.accuracy_score(labels_true, labels_pred)
|
||
## Hierarchy Precision / Recall | ||
def _aggregate_class_sets(set_function, y_true, y_pred):
    """Sum the sizes of the augmented label sets over all samples.

    Each label is expanded into the set made of the label itself plus
    whatever ``set_function(label)`` yields.

    Parameters
    ----------
    set_function : callable
        Maps a single label to an iterable of related labels.  Any
        iterable is accepted (list, tuple, set, generator, ...).
    y_true, y_pred : array-like exposing ``tolist()``
        Ground-truth and predicted labels, paired element-wise.

    Returns
    -------
    tuple
        ``(true_sum, predicted_sum, intersection_sum)``: the summed sizes
        of the true sets, the predicted sets and their per-sample
        intersections.
    """
    def _expand(label):
        # Build the set from the iterable rather than concatenating lists,
        # so set_function is not required to return a list.
        members = set(set_function(label))
        members.add(label)
        return members

    true_sum = 0
    predicted_sum = 0
    intersection_sum = 0
    for true, pred in zip(y_true.tolist(), y_pred.tolist()):
        true_set = _expand(true)
        pred_set = _expand(pred)
        true_sum += len(true_set)
        predicted_sum += len(pred_set)
        intersection_sum += len(true_set & pred_set)
    return (true_sum, predicted_sum, intersection_sum)
|
||
# Ancestors Scores (Super Class) | ||
# Precision | ||
def precision_score_ancestors(class_hierarchy, y_true, y_pred):
    """Hierarchical precision over ancestor-augmented label sets.

    Every label is expanded with ``class_hierarchy._get_ancestors`` and
    the result is the summed size of the per-sample intersections
    divided by the summed size of the predicted sets.
    """
    labels_true, labels_pred = _check_targets_hmc(y_true, y_pred)
    totals = _aggregate_class_sets(
        class_hierarchy._get_ancestors, labels_true, labels_pred)
    _, n_predicted, n_shared = totals
    return n_shared / n_predicted
|
||
# Recall | ||
def recall_score_ancestors(class_hierarchy, y_true, y_pred):
    """Hierarchical recall over ancestor-augmented label sets.

    Every label is expanded with ``class_hierarchy._get_ancestors`` and
    the result is the summed size of the per-sample intersections
    divided by the summed size of the true sets.
    """
    labels_true, labels_pred = _check_targets_hmc(y_true, y_pred)
    totals = _aggregate_class_sets(
        class_hierarchy._get_ancestors, labels_true, labels_pred)
    n_true, _, n_shared = totals
    return n_shared / n_true
|
||
# Descendants Scores (Sub Class) | ||
# Precision | ||
def precision_score_descendants(class_hierarchy, y_true, y_pred):
    """Hierarchical precision over descendant-augmented label sets.

    Every label is expanded with ``class_hierarchy._get_descendants``
    and the result is the summed size of the per-sample intersections
    divided by the summed size of the predicted sets.
    """
    labels_true, labels_pred = _check_targets_hmc(y_true, y_pred)
    totals = _aggregate_class_sets(
        class_hierarchy._get_descendants, labels_true, labels_pred)
    _, n_predicted, n_shared = totals
    return n_shared / n_predicted
|
||
# Recall | ||
def recall_score_descendants(class_hierarchy, y_true, y_pred):
    """Hierarchical recall over descendant-augmented label sets.

    Every label is expanded with ``class_hierarchy._get_descendants``
    and the result is the summed size of the per-sample intersections
    divided by the summed size of the true sets.
    """
    labels_true, labels_pred = _check_targets_hmc(y_true, y_pred)
    totals = _aggregate_class_sets(
        class_hierarchy._get_descendants, labels_true, labels_pred)
    n_true, _, n_shared = totals
    return n_shared / n_true
|
||
# Hierarchy Fscore | ||
def _fbeta_score_class_sets(set_function, y_true, y_pred, beta=1):
    """F-beta score computed from augmented label sets.

    Parameters
    ----------
    set_function : callable
        Maps a label to its related labels; forwarded to
        ``_aggregate_class_sets``.
    y_true, y_pred : array-like
        Ground-truth and predicted labels.
    beta : number, default 1
        Weight of recall relative to precision.

    Returns
    -------
    float
        ``(beta**2 + 1) * P * R / (beta**2 * P + R)`` where ``P`` and
        ``R`` are the set-based precision and recall, or ``0.0`` when the
        true and predicted sets never overlap.

    Raises
    ------
    ValueError
        Propagated from ``_check_targets_hmc`` for unsupported targets.
    """
    y_true, y_pred = _check_targets_hmc(y_true, y_pred)
    true_sum, predicted_sum, intersection_sum = _aggregate_class_sets(
        set_function, y_true, y_pred)
    if intersection_sum == 0:
        # Precision and recall are both zero here; the harmonic-mean
        # formula below would divide by zero, so report the worst score.
        return 0.0
    precision = intersection_sum / predicted_sum
    recall = intersection_sum / true_sum
    beta_squared = beta ** 2
    return ((beta_squared + 1) * precision * recall) / ((beta_squared * precision) + recall)
|
||
def f1_score_ancestors(class_hierarchy, y_true, y_pred):
    """F1 score over ancestor-augmented label sets.

    Validates the labels, then delegates to ``_fbeta_score_class_sets``
    with its default ``beta`` and ``class_hierarchy._get_ancestors`` as
    the expansion function.
    """
    labels_true, labels_pred = _check_targets_hmc(y_true, y_pred)
    expand = class_hierarchy._get_ancestors
    return _fbeta_score_class_sets(expand, labels_true, labels_pred)
|
||
def f1_score_descendants(class_hierarchy, y_true, y_pred):
    """F1 score over descendant-augmented label sets.

    Validates the labels, then delegates to ``_fbeta_score_class_sets``
    with its default ``beta`` and ``class_hierarchy._get_descendants`` as
    the expansion function.
    """
    labels_true, labels_pred = _check_targets_hmc(y_true, y_pred)
    expand = class_hierarchy._get_descendants
    return _fbeta_score_class_sets(expand, labels_true, labels_pred)
|
||
# # Classification Report | ||
# def classification_report(class_hierarchy, y_true, y_pred): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,5 +7,5 @@ | |
description='Decision tree based hierachical multi-classifier', | ||
author='David Warshaw', | ||
author_email='[email protected]', | ||
py_modules=['hmc', 'datasets'], | ||
py_modules=['hmc', 'datasets', 'metrics'], | ||
requires=['sklearn', 'numpy', 'pandas']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .test_hmc import TestClassHierarchy | ||
from .test_hmc import TestDecisionTreeHierarchicalClassifier | ||
from .test_metrics import TestMetrics |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
""" | ||
Tests for the hmc metrics module. | ||
""" | ||
|
||
import unittest | ||
|
||
import pandas as pd | ||
|
||
from sklearn import tree | ||
from sklearn.cross_validation import train_test_split | ||
from sklearn import metrics as skmetrics | ||
|
||
import hmc | ||
import hmc.metrics as metrics | ||
|
||
class TestMetrics(unittest.TestCase):
    """Compare the hierarchical metrics with scikit-learn's flat metrics.

    Each test asserts that a hierarchical score is at least as large as
    the corresponding scikit-learn score computed on the same test labels
    and predictions.
    """

    def setUp(self):
        """Fit a hierarchical and a plain decision tree on a 50/50 split."""
        self.ch = hmc.load_shades_class_hierachy()
        self.X, self.y = hmc.load_shades_data()
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            self.X, self.y, test_size=0.50, random_state=0)
        self.dt = hmc.DecisionTreeHierarchicalClassifier(self.ch)
        self.dt_nonh = tree.DecisionTreeClassifier()
        self.dt = self.dt.fit(self.X_train, self.y_train)
        self.dt_nonh = self.dt_nonh.fit(self.X_train, self.y_train)
        self.y_pred = self.dt.predict(self.X_test)
        # NOTE(review): y_pred_nonh is computed but no test reads it.
        self.y_pred_nonh = self.dt_nonh.predict(self.X_test)

    ## General Scores
    # Average accuracy
    def test_accuracy_score(self):
        accuracy = metrics.accuracy_score(self.ch, self.y_test, self.y_pred)
        accuracy_sk = skmetrics.accuracy_score(self.y_test, self.y_pred)
        # Hierarchical classification should be at least as accurate as
        # traditional classification.
        # NOTE(review): both values are computed from the same predictions
        # (self.y_pred); if the intent is to compare against the plain
        # decision tree, self.y_pred_nonh should be scored instead - confirm.
        self.assertGreaterEqual(accuracy, accuracy_sk)

    ## Hierarchy Precision / Recall
    # Ancestors Scores (Super Class)
    # Precision
    def test_precision_score_ancestors(self):
        precision_ancestors = metrics.precision_score_ancestors(self.ch, self.y_test, self.y_pred)
        precision_sk = skmetrics.precision_score(self.y_test, self.y_pred, average="macro")
        self.assertGreaterEqual(precision_ancestors, precision_sk)

    # Recall
    def test_recall_score_ancestors(self):
        recall_ancestors = metrics.recall_score_ancestors(self.ch, self.y_test, self.y_pred)
        recall_sk = skmetrics.recall_score(self.y_test, self.y_pred, average="macro")
        self.assertGreaterEqual(recall_ancestors, recall_sk)

    # Descendants Scores (Sub Class)
    # Precision
    def test_precision_score_descendants(self):
        precision_descendants = metrics.precision_score_descendants(self.ch, self.y_test, self.y_pred)
        precision_sk = skmetrics.precision_score(self.y_test, self.y_pred, average="macro")
        self.assertGreaterEqual(precision_descendants, precision_sk)

    # Recall
    def test_recall_score_descendants(self):
        recall_descendants = metrics.recall_score_descendants(self.ch, self.y_test, self.y_pred)
        recall_sk = skmetrics.recall_score(self.y_test, self.y_pred, average="macro")
        self.assertGreaterEqual(recall_descendants, recall_sk)

    # F1
    # Ancestors
    def test_f1_score_ancestors(self):
        f1_ancestors = metrics.f1_score_ancestors(self.ch, self.y_test, self.y_pred)
        f1_sk = skmetrics.f1_score(self.y_test, self.y_pred, average="macro")
        self.assertGreaterEqual(f1_ancestors, f1_sk)

    # Descendants
    def test_f1_score_descendants(self):
        f1_descendants = metrics.f1_score_descendants(self.ch, self.y_test, self.y_pred)
        f1_sk = skmetrics.f1_score(self.y_test, self.y_pred, average="macro")
        self.assertGreaterEqual(f1_descendants, f1_sk)
|
||
if __name__ == '__main__': | ||
unittest.main() |