From ab9cf7f1dd0b166a3b0282dd1aa07e86c5ee855e Mon Sep 17 00:00:00 2001
From: "Thomas J. Fan"
Date: Fri, 14 Apr 2023 17:30:43 -0400
Subject: [PATCH 1/4] DOC Adds target encoder example about internal cross validation

---
 doc/modules/preprocessing.rst                 |   1 +
 .../plot_target_encoder_cross_val.py          | 159 ++++++++++++++++++
 2 files changed, 160 insertions(+)
 create mode 100644 examples/preprocessing/plot_target_encoder_cross_val.py

diff --git a/doc/modules/preprocessing.rst b/doc/modules/preprocessing.rst
index dc151871874d4..69045147d8af9 100644
--- a/doc/modules/preprocessing.rst
+++ b/doc/modules/preprocessing.rst
@@ -941,6 +941,7 @@ learned in :meth:`~TargetEncoder.fit_transform`.
 
 .. topic:: Examples:
 
   * :ref:`sphx_glr_auto_examples_preprocessing_plot_target_encoder.py`
+  * :ref:`sphx_glr_auto_examples_preprocessing_plot_target_encoder_cross_val.py`
 
 .. topic:: References

diff --git a/examples/preprocessing/plot_target_encoder_cross_val.py b/examples/preprocessing/plot_target_encoder_cross_val.py
new file mode 100644
index 0000000000000..46ea61bb2b3d5
--- /dev/null
+++ b/examples/preprocessing/plot_target_encoder_cross_val.py
@@ -0,0 +1,159 @@
+"""
+==========================================
+Target Encoder's Internal Cross Validation
+==========================================
+
+.. currentmodule:: sklearn.preprocessing
+
+The :class:`TargetEncoder` replaces each category of a categorical feature with
+the mean of the target variable for that category. This method is useful
+in cases where there is a strong relationship between the categorical feature
+and the target. To prevent overfitting, :meth:`TargetEncoder.fit_transform` uses
+internal cross validation to encode the training data to be used by a downstream
+model. In this example, we demonstrate the importance of the cross validation
+procedure to prevent overfitting.
+"""
+
+# %%
+# Create Synthetic Dataset
+# ========================
+# For this example, we build a dataset with three categorical features: an informative
+# and feature with medium cardinality, an uninformative feature with medium cardinality,
+# and an uninformative feature with high cardinality. First, we generate the informative
+# feature:
+from sklearn.preprocessing import KBinsDiscretizer
+import numpy as np
+
+n_samples = 50_000
+
+rng = np.random.RandomState(42)
+y = rng.randn(n_samples)
+noise = 0.5 * rng.randn(n_samples)
+n_categories = 100
+
+kbins = KBinsDiscretizer(
+    n_bins=n_categories, encode="ordinal", strategy="uniform", random_state=rng
+)
+X_informative = kbins.fit_transform((y + noise).reshape(-1, 1))
+
+# Permute the feature to remove the information from the ordering
+permuted_categories = rng.permutation(n_categories)
+X_informative = permuted_categories[X_informative.astype(np.int32)]
+
+# %%
+# The uninformative feature with medium cardinality is generated by permuting the
+# informative feature and removing the relationship with the target:
+X_shuffled = rng.permutation(X_informative)
+
+# %%
+# The uninformative feature with high cardinality is generated that is independent of
+# the target variable. We will show that target encoding without cross validation will
+# cause catastrophic overfitting for the downstream regressor. These high cardinality
+# features are typically unique identifies for samples which should be generally be
+# removed from machine learning datasets. In this example, We generate them to show how
+# :class:`TargetEncoder`'s default cross validation behavior mitigates the overfitting
+# issue automatically.
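+
+# %%
+# As a quick, optional illustration of what target encoding relies on (a minimal
+# sketch that only uses the arrays defined above), we look at the spread of the
+# per-category target means: it is large for the informative feature and much
+# smaller for the shuffled one. The exact values are not important here, only the
+# contrast between the two features.
+import pandas as pd
+
+print(
+    "Spread of per-category target means (informative):",
+    pd.Series(y).groupby(X_informative.ravel()).mean().std(),
+)
+print(
+    "Spread of per-category target means (shuffled):",
+    pd.Series(y).groupby(X_shuffled.ravel()).mean().std(),
+)
+
+# %%
+# Now we generate the high cardinality feature: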
+X_near_unique_categories = rng.choice(
+    int(0.9 * n_samples), size=n_samples, replace=True
+).reshape(-1, 1)
+
+# %%
+# Finally, we assemble the dataset and perform a train test split:
+from sklearn.model_selection import train_test_split
+import pandas as pd
+
+X = pd.DataFrame(
+    np.concatenate(
+        [X_informative, X_shuffled, X_near_unique_categories],
+        axis=1,
+    ),
+    columns=["informative", "shuffled", "near_unique"],
+)
+X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+# %%
+# Training a Ridge Regressors
+# ===========================
+# In this section, we train ridge regressors on the dataset with and without
+# encoding and explore the influence of target encoder with and without the
+# internal cross validation. First, we see the Ridge model trained on the
+# raw features will underfit, because the order of the informative feature is
+# not informative:
+from sklearn.linear_model import Ridge
+import sklearn
+
+# Configure transformers to always output DataFrames
+sklearn.set_config(transform_output="pandas")
+
+ridge = Ridge(alpha=1e-6, solver="lsqr", fit_intercept=False)
+
+raw_model = ridge.fit(X_train, y_train)
+print("Raw Model score on training set: ", raw_model.score(X_train, y_train))
+print("Raw Model score on test set: ", raw_model.score(X_test, y_test))
+
+# %%
+# Next, we create a pipeline with the target encoder and ridge model. The pipeline
+# uses :meth:`TargetEncoder.fit_transform` which uses cross validation. We see that
+# the model fits the data well and generalizes to the test set:
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import TargetEncoder
+
+model_with_cv = make_pipeline(TargetEncoder(random_state=0), ridge)
+model_with_cv.fit(X_train, y_train)
+print("Model with CV on training set: ", model_with_cv.score(X_train, y_train))
+print("Model with CV on test set: ", model_with_cv.score(X_test, y_test))
+
+# %%
+# The coefficients of the linear model show that most of the weight is on the
+# feature at column index 0, which is the informative feature:
+import pandas as pd
+import matplotlib.pyplot as plt
+
+plt.rcParams["figure.constrained_layout.use"] = True
+
+coefs_cv = pd.Series(
+    model_with_cv[-1].coef_, index=model_with_cv[-1].feature_names_in_
+).sort_values()
+_ = coefs_cv.plot(kind="barh")
+
+# %%
+# While :meth:`TargetEncoder.fit_transform` uses internal cross validation,
+# :meth:`TargetEncoder.transform` itself does not perform any cross validation.
+# It uses the aggregation of the complete training set to transform the categorical
+# features. Thus, we can use :meth:`TargetEncoder.fit` followed by
+# :meth:`TargetEncoder.transform` to disable the cross validation. This encoding
+# is then passed to the ridge model.
+target_encoder = TargetEncoder(random_state=0)
+target_encoder.fit(X_train, y_train)
+X_train_no_cv_encoding = target_encoder.transform(X_train)
+X_test_no_cv_encoding = target_encoder.transform(X_test)
+
+model_no_cv = ridge.fit(X_train_no_cv_encoding, y_train)
+
+# %%
+# We evaluate the model on the non-cross validated encoding and see that it overfits:
+print(
+    "Model without CV on training set: ",
+    model_no_cv.score(X_train_no_cv_encoding, y_train),
+)
+print(
+    "Model without CV on test set: ", model_no_cv.score(X_test_no_cv_encoding, y_test)
+)
+
+# %%
+# The ridge model overfits, because it assigns more weights to the extremely high
+# cardinality feature relative to the informative feature.
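+
+# %%
+# A quick, optional check (a minimal sketch reusing the encodings computed above):
+# without cross validation, the encoded ``near_unique`` column is noticeably
+# correlated with ``y_train`` even though the raw feature carries no information
+# about the target, because each near-unique category essentially memorizes its
+# own target value.
+print(
+    "Correlation between encoded near_unique and y_train:",
+    np.corrcoef(X_train_no_cv_encoding["near_unique"], y_train)[0, 1],
+)
+
+# %%
+# The coefficients below make the same point: most of the weight goes to
+# ``near_unique`` rather than to the informative feature.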
+coefs_no_cv = pd.Series(
+    model_no_cv.coef_, index=model_no_cv.feature_names_in_
+).sort_values()
+_ = coefs_no_cv.plot(kind="barh")
+
+# %%
+# Conclusion
+# ==========
+# This example demonstrates the importance of :class:`TargetEncoder`'s internal cross
+# validation. It is important to use :meth:`TargetEncoder.fit_transform` to encode
+# training data before passing it to a machine learning model. When a
+# :class:`TargetEncoder` is a part of a :class:`~sklearn.pipeline.Pipeline` and the
+# pipeline is fitted, the pipeline will correctly call
+# :meth:`TargetEncoder.fit_transform` and pass the encoding along.

From b9c53c0cc1d5315b804c8dfa8b23d347ddbfd7ae Mon Sep 17 00:00:00 2001
From: "Thomas J. Fan"
Date: Tue, 18 Apr 2023 09:04:32 -0400
Subject: [PATCH 2/4] Apply suggestions from code review

Co-authored-by: Tim Head

---
 .../preprocessing/plot_target_encoder_cross_val.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/preprocessing/plot_target_encoder_cross_val.py b/examples/preprocessing/plot_target_encoder_cross_val.py
index 46ea61bb2b3d5..19e8f84570535 100644
--- a/examples/preprocessing/plot_target_encoder_cross_val.py
+++ b/examples/preprocessing/plot_target_encoder_cross_val.py
@@ -18,7 +18,7 @@
 # Create Synthetic Dataset
 # ========================
 # For this example, we build a dataset with three categorical features: an informative
-# and feature with medium cardinality, an uninformative feature with medium cardinality,
+# feature with medium cardinality, an uninformative feature with medium cardinality,
 # and an uninformative feature with high cardinality. First, we generate the informative
 # feature:
 from sklearn.preprocessing import KBinsDiscretizer
 import numpy as np
@@ -46,11 +46,11 @@
 X_shuffled = rng.permutation(X_informative)
 
 # %%
-# The uninformative feature with high cardinality is generated that is independent of
+# The uninformative feature with high cardinality is generated so that it is independent of
 # the target variable. We will show that target encoding without cross validation will
 # cause catastrophic overfitting for the downstream regressor. These high cardinality
-# features are typically unique identifies for samples which should be generally be
-# removed from machine learning datasets. In this example, We generate them to show how
+# features are basically unique identifiers for samples which should be generally be
+# removed from machine learning datasets. In this example, we generate them to show how
 # :class:`TargetEncoder`'s default cross validation behavior mitigates the overfitting
 # issue automatically.
 X_near_unique_categories = rng.choice(
@@ -72,7 +72,7 @@
 X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
 
 # %%
-# Training a Ridge Regressors
+# Training a Ridge Regressor
 # ===========================
 # In this section, we train ridge regressors on the dataset with and without
 # encoding and explore the influence of target encoder with and without the

From 6755256f594df5345e5d7c69f3bc3ac41e446121 Mon Sep 17 00:00:00 2001
From: "Thomas J. Fan"
Fan" Date: Tue, 18 Apr 2023 09:10:50 -0400 Subject: [PATCH 3/4] DOC Address comments --- .../preprocessing/plot_target_encoder_cross_val.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/preprocessing/plot_target_encoder_cross_val.py b/examples/preprocessing/plot_target_encoder_cross_val.py index 19e8f84570535..10f61b415b7b3 100644 --- a/examples/preprocessing/plot_target_encoder_cross_val.py +++ b/examples/preprocessing/plot_target_encoder_cross_val.py @@ -36,7 +36,8 @@ ) X_informative = kbins.fit_transform((y + noise).reshape(-1, 1)) -# Permute the feature to remove the information from the ordering +# Remove the linear relationship between y and the bin index by permuting the values of +# X_informative permuted_categories = rng.permutation(n_categories) X_informative = permuted_categories[X_informative.astype(np.int32)] @@ -73,12 +74,12 @@ # %% # Training a Ridge Regressor -# =========================== -# In this section, we train ridge regressors on the dataset with and without +# ========================== +# In this section, we train a ridge regressor on the dataset with and without # encoding and explore the influence of target encoder with and without the # interval cross validation. First, we see the Ridge model trained on the -# raw features will underfit, because the order of the informative feature is -# not informative: +# raw features will have low performance, because the order of the informative +# feature is not informative: from sklearn.linear_model import Ridge import sklearn From 392c9082abd71dffc8d54dc6174a664d6d9ab8c2 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 19 Apr 2023 11:10:18 -0400 Subject: [PATCH 4/4] Apply suggestions from code review Co-authored-by: Tim Head --- examples/preprocessing/plot_target_encoder_cross_val.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/preprocessing/plot_target_encoder_cross_val.py b/examples/preprocessing/plot_target_encoder_cross_val.py index 10f61b415b7b3..455625cc47460 100644 --- a/examples/preprocessing/plot_target_encoder_cross_val.py +++ b/examples/preprocessing/plot_target_encoder_cross_val.py @@ -50,7 +50,7 @@ # The uninformative feature with high cardinality is generated so that is independent of # the target variable. We will show that target encoding without cross validation will # cause catastrophic overfitting for the downstream regressor. These high cardinality -# features are basically unique identifiers for samples which should be generally be +# features are basically unique identifiers for samples which should generally be # removed from machine learning dataset. In this example, we generate them to show how # :class:`TargetEncoder`'s default cross validation behavior mitigates the overfitting # issue automatically. @@ -142,7 +142,7 @@ ) # %% -# The ridge model overfits, because it assigns more weights to the extremely high +# The ridge model overfits, because it assigns more weight to the extremely high # cardinality feature relative to the informative feature. coefs_no_cv = pd.Series( model_no_cv.coef_, index=model_no_cv.feature_names_in_