DOC Rework ROC example with cross-validation #29611

Open · wants to merge 16 commits into main
93 changes: 45 additions & 48 deletions examples/model_selection/plot_roc_crossval.py
@@ -13,12 +13,11 @@
better. The "steepness" of ROC curves is also important, since it is ideal to
maximize the TPR while minimizing the FPR.

This example shows the ROC response of different datasets, created from K-fold
cross-validation. Taking all of these curves, it is possible to calculate the
mean AUC, and see the variance of the curve when the
training set is split into different subsets. This roughly shows how the
classifier output is affected by changes in the training data, and how different
the splits generated by K-fold cross-validation are from one another.
This example demonstrates how the classifier's ROC response is influenced by
variations in the training data as obtained through ShuffleSplit cross-validation.
By analyzing all these curves, we can calculate the mean AUC and visualize the
variability of the estimated curves across CV splits via a quantile-based
region.
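The quantile-based region described above can be sketched in isolation. Assuming a hypothetical `curves` array of shape `(n_splits, n_points)` holding per-split TPR values interpolated onto a common FPR grid (the names here are illustrative, not from the example), the band reduces to two pointwise `np.quantile` calls:

```python
import numpy as np

# Hypothetical stand-in for per-split TPR curves interpolated onto a
# common FPR grid: shape (n_splits, n_points), values in [0, 1].
rng = np.random.default_rng(0)
curves = np.sort(rng.random((30, 101)), axis=1)

# Pointwise 5% and 95% quantiles across splits define the grey band.
lower_band = np.quantile(curves, 0.05, axis=0)
upper_band = np.quantile(curves, 0.95, axis=0)

assert lower_band.shape == (101,)
assert np.all(lower_band <= upper_band)
```

The band is computed per FPR grid point, so it tracks where the estimated curves agree tightly and where they spread out.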

.. note::

@@ -34,48 +33,45 @@
# Load and prepare data
# =====================
#
# We import the :ref:`iris_dataset` which contains 3 classes, each one
# corresponding to a type of iris plant. One class is linearly separable from
# the other 2; the latter are **not** linearly separable from each other.
#
# In the following we binarize the dataset by dropping the "virginica" class
# (`class_id=2`). This means that the "versicolor" class (`class_id=1`) is
# regarded as the positive class and "setosa" as the negative class
# (`class_id=0`).

import numpy as np

from sklearn.datasets import load_iris

iris = load_iris()
target_names = iris.target_names
X, y = iris.data, iris.target
X, y = X[y != 2], y[y != 2]
n_samples, n_features = X.shape

# %%
# We also add noisy features to make the problem harder.
random_state = np.random.RandomState(0)
X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)
# We use :class:`~sklearn.datasets.make_classification` to generate a synthetic
# dataset with 1,000 samples. The generated dataset has two classes by default.
# In this case, we set a class separation factor of 0.5, making the classes
# partially overlapping and not perfectly linearly separable.

from sklearn.datasets import make_classification

X, y = make_classification(
    n_samples=1_000,
    n_features=2,
    n_redundant=0,
    n_informative=2,
    class_sep=0.5,
    random_state=0,
    n_clusters_per_class=1,
)
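Since ROC analysis presumes a meaningful positive class, it is worth confirming that the generated dataset is roughly balanced. A minimal check (not part of the example itself; `counts` is an illustrative name) regenerates the same data and counts the labels:

```python
import numpy as np
from sklearn.datasets import make_classification

X, y = make_classification(
    n_samples=1_000,
    n_features=2,
    n_redundant=0,
    n_informative=2,
    class_sep=0.5,
    random_state=0,
    n_clusters_per_class=1,
)

# With the default (balanced) class weights, each class should hold
# close to half of the 1,000 samples.
counts = np.bincount(y)
```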

# %%
# Classification and ROC analysis
# -------------------------------
#
# Here we run a :class:`~sklearn.svm.SVC` classifier with cross-validation and
# plot the ROC curves fold-wise. Notice that the baseline to define the chance
# level (dashed ROC curve) is a classifier that would always predict the most
# frequent class.
# Here we run a :class:`~sklearn.ensemble.HistGradientBoostingClassifier`
# classifier with cross-validation and plot the ROC curves fold-wise. Notice
# that the baseline to define the chance level (dashed ROC curve) is a
# classifier that would always predict the most frequent class.
#
# In the following plot, the 5% to 95% percentile region of the interpolated
# ROC curves across splits is shown in grey, while the AUC value is reported
# as mean and standard deviation.
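The chance-level baseline described above can also be reproduced in isolation. As a sketch (the example itself draws the dashed line via `plot_chance_level`, not via this code), a `DummyClassifier` that always predicts the most frequent class produces constant scores, whose ROC curve is the diagonal with an AUC of 0.5:

```python
from sklearn.datasets import make_classification
from sklearn.dummy import DummyClassifier
from sklearn.metrics import roc_auc_score

X, y = make_classification(n_samples=1_000, random_state=0)

# A majority-class classifier assigns the same probability to every sample.
clf = DummyClassifier(strategy="most_frequent").fit(X, y)
auc_value = roc_auc_score(y, clf.predict_proba(X)[:, 1])
print(auc_value)  # 0.5: constant scores give a diagonal ROC curve
```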

import matplotlib.pyplot as plt
import numpy as np

from sklearn import svm
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.metrics import RocCurveDisplay, auc
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import StratifiedShuffleSplit

n_splits = 6
cv = StratifiedKFold(n_splits=n_splits)
classifier = svm.SVC(kernel="linear", probability=True, random_state=random_state)
n_splits = 30
cv = StratifiedShuffleSplit(n_splits=n_splits, random_state=0)
classifier = HistGradientBoostingClassifier(random_state=42)

tprs = []
aucs = []
@@ -88,11 +84,12 @@
        classifier,
        X[test],
        y[test],
        name=f"ROC fold {fold}",
        label=None,
        alpha=0.3,
        lw=1,
        ax=ax,
        plot_chance_level=(fold == n_splits - 1),
        chance_level_kw={"label": None},
    )
    interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr)
    interp_tpr[0] = 0.0
@@ -107,27 +104,27 @@
    mean_fpr,
    mean_tpr,
    color="b",
    label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
    label=rf"Mean ROC (AUC = {mean_auc:.2f} $\pm$ {std_auc:.2f})",
    lw=2,
    alpha=0.8,
)

std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)

upper_quantile = np.quantile(tprs, 0.95, axis=0)
lower_quantile = np.quantile(tprs, 0.05, axis=0)
ax.fill_between(
    mean_fpr,
    tprs_lower,
    tprs_upper,
    lower_quantile,
    upper_quantile,
    color="grey",
    alpha=0.2,
    label=r"$\pm$ 1 std. dev.",
    alpha=0.4,
    label="5% to 95% percentile region",
)

ax.set(
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title=f"Mean ROC curve with variability\n(Positive label '{target_names[1]}')",
    title="Mean ROC curve with variability",
)
ax.legend(loc="lower right")
plt.show()