
Commit 2335a8e

DOC Improve narrative of plot_roc_crossval example (#24710)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Parent: 303d712 · Commit: 2335a8e

1 file changed (+47 −34): examples/model_selection/plot_roc_crossval.py

@@ -3,77 +3,88 @@
 Receiver Operating Characteristic (ROC) with cross validation
 =============================================================
 
-Example of Receiver Operating Characteristic (ROC) metric to evaluate
-classifier output quality using cross-validation.
+This example presents how to estimate and visualize the variance of the Receiver
+Operating Characteristic (ROC) metric using cross-validation.
 
-ROC curves typically feature true positive rate on the Y axis, and false
-positive rate on the X axis. This means that the top left corner of the plot is
-the "ideal" point - a false positive rate of zero, and a true positive rate of
-one. This is not very realistic, but it does mean that a larger area under the
-curve (AUC) is usually better.
-
-The "steepness" of ROC curves is also important, since it is ideal to maximize
-the true positive rate while minimizing the false positive rate.
+ROC curves typically feature true positive rate (TPR) on the Y axis, and false
+positive rate (FPR) on the X axis. This means that the top left corner of the
+plot is the "ideal" point - a FPR of zero, and a TPR of one. This is not very
+realistic, but it does mean that a larger Area Under the Curve (AUC) is usually
+better. The "steepness" of ROC curves is also important, since it is ideal to
+maximize the TPR while minimizing the FPR.
 
 This example shows the ROC response of different datasets, created from K-fold
 cross-validation. Taking all of these curves, it is possible to calculate the
-mean area under curve, and see the variance of the curve when the
+mean AUC, and see the variance of the curve when the
 training set is split into different subsets. This roughly shows how the
-classifier output is affected by changes in the training data, and how
-different the splits generated by K-fold cross-validation are from one another.
+classifier output is affected by changes in the training data, and how different
+the splits generated by K-fold cross-validation are from one another.
 
 .. note::
 
-    See also :func:`sklearn.metrics.roc_auc_score`,
-    :func:`sklearn.model_selection.cross_val_score`,
-    :ref:`sphx_glr_auto_examples_model_selection_plot_roc.py`,
-
+    See :ref:`sphx_glr_auto_examples_model_selection_plot_roc.py` for a
+    complement of the present example explaining the averaging strategies to
+    generalize the metrics for multiclass classifiers.
 """
 
 # %%
-# Data IO and generation
-# ----------------------
-import numpy as np
+# Load and prepare data
+# =====================
+#
+# We import the :ref:`iris_dataset` which contains 3 classes, each one
+# corresponding to a type of iris plant. One class is linearly separable from
+# the other 2; the latter are **not** linearly separable from each other.
+#
+# In the following we binarize the dataset by dropping the "virginica" class
+# (`class_id=2`). This means that the "versicolor" class (`class_id=1`) is
+# regarded as the positive class and "setosa" as the negative class
+# (`class_id=0`).
 
-from sklearn import datasets
+import numpy as np
+from sklearn.datasets import load_iris
 
-# Import some data to play with
-iris = datasets.load_iris()
-X = iris.data
-y = iris.target
+iris = load_iris()
+target_names = iris.target_names
+X, y = iris.data, iris.target
 X, y = X[y != 2], y[y != 2]
 n_samples, n_features = X.shape
 
-# Add noisy features
+# %%
+# We also add noisy features to make the problem harder.
 random_state = np.random.RandomState(0)
-X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]
+X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)
 
 # %%
 # Classification and ROC analysis
 # -------------------------------
+#
+# Here we run a :class:`~sklearn.svm.SVC` classifier with cross-validation and
+# plot the ROC curves fold-wise. Notice that the baseline to define the chance
+# level (dashed ROC curve) is a classifier that would always predict the most
+# frequent class.
+
 import matplotlib.pyplot as plt
 
 from sklearn import svm
 from sklearn.metrics import auc
 from sklearn.metrics import RocCurveDisplay
 from sklearn.model_selection import StratifiedKFold
 
-# Run classifier with cross-validation and plot ROC curves
 cv = StratifiedKFold(n_splits=6)
 classifier = svm.SVC(kernel="linear", probability=True, random_state=random_state)
 
 tprs = []
 aucs = []
 mean_fpr = np.linspace(0, 1, 100)
 
-fig, ax = plt.subplots()
-for i, (train, test) in enumerate(cv.split(X, y)):
+fig, ax = plt.subplots(figsize=(6, 6))
+for fold, (train, test) in enumerate(cv.split(X, y)):
     classifier.fit(X[train], y[train])
     viz = RocCurveDisplay.from_estimator(
         classifier,
         X[test],
         y[test],
-        name="ROC fold {}".format(i),
+        name=f"ROC fold {fold}",
         alpha=0.3,
         lw=1,
         ax=ax,
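
Aside (not part of this commit): the rewritten docstring above defines TPR, FPR
and AUC in prose. A minimal, self-contained sketch of how these quantities come
out of sklearn.metrics.roc_curve, using toy labels and scores:

import numpy as np
from sklearn.metrics import auc, roc_curve

y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])  # e.g. decision_function output
# One (FPR, TPR) point per score threshold; the curve runs from (0, 0) to (1, 1).
fpr, tpr, thresholds = roc_curve(y_true, y_score)
print(fpr)  # [0.  0.  0.5 0.5 1. ]
print(tpr)  # [0.  0.5 0.5 1.  1. ]
print(auc(fpr, tpr))  # 0.75, the area under the piecewise-linear curve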
@@ -82,8 +93,7 @@
     interp_tpr[0] = 0.0
     tprs.append(interp_tpr)
     aucs.append(viz.roc_auc)
-
-ax.plot([0, 1], [0, 1], linestyle="--", lw=2, color="r", label="Chance", alpha=0.8)
+ax.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")
 
 mean_tpr = np.mean(tprs, axis=0)
 mean_tpr[-1] = 1.0
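
Aside (not part of this commit): the new label ties the dashed diagonal to a
classifier that always predicts the most frequent class, as the added narrative
comment explains. A small sketch, reusing the example's iris binarization, to
check that such a dummy model indeed scores AUC = 0.5:

from sklearn.datasets import load_iris
from sklearn.dummy import DummyClassifier
from sklearn.metrics import roc_auc_score

X, y = load_iris(return_X_y=True)
X, y = X[y != 2], y[y != 2]  # drop "virginica", as in the example
dummy = DummyClassifier(strategy="most_frequent").fit(X, y)
# The dummy assigns the same score to every sample, so its ROC curve
# collapses onto the diagonal.
print(roc_auc_score(y, dummy.predict_proba(X)[:, 1]))  # 0.5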
@@ -113,7 +123,10 @@
 ax.set(
     xlim=[-0.05, 1.05],
     ylim=[-0.05, 1.05],
-    title="Receiver operating characteristic example",
+    xlabel="False Positive Rate",
+    ylabel="True Positive Rate",
+    title=f"Mean ROC curve with variability\n(Positive label '{target_names[1]}')",
 )
+ax.axis("square")
 ax.legend(loc="lower right")
 plt.show()
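
Aside (not part of this commit): the statements that aggregate the fold-wise
curves into the mean ROC curve sit between the last two hunks and are elided
from this diff. A sketch of the usual pattern, assuming the tprs, aucs,
mean_fpr and ax objects visible in the surrounding context:

mean_tpr = np.mean(tprs, axis=0)  # average the TPRs interpolated on mean_fpr
mean_tpr[-1] = 1.0                # pin the curve's endpoint at (1, 1)
mean_auc = auc(mean_fpr, mean_tpr)
std_auc = np.std(aucs)            # spread of the per-fold AUC estimates

std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)  # clip the band to [0, 1]
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
ax.plot(mean_fpr, mean_tpr, label=f"Mean ROC (AUC = {mean_auc:.2f} +/- {std_auc:.2f})")
ax.fill_between(mean_fpr, tprs_lower, tprs_upper, alpha=0.2, label="+/- 1 std. dev.")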
