"""
=============================================================
Receiver Operating Characteristic (ROC) with cross validation
=============================================================

This example presents how to estimate and visualize the variance of the
Receiver Operating Characteristic (ROC) metric using cross-validation.

ROC curves typically feature true positive rate (TPR) on the Y axis, and false
positive rate (FPR) on the X axis. This means that the top left corner of the
plot is the "ideal" point - an FPR of zero, and a TPR of one. This is not very
realistic, but it does mean that a larger Area Under the Curve (AUC) is usually
better. The "steepness" of ROC curves is also important, since it is ideal to
maximize the TPR while minimizing the FPR.

This example shows the ROC response of different datasets, created from K-fold
cross-validation. Taking all of these curves, it is possible to calculate the
mean AUC, and see the variance of the curve when the training set is split into
different subsets. This roughly shows how the classifier output is affected by
changes in the training data, and how different the splits generated by K-fold
cross-validation are from one another.

.. note::

    See :ref:`sphx_glr_auto_examples_model_selection_plot_roc.py` for a
    complement to the present example, explaining the averaging strategies used
    to generalize the metrics for multiclass classifiers.
"""
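
# %%
# A minimal sketch of the quantities described above: the (FPR, TPR) pairs that
# trace a ROC curve can be computed directly with
# :func:`sklearn.metrics.roc_curve`. The `toy_labels` and `toy_scores` values
# below are illustrative stand-ins, not part of the iris data used next.
from sklearn.metrics import roc_curve

toy_labels = [0, 0, 1, 1]
toy_scores = [0.1, 0.4, 0.35, 0.8]
toy_fpr, toy_tpr, toy_thresholds = roc_curve(toy_labels, toy_scores)
print(f"FPR: {toy_fpr}\nTPR: {toy_tpr}\nthresholds: {toy_thresholds}")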

# %%
# Load and prepare data
# =====================
#
# We import the :ref:`iris_dataset` which contains 3 classes, each one
# corresponding to a type of iris plant. One class is linearly separable from
# the other 2; the latter are **not** linearly separable from each other.
#
# In the following we binarize the dataset by dropping the "virginica" class
# (`class_id=2`). This means that the "versicolor" class (`class_id=1`) is
# regarded as the positive class and "setosa" as the negative class
# (`class_id=0`).

import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()
target_names = iris.target_names
X, y = iris.data, iris.target
X, y = X[y != 2], y[y != 2]
n_samples, n_features = X.shape
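
# %%
# The binarized target is balanced: 50 "setosa" (negative) and 50 "versicolor"
# (positive) samples, the proportion that the stratified splits used below will
# preserve in each fold. A quick check with `np.bincount` (added here for
# illustration):
print(np.bincount(y))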

# %%
# We also add noisy features to make the problem harder.
random_state = np.random.RandomState(0)
X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)
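
# %%
# The 4 informative features are now diluted among 200 * 4 = 800 noisy ones,
# giving a data matrix of 100 samples by 804 features; with so few samples per
# feature, the fitted classifier varies noticeably from fold to fold.
print(X.shape)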

# %%
# Classification and ROC analysis
# -------------------------------
#
# Here we run a :class:`~sklearn.svm.SVC` classifier with cross-validation and
# plot the ROC curves fold-wise. Notice that the baseline to define the chance
# level (dashed ROC curve) is a classifier that would always predict the most
# frequent class.

import matplotlib.pyplot as plt

from sklearn import svm
from sklearn.metrics import RocCurveDisplay, auc
from sklearn.model_selection import StratifiedKFold

cv = StratifiedKFold(n_splits=6)
classifier = svm.SVC(kernel="linear", probability=True, random_state=random_state)

tprs = []
aucs = []
# Evaluate each fold's TPR on this common FPR grid so the curves can be
# averaged point-wise.
mean_fpr = np.linspace(0, 1, 100)

fig, ax = plt.subplots(figsize=(6, 6))
for fold, (train, test) in enumerate(cv.split(X, y)):
    classifier.fit(X[train], y[train])
    viz = RocCurveDisplay.from_estimator(
        classifier,
        X[test],
        y[test],
        name=f"ROC fold {fold}",
        alpha=0.3,
        lw=1,
        ax=ax,
    )
    # Interpolate this fold's ROC curve onto the common FPR grid.
    interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr)
    interp_tpr[0] = 0.0
    tprs.append(interp_tpr)
    aucs.append(viz.roc_auc)
ax.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")

mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr)
std_auc = np.std(aucs)
ax.plot(
    mean_fpr,
    mean_tpr,
    color="b",
    label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
    lw=2,
    alpha=0.8,
)

# Shade one standard deviation of the fold-wise TPRs around the mean curve.
std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
ax.fill_between(
    mean_fpr,
    tprs_upper,
    tprs_lower,
    color="grey",
    alpha=0.2,
    label=r"$\pm$ 1 std. dev.",
)

ax.set(
    xlim=[-0.05, 1.05],
    ylim=[-0.05, 1.05],
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title=f"Mean ROC curve with variability\n(Positive label '{target_names[1]}')",
)
ax.axis("square")
ax.legend(loc="lower right")
plt.show()
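
# %%
# As a sanity check of the chance-level baseline, a short sketch (an addition
# for illustration, using :class:`~sklearn.dummy.DummyClassifier`): a model
# that always predicts the most frequent class emits a constant score, and a
# constant score yields an AUC of 0.5.
from sklearn.dummy import DummyClassifier
from sklearn.metrics import roc_auc_score

dummy = DummyClassifier(strategy="most_frequent").fit(X, y)
print(roc_auc_score(y, dummy.predict_proba(X)[:, 1]))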