5 changes: 5 additions & 0 deletions doc/whats_new/v1.2.rst
@@ -299,6 +299,11 @@ Changelog
- |Efficiency| Improve runtime performance of :class:`ensemble.IsolationForest`
by avoiding data copies. :pr:`23252` by :user:`Zhehao Liu <MaxwellLZH>`.

- |Enhancement| Make it possible to pass the `categorical_features` parameter
of :class:`ensemble.HistGradientBoostingClassifier` and
:class:`ensemble.HistGradientBoostingRegressor` as feature names.
:pr:`24889` by :user:`Olivier Grisel <ogrisel>`.

- |Enhancement| :class:`ensemble.StackingClassifier` now supports
multilabel-indicator target
:pr:`24146` by :user:`Nicolas Peretti <nicoperetti>`,
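To make the changelog entry above concrete, here is a minimal usage sketch. It is not part of the diff; it assumes a scikit-learn build that includes this PR (1.2+) plus pandas, and the column names and toy data are invented for illustration:

    import pandas as pd

    from sklearn.ensemble import HistGradientBoostingRegressor

    # Toy data: "weather" is an ordinally encoded categorical column (values
    # must lie in [0, max_bins - 1]); the other columns are numerical.
    X = pd.DataFrame(
        {
            "weather": [0, 1, 2, 1, 0, 2],
            "hour": [0, 6, 12, 18, 3, 21],
            "temp": [3.0, 7.5, 14.2, 9.8, 4.1, 6.6],
        }
    )
    y = [10, 50, 120, 80, 15, 30]

    # Previously categorical_features had to be indices ([0]) or a boolean
    # mask ([True, False, False]); with this PR, feature names work too.
    model = HistGradientBoostingRegressor(categorical_features=["weather"]).fit(X, y)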
10 changes: 7 additions & 3 deletions examples/applications/plot_cyclical_feature_engineering.py
@@ -209,11 +209,15 @@
("categorical", ordinal_encoder, categorical_columns),
],
remainder="passthrough",
# Use short feature names to make it easier to specify the categorical
# variables in the HistGradientBoostingRegressor in the next
# step of the pipeline.
verbose_feature_names_out=False,
),
HistGradientBoostingRegressor(
categorical_features=range(4),
categorical_features=categorical_columns,
),
)
).set_output(transform="pandas")

# %%
#
@@ -263,7 +267,7 @@ def evaluate(model, X, y, cv):
import numpy as np


one_hot_encoder = OneHotEncoder(handle_unknown="ignore", sparse=False)
one_hot_encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
alphas = np.logspace(-6, 6, 25)
naive_linear_pipeline = make_pipeline(
ColumnTransformer(
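The change above relies on verbose_feature_names_out=False together with set_output(transform="pandas") so that the HistGradientBoostingRegressor receives the original column names and can resolve categorical_features given as names. A standalone sketch of that pattern, with invented column names and toy data:

    import pandas as pd

    from sklearn.compose import ColumnTransformer
    from sklearn.ensemble import HistGradientBoostingRegressor
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import OrdinalEncoder

    categorical_columns = ["weather", "season"]
    X = pd.DataFrame(
        {
            "weather": ["rain", "sun", "sun", "rain"],
            "season": ["winter", "summer", "spring", "fall"],
            "temp": [3.0, 25.0, 15.0, 10.0],
        }
    )
    y = [20, 150, 90, 60]

    pipeline = make_pipeline(
        ColumnTransformer(
            [("categorical", OrdinalEncoder(), categorical_columns)],
            remainder="passthrough",
            # Without this, the output columns would be prefixed
            # ("categorical__weather") and would no longer match the names
            # passed to categorical_features below.
            verbose_feature_names_out=False,
        ),
        HistGradientBoostingRegressor(categorical_features=categorical_columns),
        # Emit pandas DataFrames between steps so the regressor sees column names.
    ).set_output(transform="pandas")

    pipeline.fit(X, y)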
15 changes: 10 additions & 5 deletions examples/ensemble/plot_gradient_boosting_categorical.py
@@ -62,7 +62,8 @@
X = X[categorical_columns_subset + numerical_columns_subset]
X[categorical_columns_subset] = X[categorical_columns_subset].astype("category")

n_categorical_features = X.select_dtypes(include="category").shape[1]
categorical_columns = X.select_dtypes(include="category").columns
n_categorical_features = len(categorical_columns)
n_numerical_features = X.select_dtypes(include="number").shape[1]

print(f"Number of samples: {X.shape[0]}")
@@ -96,7 +97,7 @@

one_hot_encoder = make_column_transformer(
(
OneHotEncoder(sparse=False, handle_unknown="ignore"),
OneHotEncoder(sparse_output=False, handle_unknown="ignore"),
make_column_selector(dtype_include="category"),
),
remainder="passthrough",
@@ -122,6 +123,10 @@
make_column_selector(dtype_include="category"),
),
remainder="passthrough",
# Use short feature names to make it easier to specify the categorical
# variables in the HistGradientBoostingRegressor in the next step
# of the pipeline.
verbose_feature_names_out=False,
)

hist_ordinal = make_pipeline(
@@ -146,13 +151,13 @@
# The ordinal encoder will first output the categorical features, and then the
# continuous (passed-through) features

categorical_mask = [True] * n_categorical_features + [False] * n_numerical_features
hist_native = make_pipeline(
ordinal_encoder,
HistGradientBoostingRegressor(
random_state=42, categorical_features=categorical_mask
random_state=42,
categorical_features=categorical_columns,
),
)
).set_output(transform="pandas")

# %%
# Model comparison
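The comment above notes that the ordinal encoder outputs the categorical features first and the passed-through numerical features last, which is what makes a name-based (or mask-based) categorical spec safe here. A quick sketch of one way to check that ordering with get_feature_names_out, using invented column names and toy data:

    import pandas as pd

    from sklearn.compose import make_column_selector, make_column_transformer
    from sklearn.preprocessing import OrdinalEncoder

    X = pd.DataFrame(
        {
            "num_a": [1.0, 2.0, 3.0],
            "cat_a": pd.Categorical(["x", "y", "x"]),
            "cat_b": pd.Categorical(["u", "u", "v"]),
        }
    )

    ordinal_encoder = make_column_transformer(
        (
            OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1),
            make_column_selector(dtype_include="category"),
        ),
        remainder="passthrough",
        verbose_feature_names_out=False,
    ).fit(X)

    # Transformed (categorical) columns come first, passed-through columns last:
    print(ordinal_encoder.get_feature_names_out())  # ['cat_a' 'cat_b' 'num_a']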
51 changes: 44 additions & 7 deletions sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -193,16 +193,43 @@ def _check_categories(self, X):
if categorical_features.size == 0:
return None, None

if categorical_features.dtype.kind not in ("i", "b"):
if categorical_features.dtype.kind not in ("i", "b", "U", "O"):
raise ValueError(
"categorical_features must be an array-like of "
"bools or array-like of ints."
"categorical_features must be an array-like of bool, int or "
f"str, got: {categorical_features.dtype.name}."
)

if categorical_features.dtype.kind == "O":
types = set(type(f) for f in categorical_features)
if types != {str}:
raise ValueError(
"categorical_features must be an array-like of bool, int or "
f"str, got: {', '.join(sorted(t.__name__ for t in types))}."
)

n_features = X.shape[1]

# check for categorical features as indices
if categorical_features.dtype.kind == "i":
if categorical_features.dtype.kind in ("U", "O"):
# check for feature names
if not hasattr(self, "feature_names_in_"):
raise ValueError(
"categorical_features should be passed as an array of "
"integers or as a boolean mask when the model is fitted "
"on data without feature names."
)
is_categorical = np.zeros(n_features, dtype=bool)
feature_names = self.feature_names_in_.tolist()
(Review thread on the line above.)

Reviewer (Member): Is this conversion to a list necessary?

ogrisel (Member, Author), Nov 13, 2022: Arrays do not have the `index` method. Not sure how to implement this while staying in numpy and still making it easy to raise the error message promptly.

ogrisel (Member, Author), Nov 13, 2022: Also, the feature names list should never be too long (a few hundred values) for HGBDT models in practice, because those models tend to perform poorly when n_features >> n_samples.

Reviewer (Member): Thanks for the explanation.

for feature_name in categorical_features:
try:
is_categorical[feature_names.index(feature_name)] = True
except ValueError as e:
raise ValueError(
f"categorical_features has a item value '{feature_name}' "
"which is not a valid feature name of the training "
f"data. Observed feature names: {feature_names}"
) from e
elif categorical_features.dtype.kind == "i":
# check for categorical features as indices
if (
np.max(categorical_features) >= n_features
or np.min(categorical_features) < 0
@@ -1209,14 +1236,16 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
Features with a small number of unique values may use less than
``max_bins`` bins. In addition to the ``max_bins`` bins, one more bin
is always reserved for missing values. Must be no larger than 255.
categorical_features : array-like of {bool, int} of shape (n_features) \
categorical_features : array-like of {bool, int, str} of shape (n_features) \
or shape (n_categorical_features,), default=None
Indicates the categorical features.

- None : no feature will be considered categorical.
- boolean array-like : boolean mask indicating categorical features.
- integer array-like : integer indices indicating categorical
features.
- str array-like: names of categorical features (assuming the training
data has feature names).

For each categorical feature, there must be at most `max_bins` unique
categories, and each categorical value must be in [0, max_bins -1].
@@ -1227,6 +1256,9 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):

.. versionadded:: 0.24

.. versionchanged:: 1.2
Added support for feature names.

monotonic_cst : array-like of int of shape (n_features), default=None
Indicates the monotonic constraint to enforce on each feature.
- 1: monotonic increase
@@ -1541,14 +1573,16 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):
Features with a small number of unique values may use less than
``max_bins`` bins. In addition to the ``max_bins`` bins, one more bin
is always reserved for missing values. Must be no larger than 255.
categorical_features : array-like of {bool, int} of shape (n_features) \
categorical_features : array-like of {bool, int, str} of shape (n_features) \
or shape (n_categorical_features,), default=None
Indicates the categorical features.

- None : no feature will be considered categorical.
- boolean array-like : boolean mask indicating categorical features.
- integer array-like : integer indices indicating categorical
features.
- str array-like: names of categorical features (assuming the training
data has feature names).

For each categorical feature, there must be at most `max_bins` unique
categories, and each categorical value must be in [0, max_bins -1].
@@ -1559,6 +1593,9 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):

.. versionadded:: 0.24

.. versionchanged:: 1.2
Added support for feature names.

monotonic_cst : array-like of int of shape (n_features), default=None
Indicates the monotonic constraint to enforce on each feature.
- 1: monotonic increase
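Following up on the review thread above about feature_names.index(): a sketch of one possible numpy-only alternative that builds the same boolean mask with np.isin. This is not what the PR implements; the helper name and error wording are invented, and it only illustrates the idea:

    import numpy as np


    def feature_names_to_mask(feature_names_in_, categorical_features):
        """Hypothetical helper: boolean categorical mask from feature names."""
        feature_names_in_ = np.asarray(feature_names_in_, dtype=object)
        categorical_features = np.asarray(categorical_features, dtype=object)
        unknown = categorical_features[
            ~np.isin(categorical_features, feature_names_in_)
        ]
        if unknown.size:
            raise ValueError(
                f"categorical_features has invalid items {unknown.tolist()}; "
                f"valid feature names are {feature_names_in_.tolist()}"
            )
        return np.isin(feature_names_in_, categorical_features)


    print(feature_names_to_mask(["f0", "f1", "f2"], ["f0", "f2"]))
    # [ True False  True]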
sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
@@ -1,5 +1,6 @@
import warnings

import re
import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_equal
@@ -979,13 +980,26 @@ def test_categorical_encoding_strategies():
# influence predictions too much with max_iter = 1
assert 0.49 < y.mean() < 0.51

clf_cat = HistGradientBoostingClassifier(
max_iter=1, max_depth=1, categorical_features=[False, True]
)
native_cat_specs = [
[False, True],
[1],
]
try:
import pandas as pd

# Using native categorical encoding, we get perfect predictions with just
# one split
assert cross_val_score(clf_cat, X, y).mean() == 1
X = pd.DataFrame(X, columns=["f_0", "f_1"])
native_cat_specs.append(["f_1"])
except ImportError:
pass

for native_cat_spec in native_cat_specs:
clf_cat = HistGradientBoostingClassifier(
max_iter=1, max_depth=1, categorical_features=native_cat_spec
)

# Using native categorical encoding, we get perfect predictions with just
# one split
assert cross_val_score(clf_cat, X, y).mean() == 1

# quick sanity check for the bitset: 0, 2, 4 = 2**0 + 2**2 + 2**4 = 21
expected_left_bitset = [21, 0, 0, 0, 0, 0, 0, 0]
@@ -1022,24 +1036,36 @@ def test_categorical_encoding_strategies():
"categorical_features, monotonic_cst, expected_msg",
[
(
["hello", "world"],
[b"hello", b"world"],
None,
"categorical_features must be an array-like of bools or array-like of "
"ints.",
re.escape(
"categorical_features must be an array-like of bool, int or str, "
"got: bytes40."
),
),
(
np.array([b"hello", 1.3], dtype=object),
None,
re.escape(
"categorical_features must be an array-like of bool, int or str, "
"got: bytes, float."
),
),
(
[0, -1],
None,
(
r"categorical_features set as integer indices must be in "
r"\[0, n_features - 1\]"
re.escape(
"categorical_features set as integer indices must be in "
"[0, n_features - 1]"
),
),
(
[True, True, False, False, True],
None,
r"categorical_features set as a boolean mask must have shape "
r"\(n_features,\)",
re.escape(
"categorical_features set as a boolean mask must have shape "
"(n_features,)"
),
),
(
[True, True, False, False],
@@ -1063,6 +1089,39 @@ def test_categorical_spec_errors(
est.fit(X, y)


@pytest.mark.parametrize(
"Est", (HistGradientBoostingClassifier, HistGradientBoostingRegressor)
)
def test_categorical_spec_errors_with_feature_names(Est):
pd = pytest.importorskip("pandas")
n_samples = 10
X = pd.DataFrame(
{
"f0": range(n_samples),
"f1": range(n_samples),
"f2": [1.0] * n_samples,
}
)
y = [0, 1] * (n_samples // 2)

est = Est(categorical_features=["f0", "f1", "f3"])
expected_msg = re.escape(
"categorical_features has a item value 'f3' which is not a valid "
"feature name of the training data."
)
with pytest.raises(ValueError, match=expected_msg):
est.fit(X, y)

est = Est(categorical_features=["f0", "f1"])
expected_msg = re.escape(
"categorical_features should be passed as an array of integers or "
"as a boolean mask when the model is fitted on data without feature "
"names."
)
with pytest.raises(ValueError, match=expected_msg):
est.fit(X.to_numpy(), y)


@pytest.mark.parametrize(
"Est", (HistGradientBoostingClassifier, HistGradientBoostingRegressor)
)
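A quick interactive sketch mirroring what the new tests assert, assuming pandas and a scikit-learn build that includes this PR: an unknown feature name, or a name-based spec combined with a plain numpy array, both raise ValueError:

    import pandas as pd

    from sklearn.ensemble import HistGradientBoostingClassifier

    X = pd.DataFrame({"f0": range(10), "f1": range(10), "f2": [1.0] * 10})
    y = [0, 1] * 5

    # Unknown feature name -> ValueError mentioning 'f3'.
    try:
        HistGradientBoostingClassifier(categorical_features=["f0", "f3"]).fit(X, y)
    except ValueError as exc:
        print(exc)

    # Name-based spec on a plain numpy array (no feature names) -> ValueError
    # asking for integer indices or a boolean mask instead.
    try:
        HistGradientBoostingClassifier(categorical_features=["f0"]).fit(X.to_numpy(), y)
    except ValueError as exc:
        print(exc)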