From 8a88513ae2813a6e8603f754fd7910fbaca416d0 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 23 Apr 2023 20:30:13 -0400 Subject: [PATCH 01/13] ENH Support categories higher than max_bins in HistGradientBoosting --- doc/whats_new/v1.3.rst | 7 + .../gradient_boosting.py | 133 ++++++++++---- .../tests/test_gradient_boosting.py | 172 +++++++++++++----- 3 files changed, 230 insertions(+), 82 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index bb245aa466152..d55dc57fbcb89 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -255,6 +255,13 @@ Changelog out-of-bag scores via the `oob_scores_` or `oob_score_` attributes. :pr:`24882` by :user:`Ashwin Mathur `. +- |Feature| :class:`ensemble.HistGradientBoostingClassifier` and + :class:`ensemble.HistGradientBoostingRegressor` supports categories with + cardinality greater than `max_bins` or encoded with values greater than + `max_bins`. For categories with cardinality higher than `max_bins`, the + infrequent categories are grouped together such there are only `max_bins` + categories. :pr:`xxxxx` by `Thomas Fan`_. + - |Efficiency| :class:`ensemble.IsolationForest` predict time is now faster (typically by a factor of 8 or more). Internally, the estimator now precomputes decision path lengths per tree at `fit` time. It is therefore not possible diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 976335ea684d0..4792f3017a1d6 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -18,12 +18,16 @@ PinballLoss, ) from ...base import BaseEstimator, RegressorMixin, ClassifierMixin, is_classifier +from ...compose import ColumnTransformer +from ...preprocessing import OrdinalEncoder +from ...preprocessing import FunctionTransformer from ...utils import check_random_state, resample, compute_sample_weight from ...utils.validation import ( check_is_fitted, check_consistent_length, _check_sample_weight, _check_monotonic_cst, + _check_y, ) from ...utils._param_validation import Interval, StrOptions from ...utils._param_validation import RealNotInt @@ -176,6 +180,63 @@ class weights. """ return sample_weight + def _check_X(self, X, *, reset): + X = self._validate_data(X, dtype=[X_DTYPE], force_all_finite=False, reset=reset) + + if not reset: + return self._preprocessor.transform(X) + + self.is_categorical_, known_categories, requires_encoder = ( + self._check_categories(X) + ) + n_features = X.shape[1] + + if not requires_encoder: + self._preprocessor = FunctionTransformer().set_output(transform="default") + self._is_categorical_remapped = self.is_categorical_ + return X, known_categories + + # Create categories to pass into ordinal_encoder based on known_categories + categories_ = [c for c in known_categories if c is not None] + + ordinal_encoder = OrdinalEncoder( + categories=categories_, + handle_unknown="use_encoded_value", + unknown_value=np.nan, + encoded_missing_value=np.nan, + max_categories=self.max_bins, + dtype=X_DTYPE, + ) + + self._preprocessor = ColumnTransformer( + [ + ("numerical", "passthrough", ~self.is_categorical_), + ("encoder", ordinal_encoder, self.is_categorical_), + ] + ) + self._preprocessor.set_output(transform="default") + X = self._preprocessor.fit_transform(X) + + # Column Transformer places the categorical features at the end. 
+ categorical_remapped = np.zeros(n_features, dtype=bool) + n_categorical = self.is_categorical_.sum() + categorical_remapped[-n_categorical:] = True + + self._is_categorical_remapped = categorical_remapped + + # OrdinalEncoder will map categories to [0,..., cardinality - 1] + # If categories are not grouped into infrequent categories, then OrdinalEncoder + # will map categories to [0, ..., cardinality - 1] + # If there are infrequent categories, then OrdinalEncoder will map categories to + # [0, ..., max_bins - 1]. + renamed_categories = [ + np.arange(min(len(c), self.max_bins), dtype=X_DTYPE) for c in categories_ + ] + + numerical_features = n_features - n_categorical + known_categories = [None] * numerical_features + renamed_categories + return X, known_categories + def _check_categories(self, X): """Check and validate categorical features in X @@ -189,14 +250,16 @@ def _check_categories(self, X): - an array of shape (n_categories,) with the unique cat values - None if the feature is not categorical None if no feature is categorical. + requires_encoding : bool + True if categorical features require a column transformer to encode """ if self.categorical_features is None: - return None, None + return None, None, False categorical_features = np.asarray(self.categorical_features) if categorical_features.size == 0: - return None, None + return None, None, False if categorical_features.dtype.kind not in ("i", "b", "U", "O"): raise ValueError( @@ -255,12 +318,13 @@ def _check_categories(self, X): is_categorical = categorical_features if not np.any(is_categorical): - return None, None + return None, None, False # Compute the known categories in the training data. We cannot do this # in the BinMapper because it only gets a fraction of the training data # when early stopping is enabled. known_categories = [] + requires_encoding = False for f_idx in range(n_features): if is_categorical[f_idx]: @@ -274,30 +338,17 @@ def _check_categories(self, X): if negative_categories.any(): categories = categories[~negative_categories] - if hasattr(self, "feature_names_in_"): - feature_name = f"'{self.feature_names_in_[f_idx]}'" - else: - feature_name = f"at index {f_idx}" - - if categories.size > self.max_bins: - raise ValueError( - f"Categorical feature {feature_name} is expected to " - f"have a cardinality <= {self.max_bins} but actually " - f"has a cardinality of {categories.size}." - ) - - if (categories >= self.max_bins).any(): - raise ValueError( - f"Categorical feature {feature_name} is expected to " - f"be encoded with values < {self.max_bins} but the " - "largest value for the encoded categories is " - f"{categories.max()}." 
+ if not requires_encoding: + is_numerical = categories.dtype.kind in {"i", "u", "f"} + requires_encoding = is_numerical and ( + categories.size > self.max_bins + or (categories >= self.max_bins).any() ) else: categories = None known_categories.append(categories) - return is_categorical, known_categories + return is_categorical, known_categories, requires_encoding def _check_interaction_cst(self, n_features): """Check and validation for interaction constraints.""" @@ -365,8 +416,8 @@ def fit(self, X, y, sample_weight=None): acc_compute_hist_time = 0.0 # time spent computing histograms # time spent predicting X for gradient and hessians update acc_prediction_time = 0.0 - X, y = self._validate_data(X, y, dtype=[X_DTYPE], force_all_finite=False) - y = self._encode_y(y) + X, known_categories = self._check_X(X, reset=True) + y = self._encode_y(_check_y(y, estimator=self)) check_consistent_length(X, y) # Do not create unit sample weights by default to later skip some # computation @@ -391,8 +442,6 @@ def fit(self, X, y, sample_weight=None): # used for validation in predict n_samples, self._n_features = X.shape - self.is_categorical_, known_categories = self._check_categories(X) - # Encode constraints into a list of sets of features indices (integers). interaction_cst = self._check_interaction_cst(self._n_features) @@ -473,7 +522,7 @@ def fit(self, X, y, sample_weight=None): n_bins = self.max_bins + 1 # + 1 for missing values self._bin_mapper = _BinMapper( n_bins=n_bins, - is_categorical=self.is_categorical_, + is_categorical=self._is_categorical_remapped, known_categories=known_categories, random_state=self._random_seed, n_threads=n_threads, @@ -680,7 +729,7 @@ def fit(self, X, y, sample_weight=None): n_bins=n_bins, n_bins_non_missing=self._bin_mapper.n_bins_non_missing_, has_missing_values=has_missing_values, - is_categorical=self.is_categorical_, + is_categorical=self._is_categorical_remapped, monotonic_cst=monotonic_cst, interaction_cst=interaction_cst, max_leaf_nodes=self.max_leaf_nodes, @@ -1023,12 +1072,10 @@ def _raw_predict(self, X, n_threads=None): raw_predictions : array, shape (n_samples, n_trees_per_iteration) The raw predicted values. """ + check_is_fitted(self) is_binned = getattr(self, "_in_fit", False) if not is_binned: - X = self._validate_data( - X, dtype=X_DTYPE, force_all_finite=False, reset=False - ) - check_is_fitted(self) + X = self._check_X(X, reset=False) if X.shape[1] != self._n_features: raise ValueError( "X has {} features but this estimator was trained with " @@ -1094,8 +1141,8 @@ def _staged_raw_predict(self, X): The raw predictions of the input samples. The order of the classes corresponds to that in the attribute :term:`classes_`. """ - X = self._validate_data(X, dtype=X_DTYPE, force_all_finite=False, reset=False) check_is_fitted(self) + X = self._check_X(X, reset=False) if X.shape[1] != self._n_features: raise ValueError( "X has {} features but this estimator was trained with " @@ -1268,9 +1315,10 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): - str array-like: names of categorical features (assuming the training data has feature names). - For each categorical feature, there must be at most `max_bins` unique - categories, and each categorical value must be less then `max_bins - 1`. - Negative values for categorical features are treated as missing values. + For categories with cardinality higher than `max_bins`, the + infrequent categories are grouped together such there are only `max_bins` + categories. 
Negative values for categorical features are treated as + missing values. Read more in the :ref:`User Guide `. @@ -1279,6 +1327,9 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): .. versionchanged:: 1.2 Added support for feature names. + .. versionchanged:: 1.3 + Support categories with cardinality higher than `max_bins`. + monotonic_cst : array-like of int of shape (n_features) or dict, default=None Monotonic constraint to enforce on each feature are specified using the following integer values: @@ -1625,9 +1676,10 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): - str array-like: names of categorical features (assuming the training data has feature names). - For each categorical feature, there must be at most `max_bins` unique - categories, and each categorical value must be less then `max_bins - 1`. - Negative values for categorical features are treated as missing values. + For categories with cardinality higher than `max_bins`, the + infrequent categories are grouped together such there are only `max_bins` + categories. Negative values for categorical features are treated as + missing values. Read more in the :ref:`User Guide `. @@ -1636,6 +1688,9 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): .. versionchanged:: 1.2 Added support for feature names. + .. versionchanged:: 1.3 + Support categories with cardinality higher than `max_bins`. + monotonic_cst : array-like of int of shape (n_features) or dict, default=None Monotonic constraint to enforce on each feature are specified using the following integer values: diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index 94d8960b6e813..2a249743b2171 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -13,6 +13,7 @@ from sklearn.datasets import make_classification, make_regression from sklearn.datasets import make_low_rank_matrix from sklearn.preprocessing import KBinsDiscretizer, MinMaxScaler, OneHotEncoder +from sklearn.preprocessing import OrdinalEncoder from sklearn.model_selection import train_test_split, cross_val_score from sklearn.base import clone, BaseEstimator, TransformerMixin from sklearn.base import is_regressor @@ -1162,49 +1163,6 @@ def test_categorical_spec_no_categories(Est, categorical_features, as_array): assert est.is_categorical_ is None -@pytest.mark.parametrize( - "Est", (HistGradientBoostingClassifier, HistGradientBoostingRegressor) -) -@pytest.mark.parametrize( - "use_pandas, feature_name", [(False, "at index 0"), (True, "'f0'")] -) -def test_categorical_bad_encoding_errors(Est, use_pandas, feature_name): - # Test errors when categories are encoded incorrectly - - gb = Est(categorical_features=[True], max_bins=2) - - if use_pandas: - pd = pytest.importorskip("pandas") - X = pd.DataFrame({"f0": [0, 1, 2]}) - else: - X = np.array([[0, 1, 2]]).T - y = np.arange(3) - msg = ( - f"Categorical feature {feature_name} is expected to have a " - "cardinality <= 2 but actually has a cardinality of 3." 
- ) - with pytest.raises(ValueError, match=msg): - gb.fit(X, y) - - if use_pandas: - X = pd.DataFrame({"f0": [0, 2]}) - else: - X = np.array([[0, 2]]).T - y = np.arange(2) - msg = ( - f"Categorical feature {feature_name} is expected to be encoded " - "with values < 2 but the largest value for the encoded categories " - "is 2.0." - ) - with pytest.raises(ValueError, match=msg): - gb.fit(X, y) - - # nans are ignored in the counts - X = np.array([[0, 1, np.nan]]).T - y = np.arange(3) - gb.fit(X, y) - - @pytest.mark.parametrize( "Est", (HistGradientBoostingClassifier, HistGradientBoostingRegressor) ) @@ -1388,3 +1346,131 @@ def test_unknown_category_that_are_negative(): X_test_nan = np.asarray([[1, np.nan], [3, np.nan]]) assert_allclose(hist.predict(X_test_neg), hist.predict(X_test_nan)) + + +@pytest.mark.parametrize( + "Hist", [HistGradientBoostingClassifier, HistGradientBoostingRegressor] +) +def test_categorical_cardinality_higher_than_n_bins(Hist): + """Check categorical works when the cardinality is greater than max_bins.""" + + rng = np.random.RandomState(42) + n_samples = 5_000 + n_cardinality = 100 + max_bins = 64 + f_num = rng.rand(n_samples) + f_cat = rng.randint(n_cardinality, size=n_samples) + # f_cat is an informative feature + y = f_cat % 3 == 0 + categorical_features = np.asarray([False, True]) + + X = np.c_[f_num, f_cat] + + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + + hist_kwargs = dict(max_iter=10, max_bins=max_bins, random_state=0) + hist_native = Hist(categorical_features=categorical_features, **hist_kwargs) + hist_native.fit(X_train, y_train) + + # Use a preprocessor with an ordinal encoder should that gives the same model + column_transformer = make_column_transformer( + ("passthrough", ~categorical_features), + ( + OrdinalEncoder( + handle_unknown="use_encoded_value", + unknown_value=np.nan, + encoded_missing_value=np.nan, + max_categories=max_bins, + dtype=np.float64, + ), + categorical_features, + ), + ) + hist_with_prep = make_pipeline( + column_transformer, + Hist(categorical_features=categorical_features, **hist_kwargs), + ) + hist_with_prep.fit(X_train, y_train) + + assert len(hist_native._predictors) == len(hist_with_prep[-1]._predictors) + for predictor_1, predictor_2 in zip( + hist_native._predictors, hist_with_prep[-1]._predictors + ): + assert len(predictor_1[0].nodes) == len(predictor_2[0].nodes) + + score_native = hist_native.score(X_test, y_test) + score_with_prep = hist_with_prep.score(X_test, y_test) + assert score_with_prep == pytest.approx(score_native) + + +@pytest.mark.parametrize( + "Hist", [HistGradientBoostingClassifier, HistGradientBoostingRegressor] +) +def test_categorical_encoding_higher_than_n_bins(Hist): + """Check that categorical encoding can be greater than n_bins.""" + + rng = np.random.RandomState(42) + n_samples = 5_000 + n_cardinality = 4 + max_bins = 10 + f_num = rng.rand(n_samples) + f_cat = rng.randint(n_cardinality, size=n_samples) + # f_cat is an informative feature + y = f_cat % 3 == 0 + X1 = np.c_[f_num, f_cat] + categorical_features = [False, True] + + # Categorical feature above max_bins + f_cat_ = f_cat.copy() + f_cat_[f_cat_ == 3] = max_bins + 1 + X2 = np.c_[f_num, f_cat_] + + X1_train, X1_test, X2_train, X2_test, y_train, y_test = train_test_split( + X1, X2, y, random_state=0 + ) + + hist_kwargs = dict(max_iter=10, max_bins=max_bins, random_state=0) + hist_in_bounds = Hist(categorical_features=categorical_features, **hist_kwargs) + hist_in_bounds.fit(X1_train, y_train) + score_in_bounds = 
hist_in_bounds.score(X1_test, y_test) + + hist_out_of_bounds = Hist(categorical_features=categorical_features, **hist_kwargs) + hist_out_of_bounds.fit(X2_train, y_train) + score_out_of_bounds = hist_out_of_bounds.score(X2_test, y_test) + + assert len(hist_in_bounds._predictors) == len(hist_out_of_bounds._predictors) + for predictor_1, predictor_2 in zip( + hist_in_bounds._predictors, hist_out_of_bounds._predictors + ): + assert len(predictor_1[0].nodes) == len(predictor_2[0].nodes) + + assert score_in_bounds == pytest.approx(score_out_of_bounds) + + +def test_categorical_category_first(): + """Check that categorical features gives correct result as the first feature.""" + rng = np.random.RandomState(42) + n_samples = 5_000 + n_cardinality = 12 + max_bins = 10 + f_num = rng.rand(n_samples) + f_cat = rng.randint(n_cardinality, size=n_samples) + + # f_cat is an informative feature + y = f_cat % 3 == 0 + X = np.c_[f_cat, f_num] + categorical_features = [True, False] + + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + + hist_kwargs = dict(max_iter=20, max_bins=max_bins, random_state=0) + # Without categorical features we get lower performance + hist_no_cat = HistGradientBoostingRegressor(**hist_kwargs) + hist_no_cat.fit(X_train, y_train) + assert hist_no_cat.score(X_test, y_test) <= 0.65 + + hist_with_cat = HistGradientBoostingRegressor( + categorical_features=categorical_features, **hist_kwargs + ) + hist_with_cat.fit(X_train, y_train) + assert hist_with_cat.score(X_test, y_test) >= 0.95 From f60c263d8b81bd048201941b75df26c150b38624 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 23 Apr 2023 20:43:04 -0400 Subject: [PATCH 02/13] DOC Adds pr number --- doc/whats_new/v1.3.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index d55dc57fbcb89..fadad395437d4 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -260,7 +260,7 @@ Changelog cardinality greater than `max_bins` or encoded with values greater than `max_bins`. For categories with cardinality higher than `max_bins`, the infrequent categories are grouped together such there are only `max_bins` - categories. :pr:`xxxxx` by `Thomas Fan`_. + categories. :pr:`26268` by `Thomas Fan`_. - |Efficiency| :class:`ensemble.IsolationForest` predict time is now faster (typically by a factor of 8 or more). Internally, the estimator now precomputes From f01f2a5649857f38b0ffc942688e9ff19dad4f42 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Tue, 25 Apr 2023 21:16:27 -0400 Subject: [PATCH 03/13] Apply suggestions from code review Co-authored-by: Olivier Grisel Co-authored-by: Christian Lorentzen --- .../gradient_boosting.py | 17 +++++++++-------- .../tests/test_gradient_boosting.py | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 4792f3017a1d6..d62d0ce007840 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -224,10 +224,10 @@ def _check_X(self, X, *, reset): self._is_categorical_remapped = categorical_remapped - # OrdinalEncoder will map categories to [0,..., cardinality - 1] - # If categories are not grouped into infrequent categories, then OrdinalEncoder + # If the cardinality is lower than max_bins then OrdinalEncoder # will map categories to [0, ..., cardinality - 1] - # If there are infrequent categories, then OrdinalEncoder will map categories to + # Otherwise, the most infrequent categories are binned together by + # OrdinalEncoder such that all values are mapped to an index in: # [0, ..., max_bins - 1]. renamed_categories = [ np.arange(min(len(c), self.max_bins), dtype=X_DTYPE) for c in categories_ @@ -1328,7 +1328,8 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): Added support for feature names. .. versionchanged:: 1.3 - Support categories with cardinality higher than `max_bins`. + Support categories with cardinality higher than `max_bins` by + collapsing the most infrequent categories in a dedicated bin. monotonic_cst : array-like of int of shape (n_features) or dict, default=None Monotonic constraint to enforce on each feature are specified using the @@ -1676,10 +1677,10 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): - str array-like: names of categorical features (assuming the training data has feature names). - For categories with cardinality higher than `max_bins`, the - infrequent categories are grouped together such there are only `max_bins` - categories. Negative values for categorical features are treated as - missing values. + For categories with cardinality higher than `max_bins`, the most + infrequent categories are grouped together such that there are only + `max_bins` categories. Negative values for categorical features are + treated as missing values. Read more in the :ref:`User Guide `. diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index 2a249743b2171..6739cc84acbe0 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -1362,7 +1362,7 @@ def test_categorical_cardinality_higher_than_n_bins(Hist): f_cat = rng.randint(n_cardinality, size=n_samples) # f_cat is an informative feature y = f_cat % 3 == 0 - categorical_features = np.asarray([False, True]) + categorical_features = np.array([False, True]) X = np.c_[f_num, f_cat] From afacb6001ec170d05edd981aad1d4221a62549a9 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Tue, 25 Apr 2023 21:23:47 -0400 Subject: [PATCH 04/13] CLN Address comments --- .../gradient_boosting.py | 106 +++++++++++++++--- .../tests/test_gradient_boosting.py | 58 ++++------ 2 files changed, 110 insertions(+), 54 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index d62d0ce007840..b1a44889658d9 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -20,7 +20,6 @@ from ...base import BaseEstimator, RegressorMixin, ClassifierMixin, is_classifier from ...compose import ColumnTransformer from ...preprocessing import OrdinalEncoder -from ...preprocessing import FunctionTransformer from ...utils import check_random_state, resample, compute_sample_weight from ...utils.validation import ( check_is_fitted, @@ -111,6 +110,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC): "tol": [Interval(Real, 0, None, closed="left")], "max_bins": [Interval(Integral, 2, 255, closed="both")], "categorical_features": ["array-like", None], + "on_high_cardinality_categories": [StrOptions({"error", "bin_least_frequent"})], "warm_start": ["boolean"], "early_stopping": [StrOptions({"auto"}), "boolean"], "scoring": [str, callable, None], @@ -131,6 +131,7 @@ def __init__( l2_regularization, max_bins, categorical_features, + on_high_cardinality_categories, monotonic_cst, interaction_cst, warm_start, @@ -153,6 +154,7 @@ def __init__( self.monotonic_cst = monotonic_cst self.interaction_cst = interaction_cst self.categorical_features = categorical_features + self.on_high_cardinality_categories = on_high_cardinality_categories self.warm_start = warm_start self.early_stopping = early_stopping self.scoring = scoring @@ -180,22 +182,42 @@ class weights. """ return sample_weight - def _check_X(self, X, *, reset): + def _preprocess_X(self, X, *, reset): + """Preprocess and validate X. + + Parameters + ---------- + X : {array-like, pandas DataFrame} of shape (n_samples, n_features) + Input data. + + reset : bool + Whether to reset the `n_features_in_` and `feature_names_in_ attributes. + + Returns + ------- + X : ndarray of shape (n_samples, n_features) + Validated input data. + + known_categories : list of ndarray of shape (n_categories,) + List of known categories for each categorical feature. + """ X = self._validate_data(X, dtype=[X_DTYPE], force_all_finite=False, reset=reset) if not reset: + if self._preprocessor is None: + return X return self._preprocessor.transform(X) self.is_categorical_, known_categories, requires_encoder = ( self._check_categories(X) ) - n_features = X.shape[1] - if not requires_encoder: - self._preprocessor = FunctionTransformer().set_output(transform="default") + self._preprocessor = None self._is_categorical_remapped = self.is_categorical_ return X, known_categories + n_features = X.shape[1] + # Create categories to pass into ordinal_encoder based on known_categories categories_ = [c for c in known_categories if c is not None] @@ -210,17 +232,18 @@ def _check_X(self, X, *, reset): self._preprocessor = ColumnTransformer( [ - ("numerical", "passthrough", ~self.is_categorical_), ("encoder", ordinal_encoder, self.is_categorical_), + ("numerical", "passthrough", ~self.is_categorical_), ] ) self._preprocessor.set_output(transform="default") X = self._preprocessor.fit_transform(X) - # Column Transformer places the categorical features at the end. 
+ # The ColumnTransformer's output places the categorical features at the + # beginning categorical_remapped = np.zeros(n_features, dtype=bool) n_categorical = self.is_categorical_.sum() - categorical_remapped[-n_categorical:] = True + categorical_remapped[:n_categorical] = True self._is_categorical_remapped = categorical_remapped @@ -234,7 +257,7 @@ def _check_X(self, X, *, reset): ] numerical_features = n_features - n_categorical - known_categories = [None] * numerical_features + renamed_categories + known_categories = renamed_categories + [None] * numerical_features return X, known_categories def _check_categories(self, X): @@ -338,9 +361,33 @@ def _check_categories(self, X): if negative_categories.any(): categories = categories[~negative_categories] - if not requires_encoding: - is_numerical = categories.dtype.kind in {"i", "u", "f"} - requires_encoding = is_numerical and ( + if self.on_high_cardinality_categories == "error": + if hasattr(self, "feature_names_in_"): + feature_name = f"'{self.feature_names_in_[f_idx]}'" + else: + feature_name = f"at index {f_idx}" + + if categories.size > self.max_bins: + raise ValueError( + f"Categorical feature {feature_name} is expected to have a" + f" cardinality <= {self.max_bins} but actually has a" + f" cardinality of {categories.size}. Consider using" + " `on_high_cardinality_categories=`bin_least_frequent`," + " preprocess the feature using TargetEncoder, or expanding" + " the feature into many low cardinality categorical" + " features." + ) + if (categories >= self.max_bins).any(): + raise ValueError( + f"Categorical feature {feature_name} is expected to be" + f" encoded with values < {self.max_bins} but the largest" + f" value for the encoded categories is {categories.max()}." + " Consider using" + " `on_high_cardinality_categories=`bin_least_frequent` or" + " preprocess the data." + ) + elif not requires_encoding: + requires_encoding = ( categories.size > self.max_bins or (categories >= self.max_bins).any() ) @@ -416,8 +463,9 @@ def fit(self, X, y, sample_weight=None): acc_compute_hist_time = 0.0 # time spent computing histograms # time spent predicting X for gradient and hessians update acc_prediction_time = 0.0 - X, known_categories = self._check_X(X, reset=True) - y = self._encode_y(_check_y(y, estimator=self)) + X, known_categories = self._preprocess_X(X, reset=True) + y = _check_y(y, estimator=self) + y = self._encode_y(y) check_consistent_length(X, y) # Do not create unit sample weights by default to later skip some # computation @@ -1075,7 +1123,7 @@ def _raw_predict(self, X, n_threads=None): check_is_fitted(self) is_binned = getattr(self, "_in_fit", False) if not is_binned: - X = self._check_X(X, reset=False) + X = self._preprocess_X(X, reset=False) if X.shape[1] != self._n_features: raise ValueError( "X has {} features but this estimator was trained with " @@ -1142,7 +1190,7 @@ def _staged_raw_predict(self, X): classes corresponds to that in the attribute :term:`classes_`. """ check_is_fitted(self) - X = self._check_X(X, reset=False) + X = self._preprocess_X(X, reset=False) if X.shape[1] != self._n_features: raise ValueError( "X has {} features but this estimator was trained with " @@ -1331,6 +1379,17 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): Support categories with cardinality higher than `max_bins` by collapsing the most infrequent categories in a dedicated bin. 
+ on_high_cardinality_categories : {"error", "bin_least_frequent"}, default="error" + Whether to raise an error or to bin together the least frequent categorical + features. + + - `"error"`: Raises an error when the cardinality of a categorical feature + is higher than `max_bins` or is encoded with a value greater than `max_bins`. + - `"bin_least_frequent"`: Bins the least frequent categorical features + such that there is no more than `max_bins` categories. + + .. versionadded:: 1.3 + monotonic_cst : array-like of int of shape (n_features) or dict, default=None Monotonic constraint to enforce on each feature are specified using the following integer values: @@ -1503,6 +1562,7 @@ def __init__( l2_regularization=0.0, max_bins=255, categorical_features=None, + on_high_cardinality_categories="error", monotonic_cst=None, interaction_cst=None, warm_start=False, @@ -1526,6 +1586,7 @@ def __init__( monotonic_cst=monotonic_cst, interaction_cst=interaction_cst, categorical_features=categorical_features, + on_high_cardinality_categories=on_high_cardinality_categories, early_stopping=early_stopping, warm_start=warm_start, scoring=scoring, @@ -1692,6 +1753,17 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): .. versionchanged:: 1.3 Support categories with cardinality higher than `max_bins`. + on_high_cardinality_categories : {"error", "bin_least_frequent"}, default="error" + Whether to raise an error or to bin together the least frequent categorical + features. + + - `"error"`: Raises an error when the cardinality of a categorical feature + is higher than `max_bins` or is encoded with a value greater than `max_bins`. + - `"bin_least_frequent"`: Bins the least frequent categorical features + such that there is no more than `max_bins` categories. + + .. 
versionadded:: 1.3 + monotonic_cst : array-like of int of shape (n_features) or dict, default=None Monotonic constraint to enforce on each feature are specified using the following integer values: @@ -1864,6 +1936,7 @@ def __init__( l2_regularization=0.0, max_bins=255, categorical_features=None, + on_high_cardinality_categories="error", monotonic_cst=None, interaction_cst=None, warm_start=False, @@ -1886,6 +1959,7 @@ def __init__( l2_regularization=l2_regularization, max_bins=max_bins, categorical_features=categorical_features, + on_high_cardinality_categories=on_high_cardinality_categories, monotonic_cst=monotonic_cst, interaction_cst=interaction_cst, warm_start=warm_start, diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index 6739cc84acbe0..f7645b1aaf6e5 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -1362,19 +1362,23 @@ def test_categorical_cardinality_higher_than_n_bins(Hist): f_cat = rng.randint(n_cardinality, size=n_samples) # f_cat is an informative feature y = f_cat % 3 == 0 - categorical_features = np.array([False, True]) + categorical_features = np.array([True, False]) - X = np.c_[f_num, f_cat] + X = np.c_[f_cat, f_num] X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) hist_kwargs = dict(max_iter=10, max_bins=max_bins, random_state=0) - hist_native = Hist(categorical_features=categorical_features, **hist_kwargs) + hist_native = Hist( + categorical_features=categorical_features, + on_high_cardinality_categories="bin_least_frequent", + **hist_kwargs, + ) hist_native.fit(X_train, y_train) - # Use a preprocessor with an ordinal encoder should that gives the same model + # Using a preprocessor with max_categories=max_bins should give the same model + # as the native implementation which uses the same preprocessing strategy column_transformer = make_column_transformer( - ("passthrough", ~categorical_features), ( OrdinalEncoder( handle_unknown="use_encoded_value", @@ -1385,6 +1389,7 @@ def test_categorical_cardinality_higher_than_n_bins(Hist): ), categorical_features, ), + ("passthrough", ~categorical_features), ) hist_with_prep = make_pipeline( column_transformer, @@ -1402,6 +1407,8 @@ def test_categorical_cardinality_higher_than_n_bins(Hist): score_with_prep = hist_with_prep.score(X_test, y_test) assert score_with_prep == pytest.approx(score_native) + assert_allclose(hist_native.predict(X_test), hist_with_prep.predict(X_test)) + @pytest.mark.parametrize( "Hist", [HistGradientBoostingClassifier, HistGradientBoostingRegressor] @@ -1417,13 +1424,13 @@ def test_categorical_encoding_higher_than_n_bins(Hist): f_cat = rng.randint(n_cardinality, size=n_samples) # f_cat is an informative feature y = f_cat % 3 == 0 - X1 = np.c_[f_num, f_cat] - categorical_features = [False, True] + X1 = np.c_[f_cat, f_num] + categorical_features = [True, False] # Categorical feature above max_bins f_cat_ = f_cat.copy() f_cat_[f_cat_ == 3] = max_bins + 1 - X2 = np.c_[f_num, f_cat_] + X2 = np.c_[f_cat_, f_num] X1_train, X1_test, X2_train, X2_test, y_train, y_test = train_test_split( X1, X2, y, random_state=0 @@ -1434,7 +1441,11 @@ def test_categorical_encoding_higher_than_n_bins(Hist): hist_in_bounds.fit(X1_train, y_train) score_in_bounds = hist_in_bounds.score(X1_test, y_test) - hist_out_of_bounds = Hist(categorical_features=categorical_features, 
**hist_kwargs) + hist_out_of_bounds = Hist( + categorical_features=categorical_features, + on_high_cardinality_categories="bin_least_frequent", + **hist_kwargs, + ) hist_out_of_bounds.fit(X2_train, y_train) score_out_of_bounds = hist_out_of_bounds.score(X2_test, y_test) @@ -1445,32 +1456,3 @@ def test_categorical_encoding_higher_than_n_bins(Hist): assert len(predictor_1[0].nodes) == len(predictor_2[0].nodes) assert score_in_bounds == pytest.approx(score_out_of_bounds) - - -def test_categorical_category_first(): - """Check that categorical features gives correct result as the first feature.""" - rng = np.random.RandomState(42) - n_samples = 5_000 - n_cardinality = 12 - max_bins = 10 - f_num = rng.rand(n_samples) - f_cat = rng.randint(n_cardinality, size=n_samples) - - # f_cat is an informative feature - y = f_cat % 3 == 0 - X = np.c_[f_cat, f_num] - categorical_features = [True, False] - - X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) - - hist_kwargs = dict(max_iter=20, max_bins=max_bins, random_state=0) - # Without categorical features we get lower performance - hist_no_cat = HistGradientBoostingRegressor(**hist_kwargs) - hist_no_cat.fit(X_train, y_train) - assert hist_no_cat.score(X_test, y_test) <= 0.65 - - hist_with_cat = HistGradientBoostingRegressor( - categorical_features=categorical_features, **hist_kwargs - ) - hist_with_cat.fit(X_train, y_train) - assert hist_with_cat.score(X_test, y_test) >= 0.95 From c1ec3be7419a48e759d1127bf8ba3748f2f60492 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 25 Apr 2023 21:50:29 -0400 Subject: [PATCH 05/13] DOC Adjust docs --- .../_hist_gradient_boosting/gradient_boosting.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index b1a44889658d9..4d3d6f4013a91 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1363,10 +1363,7 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): - str array-like: names of categorical features (assuming the training data has feature names). - For categories with cardinality higher than `max_bins`, the - infrequent categories are grouped together such there are only `max_bins` - categories. Negative values for categorical features are treated as - missing values. + Negative values for categorical features are treated as missing values. Read more in the :ref:`User Guide `. @@ -1738,10 +1735,7 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): - str array-like: names of categorical features (assuming the training data has feature names). - For categories with cardinality higher than `max_bins`, the most - infrequent categories are grouped together such that there are only - `max_bins` categories. Negative values for categorical features are - treated as missing values. + Negative values for categorical features are treated as missing values. Read more in the :ref:`User Guide `. From 85295ed90a63db24e68b3fda05c0f717388f70b8 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Wed, 26 Apr 2023 08:34:36 -0400 Subject: [PATCH 06/13] TST Improves coverage --- .../gradient_boosting.py | 6 ++--- .../tests/test_gradient_boosting.py | 22 +++++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 4d3d6f4013a91..96813433a058e 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -256,8 +256,8 @@ def _preprocess_X(self, X, *, reset): np.arange(min(len(c), self.max_bins), dtype=X_DTYPE) for c in categories_ ] - numerical_features = n_features - n_categorical - known_categories = renamed_categories + [None] * numerical_features + n_numerical = n_features - n_categorical + known_categories = renamed_categories + [None] * n_numerical return X, known_categories def _check_categories(self, X): @@ -369,7 +369,7 @@ def _check_categories(self, X): if categories.size > self.max_bins: raise ValueError( - f"Categorical feature {feature_name} is expected to have a" + f"Categorical feature {feature_name} is expected to have" f" cardinality <= {self.max_bins} but actually has a" f" cardinality of {categories.size}. Consider using" " `on_high_cardinality_categories=`bin_least_frequent`," diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index f7645b1aaf6e5..e916feae550f6 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -1456,3 +1456,25 @@ def test_categorical_encoding_higher_than_n_bins(Hist): assert len(predictor_1[0].nodes) == len(predictor_2[0].nodes) assert score_in_bounds == pytest.approx(score_out_of_bounds) + + +@pytest.mark.parametrize( + "Hist", [HistGradientBoostingClassifier, HistGradientBoostingRegressor] +) +def test_categorical_errors(Hist): + """Check errors are raised for invalid categorical features.""" + max_bins = 5 + X = np.array([[max_bins + 1, 0, 2, 3, 1, 2, 0]]).T + y = [0, 1, 0, 1, 0, 1, 0] + + hist = Hist(max_bins=max_bins, random_state=0, categorical_features=[0]) + + msg = "Categorical feature at index 0 is expected to be encoded with values < 5" + with pytest.raises(ValueError, match=msg): + hist.fit(X, y) + + msg = "Categorical feature at index 0 is expected to have cardinality <= 5" + X = np.array([[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5]]).T + y = [0] * 10 + [1] * 11 + with pytest.raises(ValueError, match=msg): + hist.fit(X, y) From 45377b4e46ccb6aa65ef79cd8040fe9ecb4140cc Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 26 Apr 2023 10:20:06 -0400 Subject: [PATCH 07/13] DOC Update whats new --- doc/whats_new/v1.3.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index fadad395437d4..25330bba965a3 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -256,11 +256,11 @@ Changelog :pr:`24882` by :user:`Ashwin Mathur `. - |Feature| :class:`ensemble.HistGradientBoostingClassifier` and - :class:`ensemble.HistGradientBoostingRegressor` supports categories with - cardinality greater than `max_bins` or encoded with values greater than - `max_bins`. 
For categories with cardinality higher than `max_bins`, the - infrequent categories are grouped together such there are only `max_bins` - categories. :pr:`26268` by `Thomas Fan`_. + :class:`ensemble.HistGradientBoostingRegressor` now has a + `on_high_cardinality_categories="bin_least_frequent"`. option to automatically + encode high cardinality categories. For categories with cardinality higher + than `max_bins`, the infrequent categories are grouped together such there are + only `max_bins` categories. :pr:`26268` by `Thomas Fan`_. - |Efficiency| :class:`ensemble.IsolationForest` predict time is now faster (typically by a factor of 8 or more). Internally, the estimator now precomputes From d5ef2e0920b0a421ac7757b399c754ecce0641ca Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 20 Dec 2023 09:48:14 -0500 Subject: [PATCH 08/13] FIX Fixes bugs from merge --- .../gradient_boosting.py | 25 ++++++++++++--- .../tests/test_gradient_boosting.py | 32 ++++++------------- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 7bd62c497e73e..432f65744dfd6 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -280,12 +280,18 @@ def _preprocess_X(self, X, *, reset): X = self._validate_data(X, **check_X_kwargs) return X, None + if self.on_high_cardinality_categories == "bin_least_frequent": + max_categories = self.max_bins + else: + max_categories = None + n_features = X.shape[1] ordinal_encoder = OrdinalEncoder( categories="auto", handle_unknown="use_encoded_value", unknown_value=np.nan, encoded_missing_value=np.nan, + max_categories=max_categories, dtype=X_DTYPE, ) @@ -336,15 +342,23 @@ def _check_categories(self): categorical_column_indices = np.arange(self._preprocessor.n_features_in_)[ self._preprocessor.output_indices_["encoder"] ] - for feature_idx, categories in zip( - categorical_column_indices, encoder.categories_ + bin_least_frequent = self.on_high_cardinality_categories == "bin_least_frequent" + try: + n_infrequent_categories = [ + 0 if cat is None else len(cat) for cat in encoder.infrequent_categories_ + ] + except AttributeError: + n_infrequent_categories = [0] * len(encoder.categories_) + + for feature_idx, categories, n_infrequent in zip( + categorical_column_indices, encoder.categories_, n_infrequent_categories ): # OrdinalEncoder always puts np.nan as the last category if the # training data has missing values. Here we remove it because it is # already added by the _BinMapper. if len(categories) and is_scalar_nan(categories[-1]): categories = categories[:-1] - if categories.size > self.max_bins: + if not bin_least_frequent and categories.size > self.max_bins: try: feature_name = repr(encoder.feature_names_in_[feature_idx]) except AttributeError: @@ -354,7 +368,10 @@ def _check_categories(self): f"have a cardinality <= {self.max_bins} but actually " f"has a cardinality of {categories.size}." 
) - known_categories[feature_idx] = np.arange(len(categories), dtype=X_DTYPE) + + # infrequent categories are grouped into one category + total_categories = len(categories) - n_infrequent + 1 + known_categories[feature_idx] = np.arange(total_categories, dtype=X_DTYPE) return known_categories def _check_categorical_features(self, X): diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index 3e30e70df0867..6c82cbe543aa3 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -1466,7 +1466,7 @@ def test_categorical_cardinality_higher_than_n_bins(Hist): X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) - hist_kwargs = dict(max_iter=10, max_bins=max_bins, random_state=0) + hist_kwargs = dict(max_iter=1, max_bins=max_bins, random_state=0) hist_native = Hist( categorical_features=categorical_features, on_high_cardinality_categories="bin_least_frequent", @@ -1479,6 +1479,7 @@ def test_categorical_cardinality_higher_than_n_bins(Hist): column_transformer = make_column_transformer( ( OrdinalEncoder( + categories="auto", handle_unknown="use_encoded_value", unknown_value=np.nan, encoded_missing_value=np.nan, @@ -1495,6 +1496,13 @@ def test_categorical_cardinality_higher_than_n_bins(Hist): ) hist_with_prep.fit(X_train, y_train) + # Check that preprocessors returns the same transformed data + assert_allclose( + hist_native._preprocessor.transform(X_train), + hist_with_prep[0].transform(X_train), + ) + + # Check that the trees are the same and have the same performance assert len(hist_native._predictors) == len(hist_with_prep[-1]._predictors) for predictor_1, predictor_2 in zip( hist_native._predictors, hist_with_prep[-1]._predictors @@ -1556,28 +1564,6 @@ def test_categorical_encoding_higher_than_n_bins(Hist): assert score_in_bounds == pytest.approx(score_out_of_bounds) -@pytest.mark.parametrize( - "Hist", [HistGradientBoostingClassifier, HistGradientBoostingRegressor] -) -def test_categorical_errors(Hist): - """Check errors are raised for invalid categorical features.""" - max_bins = 5 - X = np.array([[max_bins + 1, 0, 2, 3, 1, 2, 0]]).T - y = [0, 1, 0, 1, 0, 1, 0] - - hist = Hist(max_bins=max_bins, random_state=0, categorical_features=[0]) - - msg = "Categorical feature at index 0 is expected to be encoded with values < 5" - with pytest.raises(ValueError, match=msg): - hist.fit(X, y) - - msg = "Categorical feature at index 0 is expected to have cardinality <= 5" - X = np.array([[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5]]).T - y = [0] * 10 + [1] * 11 - with pytest.raises(ValueError, match=msg): - hist.fit(X, y) - - @pytest.mark.parametrize("dataframe_lib", ["pandas", "polars"]) @pytest.mark.parametrize( "HistGradientBoosting", From 7f84de7a52ca14127f8aca8884094d3a5f11cd15 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Wed, 20 Dec 2023 12:52:26 -0500 Subject: [PATCH 09/13] REV Less diff --- .../ensemble/_hist_gradient_boosting/gradient_boosting.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 432f65744dfd6..6a3dccb3554d7 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1282,6 +1282,7 @@ def _raw_predict(self, X, n_threads=None): is_binned = getattr(self, "_in_fit", False) if not is_binned: X = self._preprocess_X(X, reset=False) + n_samples = X.shape[0] raw_predictions = np.zeros( shape=(n_samples, self.n_trees_per_iteration_), @@ -1528,7 +1529,11 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): exposing a ``__dataframe__`` method such as pandas or polars DataFrames to use this feature. - Negative values for categorical features are treated as missing values. + For each categorical feature, there must be at most `max_bins` unique + categories. Negative values for categorical features encoded as numeric + dtypes are treated as missing values. All categorical values are + converted to floating point numbers. This means that categorical values + of 1.0 and 1 are treated as the same category. Read more in the :ref:`User Guide `. From ac182b2fc63cf9ab2de6905c9aae45ed459ebb46 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 20 Dec 2023 12:55:36 -0500 Subject: [PATCH 10/13] Trigger CI From 547449f2beaa5bccec3c456bee15bcc550be5087 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 22 Dec 2023 12:07:06 -0500 Subject: [PATCH 11/13] API Change to bin_infrequent --- doc/whats_new/v1.3.rst | 2 +- .../_hist_gradient_boosting/gradient_boosting.py | 16 ++++++++-------- .../tests/test_gradient_boosting.py | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 2a549534cbcfd..62a1c67b832ec 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -529,7 +529,7 @@ Changelog - |Feature| :class:`ensemble.HistGradientBoostingClassifier` and :class:`ensemble.HistGradientBoostingRegressor` now has a - `on_high_cardinality_categories="bin_least_frequent"`. option to automatically + `on_high_cardinality_categories="bin_infrequent"`. option to automatically encode high cardinality categories. For categories with cardinality higher than `max_bins`, the infrequent categories are grouped together such there are only `max_bins` categories. :pr:`26268` by `Thomas Fan`_. 
diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
index 6a3dccb3554d7..94f594b7f83c4 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -169,7 +169,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC):
             Hidden(StrOptions({"warn"})),
             None,
         ],
-        "on_high_cardinality_categories": [StrOptions({"error", "bin_least_frequent"})],
+        "on_high_cardinality_categories": [StrOptions({"error", "bin_infrequent"})],
         "warm_start": ["boolean"],
         "early_stopping": [StrOptions({"auto"}), "boolean"],
         "scoring": [str, callable, None],
@@ -280,7 +280,7 @@ def _preprocess_X(self, X, *, reset):
             X = self._validate_data(X, **check_X_kwargs)
             return X, None
 
-        if self.on_high_cardinality_categories == "bin_least_frequent":
+        if self.on_high_cardinality_categories == "bin_infrequent":
             max_categories = self.max_bins
         else:
             max_categories = None
@@ -342,7 +342,7 @@ def _check_categories(self):
         categorical_column_indices = np.arange(self._preprocessor.n_features_in_)[
             self._preprocessor.output_indices_["encoder"]
         ]
-        bin_least_frequent = self.on_high_cardinality_categories == "bin_least_frequent"
+        bin_infrequent = self.on_high_cardinality_categories == "bin_infrequent"
         try:
             n_infrequent_categories = [
                 0 if cat is None else len(cat) for cat in encoder.infrequent_categories_
@@ -358,7 +358,7 @@ def _check_categories(self):
             # already added by the _BinMapper.
             if len(categories) and is_scalar_nan(categories[-1]):
                 categories = categories[:-1]
-            if not bin_least_frequent and categories.size > self.max_bins:
+            if not bin_infrequent and categories.size > self.max_bins:
                 try:
                     feature_name = repr(encoder.feature_names_in_[feature_idx])
                 except AttributeError:
@@ -1546,13 +1546,13 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
             Added `"from_dtype"` option. The default will change to `"from_dtype"` in
             v1.6.
 
-    on_high_cardinality_categories : {"error", "bin_least_frequent"}, default="error"
+    on_high_cardinality_categories : {"error", "bin_infrequent"}, default="error"
         Whether to raise an error or to bin together the least frequent categorical
         features.
 
        - `"error"`: Raises an error when the cardinality of a categorical feature
          is higher than `max_bins` or is encoded with a value greater than `max_bins`.
-        - `"bin_least_frequent"`: Bins the least frequent categorical features
+        - `"bin_infrequent"`: Bins the least frequent categories
           such that there is no more than `max_bins` categories.
 
         .. versionadded:: 1.5
@@ -1935,13 +1935,13 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):
             Added `"from_dtype"` option. The default will change to `"from_dtype"` in
             v1.6.
 
-    on_high_cardinality_categories : {"error", "bin_least_frequent"}, default="error"
+    on_high_cardinality_categories : {"error", "bin_infrequent"}, default="error"
         Whether to raise an error or to bin together the least frequent categorical
         features.
 
        - `"error"`: Raises an error when the cardinality of a categorical feature
          is higher than `max_bins` or is encoded with a value greater than `max_bins`.
-        - `"bin_least_frequent"`: Bins the least frequent categorical features
+        - `"bin_infrequent"`: Bins the least frequent categories
           such that there is no more than `max_bins` categories.
 
        ..
versionadded:: 1.5 diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index 6c82cbe543aa3..5d42fdbeef624 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -1469,7 +1469,7 @@ def test_categorical_cardinality_higher_than_n_bins(Hist): hist_kwargs = dict(max_iter=1, max_bins=max_bins, random_state=0) hist_native = Hist( categorical_features=categorical_features, - on_high_cardinality_categories="bin_least_frequent", + on_high_cardinality_categories="bin_infrequent", **hist_kwargs, ) hist_native.fit(X_train, y_train) @@ -1549,7 +1549,7 @@ def test_categorical_encoding_higher_than_n_bins(Hist): hist_out_of_bounds = Hist( categorical_features=categorical_features, - on_high_cardinality_categories="bin_least_frequent", + on_high_cardinality_categories="bin_infrequent", **hist_kwargs, ) hist_out_of_bounds.fit(X2_train, y_train) From b1d1278dbcfde506542988e21b20f334d57eb27f Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 22 Dec 2023 12:07:43 -0500 Subject: [PATCH 12/13] DOC Move to 1.4 --- doc/whats_new/v1.3.rst | 7 ------- doc/whats_new/v1.4.rst | 7 +++++++ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 62a1c67b832ec..b711eadf572f5 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -527,13 +527,6 @@ Changelog out-of-bag scores via the `oob_scores_` or `oob_score_` attributes. :pr:`24882` by :user:`Ashwin Mathur `. -- |Feature| :class:`ensemble.HistGradientBoostingClassifier` and - :class:`ensemble.HistGradientBoostingRegressor` now has a - `on_high_cardinality_categories="bin_infrequent"`. option to automatically - encode high cardinality categories. For categories with cardinality higher - than `max_bins`, the infrequent categories are grouped together such there are - only `max_bins` categories. :pr:`26268` by `Thomas Fan`_. - - |Efficiency| :class:`ensemble.IsolationForest` predict time is now faster (typically by a factor of 8 or more). Internally, the estimator now precomputes decision path lengths per tree at `fit` time. It is therefore not possible diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index d2de5ee433f94..f48ec69283db3 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -408,6 +408,13 @@ Changelog in each split. :pr:`27139` by :user:`Christian Lorentzen `. +- |MajorFeature| :class:`ensemble.HistGradientBoostingClassifier` and + :class:`ensemble.HistGradientBoostingRegressor` now has a + `on_high_cardinality_categories="bin_infrequent"`. option to automatically + encode high cardinality categories. For categories with cardinality higher + than `max_bins`, the infrequent categories are grouped together such there are + only `max_bins` categories. :pr:`26268` by `Thomas Fan`_. + - |Feature| :class:`ensemble.RandomForestClassifier`, :class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier` and :class:`ensemble.ExtraTreesRegressor` now support monotonic constraints, From dc821aae2fd8ecf974df0ac05c1e4c3383221072 Mon Sep 17 00:00:00 2001 From: "Thomas J. 
Fan" Date: Sun, 21 Jan 2024 15:44:16 -0500 Subject: [PATCH 13/13] DOC Move whats new to 1.5 --- doc/whats_new/v1.4.rst | 7 ------- doc/whats_new/v1.5.rst | 10 ++++++++++ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index a450d43abb0cc..d832e4b508359 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -457,13 +457,6 @@ Changelog in each split. :pr:`27139` by :user:`Christian Lorentzen `. -- |MajorFeature| :class:`ensemble.HistGradientBoostingClassifier` and - :class:`ensemble.HistGradientBoostingRegressor` now has a - `on_high_cardinality_categories="bin_infrequent"`. option to automatically - encode high cardinality categories. For categories with cardinality higher - than `max_bins`, the infrequent categories are grouped together such there are - only `max_bins` categories. :pr:`26268` by `Thomas Fan`_. - - |Feature| :class:`ensemble.RandomForestClassifier`, :class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier` and :class:`ensemble.ExtraTreesRegressor` now support monotonic constraints, diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 159b8029c9137..2c27c59db0c2c 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -39,6 +39,16 @@ Changelog have the `n_features_in_` and `feature_names_in_` attributes after `fit`. :pr:`27937` by :user:`Marco vd Boom `. +:mod:`sklearn.ensemble` +....................... + +- |MajorFeature| :class:`ensemble.HistGradientBoostingClassifier` and + :class:`ensemble.HistGradientBoostingRegressor` now has a + `on_high_cardinality_categories="bin_infrequent"`. option to automatically + encode high cardinality categories. For categories with cardinality higher + than `max_bins`, the infrequent categories are grouped together such there are + only `max_bins` categories. :pr:`26268` by `Thomas Fan`_. + :mod:`sklearn.feature_extraction` .................................