[MRG] Add Detection Error Tradeoff (DET) curve classification metrics #10591
@@ -0,0 +1,145 @@
""" | ||
======================================= | ||
Detection error tradeoff (DET) curve | ||
======================================= | ||
|
||
In this example, we compare receiver operating characteristic (ROC) and | ||
detection error tradeoff (DET) curves for different classification algorithms | ||
for the same classification task. | ||
|
||
DET curves are commonly plotted in normal deviate scale. | ||
To achieve this we transform the errors rates as returned by the | ||
``detection_error_tradeoff_curve`` function and the axis scale using | ||
``scipy.stats.norm``. | ||
Review comment: please add a few words on what this example demonstrates? what is the take home message? Are you saying that DET makes it easier to highlight that RF is better than ROC?

Reply: This example merely tries to demonstrate to the unfamiliar user what a DET curve can look like. I tried to not repeat myself too much, and the properties, pros and cons are already discussed in the usage guide (which uses the sample plot).
The point of this example is to demonstrate two properties of DET curves,
namely:

1. It might be easier to visually assess the overall performance of different
   classification algorithms using DET curves rather than ROC curves.
   Due to the linear scale used for plotting ROC curves, different classifiers
   usually only differ in the top left corner of the graph and appear similar
   for a large part of the plot. DET curves, on the other hand, are
   approximately straight lines in normal deviate scale, so they tend to be
   distinguishable as a whole and the area of interest spans a large part of
   the plot.
2. DET curves give the user direct feedback of the detection error tradeoff
   to aid in operating point analysis.
   The user can deduce directly from the DET-curve plot at which rate the
   false-negative error rate will improve when accepting an increase in the
   false-positive error rate (or vice versa).

The plots in this example compare ROC curves on the left side to the
corresponding DET curves on the right.
There is no particular reason why these classifiers have been chosen for the
example plot over other classifiers available in scikit-learn.
|
||
.. note:: | ||
|
||
- See :func:`sklearn.metrics.roc_curve` for further information about ROC | ||
curves. | ||
|
||
- See :func:`sklearn.metrics.detection_error_tradeoff_curve` for further | ||
information about DET curves. | ||
|
||
- This example is loosely based on | ||
:ref:`sphx_glr_auto_examples_classification_plot_classifier_comparison.py` | ||
. | ||
|
||
""" | ||
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_classification
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import detection_error_tradeoff_curve
from sklearn.metrics import roc_curve

from scipy.stats import norm
from matplotlib.ticker import FuncFormatter

N_SAMPLES = 1000

names = [
    "Linear SVM",
    "Random Forest",
]

classifiers = [
    SVC(kernel="linear", C=0.025),
    RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
]

X, y = make_classification(
    n_samples=N_SAMPLES, n_features=2, n_redundant=0, n_informative=2,
    random_state=1, n_clusters_per_class=1)

# preprocess dataset, split into training and test part
X = StandardScaler().fit_transform(X)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=.4, random_state=0)

# prepare plots
fig, [ax_roc, ax_det] = plt.subplots(1, 2, figsize=(10, 5))

# first prepare the ROC curve
ax_roc.set_title('Receiver Operating Characteristic (ROC) curves')
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.set_xlim(0, 1)
ax_roc.set_ylim(0, 1)
ax_roc.grid(linestyle='--')
ax_roc.yaxis.set_major_formatter(
    FuncFormatter(lambda y, _: '{:.0%}'.format(y)))
ax_roc.xaxis.set_major_formatter(
    FuncFormatter(lambda y, _: '{:.0%}'.format(y)))

# second prepare the DET curve
ax_det.set_title('Detection Error Tradeoff (DET) curves')
ax_det.set_xlabel('False Positive Rate')
ax_det.set_ylabel('False Negative Rate')
ax_det.set_xlim(-3, 3)
ax_det.set_ylim(-3, 3)
ax_det.grid(linestyle='--')

# customized ticks for DET curve plot to represent normal deviate scale
ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999]
tick_locs = norm.ppf(ticks)
tick_lbls = [
    '{:.0%}'.format(s) if (100*s).is_integer() else '{:.1%}'.format(s)
    for s in ticks
]
plt.sca(ax_det)
plt.xticks(tick_locs, tick_lbls)
plt.yticks(tick_locs, tick_lbls)

# iterate over classifiers
for name, clf in zip(names, classifiers):
    clf.fit(X_train, y_train)

    if hasattr(clf, "decision_function"):
        y_score = clf.decision_function(X_test)
    else:
        y_score = clf.predict_proba(X_test)[:, 1]

    roc_fpr, roc_tpr, _ = roc_curve(y_test, y_score)
    det_fpr, det_fnr, _ = detection_error_tradeoff_curve(y_test, y_score)

    ax_roc.plot(roc_fpr, roc_tpr)

    # transform errors into normal deviate scale
    ax_det.plot(
        norm.ppf(det_fpr),
        norm.ppf(det_fnr)
    )

# add a single legend
plt.sca(ax_det)
plt.legend(names, loc="upper right")

# plot
plt.tight_layout()
plt.show()
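The normal deviate transform used for the DET axes above can be sanity-checked in isolation. This small sketch (plain NumPy/SciPy, independent of the example) shows that ``norm.ppf`` is the inverse of the standard normal CDF, maps a 50% error rate to deviate 0, and is symmetric around it, which is why complementary tick rates land at mirrored positions:

```python
import numpy as np
from scipy.stats import norm

# norm.ppf is the percent-point function (inverse CDF) of the standard
# normal distribution: it maps error rates in (0, 1) onto the normal
# deviate scale used for the DET axes.
rates = np.array([0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999])
deviates = norm.ppf(rates)

# a 50% error rate sits at deviate 0, and the mapping is symmetric,
# so mirrored rates (e.g. 1% and 99%) get mirrored tick positions
print(deviates[4])                             # 0.0
print(np.allclose(deviates, -deviates[::-1]))  # True

# round-tripping through the CDF recovers the original rates
print(np.allclose(norm.cdf(deviates), rates))  # True
```

This is also why an error rate of exactly 0 or 1 cannot be placed on a DET plot: ``norm.ppf`` maps those endpoints to minus and plus infinity.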
@@ -218,6 +218,94 @@ def _binary_uninterpolated_average_precision(
                              average, sample_weight=sample_weight)

def detection_error_tradeoff_curve(y_true, y_score, pos_label=None,
                                   sample_weight=None):
    """Compute error rates for different probability thresholds.

    Note: This metric is used for ranking evaluation of a binary
    classification task.

    Read more in the :ref:`User Guide <det_curve>`.
    Parameters
    ----------
    y_true : array, shape = [n_samples]
        True targets of binary classification in range {-1, 1} or {0, 1}.

    y_score : array, shape = [n_samples]
        Estimated probabilities or decision function.

    pos_label : int, optional (default=None)
        The label of the positive class.

    sample_weight : array-like of shape = [n_samples], optional
        Sample weights.
    Returns
    -------
    fpr : array, shape = [n_thresholds]
        False positive rate (FPR) such that element i is the false positive
        rate of predictions with score >= thresholds[i]. This is occasionally
        referred to as false acceptance probability or fall-out.

    fnr : array, shape = [n_thresholds]
        False negative rate (FNR) such that element i is the false negative
        rate of predictions with score >= thresholds[i]. This is occasionally
        referred to as false rejection or miss rate.

    thresholds : array, shape = [n_thresholds]
        Decreasing score values.
    See also
    --------
    roc_curve : Compute Receiver operating characteristic (ROC) curve
    precision_recall_curve : Compute precision-recall curve
    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.metrics import detection_error_tradeoff_curve
    >>> y_true = np.array([0, 0, 1, 1])
    >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
    >>> fpr, fnr, thresholds = detection_error_tradeoff_curve(y_true, y_scores)
    >>> fpr
    array([0.5, 0.5, 0. ])
    >>> fnr
    array([0. , 0.5, 0.5])
    >>> thresholds
    array([0.35, 0.4 , 0.8 ])

    """
    if len(np.unique(y_true)) != 2:
        raise ValueError("Only one class present in y_true. Detection error "
                         "tradeoff curve is not defined in that case.")
    fps, tps, thresholds = _binary_clf_curve(y_true, y_score,
                                             pos_label=pos_label,
                                             sample_weight=sample_weight)

    fns = tps[-1] - tps
    p_count = tps[-1]
    n_count = fps[-1]

    # start with false positives zero
    first_ind = (
        fps.searchsorted(fps[0], side='right') - 1
        if fps.searchsorted(fps[0], side='right') > 0
        else None
    )
    # stop with false negatives zero
    last_ind = tps.searchsorted(tps[-1]) + 1
    sl = slice(first_ind, last_ind)

    # reverse the output such that list of false positives is decreasing
    return (
        fps[sl][::-1] / n_count,
        fns[sl][::-1] / p_count,
        thresholds[sl][::-1]
    )
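Since ``_binary_clf_curve`` is a private helper, the computation above can also be reproduced end-to-end with plain NumPy. The following sketch (the name ``det_points`` is made up for illustration; unweighted case with {0, 1} labels only) mirrors the cumulative-count logic of ``_binary_clf_curve`` plus the trimming and reversal done here, and reproduces the values from the docstring example:

```python
import numpy as np

def det_points(y_true, y_score):
    """Illustrative re-implementation (unweighted, labels in {0, 1})."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    # sort samples by decreasing score
    order = np.argsort(y_score)[::-1]
    y_true, y_score = y_true[order], y_score[order]

    # keep only the last index of each run of equal scores, mirroring
    # how _binary_clf_curve collapses tied thresholds
    idx = np.r_[np.where(np.diff(y_score))[0], y_true.size - 1]
    tps = np.cumsum(y_true)[idx]   # true positives with score >= threshold
    fps = 1 + idx - tps            # false positives with score >= threshold
    thresholds = y_score[idx]
    fns = tps[-1] - tps            # positives missed at each threshold

    # trim: start where false positives leave zero, stop once false
    # negatives reach zero (same slice logic as the PR)
    first = fps.searchsorted(fps[0], side='right') - 1
    last = tps.searchsorted(tps[-1]) + 1
    sl = slice(first, last)

    # reverse so the false-positive rate is decreasing
    return (fps[sl][::-1] / fps[-1],   # FPR
            fns[sl][::-1] / tps[-1],   # FNR
            thresholds[sl][::-1])

fpr, fnr, thr = det_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(fpr)  # [0.5 0.5 0. ]
print(fnr)  # [0.  0.5 0.5]
print(thr)  # [0.35 0.4  0.8 ]
```

Note the complementary trimming: the DET curve keeps exactly one point with FPR at its minimum and one with FNR at its minimum, so the plotted tradeoff spans both ends without redundant saturated points.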


def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None):
    """Binary roc auc score"""
    if len(np.unique(y_true)) != 2:
Review comment: Martin 1997 is not cited in the text: to avoid the sphinx warning you can either cite it where appropriate or remove the square bracket content.