[MRG] Support for multi-class roc_auc scores #7663


Closed
wants to merge 31 commits
Changes from all commits
31 commits
00a0626
better implementation of the multiclass logic (in terms of design). d…
kathyxchen Oct 12, 2016
8a84578
ovr and associated testing
kathyxchen Oct 13, 2016
485fd59
some testing implemented for the value errors, but not yet comprehensive
kathyxchen Oct 13, 2016
2ac42c2
implemented ovr with the multilabelbinarizer
kathyxchen Oct 19, 2016
4e6141f
removed the ovr implementation that was in the base.py function
kathyxchen Oct 19, 2016
7bd899e
lots more code cleanup
kathyxchen Oct 19, 2016
f4fb56f
pending, need more test cases
kathyxchen Oct 19, 2016
dd5c06a
making changes in response to PR: remove unused variable and added in…
Oct 26, 2016
91b1428
making a change to one of the rst files for documenting the multiclas…
Oct 26, 2016
3d4d065
making a change to one of the rst files for documenting the multiclas…
kathyxchen Oct 26, 2016
e037993
added a valueerror test case after checking code coverage for new fun…
kathyxchen Oct 26, 2016
acb977e
sample_weight can only be None, documentation update
Nov 20, 2016
8dd9665
model_evaluation documentation update
Nov 21, 2016
7f652aa
docstring update in _average_multiclass_ovo_score
Nov 30, 2016
4016c0c
update documentation for multiclass base function and test
Nov 30, 2016
86327d9
updated the documentation with equations and citations
Dec 1, 2016
271b882
improve the test cases for one-vs-one multiclass roc auc
Dec 6, 2016
d70ae6c
ovo uses bincount and ovr uses labelbinarizer
Dec 7, 2016
bf8c5fe
fixed a coefficient bug in the weighted HT2001 algorithm and refactor…
Dec 7, 2016
ed7e840
update the docs with the correct equation
Dec 7, 2016
b2214c8
updating the plot_roc example with plots for one vs one
Dec 7, 2016
d2aa2a0
updating plot_roc with roc_auc_score functions
Dec 10, 2016
fde6387
updating with some style changes and including the invariant under pe…
Mar 14, 2017
12592f4
flake8 on plot_roc
Mar 14, 2017
b4e498e
over-indent flake8 fix
Mar 14, 2017
5688ade
fixed the normalization equation for ovo
Mar 26, 2017
a784dbc
beginning the update to examples, needs to be tested
Mar 26, 2017
0138a75
updating the documentation for model_evaluation with new citations
kathyxchen Apr 27, 2017
ad5e93b
fix flake8 error in plot_roc
kathyxchen Apr 27, 2017
165513a
update with sample weights in ovr case
kathyxchen Apr 27, 2017
9530511
modifications to plot_roc example to improve readability, fixed one bug
Jun 7, 2017
60 changes: 55 additions & 5 deletions doc/modules/model_evaluation.rst
@@ -252,15 +252,16 @@ Some also work in the multilabel case:
precision_recall_fscore_support
precision_score
recall_score
roc_auc_score
zero_one_loss

Some work with binary and multilabel (but not multiclass) problems:

.. autosummary::
:template: function.rst

average_precision_score
roc_auc_score


In the following sub-sections, we will describe each of those functions,
@@ -976,10 +977,41 @@ In multi-label classification, the :func:`roc_auc_score` function is
extended by averaging over the labels as :ref:`above <average>`.

Compared to metrics such as the subset accuracy, the Hamming loss, or the
F1 score, ROC doesn't require optimizing a threshold for each label.

The :func:`roc_auc_score` function can also be used in multi-class
(Review comment: I would instead briefly describe the averaging strategies.)

(Review comment: Which of the averaging strategies was described by the authors with weights, and which wasn't?)

classification. [F2009]_ Two averaging strategies are currently supported: the
one-vs-one algorithm computes the average of the pairwise
ROC AUC scores, and the one-vs-rest algorithm
computes the average of the ROC AUC scores for each class against
all other classes. In both cases, the class labels are provided in
an array with values from 0 to ``n_classes - 1``, and the scores correspond
to the probability estimates that a sample belongs to a particular class.
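
As an illustrative sketch only (not part of the documentation text), the calls
below show how the two strategies would be selected, using the ``multiclass``
and ``average`` parameters proposed in this pull request and a made-up toy
problem:

import numpy as np
from sklearn.metrics import roc_auc_score

# Toy 3-class problem: labels recoded to 0..2, each row of y_prob sums to 1.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_prob = np.array([[0.7, 0.2, 0.1], [0.5, 0.3, 0.2], [0.2, 0.6, 0.2],
                   [0.1, 0.6, 0.3], [0.1, 0.2, 0.7], [0.2, 0.3, 0.5]])

# One-vs-one: uniform average [HT2001] by default, weighted [F2009]
ovo_macro = roc_auc_score(y_true, y_prob, multiclass="ovo")
ovo_weighted = roc_auc_score(y_true, y_prob, multiclass="ovo",
                             average="weighted")

# One-vs-rest: uniform average [F2006] by default, weighted [F2001]
ovr_macro = roc_auc_score(y_true, y_prob, multiclass="ovr")
ovr_weighted = roc_auc_score(y_true, y_prob, multiclass="ovr",
                             average="weighted")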

**One-vs-one Algorithm**
Computes the AUC of each class against each other class, i.e. the AUC of
all :math:`c(c-1)` possible pairwise combinations for a problem with
:math:`c` classes.

[HT2001]_ Using the uniform class distribution:

.. math:: \frac{1}{c(c-1)}\sum_{j=1}^c\sum_{k \neq j}^c \textnormal{AUC}(j, k)

[F2009]_ Weighted by the prevalence of classes `j` and `k`:

.. math:: \frac{1}{c-1}\sum_{j=1}^c\sum_{k \neq j}^c p(j \cup k)\textnormal{AUC}(j, k)
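
To make the one-vs-one averages concrete, here is a minimal hand-computed
sketch (illustrative only; the toy arrays are assumptions). The weighted
variant mirrors the prevalence weighting used by the helper in
sklearn/metrics/base.py further down, i.e. a prevalence-weighted average of
the pair scores:

from itertools import combinations
import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([0, 0, 1, 1, 2, 2])
y_prob = np.array([[0.7, 0.2, 0.1], [0.5, 0.3, 0.2], [0.2, 0.6, 0.2],
                   [0.1, 0.6, 0.3], [0.1, 0.2, 0.7], [0.2, 0.3, 0.5]])

pair_scores = []
pair_weights = []
for j, k in combinations(np.unique(y_true), 2):
    jk_mask = np.isin(y_true, [j, k])
    # AUC(j, k): class j is positive, restricted to samples of class j or k
    auc_jk = roc_auc_score(y_true[jk_mask] == j, y_prob[jk_mask, j])
    # AUC(k, j): class k is positive
    auc_kj = roc_auc_score(y_true[jk_mask] == k, y_prob[jk_mask, k])
    pair_scores.append((auc_jk + auc_kj) / 2)
    pair_weights.append(jk_mask.mean())  # p(j U k)

uniform_ovo = np.mean(pair_scores)                            # [HT2001]
weighted_ovo = np.average(pair_scores, weights=pair_weights)  # [F2009]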

**One-vs-rest Algorithm**
Computes the AUC of each class against the rest. This treats a problem
with :math:`c` classes as :math:`c` binary classification problems.

[F2006]_ Using the uniform class distribution:

.. math:: \frac{\sum_{j=1}^c \textnormal{AUC}(j, \textnormal{rest}_j)}{c}

[F2001]_ Weighted by the a priori class distribution:

.. math:: \sum_{j=1}^c p(j)\textnormal{AUC}(j, \textnormal{rest}_j)
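
A matching sketch for the one-vs-rest averages (again illustrative only, with
the same toy arrays):

import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([0, 0, 1, 1, 2, 2])
y_prob = np.array([[0.7, 0.2, 0.1], [0.5, 0.3, 0.2], [0.2, 0.6, 0.2],
                   [0.1, 0.6, 0.3], [0.1, 0.2, 0.7], [0.2, 0.3, 0.5]])

classes = np.unique(y_true)
# AUC(j, rest_j): class j is positive, all other classes are negative
per_class_auc = np.array([roc_auc_score(y_true == j, y_prob[:, j])
                          for j in classes])
prevalence = np.array([np.mean(y_true == j) for j in classes])

uniform_ovr = per_class_auc.mean()                            # [F2006]
weighted_ovr = np.average(per_class_auc, weights=prevalence)  # [F2001]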

.. image:: ../auto_examples/model_selection/images/sphx_glr_plot_roc_002.png
:target: ../auto_examples/model_selection/plot_roc.html
@@ -1000,6 +1032,24 @@
for an example of using ROC to
model species distribution.

.. topic:: References:

.. [F2001] Fawcett, T., 2001. `Using rule sets to maximize
ROC performance <http://ieeexplore.ieee.org/document/989510/>`_
In Data Mining, 2001, Proceedings of the IEEE International
Conference, pp. 131-138.
.. [F2006] Fawcett, T., 2006. `An introduction to ROC analysis.
<http://www.sciencedirect.com/science/article/pii/S016786550500303X>`_
Pattern Recognition Letters, 27(8), pp. 861-874.
.. [F2009] Ferri, C., Hernandez-Orallo, J., and Modroiu, R., 2009.
`An experimental comparison of performance measures for classification.
<http://www.sciencedirect.com/science/article/pii/S0167865508002687>`_
Pattern Recognition Letters, 30(1), pp. 27-38.
.. [HT2001] Hand, D.J. and Till, R.J., 2001. `A simple generalisation
of the area under the ROC curve for multiple class classification problems.
<http://link.springer.com/article/10.1023/A:1010920819831>`_
Machine Learning, 45(2), pp. 171-186.

.. _zero_one_loss:

Zero one loss
116 changes: 100 additions & 16 deletions examples/model_selection/plot_roc.py
@@ -19,16 +19,39 @@
-------------------

ROC curves are typically used in binary classification to study the output of
a classifier. Extensions of ROC curve and ROC area to multi-class
or multi-label classification can use the One-vs-Rest or One-vs-One scheme.

One-vs-Rest
-----------

The output is binarized and one ROC curve is drawn per label,
where that label is taken as the positive class and all other labels
(the "rest") are considered the negative class.

The ROC area can be approximated by taking the average--unweighted or weighted
by the a priori class distribution--of the one-vs-rest ROC areas.

One can also draw a ROC curve by considering
each element of the label indicator matrix as a binary prediction
(micro-averaging).

Another evaluation measure for one-vs-rest multi-class classification is
macro-averaging, which gives equal weight to the classification of each
label.

One-vs-One
----------

Two ROC curves can be drawn per pair of labels because either of the two
labels can be considered the positive class (and the other the negative
class). The ROC area of a label pair is approximated by taking the average
of these two ROC AUC scores.

The One-vs-One approximation of a multi-class ROC AUC score is the average--
unweighted or weighted by class prevalence--across all of the pairwise
approximate ROC AUC scores.

.. note::

See also :func:`sklearn.metrics.roc_auc_score`,
@@ -39,10 +62,10 @@

import numpy as np
import matplotlib.pyplot as plt
from itertools import combinations, cycle

from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
@@ -53,9 +76,8 @@
X = iris.data
y = iris.target

classes = np.unique(y)
n_classes = len(classes)

# Add noisy features to make the problem harder
random_state = np.random.RandomState(0)
@@ -72,17 +94,17 @@
y_score = classifier.fit(X_train, y_train).decision_function(X_test)

# Compute ROC curve and ROC area for each class

# Binarize y_test to compute the ROC curve
y_test_binarized = label_binarize(y_test, classes=classes)

fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test_binarized[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])



##############################################################################
# Plot of a ROC curve for a specific class
@@ -101,7 +123,12 @@


##############################################################################
# Plot ROC curves for the multiclass problem using One vs. Rest classification.

# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(
y_test_binarized.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Compute macro-average ROC curve and ROC area

@@ -143,6 +170,63 @@
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('An extension of ROC to multi-class '
          'using One-vs-Rest')
plt.legend(loc="lower right")
plt.show()

# Compute the One-vs-Rest ROC AUC score, weighted and unweighted
unweighted_roc_auc_ovr = roc_auc_score(y_test, y_score, multiclass="ovr")
weighted_roc_auc_ovr = roc_auc_score(
    y_test, y_score, multiclass="ovr", average="weighted")
print("One-vs-Rest ROC AUC scores: {0} (unweighted), {1} (weighted)".format(
    unweighted_roc_auc_ovr, weighted_roc_auc_ovr))

##############################################################################
# Plot ROC curves for the multiclass problem using One vs. One classification.

for a, b in combinations(range(n_classes), 2):
    # Filter `y_test` and `y_score` to only consider the current
    # `a` and `b` class pair.
    ab_mask = np.logical_or(y_test == a, y_test == b)
    y_true_filtered = y_test[ab_mask]
    y_score_filtered = y_score[ab_mask]

    # Compute ROC curve and ROC area with `a` as the positive class
    class_a = y_true_filtered == a
    fpr[(a, b)], tpr[(a, b)], _ = roc_curve(
        class_a, y_score_filtered[:, a])
    roc_auc[(a, b)] = auc(fpr[(a, b)], tpr[(a, b)])

    # Compute ROC curve and ROC area with `b` as the positive class
    class_b = y_true_filtered == b
    fpr[(b, a)], tpr[(b, a)], _ = roc_curve(
        class_b, y_score_filtered[:, b])
    roc_auc[(b, a)] = auc(fpr[(b, a)], tpr[(b, a)])

plt.figure()
for a, b in combinations(range(n_classes), 2):
    plt.plot(fpr[(a, b)], tpr[(a, b)], lw=lw,
             label='ROC curve: class {0} vs. {1} '
                   '(area = {2:0.2f})'.format(
                       a, b, roc_auc[(a, b)]))
    plt.plot(fpr[(b, a)], tpr[(b, a)], lw=lw,
             label='ROC curve: class {0} vs. {1} '
                   '(area = {2:0.2f})'.format(
                       b, a, roc_auc[(b, a)]))
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('An extension of ROC to multi-class '
          'using One-vs-One')
plt.legend(bbox_to_anchor=(1.1, 0.30))
plt.show()

# Compute the One-vs-One ROC AUC score, weighted and unweighted
unweighted_roc_auc_ovo = roc_auc_score(y_test, y_score, multiclass="ovo")
weighted_roc_auc_ovo = roc_auc_score(
    y_test, y_score, multiclass="ovo", average="weighted")
print("One-vs-One ROC AUC scores: {0} (unweighted), {1} (weighted)".format(
    unweighted_roc_auc_ovo, weighted_roc_auc_ovo))
67 changes: 67 additions & 0 deletions sklearn/metrics/base.py
@@ -13,6 +13,7 @@
# License: BSD 3 clause

from __future__ import division
import itertools

import numpy as np

@@ -131,3 +132,69 @@ def _average_binary_score(binary_metric, y_true, y_score, average,
        return np.average(score, weights=average_weight)
    else:
        return score


def _average_multiclass_ovo_score(binary_metric, y_true, y_score, average):
    """Uses the binary metric for one-vs-one multiclass classification,
    where the score is computed according to the Hand & Till (2001) algorithm.

    Parameters
    ----------
    binary_metric : callable
        The binary metric function to use.
        Accepts the following as input:
        y_true_target : array, shape = [n_samples_target]
            Some sub-array of y_true for a pair of classes designated
            positive and negative in the one-vs-one scheme.
        y_score_target : array, shape = [n_samples_target]
            Scores corresponding to the probability estimates
            of a sample belonging to the designated positive class label.

    y_true : array, shape = [n_samples]
        True multiclass labels.
        Assumes labels have been recoded to 0 to n_classes - 1.

    y_score : array, shape = [n_samples, n_classes]
        Target scores corresponding to probability estimates of a sample
        belonging to a particular class.

    average : 'macro' or 'weighted', default='macro'
        ``'macro'``:
            Calculate metrics for each label, and find their unweighted
            mean. This does not take label imbalance into account. Classes
            are assumed to be uniformly distributed.
        ``'weighted'``:
            Calculate metrics for each label, taking into account the a priori
            distribution of the classes.

    Returns
    -------
    score : float
        Average of the pairwise binary metric scores.
    """
    n_classes = len(np.unique(y_true))
    n_pairs = n_classes * (n_classes - 1) // 2
    prevalence = np.empty(n_pairs)
    pair_scores = np.empty(n_pairs)

    ix = 0
    for a, b in itertools.combinations(range(n_classes), 2):
        a_mask = y_true == a
        ab_mask = np.logical_or(a_mask, y_true == b)

        prevalence[ix] = np.sum(ab_mask) / len(y_true)

        y_score_filtered = y_score[ab_mask]

        a_true = a_mask[ab_mask]
        b_true = np.logical_not(a_true)

        a_true_score = binary_metric(
            a_true, y_score_filtered[:, a])
        b_true_score = binary_metric(
            b_true, y_score_filtered[:, b])
        binary_avg_score = (a_true_score + b_true_score) / 2
        pair_scores[ix] = binary_avg_score

        ix += 1
    return (np.average(pair_scores, weights=prevalence)
            if average == "weighted" else np.average(pair_scores))