diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst
index aee5f247c2a98..fdff62e3bce93 100644
--- a/doc/whats_new/v1.5.rst
+++ b/doc/whats_new/v1.5.rst
@@ -42,6 +42,13 @@ Changelog
 :mod:`sklearn.ensemble`
 .......................
 
+- |MajorFeature| :class:`ensemble.HistGradientBoostingClassifier` and
+  :class:`ensemble.HistGradientBoostingRegressor` now have an
+  `on_high_cardinality_categories="bin_infrequent"` option to automatically
+  encode high-cardinality categories. For categories with a cardinality higher
+  than `max_bins`, the infrequent categories are grouped together such that
+  there are only `max_bins` categories. :pr:`26268` by `Thomas Fan`_.
+
 - |Efficiency| Improves runtime of `predict` of
   :class:`ensemble.HistGradientBoostingClassifier` by avoiding to call
   `predict_proba`. :pr:`27844` by :user:`Christian Lorentzen <lorentzenchr>`.
diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
index 69ae0090b1fb8..fb7360c7e2dcf 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -169,6 +169,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC):
             Hidden(StrOptions({"warn"})),
             None,
         ],
+        "on_high_cardinality_categories": [StrOptions({"error", "bin_infrequent"})],
         "warm_start": ["boolean"],
         "early_stopping": [StrOptions({"auto"}), "boolean"],
         "scoring": [str, callable, None],
@@ -190,6 +191,7 @@ def __init__(
         max_features,
         max_bins,
         categorical_features,
+        on_high_cardinality_categories,
         monotonic_cst,
         interaction_cst,
         warm_start,
@@ -213,6 +215,7 @@ def __init__(
         self.monotonic_cst = monotonic_cst
         self.interaction_cst = interaction_cst
         self.categorical_features = categorical_features
+        self.on_high_cardinality_categories = on_high_cardinality_categories
         self.warm_start = warm_start
         self.early_stopping = early_stopping
         self.scoring = scoring
@@ -277,12 +280,18 @@ def _preprocess_X(self, X, *, reset):
             X = self._validate_data(X, **check_X_kwargs)
             return X, None
 
+        if self.on_high_cardinality_categories == "bin_infrequent":
+            max_categories = self.max_bins
+        else:
+            max_categories = None
+
         n_features = X.shape[1]
         ordinal_encoder = OrdinalEncoder(
             categories="auto",
             handle_unknown="use_encoded_value",
             unknown_value=np.nan,
             encoded_missing_value=np.nan,
+            max_categories=max_categories,
             dtype=X_DTYPE,
         )
 
@@ -333,15 +342,23 @@ def _check_categories(self):
         categorical_column_indices = np.arange(self._preprocessor.n_features_in_)[
             self._preprocessor.output_indices_["encoder"]
         ]
-        for feature_idx, categories in zip(
-            categorical_column_indices, encoder.categories_
+        bin_infrequent = self.on_high_cardinality_categories == "bin_infrequent"
+        try:
+            n_infrequent_categories = [
+                0 if cat is None else len(cat) for cat in encoder.infrequent_categories_
+            ]
+        except AttributeError:
+            n_infrequent_categories = [0] * len(encoder.categories_)
+
+        for feature_idx, categories, n_infrequent in zip(
+            categorical_column_indices, encoder.categories_, n_infrequent_categories
         ):
             # OrdinalEncoder always puts np.nan as the last category if the
             # training data has missing values. Here we remove it because it is
             # already added by the _BinMapper.
             if len(categories) and is_scalar_nan(categories[-1]):
                 categories = categories[:-1]
-            if categories.size > self.max_bins:
+            if not bin_infrequent and categories.size > self.max_bins:
                 try:
                     feature_name = repr(encoder.feature_names_in_[feature_idx])
                 except AttributeError:
@@ -351,7 +368,10 @@
                     f"have a cardinality <= {self.max_bins} but actually "
                     f"has a cardinality of {categories.size}."
                 )
-            known_categories[feature_idx] = np.arange(len(categories), dtype=X_DTYPE)
+
+            # infrequent categories are grouped into one category
+            total_categories = len(categories) - n_infrequent + 1
+            known_categories[feature_idx] = np.arange(total_categories, dtype=X_DTYPE)
         return known_categories
 
     def _check_categorical_features(self, X):
@@ -1527,6 +1547,17 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
           Added `"from_dtype"` option. The default will change to `"from_dtype"`
           in v1.6.
 
+    on_high_cardinality_categories : {"error", "bin_infrequent"}, default="error"
+        Whether to raise an error or to bin together the least frequent
+        categories.
+
+        - `"error"`: Raises an error when a categorical feature has a cardinality
+          higher than `max_bins` or a category encoded with a value above `max_bins`.
+        - `"bin_infrequent"`: Bins the least frequent categories together so that
+          there are no more than `max_bins` categories.
+
+        .. versionadded:: 1.5
+
     monotonic_cst : array-like of int of shape (n_features) or dict, default=None
         Monotonic constraint to enforce on each feature are specified using the
         following integer values:
@@ -1698,6 +1729,7 @@ def __init__(
         max_features=1.0,
         max_bins=255,
         categorical_features="warn",
+        on_high_cardinality_categories="error",
         monotonic_cst=None,
         interaction_cst=None,
         warm_start=False,
@@ -1722,6 +1754,7 @@ def __init__(
            monotonic_cst=monotonic_cst,
            interaction_cst=interaction_cst,
            categorical_features=categorical_features,
+            on_high_cardinality_categories=on_high_cardinality_categories,
            early_stopping=early_stopping,
            warm_start=warm_start,
            scoring=scoring,
@@ -1903,6 +1936,17 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):
          Added `"from_dtype"` option. The default will change to `"from_dtype"`
          in v1.6.
 
+    on_high_cardinality_categories : {"error", "bin_infrequent"}, default="error"
+        Whether to raise an error or to bin together the least frequent
+        categories.
+
+        - `"error"`: Raises an error when a categorical feature has a cardinality
+          higher than `max_bins` or a category encoded with a value above `max_bins`.
+        - `"bin_infrequent"`: Bins the least frequent categories together so that
+          there are no more than `max_bins` categories.
+
+        .. versionadded:: 1.5
+
     monotonic_cst : array-like of int of shape (n_features) or dict, default=None
         Monotonic constraint to enforce on each feature are specified using the
         following integer values:
@@ -2076,6 +2120,7 @@ def __init__(
         max_features=1.0,
         max_bins=255,
         categorical_features="warn",
+        on_high_cardinality_categories="error",
         monotonic_cst=None,
         interaction_cst=None,
         warm_start=False,
@@ -2099,6 +2144,7 @@ def __init__(
            max_features=max_features,
            max_bins=max_bins,
            categorical_features=categorical_features,
+            on_high_cardinality_categories=on_high_cardinality_categories,
            monotonic_cst=monotonic_cst,
            interaction_cst=interaction_cst,
            warm_start=warm_start,
diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
index bdc85eccd6607..9ef009b6c9a43 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
@@ -34,7 +34,12 @@
 from sklearn.metrics import get_scorer, mean_gamma_deviance, mean_poisson_deviance
 from sklearn.model_selection import cross_val_score, train_test_split
 from sklearn.pipeline import make_pipeline
-from sklearn.preprocessing import KBinsDiscretizer, MinMaxScaler, OneHotEncoder
+from sklearn.preprocessing import (
+    KBinsDiscretizer,
+    MinMaxScaler,
+    OneHotEncoder,
+    OrdinalEncoder,
+)
 from sklearn.utils import _IS_32BIT, shuffle
 from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
 from sklearn.utils._testing import _convert_container
@@ -1447,6 +1452,124 @@ def test_unknown_category_that_are_negative():
     assert_allclose(hist.predict(X_test_neg), hist.predict(X_test_nan))
 
 
+@pytest.mark.parametrize(
+    "Hist", [HistGradientBoostingClassifier, HistGradientBoostingRegressor]
+)
+def test_categorical_cardinality_higher_than_n_bins(Hist):
+    """Check categorical features work when cardinality is greater than max_bins."""
+
+    rng = np.random.RandomState(42)
+    n_samples = 5_000
+    n_cardinality = 100
+    max_bins = 64
+    f_num = rng.rand(n_samples)
+    f_cat = rng.randint(n_cardinality, size=n_samples)
+    # f_cat is an informative feature
+    y = f_cat % 3 == 0
+    categorical_features = np.array([True, False])
+
+    X = np.c_[f_cat, f_num]
+
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+
+    hist_kwargs = dict(max_iter=1, max_bins=max_bins, random_state=0)
+    hist_native = Hist(
+        categorical_features=categorical_features,
+        on_high_cardinality_categories="bin_infrequent",
+        **hist_kwargs,
+    )
+    hist_native.fit(X_train, y_train)
+
+    # Using a preprocessor with max_categories=max_bins should give the same model
+    # as the native implementation, which uses the same preprocessing strategy
+    column_transformer = make_column_transformer(
+        (
+            OrdinalEncoder(
+                categories="auto",
+                handle_unknown="use_encoded_value",
+                unknown_value=np.nan,
+                encoded_missing_value=np.nan,
+                max_categories=max_bins,
+                dtype=np.float64,
+            ),
+            categorical_features,
+        ),
+        ("passthrough", ~categorical_features),
+    )
+    hist_with_prep = make_pipeline(
+        column_transformer,
+        Hist(categorical_features=categorical_features, **hist_kwargs),
+    )
+    hist_with_prep.fit(X_train, y_train)
+
+    # Check that the preprocessors return the same transformed data
+    assert_allclose(
+        hist_native._preprocessor.transform(X_train),
+        hist_with_prep[0].transform(X_train),
+    )
+
+    # Check that the trees are the same and have the same performance
+    assert len(hist_native._predictors) == len(hist_with_prep[-1]._predictors)
+    for predictor_1, predictor_2 in zip(
+        hist_native._predictors, hist_with_prep[-1]._predictors
+    ):
+        assert len(predictor_1[0].nodes) == len(predictor_2[0].nodes)
+
+    score_native = hist_native.score(X_test, y_test)
+    score_with_prep = hist_with_prep.score(X_test, y_test)
+    assert score_with_prep == pytest.approx(score_native)
+
+    assert_allclose(hist_native.predict(X_test), hist_with_prep.predict(X_test))
+
+
+@pytest.mark.parametrize(
+    "Hist", [HistGradientBoostingClassifier, HistGradientBoostingRegressor]
+)
+def test_categorical_encoding_higher_than_n_bins(Hist):
+    """Check that categorical encoding can be greater than n_bins."""
+
+    rng = np.random.RandomState(42)
+    n_samples = 5_000
+    n_cardinality = 4
+    max_bins = 10
+    f_num = rng.rand(n_samples)
+    f_cat = rng.randint(n_cardinality, size=n_samples)
+    # f_cat is an informative feature
+    y = f_cat % 3 == 0
+    X1 = np.c_[f_cat, f_num]
+    categorical_features = [True, False]
+
+    # Categorical feature above max_bins
+    f_cat_ = f_cat.copy()
+    f_cat_[f_cat_ == 3] = max_bins + 1
+    X2 = np.c_[f_cat_, f_num]
+
+    X1_train, X1_test, X2_train, X2_test, y_train, y_test = train_test_split(
+        X1, X2, y, random_state=0
+    )
+
+    hist_kwargs = dict(max_iter=10, max_bins=max_bins, random_state=0)
+    hist_in_bounds = Hist(categorical_features=categorical_features, **hist_kwargs)
+    hist_in_bounds.fit(X1_train, y_train)
+    score_in_bounds = hist_in_bounds.score(X1_test, y_test)
+
+    hist_out_of_bounds = Hist(
+        categorical_features=categorical_features,
+        on_high_cardinality_categories="bin_infrequent",
+        **hist_kwargs,
+    )
+    hist_out_of_bounds.fit(X2_train, y_train)
+    score_out_of_bounds = hist_out_of_bounds.score(X2_test, y_test)
+
+    assert len(hist_in_bounds._predictors) == len(hist_out_of_bounds._predictors)
+    for predictor_1, predictor_2 in zip(
+        hist_in_bounds._predictors, hist_out_of_bounds._predictors
+    ):
+        assert len(predictor_1[0].nodes) == len(predictor_2[0].nodes)
+
+    assert score_in_bounds == pytest.approx(score_out_of_bounds)
+
+
 @pytest.mark.parametrize("dataframe_lib", ["pandas", "polars"])
 @pytest.mark.parametrize(
     "HistGradientBoosting",
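
The patch above only touches the library and its tests. As a quick orientation for reviewers, the following is a minimal usage sketch, not part of the patch, assuming the option lands exactly as shown in the diff (parameter name `on_high_cardinality_categories` with values `"error"` and `"bin_infrequent"`); the toy data (1000 categories, default `max_bins=255`) is purely illustrative::

    # Illustrative only: one categorical feature with 1000 distinct values,
    # i.e. a cardinality well above the default max_bins=255.
    import numpy as np
    from sklearn.ensemble import HistGradientBoostingClassifier

    rng = np.random.RandomState(0)
    f_cat = rng.randint(1000, size=5000)
    X = f_cat.reshape(-1, 1)
    y = f_cat % 3 == 0

    # With the default on_high_cardinality_categories="error" this fit would
    # raise because the cardinality exceeds max_bins; with "bin_infrequent" the
    # least frequent categories are grouped so at most max_bins categories remain.
    clf = HistGradientBoostingClassifier(
        categorical_features=[True],
        on_high_cardinality_categories="bin_infrequent",
    )
    clf.fit(X, y)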