diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst
index 7265d0e8529f0..101cda43c8747 100644
--- a/doc/whats_new/v1.2.rst
+++ b/doc/whats_new/v1.2.rst
@@ -310,6 +310,14 @@ Changelog
   :user:`Nestor Navarro `, :user:`Nati Tomattis `, and
   :user:`Vincent Maladiere `.
 
+- |Enhancement| :class:`ensemble.HistGradientBoostingClassifier` and
+  :class:`ensemble.HistGradientBoostingRegressor` now accept their
+  `monotonic_cst` parameter to be passed as a dictionary in addition
+  to the previously supported array-like format.
+  Such a dictionary has feature names as keys and one of `-1`, `0`, `1`
+  as the value to specify monotonicity constraints for each feature.
+  :pr:`24855` by :user:`Olivier Grisel `.
+
 - |Fix| Fixed the issue where :class:`ensemble.AdaBoostClassifier` outputs
   NaN in feature importance when fitted with very small sample weight.
   :pr:`20415` by :user:`Zhehao Liu `.
diff --git a/examples/ensemble/plot_monotonic_constraints.py b/examples/ensemble/plot_monotonic_constraints.py
index 8b0aff204a584..7e9e271256fa9 100644
--- a/examples/ensemble/plot_monotonic_constraints.py
+++ b/examples/ensemble/plot_monotonic_constraints.py
@@ -19,7 +19,7 @@
 `_.
 
 """
-
+# %%
 from sklearn.ensemble import HistGradientBoostingRegressor
 from sklearn.inspection import PartialDependenceDisplay
 import numpy as np
@@ -28,7 +28,7 @@
 
 rng = np.random.RandomState(0)
 
-n_samples = 5000
+n_samples = 1000
 f_0 = rng.rand(n_samples)
 f_1 = rng.rand(n_samples)
 X = np.c_[f_0, f_1]
@@ -37,14 +37,24 @@
 # y is positively correlated with f_0, and negatively correlated with f_1
 y = 5 * f_0 + np.sin(10 * np.pi * f_0) - 5 * f_1 - np.cos(10 * np.pi * f_1) + noise
 
-fig, ax = plt.subplots()
+# %%
+# Fit a first model on this dataset without any constraints.
+gbdt_no_cst = HistGradientBoostingRegressor()
+gbdt_no_cst.fit(X, y)
+
+# %%
+# Fit a second model on this dataset with monotonic increase (1)
+# and monotonic decrease (-1) constraints, respectively.
+gbdt_with_monotonic_cst = HistGradientBoostingRegressor(monotonic_cst=[1, -1])
+gbdt_with_monotonic_cst.fit(X, y)
 
-# Without any constraint
-gbdt = HistGradientBoostingRegressor()
-gbdt.fit(X, y)
+
+# %%
+# Let's display the partial dependence of the predictions on the two features.
+fig, ax = plt.subplots()
 disp = PartialDependenceDisplay.from_estimator(
-    gbdt,
+    gbdt_no_cst,
     X,
     features=[0, 1],
     feature_names=(
@@ -54,13 +64,8 @@
     line_kw={"linewidth": 4, "label": "unconstrained", "color": "tab:blue"},
     ax=ax,
 )
-
-# With monotonic increase (1) and a monotonic decrease (-1) constraints, respectively.
-gbdt = HistGradientBoostingRegressor(monotonic_cst=[1, -1])
-gbdt.fit(X, y)
-
 PartialDependenceDisplay.from_estimator(
-    gbdt,
+    gbdt_with_monotonic_cst,
     X,
     features=[0, 1],
     line_kw={"linewidth": 4, "label": "constrained", "color": "tab:orange"},
     ax=ax,
 )
@@ -75,5 +80,29 @@
 
 plt.legend()
 fig.suptitle("Monotonic constraints effect on partial dependences")
-
 plt.show()
+
+# %%
+# We can see that the predictions of the unconstrained model capture the
+# oscillations of the data while the constrained model follows the general
+# trend and ignores the local variations.
+
+# %%
+# .. _monotonic_cst_features_names:
+#
+# Using feature names to specify monotonic constraints
+# ----------------------------------------------------
+#
+# Note that if the training data has feature names, it's possible to specify the
+# monotonic constraints by passing a dictionary:
+import pandas as pd
+
+X_df = pd.DataFrame(X, columns=["f_0", "f_1"])
+
+gbdt_with_monotonic_cst_df = HistGradientBoostingRegressor(
+    monotonic_cst={"f_0": 1, "f_1": -1}
+).fit(X_df, y)
+
+np.allclose(
+    gbdt_with_monotonic_cst_df.predict(X_df), gbdt_with_monotonic_cst.predict(X)
+)
diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
index 4f5419dccd8cb..bdf9ea414c210 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -23,6 +23,7 @@
     check_is_fitted,
     check_consistent_length,
     _check_sample_weight,
+    _check_monotonic_cst,
 )
 from ...utils._param_validation import Interval, StrOptions
 from ...utils._openmp_helpers import _openmp_effective_n_threads
@@ -91,7 +92,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC):
         "max_depth": [Interval(Integral, 1, None, closed="left"), None],
         "min_samples_leaf": [Interval(Integral, 1, None, closed="left")],
         "l2_regularization": [Interval(Real, 0, None, closed="left")],
-        "monotonic_cst": ["array-like", None],
+        "monotonic_cst": ["array-like", dict, None],
         "interaction_cst": [Iterable, None],
         "n_iter_no_change": [Interval(Integral, 1, None, closed="left")],
         "validation_fraction": [
@@ -369,6 +370,7 @@ def fit(self, X, y, sample_weight=None):
         self._random_seed = rng.randint(np.iinfo(np.uint32).max, dtype="u8")
 
         self._validate_parameters()
+        monotonic_cst = _check_monotonic_cst(self, self.monotonic_cst)
 
         # used for validation in predict
         n_samples, self._n_features = X.shape
@@ -664,7 +666,7 @@ def fit(self, X, y, sample_weight=None):
             n_bins_non_missing=self._bin_mapper.n_bins_non_missing_,
             has_missing_values=has_missing_values,
             is_categorical=self.is_categorical_,
-            monotonic_cst=self.monotonic_cst,
+            monotonic_cst=monotonic_cst,
             interaction_cst=interaction_cst,
             max_leaf_nodes=self.max_leaf_nodes,
             max_depth=self.max_depth,
@@ -1259,16 +1261,27 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
         .. versionchanged:: 1.2
            Added support for feature names.
 
-    monotonic_cst : array-like of int of shape (n_features), default=None
-        Indicates the monotonic constraint to enforce on each feature.
-          - 1: monotonic increase
-          - 0: no constraint
-          - -1: monotonic decrease
+    monotonic_cst : array-like of int of shape (n_features) or dict, default=None
+        Monotonic constraints to enforce on each feature are specified using the
+        following integer values:
+        - 1: monotonic increase
+        - 0: no constraint
+        - -1: monotonic decrease
+
+        If a dict with str keys, features are mapped to monotonic constraints by name.
+        If an array, the features are mapped to constraints by position. See
+        :ref:`monotonic_cst_features_names` for a usage example.
+
+        The constraints are only valid for binary classifications and hold
+        over the probability of the positive class.
 
         Read more in the :ref:`User Guide `.
 
         .. versionadded:: 0.23
 
+        .. versionchanged:: 1.2
+           Accept dict of constraints with feature names as keys.
+
     interaction_cst : iterable of iterables of int, default=None
         Specify interaction constraints, the sets of features which can
         interact with each other in child node splits.
@@ -1596,11 +1609,17 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):
         .. versionchanged:: 1.2
            Added support for feature names.
 
-    monotonic_cst : array-like of int of shape (n_features), default=None
-        Indicates the monotonic constraint to enforce on each feature.
-          - 1: monotonic increase
-          - 0: no constraint
-          - -1: monotonic decrease
+    monotonic_cst : array-like of int of shape (n_features) or dict, default=None
+        Monotonic constraints to enforce on each feature are specified using the
+        following integer values:
+
+        - 1: monotonic increase
+        - 0: no constraint
+        - -1: monotonic decrease
+
+        If a dict with str keys, features are mapped to monotonic constraints by name.
+        If an array, the features are mapped to constraints by position. See
+        :ref:`monotonic_cst_features_names` for a usage example.
 
         The constraints are only valid for binary classifications and hold
         over the probability of the positive class.
@@ -1608,6 +1627,9 @@
 
         .. versionadded:: 0.23
 
+        .. versionchanged:: 1.2
+           Accept dict of constraints with feature names as keys.
+
     interaction_cst : iterable of iterables of int, default=None
         Specify interaction constraints, the sets of features which can
         interact with each other in child node splits.
diff --git a/sklearn/ensemble/_hist_gradient_boosting/grower.py b/sklearn/ensemble/_hist_gradient_boosting/grower.py
index 5e3010fa4a509..1ad6dba661552 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/grower.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/grower.py
@@ -264,28 +264,18 @@ def __init__(
             has_missing_values = [has_missing_values] * X_binned.shape[1]
         has_missing_values = np.asarray(has_missing_values, dtype=np.uint8)
 
+        # `monotonic_cst` validation is done in _check_monotonic_cst
+        # at the estimator level and therefore the following should not be
+        # needed when using the public API.
         if monotonic_cst is None:
-            self.with_monotonic_cst = False
             monotonic_cst = np.full(
                 shape=X_binned.shape[1],
                 fill_value=MonotonicConstraint.NO_CST,
                 dtype=np.int8,
             )
         else:
-            self.with_monotonic_cst = True
             monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8)
-
-            if monotonic_cst.shape[0] != X_binned.shape[1]:
-                raise ValueError(
-                    "monotonic_cst has shape {} but the input data "
-                    "X has {} features.".format(
-                        monotonic_cst.shape[0], X_binned.shape[1]
-                    )
-                )
-            if np.any(monotonic_cst < -1) or np.any(monotonic_cst > 1):
-                raise ValueError(
-                    "monotonic_cst must be None or an array-like of -1, 0 or 1."
-                )
+        self.with_monotonic_cst = np.any(monotonic_cst != MonotonicConstraint.NO_CST)
 
         if is_categorical is None:
             is_categorical = np.zeros(shape=X_binned.shape[1], dtype=np.uint8)
diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py
index 4ab65c55a8620..9456b9d9934b1 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py
@@ -1,3 +1,4 @@
+import re
 import numpy as np
 import pytest
 
@@ -197,24 +198,33 @@ def test_nodes_values(monotonic_cst, seed):
     assert_leaves_values_monotonic(predictor, monotonic_cst)
 
 
-@pytest.mark.parametrize("seed", range(3))
-def test_predictions(seed):
+@pytest.mark.parametrize("use_feature_names", (True, False))
+def test_predictions(global_random_seed, use_feature_names):
     # Train a model with a POS constraint on the first feature and a NEG
     # constraint on the second feature, and make sure the constraints are
     # respected by checking the predictions.
     # test adapted from lightgbm's test_monotone_constraint(), itself inspired
     # by https://xgboost.readthedocs.io/en/latest/tutorials/monotonic.html
-    rng = np.random.RandomState(seed)
+    rng = np.random.RandomState(global_random_seed)
 
     n_samples = 1000
     f_0 = rng.rand(n_samples)  # positive correlation with y
     f_1 = rng.rand(n_samples)  # negative correslation with y
     X = np.c_[f_0, f_1]
+    if use_feature_names:
+        pd = pytest.importorskip("pandas")
+        X = pd.DataFrame(X, columns=["f_0", "f_1"])
+
     noise = rng.normal(loc=0.0, scale=0.01, size=n_samples)
     y = 5 * f_0 + np.sin(10 * np.pi * f_0) - 5 * f_1 - np.cos(10 * np.pi * f_1) + noise
 
-    gbdt = HistGradientBoostingRegressor(monotonic_cst=[1, -1])
+    if use_feature_names:
+        monotonic_cst = {"f_0": +1, "f_1": -1}
+    else:
+        monotonic_cst = [+1, -1]
+
+    gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst)
     gbdt.fit(X, y)
 
     linspace = np.linspace(0, 1, 100)
@@ -258,15 +268,16 @@ def test_input_error():
     gbdt = HistGradientBoostingRegressor(monotonic_cst=[1, 0, -1])
     with pytest.raises(
-        ValueError, match="monotonic_cst has shape 3 but the input data"
+        ValueError, match=re.escape("monotonic_cst has shape (3,) but the input data")
     ):
         gbdt.fit(X, y)
 
-    for monotonic_cst in ([1, 3], [1, -3]):
+    for monotonic_cst in ([1, 3], [1, -3], [0.3, -0.7]):
         gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst)
-        with pytest.raises(
-            ValueError, match="must be None or an array-like of -1, 0 or 1"
-        ):
+        expected_msg = re.escape(
+            "must be an array-like of -1, 0 or 1. Observed values:"
+        )
+        with pytest.raises(ValueError, match=expected_msg):
             gbdt.fit(X, y)
 
     gbdt = HistGradientBoostingClassifier(monotonic_cst=[0, 1])
@@ -277,6 +288,44 @@ def test_input_error():
         gbdt.fit(X, y)
 
 
+def test_input_error_related_to_feature_names():
+    pd = pytest.importorskip("pandas")
+    X = pd.DataFrame({"a": [0, 1, 2], "b": [0, 1, 2]})
+    y = np.array([0, 1, 0])
+
+    monotonic_cst = {"d": 1, "a": 1, "c": -1}
+    gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst)
+    expected_msg = re.escape(
+        "monotonic_cst contains 2 unexpected feature names: ['c', 'd']."
+    )
+    with pytest.raises(ValueError, match=expected_msg):
+        gbdt.fit(X, y)
+
+    monotonic_cst = {k: 1 for k in "abcdefghijklmnopqrstuvwxyz"}
+    gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst)
+    expected_msg = re.escape(
+        "monotonic_cst contains 24 unexpected feature names: "
+        "['c', 'd', 'e', 'f', 'g', '...']."
+    )
+    with pytest.raises(ValueError, match=expected_msg):
+        gbdt.fit(X, y)
+
+    monotonic_cst = {"a": 1}
+    gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst)
+    expected_msg = re.escape(
+        "HistGradientBoostingRegressor was not fitted on data with feature "
+        "names. Pass monotonic_cst as an integer array instead."
+    )
+    with pytest.raises(ValueError, match=expected_msg):
+        gbdt.fit(X.values, y)
+
+    monotonic_cst = {"b": -1, "a": "+"}
+    gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst)
+    expected_msg = re.escape("monotonic_cst['a'] must be either -1, 0 or 1. Got '+'.")
+    with pytest.raises(ValueError, match=expected_msg):
+        gbdt.fit(X, y)
+
+
 def test_bounded_value_min_gain_to_split():
     # The purpose of this test is to show that when computing the gain at a
     # given split, the value of the current node should be properly bounded to
diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py
index d204b332eef8c..aeb3a8814be22 100644
--- a/sklearn/utils/validation.py
+++ b/sklearn/utils/validation.py
@@ -1979,3 +1979,85 @@ def _generate_get_feature_names_out(estimator, n_features_out, input_features=None):
     return np.asarray(
         [f"{estimator_name}{i}" for i in range(n_features_out)], dtype=object
     )
+
+
+def _check_monotonic_cst(estimator, monotonic_cst=None):
+    """Check the monotonic constraints and return the corresponding array.
+
+    This helper function should be used in the `fit` method of an estimator
+    that supports monotonic constraints and called after the estimator has
+    introspected input data to set the `n_features_in_` and optionally the
+    `feature_names_in_` attributes.
+
+    .. versionadded:: 1.2
+
+    Parameters
+    ----------
+    estimator : estimator instance
+
+    monotonic_cst : array-like of int, dict of str or None, default=None
+        Monotonic constraints for the features.
+
+        - If array-like, then it should contain only -1, 0 or 1. Each value
+          will be checked to be in [-1, 0, 1]. If a value is -1, then the
+          corresponding feature is required to be monotonically decreasing.
+        - If dict, then the keys should be the feature names occurring in
+          `estimator.feature_names_in_` and the values should be -1, 0 or 1.
+        - If None, then an array of 0s will be allocated.
+
+    Returns
+    -------
+    monotonic_cst : ndarray of int
+        Monotonic constraints for each feature.
+    """
+    original_monotonic_cst = monotonic_cst
+    if monotonic_cst is None or isinstance(monotonic_cst, dict):
+        monotonic_cst = np.full(
+            shape=estimator.n_features_in_,
+            fill_value=0,
+            dtype=np.int8,
+        )
+        if isinstance(original_monotonic_cst, dict):
+            if not hasattr(estimator, "feature_names_in_"):
+                raise ValueError(
+                    f"{estimator.__class__.__name__} was not fitted on data "
+                    "with feature names. Pass monotonic_cst as an integer "
+                    "array instead."
+                )
+            unexpected_feature_names = list(
+                set(original_monotonic_cst) - set(estimator.feature_names_in_)
+            )
+            unexpected_feature_names.sort()  # deterministic error message
+            n_unexpected = len(unexpected_feature_names)
+            if unexpected_feature_names:
+                if len(unexpected_feature_names) > 5:
+                    unexpected_feature_names = unexpected_feature_names[:5]
+                    unexpected_feature_names.append("...")
+                raise ValueError(
+                    f"monotonic_cst contains {n_unexpected} unexpected feature "
+                    f"names: {unexpected_feature_names}."
+                )
+            for feature_idx, feature_name in enumerate(estimator.feature_names_in_):
+                if feature_name in original_monotonic_cst:
+                    cst = original_monotonic_cst[feature_name]
+                    if cst not in [-1, 0, 1]:
+                        raise ValueError(
+                            f"monotonic_cst['{feature_name}'] must be either "
+                            f"-1, 0 or 1. Got {cst!r}."
+                        )
+                    monotonic_cst[feature_idx] = cst
+    else:
+        unexpected_cst = np.setdiff1d(monotonic_cst, [-1, 0, 1])
+        if unexpected_cst.shape[0]:
+            raise ValueError(
+                "monotonic_cst must be an array-like of -1, 0 or 1. Observed "
+                f"values: {unexpected_cst.tolist()}."
+            )
+
+        monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8)
+        if monotonic_cst.shape[0] != estimator.n_features_in_:
+            raise ValueError(
+                f"monotonic_cst has shape {monotonic_cst.shape} but the input data "
+                f"X has {estimator.n_features_in_} features."
+            )
+    return monotonic_cst
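
Usage sketch (not part of the patch itself): a minimal, self-contained example of the dictionary form of `monotonic_cst` introduced above, assuming the patch is applied to scikit-learn >= 1.2. The synthetic data, sample size and the column names `f_0`/`f_1` are illustrative only; the final check relies on the behavior of `_check_monotonic_cst`, where features not listed in the dict default to 0 (no constraint) and both spellings resolve to the same per-feature constraint array.

    import numpy as np
    import pandas as pd
    from sklearn.ensemble import HistGradientBoostingRegressor

    rng = np.random.RandomState(0)
    # Toy DataFrame with named columns (names are arbitrary for this sketch).
    X = pd.DataFrame(rng.rand(500, 2), columns=["f_0", "f_1"])
    # Target increases with f_0 and decreases with f_1, plus a little noise.
    y = 5 * X["f_0"] - 5 * X["f_1"] + rng.normal(scale=0.01, size=500)

    # Positional form: one constraint per column, in column order.
    gbdt_array = HistGradientBoostingRegressor(monotonic_cst=[1, -1]).fit(X, y)

    # Named form added by this patch: keys are column names, values are -1, 0 or 1;
    # columns not listed in the dict get no constraint (0).
    gbdt_dict = HistGradientBoostingRegressor(
        monotonic_cst={"f_0": 1, "f_1": -1}
    ).fit(X, y)

    # Both spellings map to the same constraint array, so the fitted models
    # should agree on this DataFrame (expected to print True).
    print(np.allclose(gbdt_array.predict(X), gbdt_dict.predict(X)))

The dictionary form is mainly useful when the training DataFrame has many columns or when the column order is not fixed, since the constraints no longer have to be kept in sync with column positions.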