diff --git a/doc/conf.py b/doc/conf.py
index 61df593f3fe8f..278b588c103b5 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -447,6 +447,9 @@ def add_js_css_files(app, pagename, templatename, context, doctree):
     "auto_examples/model_selection/grid_search_text_feature_extraction": (
         "auto_examples/model_selection/plot_grid_search_text_feature_extraction"
     ),
+    "auto_examples/model_selection/plot_validation_curve": (
+        "auto_examples/model_selection/plot_train_error_vs_test_error"
+    ),
     "auto_examples/datasets/plot_digits_last_image": (
         "auto_examples/exercises/plot_digits_classification_exercises"
     ),
diff --git a/examples/model_selection/plot_train_error_vs_test_error.py b/examples/model_selection/plot_train_error_vs_test_error.py
index dc370383b2ef7..a64b4ca94846e 100644
--- a/examples/model_selection/plot_train_error_vs_test_error.py
+++ b/examples/model_selection/plot_train_error_vs_test_error.py
@@ -1,15 +1,18 @@
 """
-=========================
-Train error vs Test error
-=========================
+=========================================================
+Effect of model regularization on training and test error
+=========================================================
 
-Illustration of how the performance of an estimator on unseen data (test data)
-is not the same as the performance on training data. As the regularization
-increases the performance on train decreases while the performance on test
-is optimal within a range of values of the regularization parameter.
-The example with an Elastic-Net regression model and the performance is
-measured using the explained variance a.k.a. R^2.
+In this example, we evaluate the impact of the regularization parameter in a
+linear model called :class:`~sklearn.linear_model.ElasticNet`. To carry out this
+evaluation, we use a validation curve using
+:class:`~sklearn.model_selection.ValidationCurveDisplay`. This curve shows the
+training and test scores of the model for different values of the regularization
+parameter.
+Once we identify the optimal regularization parameter, we compare the true and
+estimated coefficients of the model to determine if the model is able to recover
+the coefficients from the noisy input data.
 
 """
 
 # Authors: The scikit-learn developers
@@ -18,71 +21,146 @@
 # %%
 # Generate sample data
 # --------------------
-import numpy as np
-
-from sklearn import linear_model
+#
+# We generate a regression dataset that contains many features relative to the
+# number of samples. However, only 10% of the features are informative. In this context,
+# linear models exposing L1 penalization are commonly used to recover a sparse
+# set of coefficients.
 from sklearn.datasets import make_regression
 from sklearn.model_selection import train_test_split
 
-n_samples_train, n_samples_test, n_features = 75, 150, 500
-X, y, coef = make_regression(
+n_samples_train, n_samples_test, n_features = 150, 300, 500
+X, y, true_coef = make_regression(
     n_samples=n_samples_train + n_samples_test,
     n_features=n_features,
     n_informative=50,
     shuffle=False,
     noise=1.0,
     coef=True,
+    random_state=42,
 )
 X_train, X_test, y_train, y_test = train_test_split(
     X, y, train_size=n_samples_train, test_size=n_samples_test, shuffle=False
 )
+
 
 # %%
-# Compute train and test errors
-# -----------------------------
+# Model definition
+# ----------------
+#
+# Here, we do not use a model that only exposes an L1 penalty. Instead, we use
+# an :class:`~sklearn.linear_model.ElasticNet` model that exposes both L1 and L2
+# penalties.
+#
+# We fix the `l1_ratio` parameter such that the solution found by the model is still
+# sparse. Therefore, this type of model tries to find a sparse solution but at the same
+# time also tries to shrink all coefficients towards zero.
+#
+# In addition, we force the coefficients of the model to be positive since we know that
+# `make_regression` generates a response with a positive signal. So we use this
+# pre-knowledge to get a better model.
+
+from sklearn.linear_model import ElasticNet
+
+enet = ElasticNet(l1_ratio=0.9, positive=True, max_iter=10_000)
+
+
+# %%
+# Evaluate the impact of the regularization parameter
+# ---------------------------------------------------
+#
+# To evaluate the impact of the regularization parameter, we use a validation
+# curve. This curve shows the training and test scores of the model for different
+# values of the regularization parameter.
+#
+# The regularization `alpha` is a parameter applied to the coefficients of the model:
+# when it tends to zero, no regularization is applied and the model tries to fit the
+# training data with the least amount of error. However, it leads to overfitting when
+# features are noisy. When `alpha` increases, the model coefficients are constrained,
+# and thus the model cannot fit the training data as closely, avoiding overfitting.
+# However, if too much regularization is applied, the model underfits the data and
+# is not able to properly capture the signal.
+#
+# The validation curve helps in finding a good trade-off between both extremes: the
+# model is not regularized and thus flexible enough to fit the signal, but not too
+# flexible to overfit. The :class:`~sklearn.model_selection.ValidationCurveDisplay`
+# allows us to display the training and validation scores across a range of alpha
+# values.
+import numpy as np
+
+from sklearn.model_selection import ValidationCurveDisplay
+
 alphas = np.logspace(-5, 1, 60)
-enet = linear_model.ElasticNet(l1_ratio=0.7, max_iter=10000)
-train_errors = list()
-test_errors = list()
-for alpha in alphas:
-    enet.set_params(alpha=alpha)
-    enet.fit(X_train, y_train)
-    train_errors.append(enet.score(X_train, y_train))
-    test_errors.append(enet.score(X_test, y_test))
-
-i_alpha_optim = np.argmax(test_errors)
-alpha_optim = alphas[i_alpha_optim]
-print("Optimal regularization parameter : %s" % alpha_optim)
-
-# Estimate the coef_ on full data with optimal regularization parameter
-enet.set_params(alpha=alpha_optim)
-coef_ = enet.fit(X, y).coef_
+disp = ValidationCurveDisplay.from_estimator(
+    enet,
+    X_train,
+    y_train,
+    param_name="alpha",
+    param_range=alphas,
+    scoring="r2",
+    n_jobs=2,
+    score_type="both",
+)
+disp.ax_.set(
+    title=r"Validation Curve for ElasticNet (R$^2$ Score)",
+    xlabel=r"alpha (regularization strength)",
+    ylabel="R$^2$ Score",
+)
+
+test_scores_mean = disp.test_scores.mean(axis=1)
+idx_avg_max_test_score = np.argmax(test_scores_mean)
+disp.ax_.vlines(
+    alphas[idx_avg_max_test_score],
+    disp.ax_.get_ylim()[0],
+    test_scores_mean[idx_avg_max_test_score],
+    color="k",
+    linewidth=2,
+    linestyle="--",
+    label=f"Optimum on test\n$\\alpha$ = {alphas[idx_avg_max_test_score]:.2e}",
+)
+_ = disp.ax_.legend(loc="lower right")
 
 # %%
-# Plot results functions
-# ----------------------
+# To find the optimal regularization parameter, we can select the value of `alpha`
+# that maximizes the validation score.
+#
+# Coefficients comparison
+# -----------------------
+#
+# Now that we have identified the optimal regularization parameter, we can compare the
+# true coefficients and the estimated coefficients.
+#
+# First, let's set the regularization parameter to the optimal value and fit the
+# model on the training data. In addition, we'll show the test score for this model.
+enet.set_params(alpha=alphas[idx_avg_max_test_score]).fit(X_train, y_train)
+print(
+    f"Test score: {enet.score(X_test, y_test):.3f}",
+)
+
+# %%
+# Now, we plot the true coefficients and the estimated coefficients.
 import matplotlib.pyplot as plt
 
-plt.subplot(2, 1, 1)
-plt.semilogx(alphas, train_errors, label="Train")
-plt.semilogx(alphas, test_errors, label="Test")
-plt.vlines(
-    alpha_optim,
-    plt.ylim()[0],
-    np.max(test_errors),
-    color="k",
-    linewidth=3,
-    label="Optimum on test",
+fig, axs = plt.subplots(ncols=2, figsize=(12, 6), sharex=True, sharey=True)
+for ax, coef, title in zip(axs, [true_coef, enet.coef_], ["True", "Model"]):
+    ax.stem(coef)
+    ax.set(
+        title=f"{title} Coefficients",
+        xlabel="Feature Index",
+        ylabel="Coefficient Value",
+    )
+fig.suptitle(
+    "Comparison of the coefficients of the true generative model and \n"
+    "the estimated elastic net coefficients"
 )
-plt.legend(loc="lower right")
-plt.ylim([0, 1.2])
-plt.xlabel("Regularization parameter")
-plt.ylabel("Performance")
-
-# Show estimated coef_ vs true coef
-plt.subplot(2, 1, 2)
-plt.plot(coef, label="True coef")
-plt.plot(coef_, label="Estimated coef")
-plt.legend()
-plt.subplots_adjust(0.09, 0.04, 0.94, 0.94, 0.26, 0.26)
+
 plt.show()
+
+# %%
+# While the original coefficients are sparse, the estimated coefficients are not
+# as sparse. The reason is that we fixed the `l1_ratio` parameter to 0.9. We could
+# force the model to get a sparser solution by increasing the `l1_ratio` parameter.
+#
+# However, we observed that for the estimated coefficients that are close to zero in
+# the true generative model, our model shrinks them towards zero. So we don't recover
+# the true coefficients, but we get a sensible outcome in line with the performance
+# obtained on the test set.
diff --git a/examples/model_selection/plot_validation_curve.py b/examples/model_selection/plot_validation_curve.py
deleted file mode 100644
index 44a382fed0c17..0000000000000
--- a/examples/model_selection/plot_validation_curve.py
+++ /dev/null
@@ -1,43 +0,0 @@
-"""
-==========================
-Plotting Validation Curves
-==========================
-
-In this plot you can see the training scores and validation scores of an SVM
-for different values of the kernel parameter gamma. For very low values of
-gamma, you can see that both the training score and the validation score are
-low. This is called underfitting. Medium values of gamma will result in high
-values for both scores, i.e. the classifier is performing fairly well. If gamma
-is too high, the classifier will overfit, which means that the training score
-is good but the validation score is poor.
- -""" - -# Authors: The scikit-learn developers -# SPDX-License-Identifier: BSD-3-Clause - -import matplotlib.pyplot as plt -import numpy as np - -from sklearn.datasets import load_digits -from sklearn.model_selection import ValidationCurveDisplay -from sklearn.svm import SVC - -X, y = load_digits(return_X_y=True) -subset_mask = np.isin(y, [1, 2]) # binary classification: 1 vs 2 -X, y = X[subset_mask], y[subset_mask] - -disp = ValidationCurveDisplay.from_estimator( - SVC(), - X, - y, - param_name="gamma", - param_range=np.logspace(-6, -1, 5), - score_type="both", - n_jobs=2, - score_name="Accuracy", -) -disp.ax_.set_title("Validation Curve for SVM with an RBF kernel") -disp.ax_.set_xlabel(r"gamma (inverse radius of the RBF kernel)") -disp.ax_.set_ylim(0.0, 1.1) -plt.show()