[MRG] Multi-class roc_auc_score #10481
@@ -31,9 +31,10 @@
 from ..utils.extmath import stable_cumsum
 from ..utils.sparsefuncs import count_nonzero
 from ..exceptions import UndefinedMetricWarning
-from ..preprocessing import label_binarize
+from ..preprocessing import LabelBinarizer, label_binarize

-from .base import _average_binary_score
+from .base import _average_binary_score, _average_multiclass_ovo_score, \
+    _average_multiclass_ovr_score


 def auc(x, y, reorder='deprecated'):
@@ -157,7 +158,8 @@ def average_precision_score(y_true, y_score, average="macro",
         class, confidence values, or non-thresholded measure of decisions
         (as returned by "decision_function" on some classifiers).

-    average : string, [None, 'micro', 'macro' (default), 'samples', 'weighted']
+    average : string, {None, 'micro', 'macro', 'samples', 'weighted'},
+            default 'macro'
         If ``None``, the scores for each class are returned. Otherwise,
         this determines the type of averaging performed on the data:
@@ -222,29 +224,39 @@ def _binary_uninterpolated_average_precision(
         sample_weight=sample_weight)


-def roc_auc_score(y_true, y_score, average="macro", sample_weight=None,
-                  max_fpr=None):
-    """Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC)
-    from prediction scores.
-
-    Note: this implementation is restricted to the binary classification task
-    or multilabel classification task in label indicator format.
+def roc_auc_score(y_true, y_score, multiclass="ovr", average="macro",
+                  sample_weight=None, max_fpr=None):
+    """Compute Area Under the Curve (AUC) from prediction scores.

     Read more in the :ref:`User Guide <roc_metrics>`.

     Parameters
     ----------
     y_true : array, shape = [n_samples] or [n_samples, n_classes]
-        True binary labels or binary label indicators.
+        True binary labels in binary label indicators.
+        The multiclass case expects shape = [n_samples] and labels
+        with values from 0 to (n_classes-1), inclusive.

     y_score : array, shape = [n_samples] or [n_samples, n_classes]
         Target scores, can either be probability estimates of the positive
         class, confidence values, or non-thresholded measure of decisions
-        (as returned by "decision_function" on some classifiers). For binary
-        y_true, y_score is supposed to be the score of the class with greater
-        label.
-
-    average : string, [None, 'micro', 'macro' (default), 'samples', 'weighted']
+        (as returned by "decision_function" on some classifiers).
+        The multiclass case expects shape = [n_samples, n_classes]
+        where the scores correspond to probability estimates.
+
+    multiclass : string, 'ovr' or 'ovo', default 'ovr'
+        Note: multiclass ROC AUC currently only handles the 'macro' and
+        'weighted' averages.
+
+        ``'ovr'``:
+            Calculate metrics for the multiclass case using the one-vs-rest
+            approach.
+        ``'ovo'``:
+            Calculate metrics for the multiclass case using the one-vs-one
+            approach.
+
+    average : string, {None, 'micro', 'macro', 'samples', 'weighted'},
+            default 'macro'
         If ``None``, the scores for each class are returned. Otherwise,
         this determines the type of averaging performed on the data:
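To make the proposed API concrete, here is a minimal usage sketch. It assumes this PR's branch is installed, since the multiclass parameter does not exist in released scikit-learn; the data is made up for illustration:

    import numpy as np
    from sklearn.metrics import roc_auc_score  # this PR's branch assumed

    y_true = np.array([0, 1, 2, 2])            # labels in 0..(n_classes-1)
    y_score = np.array([[0.7, 0.2, 0.1],       # probability estimates,
                        [0.2, 0.6, 0.2],       # each row sums to 1.0
                        [0.1, 0.2, 0.7],
                        [0.2, 0.2, 0.6]])

    roc_auc_score(y_true, y_score, multiclass="ovo", average="macro")
    roc_auc_score(y_true, y_score, multiclass="ovr", average="weighted")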
@@ -265,7 +277,9 @@ def roc_auc_score(y_true, y_score, average="macro", sample_weight=None,

     max_fpr : float > 0 and <= 1, optional
         If not ``None``, the standardized partial AUC [3]_ over the range
-        [0, max_fpr] is returned.
+        [0, max_fpr] is returned. If multiclass task, should be either
+        equal to ``None`` or ``1.0`` as AUC ROC partial computation currently
+        not supported in this case.

     Returns
     -------
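The restriction above applies only to the multiclass path; for binary targets max_fpr keeps working as before. A quick self-contained illustration of the binary partial-AUC call (standard scikit-learn usage, nothing specific to this PR):

    import numpy as np
    from sklearn.metrics import roc_auc_score

    y_true = np.array([0, 0, 1, 1])
    y_score = np.array([0.1, 0.4, 0.35, 0.8])

    roc_auc_score(y_true, y_score)               # full AUC
    roc_auc_score(y_true, y_score, max_fpr=0.5)  # standardized partial AUC over [0, 0.5]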
@@ -326,13 +340,65 @@ def _binary_roc_auc_score(y_true, y_score, sample_weight=None):
         return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))

     y_type = type_of_target(y_true)
-    if y_type == "binary":
+    y_true = check_array(y_true, ensure_2d=False, dtype=None)
+    y_score = check_array(y_score, ensure_2d=False)
+
+    if y_type == "multiclass" or (y_type == "binary" and
+                                  y_score.ndim == 2 and
+                                  y_score.shape[1] > 2):
+        # validation of the input y_score
+        if not np.allclose(1, y_score.sum(axis=1)):
+            raise ValueError("Target scores should sum up to 1.0 for all"
+                             "samples.")

Reviewer comment: space missing between "all" and "samples"

Reviewer comment: We only need this for OvO, not for OvR, right?
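As an aside on the check being discussed: it distinguishes probability estimates, whose rows sum to one, from raw decision values, which in general do not. A tiny illustration (mine, not from the PR):

    import numpy as np

    proba = np.array([[0.7, 0.2, 0.1],      # e.g. predict_proba output
                      [0.2, 0.6, 0.2]])
    scores = np.array([[2.1, -0.3, 0.5],    # e.g. decision_function output
                       [-1.0, 1.7, 0.2]])

    np.allclose(1, proba.sum(axis=1))       # True  -> accepted
    np.allclose(1, scores.sum(axis=1))      # False -> the ValueError above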
+
+        # do not support partial ROC computation for multiclass
+        if max_fpr is not None and max_fpr != 1.:
+            raise ValueError("Partial AUC computation not available in "
+                             "multiclass setting. Parameter 'max_fpr' must be"
Reviewer comment: Please be consistent within a string about whether white space appears at the end or start of a line.

+                             " set to `None`. Received `max_fpr={0}` "
+                             "instead.".format(max_fpr))
+
+        # validation for multiclass parameter specifications
+        average_options = ("macro", "weighted")
+        if average not in average_options:
+            raise ValueError("Parameter 'average' must be one of {0} for"
+                             " multiclass problems.".format(average_options))
+        multiclass_options = ("ovo", "ovr")
+        if multiclass not in multiclass_options:
+            raise ValueError("Parameter multiclass='{0}' is not supported"
+                             " for multiclass ROC AUC. 'multiclass' must be"
+                             " one of {1}.".format(
+                                 multiclass, multiclass_options))
+        if sample_weight is not None:
+            # TODO: check if only in ovo case, if yes, do not raise when ovr
+            raise ValueError("Parameter 'sample_weight' is not supported"
+                             " for multiclass one-vs-one ROC AUC."
+                             " 'sample_weight' must be None in this case.")
+
+        if multiclass == "ovo":
+            # Hand & Till (2001) implementation
+            return _average_multiclass_ovo_score(
+                _binary_roc_auc_score, y_true, y_score, average)
+        elif multiclass == "ovr" and average == "weighted":
Reviewer comment: Is it the best way to use the P&D definition?

Reviewer comment: What happens if someone sets multiclass='ovr' and average='macro' right now?
+            # Provost & Domingos (2001) implementation
+            return _average_multiclass_ovr_score(
+                _binary_roc_auc_score, y_true, y_score, average)
+        else:
+            y_true = y_true.reshape((-1, 1))
+            y_true_multilabel = LabelBinarizer().fit_transform(y_true)
+            return _average_binary_score(
+                _binary_roc_auc_score, y_true_multilabel, y_score, average,
+                sample_weight=sample_weight)
+    elif y_type == "binary":
         labels = np.unique(y_true)
         y_true = label_binarize(y_true, labels)[:, 0]
-
-    return _average_binary_score(
-        _binary_roc_auc_score, y_true, y_score, average,
-        sample_weight=sample_weight)
+        return _average_binary_score(
+            _binary_roc_auc_score, y_true, y_score, average,
+            sample_weight=sample_weight)
+    else:
+        return _average_binary_score(
+            _binary_roc_auc_score, y_true, y_score, average,
+            sample_weight=sample_weight)


 def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
Reviewer comment: is this not the same as _average_binary_score?
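For context on that last question: as I read the diff, the plain 'ovr' branch binarizes y_true with LabelBinarizer and then defers to _average_binary_score, so with average='macro' it reduces to a mean of per-class binary AUCs. A minimal sketch of that reduction, using only the released scikit-learn API (the data is made up for illustration):

    import numpy as np
    from sklearn.metrics import roc_auc_score
    from sklearn.preprocessing import LabelBinarizer

    y_true = np.array([0, 1, 2, 2])
    y_score = np.array([[0.7, 0.2, 0.1],
                        [0.2, 0.6, 0.2],
                        [0.1, 0.2, 0.7],
                        [0.2, 0.2, 0.6]])

    # Binarize labels into an indicator matrix, score each class column as an
    # independent binary problem, then macro-average over classes.
    Y = LabelBinarizer().fit_transform(y_true)
    per_class = [roc_auc_score(Y[:, k], y_score[:, k]) for k in range(Y.shape[1])]
    print(np.mean(per_class))

The 'ovo' branch implements the Hand & Till (2001) measure instead. A condensed sketch of that pairwise average, continuing from the snippet above (my reading of the paper, not the PR's _average_multiclass_ovo_score):

    from itertools import combinations

    def ovo_macro_auc(y_true, y_score):
        """Average pairwise binary AUCs over all unordered class pairs."""
        classes = np.unique(y_true)
        pair_scores = []
        for a, b in combinations(classes, 2):
            mask = np.isin(y_true, [a, b])           # samples of class a or b only
            auc_a = roc_auc_score(y_true[mask] == a, y_score[mask, a])
            auc_b = roc_auc_score(y_true[mask] == b, y_score[mask, b])
            pair_scores.append((auc_a + auc_b) / 2)  # symmetric pair score
        return np.mean(pair_scores)

    print(ovo_macro_auc(y_true, y_score))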