DOC Rework ROC example with cross-validation #29611

Open · wants to merge 17 commits into base: main
99 changes: 51 additions & 48 deletions examples/model_selection/plot_roc_crossval.py
@@ -13,12 +13,11 @@
better. The "steepness" of ROC curves is also important, since it is ideal to
maximize the TPR while minimizing the FPR.

This example shows the ROC response of different datasets, created from K-fold
cross-validation. Taking all of these curves, it is possible to calculate the
mean AUC, and see the variance of the curve when the
training set is split into different subsets. This roughly shows how the
classifier output is affected by changes in the training data, and how different
the splits generated by K-fold cross-validation are from one another.
This example demonstrates how the classifier's ROC response is influenced by
variations in the training data, as obtained through ShuffleSplit
cross-validation. By aggregating all of these curves, we can compute the mean
AUC and visualize the variability of the estimated curves across CV splits
via a quantile-based region.
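
As a minimal illustration of the quantile-based region described above (a
sketch using assumed names, not part of the example itself): given per-split
TPR curves interpolated on a common FPR grid, the band is the pointwise
5%-95% percentile range::

    import numpy as np

    rng = np.random.default_rng(0)
    # hypothetical stack of interpolated TPR curves, shape (n_splits, n_grid)
    interp_tprs = rng.uniform(size=(30, 100))
    lower, upper = np.quantile(interp_tprs, [0.05, 0.95], axis=0)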

.. note::

@@ -34,57 +33,61 @@
# Load and prepare data
# =====================
#
# We import the :ref:`iris_dataset` which contains 3 classes, each one
# corresponding to a type of iris plant. One class is linearly separable from
# the other 2; the latter are **not** linearly separable from each other.
#
# In the following we binarize the dataset by dropping the "virginica" class
# (`class_id=2`). This means that the "versicolor" class (`class_id=1`) is
# regarded as the positive class and "setosa" as the negative class
# (`class_id=0`).

import numpy as np

from sklearn.datasets import load_iris

iris = load_iris()
target_names = iris.target_names
X, y = iris.data, iris.target
X, y = X[y != 2], y[y != 2]
n_samples, n_features = X.shape

# %%
# We also add noisy features to make the problem harder.
random_state = np.random.RandomState(0)
X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)
# We use :func:`~sklearn.datasets.make_classification` to generate a synthetic
# dataset with 1,000 samples. The generated dataset has two classes by default.
# Here, we set a class separation factor of 0.5, making the classes partially
# overlapping and not perfectly linearly separable.

from sklearn.datasets import make_classification

X, y = make_classification(
n_samples=1_000,
n_features=2,
n_redundant=0,
n_informative=2,
class_sep=0.5,
random_state=0,
n_clusters_per_class=1,
)
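
# %%
# As a quick sanity check (an illustrative addition, not part of the original
# example), we can verify that the two generated classes are roughly balanced:

import numpy as np

print(np.bincount(y))  # expected: roughly 500 samples per class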

# %%
# Classification and ROC analysis
# -------------------------------
#
# Here we run :func:`~sklearn.model_selection.cross_validate` on a
# :class:`~sklearn.svm.SVC` classifier, then use the computed cross-validation results
# to plot the ROC curves fold-wise. Notice that the baseline to define the chance
# level (dashed ROC curve) is a classifier that would always predict the most
# frequent class.
# :class:`~sklearn.ensemble.HistGradientBoostingClassifier`, then use the
# computed cross-validation results to plot the ROC curves fold-wise. Notice
# that the baseline to define the chance level (dashed ROC curve) is a
# classifier that would always predict the most frequent class.
#
# In the following plot, the quantile coverage region is shown in grey, while
# the AUC value is reported as the mean ± standard deviation across splits.
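#
# As a hypothetical illustration of that chance-level baseline (an addition
# for exposition, not part of the original example), a
# :class:`~sklearn.dummy.DummyClassifier` predicting the most frequent class
# yields constant scores and hence an ROC AUC of 0.5:

from sklearn.dummy import DummyClassifier
from sklearn.metrics import roc_auc_score

dummy = DummyClassifier(strategy="most_frequent").fit(X, y)
print(roc_auc_score(y, dummy.predict_proba(X)[:, 1]))  # 0.5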

import matplotlib.pyplot as plt
import numpy as np

from sklearn import svm
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.metrics import RocCurveDisplay, auc
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.model_selection import StratifiedShuffleSplit, cross_validate

n_splits = 30
cv = StratifiedShuffleSplit(n_splits=n_splits, random_state=0)
classifier = HistGradientBoostingClassifier(random_state=42)

n_splits = 6
cv = StratifiedKFold(n_splits=n_splits)
classifier = svm.SVC(kernel="linear", probability=True, random_state=random_state)
cv_results = cross_validate(
classifier, X, y, cv=cv, return_estimator=True, return_indices=True
)
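
# A minimal sketch (our reading of the `cross_validate` API, not from the
# original diff) of what `cv_results` holds: with `return_estimator=True`,
# each entry of `cv_results["estimator"]` is a classifier fitted on one
# split; with `return_indices=True`, the matching test indices live in
# `cv_results["indices"]["test"]`.
fitted_fold_0 = cv_results["estimator"][0]
test_idx_fold_0 = cv_results["indices"]["test"][0]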

prop_cycle = plt.rcParams["axes.prop_cycle"]
colors = prop_cycle.by_key()["color"]
curve_kwargs_list = [
dict(alpha=0.3, lw=1, color=colors[fold % len(colors)]) for fold in range(n_splits)
dict(
alpha=0.3,
lw=1,
label=None,
color=colors[fold % len(colors)],
)
for fold in range(n_splits)
]
names = [f"ROC fold {idx}" for idx in range(n_splits)]

@@ -116,27 +119,27 @@
mean_fpr,
mean_tpr,
color="b",
label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
label=rf"Mean ROC (AUC = {mean_auc:.2f} $\pm$ {std_auc:.2f})",
lw=2,
alpha=0.8,
)

std_tpr = np.std(interp_tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)

upper_quantile = np.quantile(interp_tprs, 0.95, axis=0)
lower_quantile = np.quantile(interp_tprs, 0.05, axis=0)
ax.fill_between(
mean_fpr,
tprs_lower,
tprs_upper,
lower_quantile,
upper_quantile,
color="grey",
alpha=0.2,
label=r"$\pm$ 1 std. dev.",
alpha=0.4,
label="5% to 95% percentile region",
)

ax.set(
xlabel="False Positive Rate",
ylabel="True Positive Rate",
title=f"Mean ROC curve with variability\n(Positive label '{target_names[1]}')",
title="Mean ROC curve with variability",
)
ax.legend(loc="lower right")
plt.show()