Skip to content

Commit 02dc9ed

Browse files
adrinjalaliamueller
authored andcommitted
Fix max_depth overshoot in BFS expansion of trees (#12344)
* fix the issue with max_depth and BestFirstTreeBuilder * fix the test * fix max_depth overshoot in BFS expansion * fix forest tests * remove the warning, add whats_new entry * remove extra line * add affected classes to changed classes * add other affected estimators to the whats_new changed models * shorten whats_new changed models entry
1 parent d0e99db commit 02dc9ed

File tree

5 files changed

+18
-6
lines changed

5 files changed

+18
-6
lines changed

doc/whats_new/v0.21.rst

+13
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ occurs due to changes in the modelling logic (bug fixes or enhancements), or in
1818
random sampling procedures.
1919

2020
- please add class and reason here (see version 0.20 what's new)
21+
- Decision trees and derived ensembles when both `max_depth` and
22+
`max_leaf_nodes` are set. (bug fix)
2123

2224
Details are listed in the changelog below.
2325

@@ -125,6 +127,17 @@ Support for Python 3.4 and below has been officially dropped.
125127
and :class:`tree.ExtraTreeRegressor`.
126128
:issue:`12300` by :user:`Adrin Jalali <adrinjalali>`.
127129

130+
- |Fix| Fixed an issue with :class:`tree.BaseDecisionTree`
131+
and consequently all estimators based
132+
on it, including :class:`tree.DecisionTreeClassifier`,
133+
:class:`tree.DecisionTreeRegressor`, :class:`tree.ExtraTreeClassifier`,
134+
and :class:`tree.ExtraTreeRegressor`, where they used to exceed the given
135+
``max_depth`` by 1 while expanding the tree if ``max_leaf_nodes`` and
136+
``max_depth`` were both specified by the user. Please note that this also
137+
affects all ensemble methods using decision trees.
138+
:pr:`12344` by :user:`Adrin Jalali <adrinjalali>`.
139+
140+
128141
Multiple modules
129142
................
130143

sklearn/ensemble/tests/test_forest.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -711,11 +711,11 @@ def check_max_leaf_nodes_max_depth(name):
711711
ForestEstimator = FOREST_ESTIMATORS[name]
712712
est = ForestEstimator(max_depth=1, max_leaf_nodes=4,
713713
n_estimators=1, random_state=0).fit(X, y)
714-
assert_greater(est.estimators_[0].tree_.max_depth, 1)
714+
assert_equal(est.estimators_[0].get_depth(), 1)
715715

716716
est = ForestEstimator(max_depth=1, n_estimators=1,
717717
random_state=0).fit(X, y)
718-
assert_equal(est.estimators_[0].tree_.max_depth, 1)
718+
assert_equal(est.estimators_[0].get_depth(), 1)
719719

720720

721721
@pytest.mark.parametrize('name', FOREST_ESTIMATORS)

sklearn/ensemble/tests/test_gradient_boosting.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1103,7 +1103,7 @@ def test_max_leaf_nodes_max_depth(GBEstimator):
11031103

11041104
est = GBEstimator(max_depth=1, max_leaf_nodes=k).fit(X, y)
11051105
tree = est.estimators_[0, 0].tree_
1106-
assert_greater(tree.max_depth, 1)
1106+
assert_equal(tree.max_depth, 1)
11071107

11081108
est = GBEstimator(max_depth=1).fit(X, y)
11091109
tree = est.estimators_[0, 0].tree_

sklearn/tree/_tree.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
450450
impurity = splitter.node_impurity()
451451

452452
n_node_samples = end - start
453-
is_leaf = (depth > self.max_depth or
453+
is_leaf = (depth >= self.max_depth or
454454
n_node_samples < self.min_samples_split or
455455
n_node_samples < 2 * self.min_samples_leaf or
456456
weighted_n_node_samples < 2 * self.min_weight_leaf or

sklearn/tree/tests/test_tree.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1210,7 +1210,6 @@ def test_class_weight_errors(name):
12101210

12111211
def test_max_leaf_nodes():
12121212
# Test greedy trees with max_depth + 1 leafs.
1213-
from sklearn.tree._tree import TREE_LEAF
12141213
X, y = datasets.make_hastie_10_2(n_samples=100, random_state=1)
12151214
k = 4
12161215
for name, TreeEstimator in ALL_TREES.items():
@@ -1232,7 +1231,7 @@ def test_max_leaf_nodes_max_depth():
12321231
k = 4
12331232
for name, TreeEstimator in ALL_TREES.items():
12341233
est = TreeEstimator(max_depth=1, max_leaf_nodes=k).fit(X, y)
1235-
assert_greater(est.get_depth(), 1)
1234+
assert_equal(est.get_depth(), 1)
12361235

12371236

12381237
def test_arrays_persist():

0 commit comments

Comments
 (0)