Skip to content

Commit 8e1e87c

Browse files
samronsinpat-oreillydsleoogriselthomasjpfan
authored andcommitted
ENH Monotonic Contraints for Tree-based models (scikit-learn#13649)
Co-authored-by: Pat O'Reilly <patrick.oreilly256@gmail.com> Co-authored-by: dsleo <leooleds@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
1 parent 3afc8dc commit 8e1e87c

File tree

11 files changed

+1181
-45
lines changed

11 files changed

+1181
-45
lines changed

doc/whats_new/v1.4.rst

+25
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,31 @@ TODO: update at the time of the release.
5959
passed to the ``fit`` method of the the estimator. :pr:`26506` by `Adrin
6060
Jalali`_.
6161

62+
63+
:mod:`sklearn.ensemble`
64+
.......................
65+
66+
- |Feature| :class:`ensemble.RandomForestClassifier`,
67+
:class:`ensemble.RandomForestRegressor`, :class:`ensemble.ExtraTreesClassifier`
68+
and :class:`ensemble.ExtraTreesRegressor` now support monotonic constraints,
69+
useful when features are supposed to have a positive/negative effect on the target.
70+
Missing values in the train data and multi-output targets are not supported.
71+
:pr:`13649` by :user:`Samuel Ronsin <samronsin>`,
72+
initiated by :user:`Patrick O'Reilly <pat-oreilly>`.
73+
74+
75+
:mod:`sklearn.tree`
76+
...................
77+
78+
- |Feature| :class:`tree.DecisionTreeClassifier`, :class:`tree.DecisionTreeRegressor`,
79+
:class:`tree.ExtraTreeClassifier` and :class:`tree.ExtraTreeRegressor` now support
80+
monotonic constraints, useful when features are supposed to have a positive/negative
81+
effect on the target. Missing values in the train data and multi-output targets are
82+
not supported.
83+
:pr:`13649` by :user:`Samuel Ronsin <samronsin>`, initiated by
84+
:user:`Patrick O'Reilly <pat-oreilly>`.
85+
86+
6287
:mod:`sklearn.decomposition`
6388
............................
6489

sklearn/ensemble/_forest.py

+83-1
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,25 @@ class RandomForestClassifier(ForestClassifier):
12731273
12741274
.. versionadded:: 0.22
12751275
1276+
monotonic_cst : array-like of int of shape (n_features), default=None
1277+
Indicates the monotonicity constraint to enforce on each feature.
1278+
- 1: monotonic increase
1279+
- 0: no constraint
1280+
- -1: monotonic decrease
1281+
1282+
If monotonic_cst is None, no constraints are applied.
1283+
1284+
Monotonicity constraints are not supported for:
1285+
- multiclass classifications (i.e. when `n_classes > 2`),
1286+
- multioutput classifications (i.e. when `n_outputs_ > 1`),
1287+
- classifications trained on data with missing values.
1288+
1289+
The constraints hold over the probability of the positive class.
1290+
1291+
Read more in the :ref:`User Guide <monotonic_cst_gbdt>`.
1292+
1293+
.. versionadded:: 1.4
1294+
12761295
Attributes
12771296
----------
12781297
estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier`
@@ -1413,6 +1432,7 @@ def __init__(
14131432
class_weight=None,
14141433
ccp_alpha=0.0,
14151434
max_samples=None,
1435+
monotonic_cst=None,
14161436
):
14171437
super().__init__(
14181438
estimator=DecisionTreeClassifier(),
@@ -1428,6 +1448,7 @@ def __init__(
14281448
"min_impurity_decrease",
14291449
"random_state",
14301450
"ccp_alpha",
1451+
"monotonic_cst",
14311452
),
14321453
bootstrap=bootstrap,
14331454
oob_score=oob_score,
@@ -1447,6 +1468,7 @@ def __init__(
14471468
self.max_features = max_features
14481469
self.max_leaf_nodes = max_leaf_nodes
14491470
self.min_impurity_decrease = min_impurity_decrease
1471+
self.monotonic_cst = monotonic_cst
14501472
self.ccp_alpha = ccp_alpha
14511473

14521474

@@ -1627,6 +1649,22 @@ class RandomForestRegressor(ForestRegressor):
16271649
16281650
.. versionadded:: 0.22
16291651
1652+
monotonic_cst : array-like of int of shape (n_features), default=None
1653+
Indicates the monotonicity constraint to enforce on each feature.
1654+
- 1: monotonically increasing
1655+
- 0: no constraint
1656+
- -1: monotonically decreasing
1657+
1658+
If monotonic_cst is None, no constraints are applied.
1659+
1660+
Monotonicity constraints are not supported for:
1661+
- multioutput regressions (i.e. when `n_outputs_ > 1`),
1662+
- regressions trained on data with missing values.
1663+
1664+
Read more in the :ref:`User Guide <monotonic_cst_gbdt>`.
1665+
1666+
.. versionadded:: 1.4
1667+
16301668
Attributes
16311669
----------
16321670
estimator_ : :class:`~sklearn.tree.DecisionTreeRegressor`
@@ -1754,6 +1792,7 @@ def __init__(
17541792
warm_start=False,
17551793
ccp_alpha=0.0,
17561794
max_samples=None,
1795+
monotonic_cst=None,
17571796
):
17581797
super().__init__(
17591798
estimator=DecisionTreeRegressor(),
@@ -1769,6 +1808,7 @@ def __init__(
17691808
"min_impurity_decrease",
17701809
"random_state",
17711810
"ccp_alpha",
1811+
"monotonic_cst",
17721812
),
17731813
bootstrap=bootstrap,
17741814
oob_score=oob_score,
@@ -1788,6 +1828,7 @@ def __init__(
17881828
self.max_leaf_nodes = max_leaf_nodes
17891829
self.min_impurity_decrease = min_impurity_decrease
17901830
self.ccp_alpha = ccp_alpha
1831+
self.monotonic_cst = monotonic_cst
17911832

17921833

17931834
class ExtraTreesClassifier(ForestClassifier):
@@ -1975,6 +2016,25 @@ class ExtraTreesClassifier(ForestClassifier):
19752016
19762017
.. versionadded:: 0.22
19772018
2019+
monotonic_cst : array-like of int of shape (n_features), default=None
2020+
Indicates the monotonicity constraint to enforce on each feature.
2021+
- 1: monotonically increasing
2022+
- 0: no constraint
2023+
- -1: monotonically decreasing
2024+
2025+
If monotonic_cst is None, no constraints are applied.
2026+
2027+
Monotonicity constraints are not supported for:
2028+
- multiclass classifications (i.e. when `n_classes > 2`),
2029+
- multioutput classifications (i.e. when `n_outputs_ > 1`),
2030+
- classifications trained on data with missing values.
2031+
2032+
The constraints hold over the probability of the positive class.
2033+
2034+
Read more in the :ref:`User Guide <monotonic_cst_gbdt>`.
2035+
2036+
.. versionadded:: 1.4
2037+
19782038
Attributes
19792039
----------
19802040
estimator_ : :class:`~sklearn.tree.ExtraTreesClassifier`
@@ -2104,6 +2164,7 @@ def __init__(
21042164
class_weight=None,
21052165
ccp_alpha=0.0,
21062166
max_samples=None,
2167+
monotonic_cst=None,
21072168
):
21082169
super().__init__(
21092170
estimator=ExtraTreeClassifier(),
@@ -2119,6 +2180,7 @@ def __init__(
21192180
"min_impurity_decrease",
21202181
"random_state",
21212182
"ccp_alpha",
2183+
"monotonic_cst",
21222184
),
21232185
bootstrap=bootstrap,
21242186
oob_score=oob_score,
@@ -2139,6 +2201,7 @@ def __init__(
21392201
self.max_leaf_nodes = max_leaf_nodes
21402202
self.min_impurity_decrease = min_impurity_decrease
21412203
self.ccp_alpha = ccp_alpha
2204+
self.monotonic_cst = monotonic_cst
21422205

21432206

21442207
class ExtraTreesRegressor(ForestRegressor):
@@ -2314,6 +2377,22 @@ class ExtraTreesRegressor(ForestRegressor):
23142377
23152378
.. versionadded:: 0.22
23162379
2380+
monotonic_cst : array-like of int of shape (n_features), default=None
2381+
Indicates the monotonicity constraint to enforce on each feature.
2382+
- 1: monotonically increasing
2383+
- 0: no constraint
2384+
- -1: monotonically decreasing
2385+
2386+
If monotonic_cst is None, no constraints are applied.
2387+
2388+
Monotonicity constraints are not supported for:
2389+
- multioutput regressions (i.e. when `n_outputs_ > 1`),
2390+
- regressions trained on data with missing values.
2391+
2392+
Read more in the :ref:`User Guide <monotonic_cst_gbdt>`.
2393+
2394+
.. versionadded:: 1.4
2395+
23172396
Attributes
23182397
----------
23192398
estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor`
@@ -2426,6 +2505,7 @@ def __init__(
24262505
warm_start=False,
24272506
ccp_alpha=0.0,
24282507
max_samples=None,
2508+
monotonic_cst=None,
24292509
):
24302510
super().__init__(
24312511
estimator=ExtraTreeRegressor(),
@@ -2441,6 +2521,7 @@ def __init__(
24412521
"min_impurity_decrease",
24422522
"random_state",
24432523
"ccp_alpha",
2524+
"monotonic_cst",
24442525
),
24452526
bootstrap=bootstrap,
24462527
oob_score=oob_score,
@@ -2460,6 +2541,7 @@ def __init__(
24602541
self.max_leaf_nodes = max_leaf_nodes
24612542
self.min_impurity_decrease = min_impurity_decrease
24622543
self.ccp_alpha = ccp_alpha
2544+
self.monotonic_cst = monotonic_cst
24632545

24642546

24652547
class RandomTreesEmbedding(TransformerMixin, BaseForest):
@@ -2653,7 +2735,7 @@ class RandomTreesEmbedding(TransformerMixin, BaseForest):
26532735
**BaseDecisionTree._parameter_constraints,
26542736
"sparse_output": ["boolean"],
26552737
}
2656-
for param in ("max_features", "ccp_alpha", "splitter"):
2738+
for param in ("max_features", "ccp_alpha", "splitter", "monotonic_cst"):
26572739
_parameter_constraints.pop(param)
26582740

26592741
criterion = "squared_error"

sklearn/ensemble/_gb.py

+1
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class BaseGradientBoosting(BaseEnsemble, metaclass=ABCMeta):
137137
"tol": [Interval(Real, 0.0, None, closed="left")],
138138
}
139139
_parameter_constraints.pop("splitter")
140+
_parameter_constraints.pop("monotonic_cst")
140141

141142
@abstractmethod
142143
def __init__(

0 commit comments

Comments
 (0)