Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sklearn/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .classification import accuracy_score
from .classification import classification_report
from .classification import binarized_multilabel_confusion_matrix
from .classification import confusion_matrix
from .classification import f1_score
from .classification import fbeta_score
Expand Down Expand Up @@ -64,6 +65,7 @@
'classification_report',
'cluster',
'completeness_score',
'binarized_multilabel_confusion_matrix',
'confusion_matrix',
'consensus_score',
'euclidean_distances',
Expand Down
64 changes: 64 additions & 0 deletions sklearn/metrics/classification.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,70 @@ def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
return _weighted_sum(score, sample_weight, normalize)


def binarized_multilabel_confusion_matrix(y_true, y_pred):
    """Compute per-label true/false positive and negative counts for a
    multilabel classification problem.

    Parameters
    ----------
    y_true : array-like or label indicator matrix, \
            shape = [n_samples, n_labels]
        Ground truth (correct) target values.

    y_pred : array-like or label indicator matrix, \
            shape = [n_samples, n_labels]
        Estimated targets as returned by a classifier.

    Returns
    -------
    C : structured array, shape = [n_labels, ]
        Record array with fields 'tp', 'fp', 'fn', 'tn'.
        For example, ``C['tp']`` is an array ``a`` where ``a[i]``
        contains the number of true positives for label ``i``.

    References
    ----------
    .. [1] `Wikipedia entry for the Confusion matrix
           <http://en.wikipedia.org/wiki/Sensitivity_and_specificity>`_
    .. [2] http://www.cnts.ua.ac.be/~vincent/pdf/microaverage.pdf

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.metrics import binarized_multilabel_confusion_matrix
    >>> y_true = np.array([[1, 0], [0, 1]])
    >>> y_pred = np.array([[1, 1], [1, 0]])
    >>> binarized_multilabel_confusion_matrix(y_true, y_pred)['tp']
    array([1, 0])
    """
    y_type, y_true, y_pred = _check_targets(y_true, y_pred)
    if y_type != 'multilabel-indicator':
        raise ValueError("%s is not supported" % y_type)

    # _check_targets returns sparse indicator matrices here; work
    # column-by-column (one column per label).
    n_samples, n_labels = y_true.get_shape()
    # Preallocate one (n_labels, 4) int array instead of growing a flat
    # array with np.hstack inside the loop (which is quadratic).
    counts = np.empty((n_labels, 4), dtype=int)
    for label_idx in range(n_labels):
        y_true_col = y_true.getcol(label_idx)
        y_pred_col = y_pred.getcol(label_idx)
        # True positives: dot product of the two indicator columns.
        t_pos = int(y_true_col.transpose().dot(y_pred_col).toarray()[0][0])
        # False positives: predicted ones that are zero in y_true.
        f_pos = y_pred_col.getnnz() - t_pos
        # False negatives: true ones that were not predicted.
        f_neg = y_true_col.getnnz() - t_pos
        # True negatives: zeros in y_true minus the false positives.
        t_neg = (n_samples - y_true_col.getnnz()) - f_pos
        counts[label_idx] = (t_pos, f_pos, f_neg, t_neg)

    # View the contiguous (n_labels, 4) int array as a struct array
    # rather than rebuilding it via list(map(tuple, ...)).
    struct_dtype = [('tp', int), ('fp', int), ('fn', int), ('tn', int)]
    return counts.view(dtype=struct_dtype).reshape(n_labels)


def confusion_matrix(y_true, y_pred, labels=None):
"""Compute confusion matrix to evaluate the accuracy of a classification

Expand Down
1 change: 1 addition & 0 deletions sklearn/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from .classification import accuracy_score
from .classification import classification_report
from .classification import binarized_multilabel_confusion_matrix
from .classification import confusion_matrix
from .classification import f1_score
from .classification import fbeta_score
Expand Down