diff --git a/examples/model_selection/plot_roc_crossval.py b/examples/model_selection/plot_roc_crossval.py
index 3c5c3fc9119b7..5af8276bb36d7 100644
--- a/examples/model_selection/plot_roc_crossval.py
+++ b/examples/model_selection/plot_roc_crossval.py
@@ -13,12 +13,11 @@
 better. The "steepness" of ROC curves is also important, since it is ideal to
 maximize the TPR while minimizing the FPR.
 
-This example shows the ROC response of different datasets, created from K-fold
-cross-validation. Taking all of these curves, it is possible to calculate the
-mean AUC, and see the variance of the curve when the
-training set is split into different subsets. This roughly shows how the
-classifier output is affected by changes in the training data, and how different
-the splits generated by K-fold cross-validation are from one another.
+This example demonstrates how the classifier's ROC response is influenced by
+variations in the training data as obtained through ShuffleSplit cross-validation.
+By analyzing all these curves, we can calculate the mean AUC and visualize the
+variability of the estimated curves across CV splits via a quantile-based
+region.
 
 .. note::
 
@@ -34,49 +33,47 @@
 # Load and prepare data
 # =====================
 #
-# We import the :ref:`iris_dataset` which contains 3 classes, each one
-# corresponding to a type of iris plant. One class is linearly separable from
-# the other 2; the latter are **not** linearly separable from each other.
-#
-# In the following we binarize the dataset by dropping the "virginica" class
-# (`class_id=2`). This means that the "versicolor" class (`class_id=1`) is
-# regarded as the positive class and "setosa" as the negative class
-# (`class_id=0`).
-
-import numpy as np
-
-from sklearn.datasets import load_iris
-
-iris = load_iris()
-target_names = iris.target_names
-X, y = iris.data, iris.target
-X, y = X[y != 2], y[y != 2]
-n_samples, n_features = X.shape
-
-# %%
-# We also add noisy features to make the problem harder.
-random_state = np.random.RandomState(0)
-X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)
+# We use :class:`~sklearn.datasets.make_classification` to generate a synthetic
+# dataset with 1,000 samples. The generated dataset has two classes by default.
+# In this case, we set a class separation factor of 0.5, making the classes
+# partially overlapping and not perfectly linearly separable.
+
+from sklearn.datasets import make_classification
+
+X, y = make_classification(
+    n_samples=1_000,
+    n_features=2,
+    n_redundant=0,
+    n_informative=2,
+    class_sep=0.5,
+    random_state=0,
+    n_clusters_per_class=1,
+)
 
 # %%
 # Classification and ROC analysis
 # -------------------------------
 #
 # Here we run :func:`~sklearn.model_selection.cross_validate` on a
-# :class:`~sklearn.svm.SVC` classifier, then use the computed cross-validation results
-# to plot the ROC curves fold-wise. Notice that the baseline to define the chance
-# level (dashed ROC curve) is a classifier that would always predict the most
-# frequent class.
+# :class:`~sklearn.ensemble.HistGradientBoostingClassifier`, then use the
+# computed cross-validation results to plot the ROC curves fold-wise. Notice
+# that the baseline to define the chance level (dashed ROC curve) is a
+# classifier that would always predict the most frequent class.
+#
+# In the following plot, the quantile coverage is represented in grey, while
+# the AUC value is reported in terms of the mean and standard deviation.
 
 import matplotlib.pyplot as plt
+import numpy as np
 
-from sklearn import svm
+from sklearn.ensemble import HistGradientBoostingClassifier
 from sklearn.metrics import RocCurveDisplay, auc
-from sklearn.model_selection import StratifiedKFold, cross_validate
+from sklearn.model_selection import StratifiedShuffleSplit, cross_validate
+
+n_splits = 30
+cv = StratifiedShuffleSplit(n_splits=n_splits, random_state=0)
+classifier = HistGradientBoostingClassifier(random_state=42)
 
-n_splits = 6
-cv = StratifiedKFold(n_splits=n_splits)
-classifier = svm.SVC(kernel="linear", probability=True, random_state=random_state)
 cv_results = cross_validate(
     classifier, X, y, cv=cv, return_estimator=True, return_indices=True
 )
@@ -84,7 +81,13 @@
 prop_cycle = plt.rcParams["axes.prop_cycle"]
 colors = prop_cycle.by_key()["color"]
 curve_kwargs_list = [
-    dict(alpha=0.3, lw=1, color=colors[fold % len(colors)]) for fold in range(n_splits)
+    dict(
+        alpha=0.3,
+        lw=1,
+        label=None,
+        color=colors[fold % len(colors)],
+    )
+    for fold in range(n_splits)
 ]
 names = [f"ROC fold {idx}" for idx in range(n_splits)]
 
@@ -116,27 +119,27 @@
     mean_fpr,
     mean_tpr,
     color="b",
-    label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
+    label=rf"Mean ROC (AUC = {mean_auc:.2f} $\pm$ {std_auc:.2f})",
     lw=2,
     alpha=0.8,
 )
 
-std_tpr = np.std(interp_tprs, axis=0)
-tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
-tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
+
+upper_quantile = np.quantile(interp_tprs, 0.95, axis=0)
+lower_quantile = np.quantile(interp_tprs, 0.05, axis=0)
 ax.fill_between(
     mean_fpr,
-    tprs_lower,
-    tprs_upper,
+    lower_quantile,
+    upper_quantile,
     color="grey",
-    alpha=0.2,
-    label=r"$\pm$ 1 std. dev.",
+    alpha=0.4,
+    label="5th to 95th percentile region",
 )
 
 ax.set(
     xlabel="False Positive Rate",
     ylabel="True Positive Rate",
-    title=f"Mean ROC curve with variability\n(Positive label '{target_names[1]}')",
+    title="Mean ROC curve with variability",
 )
 ax.legend(loc="lower right")
 plt.show()
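Reviewer note: the quantities consumed by the last two hunks (`mean_fpr`, `interp_tprs`, `aucs`, `mean_auc`, `std_auc`) are computed in unchanged context that the diff elides. The sketch below illustrates that per-split ROC bookkeeping end to end so the new quantile band can be sanity-checked in isolation. It is an assumption reconstructed from the identifiers visible above, not the file's exact code: it refits one model per split directly instead of reusing the estimators and indices returned by `cross_validate`, and the names `clf` and `y_score` are hypothetical.

    import numpy as np

    from sklearn.datasets import make_classification
    from sklearn.ensemble import HistGradientBoostingClassifier
    from sklearn.metrics import auc, roc_curve
    from sklearn.model_selection import StratifiedShuffleSplit

    # Same synthetic data and CV setup as in the diff above.
    X, y = make_classification(
        n_samples=1_000,
        n_features=2,
        n_redundant=0,
        n_informative=2,
        class_sep=0.5,
        random_state=0,
        n_clusters_per_class=1,
    )
    cv = StratifiedShuffleSplit(n_splits=30, random_state=0)

    mean_fpr = np.linspace(0, 1, 100)  # common FPR grid so per-split TPRs align
    interp_tprs, aucs = [], []
    for train_idx, test_idx in cv.split(X, y):
        clf = HistGradientBoostingClassifier(random_state=42)
        clf.fit(X[train_idx], y[train_idx])
        y_score = clf.predict_proba(X[test_idx])[:, 1]
        fpr, tpr, _ = roc_curve(y[test_idx], y_score)
        interp_tpr = np.interp(mean_fpr, fpr, tpr)  # resample onto the grid
        interp_tpr[0] = 0.0  # pin each curve to the origin
        interp_tprs.append(interp_tpr)
        aucs.append(auc(fpr, tpr))

    mean_tpr = np.mean(interp_tprs, axis=0)
    mean_tpr[-1] = 1.0  # pin the mean curve to (1, 1)
    mean_auc, std_auc = auc(mean_fpr, mean_tpr), np.std(aucs)

    # The band introduced by the last hunk: pointwise 5th/95th percentiles
    # of the per-split TPRs rather than mean +/- 1 standard deviation.
    upper_quantile = np.quantile(interp_tprs, 0.95, axis=0)
    lower_quantile = np.quantile(interp_tprs, 0.05, axis=0)

One upside of the quantile band over the previous mean-plus-or-minus-std band: pointwise quantiles of TPRs always stay within [0, 1], so no clipping is needed (the removed code clamped with `np.minimum`/`np.maximum`), and the band is less sensitive to outlier splits.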