FEA confusion matrix derived metrics #17265
Closed
haochunchang wants to merge 19 commits into scikit-learn:main from haochunchang:confusion-matrix-derived-metrics
Commits (19)
9dbfbc8 added a function with confusion matrix derived metrics (fpr, tpr, tnr…
64a5a7b changed the true postive sum in the function
523eaa0 add print
b977216 remove one print
5a061ef remove print statements
6493977 add coauthors.
141fa4a fix doc string outputs
9615ae8 pep8 test
79e1562 trivial
8f21052 remove imported but unused flake8
3ffd830 to trigger test
fb73c6e Take over PR #15522
haochunchang c780053 Modify doc and zero-division in the weighted average.
haochunchang 408c2db Add tests for binary, multiclass and empty prediction.
haochunchang 4adfe2e Add tpr_fpr_tnr_fnr_scores to test_common.py.
haochunchang 53d6fd2 Remove pred_sum variable
haochunchang 88c41af Merge branch 'master' of https://github.com/scikit-learn/scikit-learn…
haochunchang a5b5262 Fix linting
haochunchang f74fc10 Fix parameter documentation
@@ -1538,6 +1538,209 @@ def precision_recall_fscore_support(y_true, y_pred, *, beta=1.0, labels=None,
    return precision, recall, f_score, true_sum


@_deprecate_positional_args
def tpr_fpr_tnr_fnr_scores(y_true, y_pred, *, labels=None, pos_label=1,
                           average=None, warn_for=('tpr', 'fpr',
                                                   'tnr', 'fnr'),
                           sample_weight=None, zero_division="warn"):
"""Compute True Positive Rate (TPR), False Positive Rate (FPR),\ | ||
True Negative Rate (TNR), False Negative Rate (FNR) for each class | ||
|
||
The TPR is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of | ||
true positives and ``fn`` the number of false negatives. | ||
|
||
The FPR is the ratio ``fp / (tn + fp)`` where ``tn`` is the number of | ||
true negatives and ``fp`` the number of false positives. | ||
|
||
The TNR is the ratio ``tn / (tn + fp)`` where ``tn`` is the number of | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Might be more informative |
||
true negatives and ``fp`` the number of false positives. | ||
|
||
The FNR is the ratio ``fn / (tp + fn)`` where ``tp`` is the number of | ||
true positives and ``fn`` the number of false negatives. | ||
|
||
If ``pos_label is None`` and in binary classification, this function | ||
returns the true positive rate, false positive rate, true negative rate | ||
and false negative rate if ``average`` is one of ``'micro'``, ``'macro'``, | ||
``'weighted'`` or ``'samples'``. | ||
|
||
    Parameters
    ----------
    y_true : {array-like, label indicator array, sparse matrix} \
            of shape (n_samples,)
        Ground truth (correct) target values.

    y_pred : {array-like, label indicator array, sparse matrix} \
            of shape (n_samples,)
        Estimated targets as returned by a classifier.

    labels : list, default=None
        The set of labels to include when ``average != 'binary'``, and their
        order if ``average is None``. Labels present in the data can be
        excluded, for example to calculate a multiclass average ignoring a
        majority negative class, while labels not present in the data will
        result in 0 components in a macro average. For multilabel targets,
        labels are column indices. By default, all labels in ``y_true`` and
        ``y_pred`` are used in sorted order.

    pos_label : str or int, default=1
        The class to report if ``average='binary'`` and the data is binary.
        If the data are multiclass or multilabel, this will be ignored;
        setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
        scores for that label only.
    average : str, {None, 'binary', 'micro', 'macro', 'samples', 'weighted'}, \
            default=None
        If ``None``, the scores for each class are returned. Otherwise, this
        determines the type of averaging performed on the data:

        ``'binary'``:
            Only report results for the class specified by ``pos_label``.
            This is applicable only if targets (``y_{true,pred}``) are binary.
        ``'micro'``:
            Calculate metrics globally by counting the total true positives,
            false negatives and false positives.
        ``'macro'``:
            Calculate metrics for each label, and find their unweighted
            mean. This does not take label imbalance into account.
        ``'weighted'``:
            Calculate metrics for each label, and find their average weighted
            by support (the number of true instances for each label). This
            alters 'macro' to account for label imbalance.
        ``'samples'``:
            Calculate metrics for each instance, and find their average (only
            meaningful for multilabel classification where this differs from
            :func:`accuracy_score`).

    warn_for : tuple or set, for internal use
        This determines which warnings will be made in the case that this
        function is being used to return only one of its metrics.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    zero_division : str or int, {'warn', 0, 1}, default="warn"
        Sets the value to return when there is a zero division:

        - tpr, fnr: when there are no positive labels
        - fpr, tnr: when there are no negative labels

        If set to "warn", this acts as 0, but warnings are also raised.
    Returns
    -------
    tpr : float (if average is not None), \
            or ndarray of shape (n_unique_labels,)

    fpr : float (if average is not None), \
            or ndarray of shape (n_unique_labels,)

    tnr : float (if average is not None), \
            or ndarray of shape (n_unique_labels,)

    fnr : float (if average is not None), \
            or ndarray of shape (n_unique_labels,)
    References
    ----------
    .. [1] `Wikipedia entry for confusion matrix
           <https://en.wikipedia.org/wiki/Confusion_matrix>`_

    .. [2] `Discriminative Methods for Multi-labeled Classification Advances
           in Knowledge Discovery and Data Mining (2004), pp. 22-30 by
           Shantanu Godbole, Sunita Sarawagi
           <http://www.godbole.net/shantanu/pubs/multilabelsvm-pakdd04.pdf>`_
    Examples
    --------
    >>> import numpy as np
    >>> y_true = np.array(['cat', 'dog', 'pig', 'cat', 'dog', 'pig'])
    >>> y_pred = np.array(['cat', 'pig', 'dog', 'cat', 'cat', 'dog'])
    >>> tpr_fpr_tnr_fnr_scores(y_true, y_pred, average='macro')
    (0.3333333333333333, 0.3333333333333333, 0.6666666666666666,
    0.6666666666666666)
    >>> tpr_fpr_tnr_fnr_scores(y_true, y_pred, average='micro')
    (0.3333333333333333, 0.3333333333333333, 0.6666666666666666,
    0.6666666666666666)
    >>> tpr_fpr_tnr_fnr_scores(y_true, y_pred, average='weighted')
    (0.3333333333333333, 0.3333333333333333, 0.6666666666666666,
    0.6666666666666666)

    It is possible to compute per-label tpr, fpr, tnr and fnr
    instead of averaging:

    >>> tpr_fpr_tnr_fnr_scores(y_true, y_pred, average=None,
    ...                        labels=['pig', 'dog', 'cat'])
    (array([0., 0., 1.]), array([0.25, 0.5 , 0.25]),
    array([0.75, 0.5 , 0.75]), array([1., 1., 0.]))
    Notes
    -----
    When ``true positive + false negative == 0``, TPR and FNR are undefined;
    when ``true negative + false positive == 0``, FPR and TNR are undefined.
    In such cases, by default the metric will be set to 0 and
    ``UndefinedMetricWarning`` will be raised. This behavior can be
    modified with ``zero_division``.
    """
    _check_zero_division(zero_division)

    labels = _check_set_wise_labels(y_true, y_pred, average, labels, pos_label)

    # Calculate tp_sum, fp_sum, tn_sum, fn_sum, pos_sum, neg_sum
    samplewise = average == 'samples'
    MCM = multilabel_confusion_matrix(y_true, y_pred,
                                      sample_weight=sample_weight,
                                      labels=labels, samplewise=samplewise)
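    # multilabel_confusion_matrix returns one 2x2 matrix per label, laid
    # out as [[tn, fp], [fn, tp]]; the slices below pick out each cell
    # across all labels at once.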
    tn_sum = MCM[:, 0, 0]
    fp_sum = MCM[:, 0, 1]
    fn_sum = MCM[:, 1, 0]
    tp_sum = MCM[:, 1, 1]
    neg_sum = tn_sum + fp_sum
    pos_sum = fn_sum + tp_sum

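    # With 'micro' averaging, the per-label counts are pooled into single
    # totals before dividing, so one global rate is computed per metric.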
    if average == 'micro':
        tp_sum = np.array([tp_sum.sum()])
        fp_sum = np.array([fp_sum.sum()])
        tn_sum = np.array([tn_sum.sum()])
        fn_sum = np.array([fn_sum.sum()])
        neg_sum = np.array([neg_sum.sum()])
        pos_sum = np.array([pos_sum.sum()])

    # Divide, and on zero-division, set scores and/or warn according to
    # zero_division:
    tpr = _prf_divide(tp_sum, pos_sum, 'tpr', 'positives',
                      average, warn_for, zero_division)
    fpr = _prf_divide(fp_sum, neg_sum, 'fpr', 'negatives',
                      average, warn_for, zero_division)
    tnr = _prf_divide(tn_sum, neg_sum, 'tnr', 'negatives',
                      average, warn_for, zero_division)
    fnr = _prf_divide(fn_sum, pos_sum, 'fnr', 'positives',
                      average, warn_for, zero_division)
    # Average the results
    if average == 'weighted':
        weights = pos_sum
        if weights.sum() == 0:
            zero_division_value = 0.0 if zero_division in ["warn", 0] else 1.0
            # TPR and FNR are zero_division if there are no positive labels
            # FPR and TNR are zero_division if there are no negative labels
            return (zero_division_value if pos_sum.sum() == 0 else 0,
                    zero_division_value if neg_sum.sum() == 0 else 0,
                    zero_division_value if neg_sum.sum() == 0 else 0,
                    zero_division_value if pos_sum.sum() == 0 else 0)
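    # Note: np.average raises an error when the weights sum to zero, so
    # the all-zero-support case must be short-circuited above.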

    elif average == 'samples':
        weights = sample_weight
    else:
        weights = None

    if average is not None:
        assert average != 'binary' or len(fpr) == 1
        fpr = np.average(fpr, weights=weights)
        tnr = np.average(tnr, weights=weights)
        fnr = np.average(fnr, weights=weights)
        tpr = np.average(tpr, weights=weights)
    return tpr, fpr, tnr, fnr


@_deprecate_positional_args
def precision_score(y_true, y_pred, *, labels=None, pos_label=1,
                    average='binary', sample_weight=None,
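As a quick sanity check on the docstring example (a usage sketch, not part of the diff; it relies only on the public multilabel_confusion_matrix API that the new function itself uses), the macro and micro figures can be reproduced by hand:

import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

y_true = np.array(['cat', 'dog', 'pig', 'cat', 'dog', 'pig'])
y_pred = np.array(['cat', 'pig', 'dog', 'cat', 'cat', 'dog'])

# One 2x2 matrix per label (sorted: cat, dog, pig), laid out [[tn, fp], [fn, tp]].
mcm = multilabel_confusion_matrix(y_true, y_pred)
tn, fp, fn, tp = mcm[:, 0, 0], mcm[:, 0, 1], mcm[:, 1, 0], mcm[:, 1, 1]

tpr = tp / (tp + fn)  # per-label: [1., 0., 0.]
fpr = fp / (tn + fp)  # per-label: [0.25, 0.5, 0.25]

print(tpr.mean(), fpr.mean())            # macro averages: 0.333..., 0.333...
print(tp.sum() / (tp.sum() + fn.sum()))  # micro TPR: pooled counts, 0.333...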
@@ -2174,7 +2377,8 @@ def log_loss(y_true, y_pred, *, eps=1e-15, normalize=True, sample_weight=None,
    y_true : array-like or label indicator matrix
        Ground truth (correct) labels for n_samples samples.

-   y_pred : array-like of float, shape = (n_samples, n_classes) or (n_samples,)
+   y_pred : array-like of float, shape = (n_samples, n_classes) \
+           or (n_samples,)
        Predicted probabilities, as returned by a classifier's
        predict_proba method. If ``y_pred.shape = (n_samples,)``
        the probabilities provided are assumed to be that of the
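For context on the rewrapped y_pred shape line (a usage sketch, not part of the diff): with binary targets, log_loss accepts either the full (n_samples, n_classes) probability matrix or a 1-D array of positive-class probabilities, and the two calls agree:

import numpy as np
from sklearn.metrics import log_loss

y_true = [0, 1, 1, 0]
proba_pos = np.array([0.1, 0.8, 0.7, 0.2])                # shape (n_samples,)
proba_full = np.column_stack([1 - proba_pos, proba_pos])  # (n_samples, 2)

# Both forms give the same loss for binary targets.
assert np.isclose(log_loss(y_true, proba_pos),
                  log_loss(y_true, proba_full))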