From 825400e1ba418460f9813b1e20756a6feebcd246 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Mon, 7 Nov 2022 19:43:35 +0100 Subject: [PATCH 01/24] Add support for feature names in monotonic_cst --- .../gradient_boosting.py | 21 ++++- .../_hist_gradient_boosting/grower.py | 19 ++--- .../tests/test_monotonic_contraints.py | 7 +- sklearn/utils/validation.py | 81 +++++++++++++++++++ 4 files changed, 106 insertions(+), 22 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 470020dbd492b..427e882e5ea4d 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -23,6 +23,7 @@ check_is_fitted, check_consistent_length, _check_sample_weight, + _check_monotonic_cst, ) from ...utils._param_validation import Interval, StrOptions from ...utils._openmp_helpers import _openmp_effective_n_threads @@ -91,7 +92,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC): "max_depth": [Interval(Integral, 1, None, closed="left"), None], "min_samples_leaf": [Interval(Integral, 1, None, closed="left")], "l2_regularization": [Interval(Real, 0, None, closed="left")], - "monotonic_cst": ["array-like", None], + "monotonic_cst": ["array-like", dict, None], "interaction_cst": [Iterable, None], "n_iter_no_change": [Interval(Integral, 1, None, closed="left")], "validation_fraction": [ @@ -342,6 +343,7 @@ def fit(self, X, y, sample_weight=None): self._random_seed = rng.randint(np.iinfo(np.uint32).max, dtype="u8") self._validate_parameters() + monotonic_cst = _check_monotonic_cst(self, self.monotonic_cst) # used for validation in predict n_samples, self._n_features = X.shape @@ -637,7 +639,7 @@ def fit(self, X, y, sample_weight=None): n_bins_non_missing=self._bin_mapper.n_bins_non_missing_, has_missing_values=has_missing_values, is_categorical=self.is_categorical_, - monotonic_cst=self.monotonic_cst, + 
monotonic_cst=monotonic_cst, interaction_cst=interaction_cst, max_leaf_nodes=self.max_leaf_nodes, max_depth=self.max_depth, @@ -1559,8 +1561,15 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): .. versionadded:: 0.24 - monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. + monotonic_cst : dict of str, array-like of int of shape (n_features), \ + default=None + + If a dict with str keys, map feature names to monotonic constraints by + feature names. If an array, the feature are mapped to constraints by + position. + + Monotonic constraint to enforce on each feature are specified using the + following integer values: - 1: monotonic increase - 0: no constraint - -1: monotonic decrease @@ -1571,6 +1580,10 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): .. versionadded:: 0.23 + .. versionchanged:: 1.2 + + Accept dict of constraints with feature names as keys. + interaction_cst : iterable of iterables of int, default=None Specify interaction constraints, i.e. sets of features which can only interact with each other in child nodes splits. diff --git a/sklearn/ensemble/_hist_gradient_boosting/grower.py b/sklearn/ensemble/_hist_gradient_boosting/grower.py index 5e3010fa4a509..db0522647f74e 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/grower.py +++ b/sklearn/ensemble/_hist_gradient_boosting/grower.py @@ -264,28 +264,19 @@ def __init__( has_missing_values = [has_missing_values] * X_binned.shape[1] has_missing_values = np.asarray(has_missing_values, dtype=np.uint8) + # Shallow validation of monotonic_cst to make TreeGrower easier to + # test. A more complete validation is done in _validate_monotonic_cst + # at the estimator level and therefore the following should not be + # needed when using the public API. 
if monotonic_cst is None: - self.with_monotonic_cst = False monotonic_cst = np.full( shape=X_binned.shape[1], fill_value=MonotonicConstraint.NO_CST, dtype=np.int8, ) else: - self.with_monotonic_cst = True monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8) - - if monotonic_cst.shape[0] != X_binned.shape[1]: - raise ValueError( - "monotonic_cst has shape {} but the input data " - "X has {} features.".format( - monotonic_cst.shape[0], X_binned.shape[1] - ) - ) - if np.any(monotonic_cst < -1) or np.any(monotonic_cst > 1): - raise ValueError( - "monotonic_cst must be None or an array-like of -1, 0 or 1." - ) + self.with_monotonic_cst = np.any(monotonic_cst != MonotonicConstraint.NO_CST) if is_categorical is None: is_categorical = np.zeros(shape=X_binned.shape[1], dtype=np.uint8) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py index 4ab65c55a8620..1d38eee5f7e7f 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py @@ -1,3 +1,4 @@ +import re import numpy as np import pytest @@ -258,15 +259,13 @@ def test_input_error(): gbdt = HistGradientBoostingRegressor(monotonic_cst=[1, 0, -1]) with pytest.raises( - ValueError, match="monotonic_cst has shape 3 but the input data" + ValueError, match=re.escape("monotonic_cst has shape (3,) but the input data") ): gbdt.fit(X, y) for monotonic_cst in ([1, 3], [1, -3]): gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst) - with pytest.raises( - ValueError, match="must be None or an array-like of -1, 0 or 1" - ): + with pytest.raises(ValueError, match="must be an array-like of -1, 0 or 1"): gbdt.fit(X, y) gbdt = HistGradientBoostingClassifier(monotonic_cst=[0, 1]) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index d204b332eef8c..72cdf1ab2a19c 100644 --- 
a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1979,3 +1979,84 @@ def _generate_get_feature_names_out(estimator, n_features_out, input_features=No return np.asarray( [f"{estimator_name}{i}" for i in range(n_features_out)], dtype=object ) + + +def _check_monotonic_cst(estimator, monotonic_cst=None): + """Check the monotonic constraints and return the corresponding array. + + This helper function should be used in the `fit` method of an estimator + that supports monotonic constraints and called after the estimator has + introspected input data to set the `n_features_in_` and optionally the + `feature_names_in_` attributes. + + .. versionadded:: 1.2 + + Parameters + ---------- + estimator : estimator instance + + monotonic_cst : array-like of int, dict of str or None, default=None + Monotonic constraints for the features. + + - If array-like, then it should contain only -1, 0 or 1. Each value + will be checked to be in [-1, 0, 1]. If a value is -1, then the + corresponding feature is required to be monotonically decreasing. + - If dict, then it the keys should be the feature names occurring in + `estimator.feature_names_in_` and the values should be -1, 0 or 1. + - If None, then an array of 0s will be allocated. + + Returns + ------- + monotonic_cst : ndarray of int + Monotonic constraints for each feature. + """ + original_monotonic_cst = monotonic_cst + if monotonic_cst is None or isinstance(monotonic_cst, dict): + monotonic_cst = np.full( + shape=estimator.n_features_in_, + fill_value=0, + dtype=np.int8, + ) + if isinstance(original_monotonic_cst, dict): + if not hasattr(estimator, "feature_names_in_"): + raise ValueError( + f"{estimator.__class__.__name__} was not fitted on data " + "with feature names. Pass monotonic_cst as an integer " + "array instead." 
+ ) + unexpected_feature_names = list( + set(original_monotonic_cst) - set(estimator.feature_names_in_) + ) + unexpected_feature_names.sort() # deterministic error message + if unexpected_feature_names: + if len(unexpected_feature_names) > 5: + unexpected_feature_names = unexpected_feature_names[:5] + unexpected_feature_names.append("...") + raise ValueError( + "monotonic_cst contains unexpected feature names: " + f"{unexpected_feature_names}" + ) + for feature_idx, feature_name in estimator.feature_names_in_: + if feature_name in original_monotonic_cst: + cst = original_monotonic_cst[feature_name] + if cst not in [-1, 0, 1]: + raise ValueError( + f"monotonic_cst[{feature_name}] must be either " + f"-1, 0 or 1. Got {cst}." + ) + monotonic_cst[feature_idx] = cst + else: + monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8) + + if monotonic_cst.shape[0] != estimator.n_features_in_: + raise ValueError( + f"monotonic_cst has shape {monotonic_cst.shape} but the input data " + f"X has {estimator.n_features_in_} features." + ) + unexpected_cst = np.setdiff1d(monotonic_cst, [-1, 0, 1]) + if unexpected_cst.shape[0]: + raise ValueError( + "monotonic_cst must be an array-like of -1, 0 or 1. Observed " + f"values: {unexpected_cst.tolist()}." + ) + return monotonic_cst From e660969fca504dfe55e0aea45f0c055af725a399 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 8 Nov 2022 00:16:29 +0100 Subject: [PATCH 02/24] docstring format --- sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 427e882e5ea4d..78cae96d7c409 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1581,8 +1581,7 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): .. 
versionadded:: 0.23 .. versionchanged:: 1.2 - - Accept dict of constraints with feature names as keys. + Accept dict of constraints with feature names as keys. interaction_cst : iterable of iterables of int, default=None Specify interaction constraints, i.e. sets of features which can From 3b15660a878a0af7b92f0b429fbf903114cfdc23 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 8 Nov 2022 08:22:33 +0100 Subject: [PATCH 03/24] docstring format --- sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 78cae96d7c409..eb5049f159b2e 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1561,9 +1561,7 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): .. versionadded:: 0.24 - monotonic_cst : dict of str, array-like of int of shape (n_features), \ - default=None - + monotonic_cst : dict, array-like of int of shape (n_features), default=None If a dict with str keys, map feature names to monotonic constraints by feature names. If an array, the feature are mapped to constraints by position. From cecf92ee4c15e465e8cb5a671f83a72f4c5d7fc7 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 8 Nov 2022 10:06:47 +0100 Subject: [PATCH 04/24] Fix indentation in docstring? 
--- sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index eb5049f159b2e..40eb0f4934fad 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1579,7 +1579,7 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): .. versionadded:: 0.23 .. versionchanged:: 1.2 - Accept dict of constraints with feature names as keys. + Accept dict of constraints with feature names as keys. interaction_cst : iterable of iterables of int, default=None Specify interaction constraints, i.e. sets of features which can From 1cd97d00100ddfb45ce07921e754cc29f2515f48 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 8 Nov 2022 10:29:05 +0100 Subject: [PATCH 05/24] More docstring tweaking --- sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 40eb0f4934fad..93fabafdc8717 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1577,9 +1577,10 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): Read more in the :ref:`User Guide `. .. versionadded:: 0.23 + Support for monotonic constraints via array of integers. .. versionchanged:: 1.2 - Accept dict of constraints with feature names as keys. + Accept dict of constraints with feature names as keys. interaction_cst : iterable of iterables of int, default=None Specify interaction constraints, i.e. 
sets of features which can From 44c33d394bfa3ff98f21f55b7e451cfa7b3ed6f5 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 8 Nov 2022 11:19:48 +0100 Subject: [PATCH 06/24] Add a test for the nominal case --- .../tests/test_monotonic_contraints.py | 17 +++++++++++++---- sklearn/utils/validation.py | 2 +- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py index 1d38eee5f7e7f..3b3fa2d09e832 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py @@ -198,24 +198,33 @@ def test_nodes_values(monotonic_cst, seed): assert_leaves_values_monotonic(predictor, monotonic_cst) -@pytest.mark.parametrize("seed", range(3)) -def test_predictions(seed): +@pytest.mark.parametrize("use_feature_names", (True, False)) +def test_predictions(global_random_seed, use_feature_names): # Train a model with a POS constraint on the first feature and a NEG # constraint on the second feature, and make sure the constraints are # respected by checking the predictions. 
# test adapted from lightgbm's test_monotone_constraint(), itself inspired # by https://xgboost.readthedocs.io/en/latest/tutorials/monotonic.html - rng = np.random.RandomState(seed) + rng = np.random.RandomState(global_random_seed) n_samples = 1000 f_0 = rng.rand(n_samples) # positive correlation with y f_1 = rng.rand(n_samples) # negative correslation with y X = np.c_[f_0, f_1] + if use_feature_names: + pd = pytest.importorskip("pandas") + X = pd.DataFrame(X, columns=["f_0", "f_1"]) + noise = rng.normal(loc=0.0, scale=0.01, size=n_samples) y = 5 * f_0 + np.sin(10 * np.pi * f_0) - 5 * f_1 - np.cos(10 * np.pi * f_1) + noise - gbdt = HistGradientBoostingRegressor(monotonic_cst=[1, -1]) + if use_feature_names: + monotonic_cst = {"f_0": +1, "f_1": -1} + else: + monotonic_cst = [+1, -1] + + gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst) gbdt.fit(X, y) linspace = np.linspace(0, 1, 100) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 72cdf1ab2a19c..d83063578a21e 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -2036,7 +2036,7 @@ def _check_monotonic_cst(estimator, monotonic_cst=None): "monotonic_cst contains unexpected feature names: " f"{unexpected_feature_names}" ) - for feature_idx, feature_name in estimator.feature_names_in_: + for feature_idx, feature_name in enumerate(estimator.feature_names_in_): if feature_name in original_monotonic_cst: cst = original_monotonic_cst[feature_name] if cst not in [-1, 0, 1]: From 73d8a3769fc835db8a90d7e8be4e7128c6a5048c Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 8 Nov 2022 12:15:59 +0100 Subject: [PATCH 07/24] Test error messages --- .../tests/test_monotonic_contraints.py | 29 +++++++++++++++++++ sklearn/utils/validation.py | 4 +-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py 
index 3b3fa2d09e832..4379812e4b42a 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py @@ -285,6 +285,35 @@ def test_input_error(): gbdt.fit(X, y) +def test_input_error_related_to_feature_names(): + pd = pytest.importorskip("pandas") + X = pd.DataFrame({"a": [0, 1, 2], "b": [0, 1, 2]}) + y = np.array([0, 1, 0]) + + monotonic_cst = {"d": 1, "a": 1, "c": -1} + gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst) + expected_msg = re.escape( + "monotonic_cst contains unexpected feature names: ['c', 'd']." + ) + with pytest.raises(ValueError, match=expected_msg): + gbdt.fit(X, y) + + monotonic_cst = {"a": 1} + gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst) + expected_msg = re.escape( + "HistGradientBoostingRegressor was not fitted on data with feature " + "names. Pass monotonic_cst as an integer array instead." + ) + with pytest.raises(ValueError, match=expected_msg): + gbdt.fit(X.values, y) + + monotonic_cst = {"b": -1, "a": "+"} + gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst) + expected_msg = re.escape("monotonic_cst[a] must be either -1, 0 or 1. Got '+'.") + with pytest.raises(ValueError, match=expected_msg): + gbdt.fit(X, y) + + def test_bounded_value_min_gain_to_split(): # The purpose of this test is to show that when computing the gain at a # given split, the value of the current node should be properly bounded to diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index d83063578a21e..4e0cb3e808fdc 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -2034,7 +2034,7 @@ def _check_monotonic_cst(estimator, monotonic_cst=None): unexpected_feature_names.append("...") raise ValueError( "monotonic_cst contains unexpected feature names: " - f"{unexpected_feature_names}" + f"{unexpected_feature_names}." 
) for feature_idx, feature_name in enumerate(estimator.feature_names_in_): if feature_name in original_monotonic_cst: @@ -2042,7 +2042,7 @@ def _check_monotonic_cst(estimator, monotonic_cst=None): if cst not in [-1, 0, 1]: raise ValueError( f"monotonic_cst[{feature_name}] must be either " - f"-1, 0 or 1. Got {cst}." + f"-1, 0 or 1. Got {cst!r}." ) monotonic_cst[feature_idx] = cst else: From 7ed48f5b82222cb568614b8d8c8052a7c4f3e652 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 8 Nov 2022 14:35:49 +0100 Subject: [PATCH 08/24] Changelog entry --- doc/whats_new/v1.2.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 1427e61c03385..b7e1eac472150 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -301,6 +301,13 @@ Changelog :user:`Nestor Navarro `, :user:`Nati Tomattis `, and :user:`Vincent Maladiere `. +- |Enhancement| :class:`ensemble.HistGradientBoostingClassifier` and + :class:`ensemble.HistGradientBoostingClassifier` now accept their + `monotonic_cst` parameter to be passed as a dictionary with feature names as + keys in addition the previously supported format that used an array of + ternary integers to specify monotonicity constraints by feature position. + :pr:`24855` by :user:`Olivier Grisel `. + - |Fix| Fixed the issue where :class:`ensemble.AdaBoostClassifier` outputs NaN in feature importance when fitted with very small sample weight. :pr:`20415` by :user:`Zhehao Liu `. 
From 498ca43d42018bebbde7bf11be97965383fb3baa Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 10 Nov 2022 10:33:34 +0100 Subject: [PATCH 09/24] Update example --- .../ensemble/plot_monotonic_constraints.py | 49 +++++++++++++------ 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/examples/ensemble/plot_monotonic_constraints.py b/examples/ensemble/plot_monotonic_constraints.py index 8b0aff204a584..1c25915de3d6d 100644 --- a/examples/ensemble/plot_monotonic_constraints.py +++ b/examples/ensemble/plot_monotonic_constraints.py @@ -19,7 +19,7 @@ `_. """ - +# %% from sklearn.ensemble import HistGradientBoostingRegressor from sklearn.inspection import PartialDependenceDisplay import numpy as np @@ -28,7 +28,7 @@ rng = np.random.RandomState(0) -n_samples = 5000 +n_samples = 1000 f_0 = rng.rand(n_samples) f_1 = rng.rand(n_samples) X = np.c_[f_0, f_1] @@ -37,14 +37,23 @@ # y is positively correlated with f_0, and negatively correlated with f_1 y = 5 * f_0 + np.sin(10 * np.pi * f_0) - 5 * f_1 - np.cos(10 * np.pi * f_1) + noise -fig, ax = plt.subplots() + +# %% +# Fit a first model on this dataset without any constraint +gbdt_no_cst = HistGradientBoostingRegressor() +gbdt_no_cst.fit(X, y) + +# %% +# With monotonic increase (1) and a monotonic decrease (-1) constraints, respectively. +gbdt_monotonic = HistGradientBoostingRegressor(monotonic_cst=[1, -1]) +gbdt_monotonic.fit(X, y) -# Without any constraint -gbdt = HistGradientBoostingRegressor() -gbdt.fit(X, y) +# %% +# Let's display the partial dependence of the predictions on the two features. +fig, ax = plt.subplots() disp = PartialDependenceDisplay.from_estimator( - gbdt, + gbdt_no_cst, X, features=[0, 1], feature_names=( @@ -54,13 +63,8 @@ line_kw={"linewidth": 4, "label": "unconstrained", "color": "tab:blue"}, ax=ax, ) - -# With monotonic increase (1) and a monotonic decrease (-1) constraints, respectively. 
-gbdt = HistGradientBoostingRegressor(monotonic_cst=[1, -1]) -gbdt.fit(X, y) - PartialDependenceDisplay.from_estimator( - gbdt, + gbdt_monotonic, X, features=[0, 1], line_kw={"linewidth": 4, "label": "constrained", "color": "tab:orange"}, @@ -75,5 +79,22 @@ plt.legend() fig.suptitle("Monotonic constraints effect on partial dependences") - plt.show() + +# %% +# We can see that the predictions of the unconstrained model capture the +# oscillations of the the data while the constrained model follows the general +# trend and ignores the local variations. + +# %% +# Note that if the training data has feature names, it's possible to specifiy the +# monotonic constraints by passing a dictionary: +import pandas as pd + +X_df = pd.DataFrame(X, columns=["f_0", "f_1"]) + +gbdt_monotonic_df = HistGradientBoostingRegressor( + monotonic_cst={"f_0": 1, "f_1": -1} +).fit(X_df, y) + +np.allclose(gbdt_monotonic_df.predict(X_df), gbdt_monotonic.predict(X)) From 9f37dd77f077afbbb77fb09de14a081a26cb439d Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 10 Nov 2022 15:20:18 +0100 Subject: [PATCH 10/24] Docstring tweak for sphinx? --- .../_hist_gradient_boosting/gradient_boosting.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index eb1d4c5b5fada..1b885964571e6 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1561,17 +1561,17 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): .. versionadded:: 0.24 - monotonic_cst : dict, array-like of int of shape (n_features), default=None - If a dict with str keys, map feature names to monotonic constraints by - feature names. If an array, the feature are mapped to constraints by - position. 
- + monotonic_cst : array-like of int of shape (n_features) or dict, default=None Monotonic constraint to enforce on each feature are specified using the following integer values: - 1: monotonic increase - 0: no constraint - -1: monotonic decrease + If a dict with str keys, map feature names to monotonic constraints by + feature names. If an array, the feature are mapped to constraints by + position. + The constraints are only valid for binary classifications and hold over the probability of the positive class. Read more in the :ref:`User Guide `. From c754b803785a1d805e6d9374288adb67bac2ab80 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 10 Nov 2022 16:01:35 +0100 Subject: [PATCH 11/24] More indentation tweaking --- .../ensemble/_hist_gradient_boosting/gradient_boosting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 1b885964571e6..9a77589d172ff 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1564,9 +1564,9 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): monotonic_cst : array-like of int of shape (n_features) or dict, default=None Monotonic constraint to enforce on each feature are specified using the following integer values: - - 1: monotonic increase - - 0: no constraint - - -1: monotonic decrease + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease If a dict with str keys, map feature names to monotonic constraints by feature names. 
If an array, the feature are mapped to constraints by From 5eb8174c18c7a3582d34224567489d5c6fe64cea Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 10 Nov 2022 16:21:00 +0100 Subject: [PATCH 12/24] Update the regressors' docstring --- .../gradient_boosting.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 9a77589d172ff..5947372a528c9 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1229,15 +1229,27 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): .. versionadded:: 0.24 - monotonic_cst : array-like of int of shape (n_features), default=None - Indicates the monotonic constraint to enforce on each feature. - - 1: monotonic increase - - 0: no constraint - - -1: monotonic decrease + monotonic_cst : array-like of int of shape (n_features) or dict, default=None + Monotonic constraint to enforce on each feature are specified using the + following integer values: + + - 1: monotonic increase + - 0: no constraint + - -1: monotonic decrease + + If a dict with str keys, map feature names to monotonic constraints by + feature names. If an array, the feature are mapped to constraints by + position. + The constraints are only valid for binary classifications and hold + over the probability of the positive class. Read more in the :ref:`User Guide `. .. versionadded:: 0.23 + Support for monotonic constraints via array of integers. + + .. versionchanged:: 1.2 + Accept dict of constraints with feature names as keys. 
interaction_cst : iterable of iterables of int, default=None Specify interaction constraints, the sets of features which can @@ -1564,6 +1576,7 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): monotonic_cst : array-like of int of shape (n_features) or dict, default=None Monotonic constraint to enforce on each feature are specified using the following integer values: + - 1: monotonic increase - 0: no constraint - -1: monotonic decrease From a523aaf4686b39db1ec91a26f12e88426e8f880b Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 10 Nov 2022 16:56:11 +0100 Subject: [PATCH 13/24] Fix docstring formating and phrasing --- .../_hist_gradient_boosting/gradient_boosting.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 5947372a528c9..680b255de77e9 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1237,16 +1237,14 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): - 0: no constraint - -1: monotonic decrease - If a dict with str keys, map feature names to monotonic constraints by - feature names. If an array, the feature are mapped to constraints by - position. + If a dict with str keys, map feature to monotonic constraints by name. + If an array, the feature are mapped to constraints by position. The constraints are only valid for binary classifications and hold over the probability of the positive class. Read more in the :ref:`User Guide `. .. versionadded:: 0.23 - Support for monotonic constraints via array of integers. .. versionchanged:: 1.2 Accept dict of constraints with feature names as keys. 
@@ -1581,16 +1579,14 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): - 0: no constraint - -1: monotonic decrease - If a dict with str keys, map feature names to monotonic constraints by - feature names. If an array, the feature are mapped to constraints by - position. + If a dict with str keys, map feature to monotonic constraints by name. + If an array, the feature are mapped to constraints by position. The constraints are only valid for binary classifications and hold over the probability of the positive class. Read more in the :ref:`User Guide `. .. versionadded:: 0.23 - Support for monotonic constraints via array of integers. .. versionchanged:: 1.2 Accept dict of constraints with feature names as keys. From f929848fc27318e05abaf31befada18723930359 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 10 Nov 2022 21:12:02 +0100 Subject: [PATCH 14/24] Apply suggestions from code review Co-authored-by: Julien Jerphanion --- doc/whats_new/v1.2.rst | 7 ++++--- examples/ensemble/plot_monotonic_constraints.py | 17 +++++++++-------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 76cff308c5d49..27f0800654488 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -307,9 +307,10 @@ Changelog - |Enhancement| :class:`ensemble.HistGradientBoostingClassifier` and :class:`ensemble.HistGradientBoostingClassifier` now accept their - `monotonic_cst` parameter to be passed as a dictionary with feature names as - keys in addition the previously supported format that used an array of - ternary integers to specify monotonicity constraints by feature position. + `monotonic_cst` parameter to be passed as a dictionary in addition + to the previously supported array-like format. + Such a dictionary has feature names as keys and one of `-1`, `0`, `1` + as values to specify the monotonicity constraint of each feature. :pr:`24855` by :user:`Olivier Grisel `. 
- |Fix| Fixed the issue where :class:`ensemble.AdaBoostClassifier` outputs diff --git a/examples/ensemble/plot_monotonic_constraints.py b/examples/ensemble/plot_monotonic_constraints.py index 1c25915de3d6d..f75bf68c77098 100644 --- a/examples/ensemble/plot_monotonic_constraints.py +++ b/examples/ensemble/plot_monotonic_constraints.py @@ -39,14 +39,15 @@ # %% -# Fit a first model on this dataset without any constraint +# Fit a first model on this dataset without any constraints. gbdt_no_cst = HistGradientBoostingRegressor() gbdt_no_cst.fit(X, y) # %% -# With monotonic increase (1) and a monotonic decrease (-1) constraints, respectively. -gbdt_monotonic = HistGradientBoostingRegressor(monotonic_cst=[1, -1]) -gbdt_monotonic.fit(X, y) +# Fit a second model on this dataset with monotonic increase (1) +# and a monotonic decrease (-1) constraints, respectively. +gbdt_with_monotonic_cst = HistGradientBoostingRegressor(monotonic_cst=[1, -1]) +gbdt_with_monotonic_cst.fit(X, y) # %% @@ -64,7 +65,7 @@ ax=ax, ) PartialDependenceDisplay.from_estimator( - gbdt_monotonic, + gbdt_with_monotonic_cst, X, features=[0, 1], line_kw={"linewidth": 4, "label": "constrained", "color": "tab:orange"}, @@ -83,7 +84,7 @@ # %% # We can see that the predictions of the unconstrained model capture the -# oscillations of the the data while the constrained model follows the general +# oscillations of the data while the constrained model follows the general # trend and ignores the local variations. 
# %% @@ -93,8 +94,8 @@ X_df = pd.DataFrame(X, columns=["f_0", "f_1"]) -gbdt_monotonic_df = HistGradientBoostingRegressor( +gbdt_with_monotonic_cst_df = HistGradientBoostingRegressor( monotonic_cst={"f_0": 1, "f_1": -1} ).fit(X_df, y) -np.allclose(gbdt_monotonic_df.predict(X_df), gbdt_monotonic.predict(X)) +np.allclose(gbdt_with_monotonic_cst_df.predict(X_df), gbdt_monotonic.predict(X)) From 86c4cc79d2448c5019b9e91c995d003b60dbeaa7 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 10 Nov 2022 21:39:18 +0100 Subject: [PATCH 15/24] Fix undefined variable --- examples/ensemble/plot_monotonic_constraints.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/ensemble/plot_monotonic_constraints.py b/examples/ensemble/plot_monotonic_constraints.py index f75bf68c77098..db284300d88e5 100644 --- a/examples/ensemble/plot_monotonic_constraints.py +++ b/examples/ensemble/plot_monotonic_constraints.py @@ -98,4 +98,6 @@ monotonic_cst={"f_0": 1, "f_1": -1} ).fit(X_df, y) -np.allclose(gbdt_with_monotonic_cst_df.predict(X_df), gbdt_monotonic.predict(X)) +np.allclose( + gbdt_with_monotonic_cst_df.predict(X_df), gbdt_with_monotonic_cst.predict(X) +) From 5ac2617b648eef5376ffe4b33f9a243fa3dbdb85 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 10 Nov 2022 21:56:13 +0100 Subject: [PATCH 16/24] Exclude invalid values in ]-1, 1[ --- .../tests/test_monotonic_contraints.py | 7 +++++-- sklearn/utils/validation.py | 14 +++++++------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py index 4379812e4b42a..40b96176674f0 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py @@ -272,9 +272,12 @@ def test_input_error(): ): gbdt.fit(X, y) - for monotonic_cst in ([1, 3], [1, -3]): 
+ for monotonic_cst in ([1, 3], [1, -3], [0.3, -0.7]): gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst) - with pytest.raises(ValueError, match="must be an array-like of -1, 0 or 1"): + expected_msg = re.escape( + "must be an array-like of -1, 0 or 1. Observed values:" + ) + with pytest.raises(ValueError, match=expected_msg): gbdt.fit(X, y) gbdt = HistGradientBoostingClassifier(monotonic_cst=[0, 1]) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 4e0cb3e808fdc..a450b91f03f63 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -2046,17 +2046,17 @@ def _check_monotonic_cst(estimator, monotonic_cst=None): ) monotonic_cst[feature_idx] = cst else: - monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8) - - if monotonic_cst.shape[0] != estimator.n_features_in_: - raise ValueError( - f"monotonic_cst has shape {monotonic_cst.shape} but the input data " - f"X has {estimator.n_features_in_} features." - ) unexpected_cst = np.setdiff1d(monotonic_cst, [-1, 0, 1]) if unexpected_cst.shape[0]: raise ValueError( "monotonic_cst must be an array-like of -1, 0 or 1. Observed " f"values: {unexpected_cst.tolist()}." ) + + monotonic_cst = np.asarray(monotonic_cst, dtype=np.int8) + if monotonic_cst.shape[0] != estimator.n_features_in_: + raise ValueError( + f"monotonic_cst has shape {monotonic_cst.shape} but the input data " + f"X has {estimator.n_features_in_} features." 
+ ) return monotonic_cst From 36dc2b755648db0629343722953bc613bc9f19bd Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Thu, 10 Nov 2022 22:01:46 +0100 Subject: [PATCH 17/24] Report number of unexpected feature names --- .../tests/test_monotonic_contraints.py | 2 +- sklearn/utils/validation.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py index 40b96176674f0..8504e99a922eb 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py @@ -296,7 +296,7 @@ def test_input_error_related_to_feature_names(): monotonic_cst = {"d": 1, "a": 1, "c": -1} gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst) expected_msg = re.escape( - "monotonic_cst contains unexpected feature names: ['c', 'd']." + "monotonic_cst contains 2 unexpected feature names: ['c', 'd']." ) with pytest.raises(ValueError, match=expected_msg): gbdt.fit(X, y) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index a450b91f03f63..3611f096f5df8 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -2028,13 +2028,14 @@ def _check_monotonic_cst(estimator, monotonic_cst=None): set(original_monotonic_cst) - set(estimator.feature_names_in_) ) unexpected_feature_names.sort() # deterministic error message + n_unexpeced = len(unexpected_feature_names) if unexpected_feature_names: if len(unexpected_feature_names) > 5: unexpected_feature_names = unexpected_feature_names[:5] unexpected_feature_names.append("...") raise ValueError( - "monotonic_cst contains unexpected feature names: " - f"{unexpected_feature_names}." + f"monotonic_cst contains {n_unexpeced} unexpected feature " + f"names: {unexpected_feature_names}." 
) for feature_idx, feature_name in enumerate(estimator.feature_names_in_): if feature_name in original_monotonic_cst: From 766d1f8601e4c04a35a0b42316c209836f5dc4c1 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 11 Nov 2022 06:17:56 +0100 Subject: [PATCH 18/24] Update sklearn/ensemble/_hist_gradient_boosting/grower.py Co-authored-by: Thomas J. Fan --- sklearn/ensemble/_hist_gradient_boosting/grower.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/grower.py b/sklearn/ensemble/_hist_gradient_boosting/grower.py index db0522647f74e..1ad6dba661552 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/grower.py +++ b/sklearn/ensemble/_hist_gradient_boosting/grower.py @@ -264,8 +264,7 @@ def __init__( has_missing_values = [has_missing_values] * X_binned.shape[1] has_missing_values = np.asarray(has_missing_values, dtype=np.uint8) - # Shallow validation of monotonic_cst to make TreeGrower easier to - # test. A more complete validation is done in _validate_monotonic_cst + # `monotonic_cst` validation is done in _validate_monotonic_cst # at the estimator level and therefore the following should not be # needed when using the public API. if monotonic_cst is None: From 1925f0be3b65f5ad5e3d13203dad5d5b56a8bda2 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 11 Nov 2022 06:23:13 +0100 Subject: [PATCH 19/24] Link to example from docstring --- examples/ensemble/plot_monotonic_constraints.py | 5 +++++ .../ensemble/_hist_gradient_boosting/gradient_boosting.py | 6 ++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/ensemble/plot_monotonic_constraints.py b/examples/ensemble/plot_monotonic_constraints.py index db284300d88e5..45f0d2ccdf08d 100644 --- a/examples/ensemble/plot_monotonic_constraints.py +++ b/examples/ensemble/plot_monotonic_constraints.py @@ -88,6 +88,11 @@ # trend and ignores the local variations. # %% +# .. 
monotonic_cst_features_names: +# +# Using feature names to specify monotonic constraints +# ---------------------------------------------------- +# # Note that if the training data has feature names, it's possible to specifiy the # monotonic constraints by passing a dictionary: import pandas as pd diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 680b255de77e9..395bf260ecfa8 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1238,7 +1238,8 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): - -1: monotonic decrease If a dict with str keys, map feature to monotonic constraints by name. - If an array, the feature are mapped to constraints by position. + If an array, the feature are mapped to constraints by position. See + example usage :ref:`monotonic_cst_features_names`. The constraints are only valid for binary classifications and hold over the probability of the positive class. @@ -1580,7 +1581,8 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): - -1: monotonic decrease If a dict with str keys, map feature to monotonic constraints by name. - If an array, the feature are mapped to constraints by position. + If an array, the feature are mapped to constraints by position. See + example usage in :ref:`monotonic_cst_features_names`. The constraints are only valid for binary classifications and hold over the probability of the positive class. 
From 7b2b3c39cc4bd1b6de2870d544220033e7f1b4cd Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 11 Nov 2022 06:26:11 +0100 Subject: [PATCH 20/24] Add missing test case to increase coverage --- .../tests/test_monotonic_contraints.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py index 8504e99a922eb..d02e24957087c 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py @@ -301,6 +301,15 @@ def test_input_error_related_to_feature_names(): with pytest.raises(ValueError, match=expected_msg): gbdt.fit(X, y) + monotonic_cst = {k: 1 for k in list("abcdefghijklmnopqrstuvwxyz")} + gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst) + expected_msg = re.escape( + "monotonic_cst contains 24 unexpected feature names: " + "['c', 'd', 'e', 'f', 'g', '...']." + ) + with pytest.raises(ValueError, match=expected_msg): + gbdt.fit(X, y) + monotonic_cst = {"a": 1} gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst) expected_msg = re.escape( From fa814d06f809ff06a1f6aa12fe48f2218e9f25f8 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Fri, 11 Nov 2022 07:00:49 +0100 Subject: [PATCH 21/24] Fix ref to example section --- examples/ensemble/plot_monotonic_constraints.py | 2 +- sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/ensemble/plot_monotonic_constraints.py b/examples/ensemble/plot_monotonic_constraints.py index 45f0d2ccdf08d..7e9e271256fa9 100644 --- a/examples/ensemble/plot_monotonic_constraints.py +++ b/examples/ensemble/plot_monotonic_constraints.py @@ -88,7 +88,7 @@ # trend and ignores the local variations. # %% -# .. monotonic_cst_features_names: +# .. 
_monotonic_cst_features_names: # # Using feature names to specify monotonic constraints # ---------------------------------------------------- diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 395bf260ecfa8..d238e204e0558 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1239,7 +1239,7 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): If a dict with str keys, map feature to monotonic constraints by name. If an array, the feature are mapped to constraints by position. See - example usage :ref:`monotonic_cst_features_names`. + :ref:`monotonic_cst_features_names` for a usage example. The constraints are only valid for binary classifications and hold over the probability of the positive class. @@ -1582,7 +1582,7 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): If a dict with str keys, map feature to monotonic constraints by name. If an array, the feature are mapped to constraints by position. See - example usage in :ref:`monotonic_cst_features_names`. + :ref:`monotonic_cst_features_names` for a usage example. The constraints are only valid for binary classifications and hold over the probability of the positive class. From 161e87cc8e39116dbc09723892f50cc51fbe04be Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Mon, 14 Nov 2022 09:12:00 +0100 Subject: [PATCH 22/24] Update sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py Co-authored-by: Thomas J. 
Fan --- .../_hist_gradient_boosting/tests/test_monotonic_contraints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py index d02e24957087c..296f961ded60a 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py @@ -301,7 +301,7 @@ def test_input_error_related_to_feature_names(): with pytest.raises(ValueError, match=expected_msg): gbdt.fit(X, y) - monotonic_cst = {k: 1 for k in list("abcdefghijklmnopqrstuvwxyz")} + monotonic_cst = {k: 1 for k in "abcdefghijklmnopqrstuvwxyz"} gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst) expected_msg = re.escape( "monotonic_cst contains 24 unexpected feature names: " From 5c6a7ee277f183656970ef6cfbeef06e59a5ccb7 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Mon, 14 Nov 2022 09:16:19 +0100 Subject: [PATCH 23/24] Cosmetic change in error message --- .../_hist_gradient_boosting/tests/test_monotonic_contraints.py | 2 +- sklearn/utils/validation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py index 296f961ded60a..9456b9d9934b1 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py @@ -321,7 +321,7 @@ def test_input_error_related_to_feature_names(): monotonic_cst = {"b": -1, "a": "+"} gbdt = HistGradientBoostingRegressor(monotonic_cst=monotonic_cst) - expected_msg = re.escape("monotonic_cst[a] must be either -1, 0 or 1. Got '+'.") + expected_msg = re.escape("monotonic_cst['a'] must be either -1, 0 or 1. 
Got '+'.") with pytest.raises(ValueError, match=expected_msg): gbdt.fit(X, y) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 3611f096f5df8..aeb3a8814be22 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -2042,7 +2042,7 @@ def _check_monotonic_cst(estimator, monotonic_cst=None): cst = original_monotonic_cst[feature_name] if cst not in [-1, 0, 1]: raise ValueError( - f"monotonic_cst[{feature_name}] must be either " + f"monotonic_cst['{feature_name}'] must be either " f"-1, 0 or 1. Got {cst!r}." ) monotonic_cst[feature_idx] = cst From 90060dab8c4b0fa4023287486635df26ca2e625f Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 15 Nov 2022 15:25:45 +0100 Subject: [PATCH 24/24] Apply suggestions from code review Co-authored-by: Tim Head --- sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 58f735db43cdf..bdf9ea414c210 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -1270,7 +1270,7 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): - -1: monotonic decrease If a dict with str keys, map feature to monotonic constraints by name. - If an array, the feature are mapped to constraints by position. See + If an array, the features are mapped to constraints by position. See :ref:`monotonic_cst_features_names` for a usage example. The constraints are only valid for binary classifications and hold @@ -1618,7 +1618,7 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting): - -1: monotonic decrease If a dict with str keys, map feature to monotonic constraints by name. - If an array, the feature are mapped to constraints by position. 
See + If an array, the features are mapped to constraints by position. See :ref:`monotonic_cst_features_names` for a usage example. The constraints are only valid for binary classifications and hold