From 96189d95405de2b44833299bf41996d5cfe1bb8f Mon Sep 17 00:00:00 2001
From: Magellnea
Date: Sun, 31 Aug 2014 18:39:53 +0300
Subject: [PATCH] ENH Implement multilabel confusion matrix

FIX Use a structured array to access the labels of the binarized
multilabel confusion matrix.

DOC Minor doctest fix.

FIX Clean-up and minor fixes.

FIX Make map() usage compatible with Python 3.
---
 sklearn/metrics/__init__.py       |  2 +
 sklearn/metrics/classification.py | 64 +++++++++++++++++++++++++++++++
 sklearn/metrics/metrics.py        | 34 ++++++++++++++++
 3 files changed, 100 insertions(+)
 mode change 100644 => 100755 sklearn/metrics/classification.py
 create mode 100644 sklearn/metrics/metrics.py

diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py
index 413831939fbbc..cfa10d1916fa7 100644
--- a/sklearn/metrics/__init__.py
+++ b/sklearn/metrics/__init__.py
@@ -16,6 +16,7 @@
 from .classification import accuracy_score
 from .classification import classification_report
 from .classification import cohen_kappa_score
+from .classification import binarized_multilabel_confusion_matrix
 from .classification import confusion_matrix
 from .classification import f1_score
 from .classification import fbeta_score
@@ -70,6 +71,7 @@
     'classification_report',
     'cluster',
     'completeness_score',
+    'binarized_multilabel_confusion_matrix',
     'confusion_matrix',
     'consensus_score',
     'coverage_error',
diff --git a/sklearn/metrics/classification.py b/sklearn/metrics/classification.py
old mode 100644
new mode 100755
index ee07fa634d080..eee8f1b702537
--- a/sklearn/metrics/classification.py
+++ b/sklearn/metrics/classification.py
@@ -275,6 +275,70 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None):
     return CM
 
 
+def binarized_multilabel_confusion_matrix(y_true, y_pred):
+    """Compute the true positives, false positives, false negatives and
+    true negatives for each label of a multilabel classification problem.
+
+    Parameters
+    ----------
+    y_true : array, shape = [n_samples, n_classes]
+        Ground truth (correct) target values.
+
+    y_pred : array, shape = [n_samples, n_classes]
+        Estimated targets as returned by a classifier.
+
+    Returns
+    -------
+    C : structured array, shape = [n_classes,]
+        Multilabel confusion matrix whose fields are accessed with the
+        keys 'tp', 'fp', 'fn' and 'tn'. For example, C['tp'] returns an
+        array a where a[i] contains the number of true positives for
+        class i.
+
+    References
+    ----------
+    .. [1] `Wikipedia entry for the Confusion matrix
+           <https://en.wikipedia.org/wiki/Confusion_matrix>`_
+
+    .. [2] http://www.cnts.ua.ac.be/~vincent/pdf/microaverage.pdf
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.metrics import binarized_multilabel_confusion_matrix
+    >>> y_true = np.array([[1, 0], [0, 1]])
+    >>> y_pred = np.array([[1, 1], [1, 0]])
+    >>> binarized_multilabel_confusion_matrix(y_true, y_pred)['tp']
+    array([1, 0])
+    """
+    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    if y_type != 'multilabel-indicator':
+        raise ValueError("%s is not supported" % y_type)
+
+    n_labels = y_true.get_shape()[1]
+    data = np.array([])
+    for label_idx in range(n_labels):
+        y_pred_col = y_pred.getcol(label_idx)
+        y_true_col = y_true.getcol(label_idx)
+        # true positives: dot product of the two indicator columns
+        t_pos = y_true_col.transpose().dot(y_pred_col).toarray()[0][0]
+        # false positives: ones in y_pred that match zeros in y_true
+        f_pos = y_pred_col.getnnz() - t_pos
+        # false negatives: ones in y_true that y_pred missed
+        f_neg = y_true_col.getnnz() - t_pos
+        # true negatives: zeros in y_true that were not predicted as ones
+        zeros = y_true_col.get_shape()[0] - y_true_col.getnnz()
+        t_neg = zeros - f_pos
+        data = np.hstack([data, [t_pos, f_pos, f_neg, t_neg]])
+    rows = np.tile([0, 1, 2, 3], n_labels)
+    columns = np.repeat(range(n_labels), 4)
+    mcm = coo_matrix((data, (rows, columns)), shape=(4, n_labels)).toarray()
+    # list(map(...)) is needed on Python 3, where map returns an iterator
+    return np.array(list(map(tuple, np.transpose(mcm))),
+                    dtype=[('tp', int), ('fp', int),
+                           ('fn', int), ('tn', int)])
+
+
 def cohen_kappa_score(y1, y2, labels=None, weights=None):
     """Cohen's kappa: a statistic that measures inter-annotator agreement.

diff --git a/sklearn/metrics/metrics.py b/sklearn/metrics/metrics.py
new file mode 100644
index 0000000000000..cb83067d42465
--- /dev/null
+++ b/sklearn/metrics/metrics.py
@@ -0,0 +1,34 @@
+import warnings
+warnings.warn("sklearn.metrics.metrics is deprecated and will be removed in "
+              "0.18. Please import from sklearn.metrics",
+              DeprecationWarning)
+
+
+from .ranking import auc
+from .ranking import average_precision_score
+from .ranking import label_ranking_average_precision_score
+from .ranking import precision_recall_curve
+from .ranking import roc_auc_score
+from .ranking import roc_curve
+
+from .classification import accuracy_score
+from .classification import classification_report
+from .classification import binarized_multilabel_confusion_matrix
+from .classification import confusion_matrix
+from .classification import f1_score
+from .classification import fbeta_score
+from .classification import hamming_loss
+from .classification import hinge_loss
+from .classification import jaccard_similarity_score
+from .classification import log_loss
+from .classification import matthews_corrcoef
+from .classification import precision_recall_fscore_support
+from .classification import precision_score
+from .classification import recall_score
+from .classification import zero_one_loss
+
+from .regression import explained_variance_score
+from .regression import mean_absolute_error
+from .regression import mean_squared_error
+from .regression import median_absolute_error
+from .regression import r2_score
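
Note (not part of the patch): a minimal NumPy sketch of the same per-label
counting on dense 0/1 indicator arrays, useful for checking the patched
function against an independent implementation. The helper name
per_label_counts and the micro-averaged precision at the end are
illustrative, not part of this patch; micro-averaging pools the per-label
counts across labels as described in reference [2].

    import numpy as np

    def per_label_counts(y_true, y_pred):
        # Per-label tp/fp/fn/tn for dense 0/1 indicator arrays, mirroring
        # what binarized_multilabel_confusion_matrix computes on sparse input.
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        tp = ((y_true == 1) & (y_pred == 1)).sum(axis=0)
        fp = ((y_true == 0) & (y_pred == 1)).sum(axis=0)
        fn = ((y_true == 1) & (y_pred == 0)).sum(axis=0)
        tn = ((y_true == 0) & (y_pred == 0)).sum(axis=0)
        return np.array(list(zip(tp, fp, fn, tn)),
                        dtype=[('tp', int), ('fp', int),
                               ('fn', int), ('tn', int)])

    y_true = np.array([[1, 0], [0, 1]])
    y_pred = np.array([[1, 1], [1, 0]])
    C = per_label_counts(y_true, y_pred)
    print(C['tp'])              # [1 0], matches the doctest above
    # micro-averaged precision: pool tp and fp over all labels (see [2])
    micro_p = C['tp'].sum() / float(C['tp'].sum() + C['fp'].sum())
    print(micro_p)              # 0.333...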