
Fix max_depth overshoot in BFS expansion of trees #12344


Merged

Conversation

adrinjalali
Member

This fixes an issue with BFS expansion of trees, which overshoots the max_depth of a tree by 1. The output of the following two cases should be the same, but isn't:

>>> from sklearn.datasets import load_iris
>>> from sklearn.tree import DecisionTreeClassifier
>>> X, y = load_iris(return_X_y=True)
>>> 
>>> clf = DecisionTreeClassifier(random_state=0, max_depth=1, max_leaf_nodes=100)
>>> clf = clf.fit(X, y)
>>> clf.get_depth()
2
>>> clf.get_n_leaves()
3
>>> 
>>> clf = DecisionTreeClassifier(random_state=0, max_depth=1)
>>> clf = clf.fit(X, y)
>>> clf.get_depth()
1
>>> clf.get_n_leaves()
2
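The off-by-one comes from the depth condition used when deciding whether a node on the best-first frontier may still be split. The sketch below is a simplified, hypothetical model of such a builder (a priority queue of `(priority, depth)` pairs; it is not the actual Cython code in `sklearn/tree/_tree.pyx`), showing how treating a node at `depth == max_depth` as splittable yields leaves one level too deep:

```python
import heapq
import itertools

def grown_depth(max_depth, n_splits=10, buggy=False):
    """Grow a toy binary tree best-first; return the depth of the deepest node.

    The frontier is a heap of (priority, depth) pairs, where priority
    stands in for impurity improvement.  With ``buggy=True`` a node at
    ``depth == max_depth`` is still treated as splittable, reproducing
    the pre-fix overshoot.
    """
    tick = itertools.count()          # stand-in priority: insertion order
    frontier = [(next(tick), 0)]      # the root sits at depth 0
    deepest = 0
    for _ in range(n_splits):
        if not frontier:
            break
        _, depth = heapq.heappop(frontier)
        limit = max_depth if buggy else max_depth - 1
        if depth > limit:             # node is a leaf: too deep to split
            continue
        for _child in range(2):       # binary split: push both children
            heapq.heappush(frontier, (next(tick), depth + 1))
        deepest = max(deepest, depth + 1)
    return deepest
```

With `max_depth=1`, the buggy condition grows leaves at depth 2, matching the `get_depth() == 2` output above; the fixed condition stops at depth 1.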

This PR fixes the issue and warns the user about the changed behavior when the code path that previously overshot is actually reached.

I'm not sure about the warning, but right now, this would be the output after this PR:

>>> from sklearn.datasets import load_iris
>>> from sklearn.tree import DecisionTreeClassifier
>>> X, y = load_iris(return_X_y=True)
>>> 
>>> clf = DecisionTreeClassifier(random_state=0, max_depth=1, max_leaf_nodes=100)
>>> clf = clf.fit(X, y)
.../tree.py:380: UserWarning: Due to a bugfix in v0.21 the maximum depth of a tree now does not pass the given max_depth!
  builder.build(self.tree_, X, y, sample_weight, X_idx_sorted)
>>> clf.get_depth()
1
>>> clf.get_n_leaves()
2
>>> 
>>> clf = DecisionTreeClassifier(random_state=0, max_depth=1)
>>> clf = clf.fit(X, y)
>>> clf.get_depth()
1
>>> clf.get_n_leaves()
2

@amueller
Member

I'm -1 on the warning. I don't like warnings that the user can't disable / avoid.

Member

@jnothman jnothman left a comment

I'm okay with no warning. We need to note it clearly in what's new, though

        impurity <= min_impurity_split)):
    with gil:
        warnings.warn("Due to a bugfix in v0.21 the maximum depth of a"
                      " tree now does not pass the given max_depth!",
Member

pass -> exceed

@jnothman
Member

Although: all warnings can be disabled if you try hard enough
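For reference, a specific warning like this one can be silenced without suppressing anything else, using only the stdlib `warnings` module (a sketch; `fit_quietly` is an illustrative helper, not part of scikit-learn):

```python
import warnings

def fit_quietly(fit, *args, **kwargs):
    """Call ``fit`` while silencing only the max_depth bugfix warning."""
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="Due to a bugfix in v0.21",  # matched as a regex prefix
            category=UserWarning,
        )
        return fit(*args, **kwargs)
```

Other warnings raised inside the call still propagate; only messages starting with the given prefix are dropped.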

@adrinjalali
Member Author

The warning is removed now.

Member

@jnothman jnothman left a comment

Please also add an entry under "Changed models" in what's new.

@adrinjalali
Member Author

I'm not sure whether I should also add the ensemble methods that use trees to the changed classes and the Fix section of whats_new, though. Their tests had to be modified as well, so their behavior has changed in the same way.

@jnothman
Member

Yes, it doesn't hurt to mention the forests etc
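One way such a test can check the fix independently of `get_depth()` (a hypothetical helper, not the code in this PR): scikit-learn trees expose `children_left`/`children_right` arrays with `-1` marking a leaf, so the depth of every fitted tree can be recomputed and compared against `max_depth`:

```python
def tree_depth(children_left, children_right, node=0):
    """Depth of a tree stored as parallel child-index arrays (-1 = leaf)."""
    if children_left[node] == -1:     # leaves have no children
        return 0
    return 1 + max(
        tree_depth(children_left, children_right, children_left[node]),
        tree_depth(children_left, children_right, children_right[node]),
    )
```

A forest test could then assert `all(tree_depth(t.tree_.children_left, t.tree_.children_right) <= max_depth for t in forest.estimators_)`.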

@@ -18,6 +18,22 @@ occurs due to changes in the modelling logic (bug fixes or enhancements), or in
random sampling procedures.

- please add class and reason here (see version 0.20 what's new)
- :class:`ensemble.AdaBoostClassifier` (bug fix)
Member

Okay. This is excessive. Firstly, the user can't specify max_depth in adaboost (or bagging) without explicitly constructing a decision tree.
Secondly, I think here we are trying to keep things succinct. Saying decision trees and derived ensembles are affected should be sufficient. It might also be appropriate to say "with max_depth and max_leaf_nodes set" to avoid undue panic.

@amueller amueller merged commit 02dc9ed into scikit-learn:master Nov 13, 2018
@adrinjalali adrinjalali deleted the bug/tree/maxdepthandleafnodes branch November 13, 2018 18:39
thoo added a commit to thoo/scikit-learn that referenced this pull request Nov 13, 2018
…ybutton

* upstream/master:
  Fix max_depth overshoot in BFS expansion of trees (scikit-learn#12344)
  TST don't test utils.fixes docstrings (scikit-learn#12576)
  DOC Fix typo (scikit-learn#12563)
  FIX Workaround limitation of cloudpickle under PyPy (scikit-learn#12566)
  MNT bare asserts (scikit-learn#12571)
  FIX incorrect error when OneHotEncoder.transform called prior to fit (scikit-learn#12443)
thoo pushed a commit to thoo/scikit-learn that referenced this pull request Nov 13, 2018
* fix the issue with max_depth and BestFirstTreeBuilder

* fix the test

* fix max_depth overshoot in BFS expansion

* fix forest tests

* remove the warning, add whats_new entry

* remove extra line

* add affected classes to changed classes

* add other affected estimators to the whats_new changed models

* shorten whats_new changed models entry
thoo added a commit to thoo/scikit-learn that referenced this pull request Nov 13, 2018
…ikit-learn into add_codeblock_copybutton

* 'add_codeblock_copybutton' of https://github.com/thoo/scikit-learn:
  Move an extension under sphinx_copybutton/
  Move css/js file under sphinxext/
  Fix max_depth overshoot in BFS expansion of trees (scikit-learn#12344)
  TST don't test utils.fixes docstrings (scikit-learn#12576)
  DOC Fix typo (scikit-learn#12563)
  FIX Workaround limitation of cloudpickle under PyPy (scikit-learn#12566)
  MNT bare asserts (scikit-learn#12571)
  FIX incorrect error when OneHotEncoder.transform called prior to fit (scikit-learn#12443)
  Retrigger travis:max time limit error
  DOC: Clarify `cv` parameter description in `GridSearchCV` (scikit-learn#12495)
  FIX remove FutureWarning in _object_dtype_isnan and add test (scikit-learn#12567)
  DOC Add 's' to "correspond" in docs for Hamming Loss. (scikit-learn#12565)
  EXA Fix comment in plot-iris-logistic example (scikit-learn#12564)
  FIX stop words validation in text vectorizers with custom preprocessors / tokenizers (scikit-learn#12393)
  DOC Add skorch to related projects (scikit-learn#12561)
  MNT Don't change self.n_values in OneHotEncoder.fit (scikit-learn#12286)
  MNT Remove unused assert_true imports (scikit-learn#12560)
  TST autoreplace assert_true(...==...) with plain assert (scikit-learn#12547)
  DOC: add a testimonial from JP Morgan (scikit-learn#12555)
thoo pushed a commit to thoo/scikit-learn that referenced this pull request Nov 14, 2018
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
xhluca pushed a commit to xhluca/scikit-learn that referenced this pull request Apr 28, 2019
koenvandevelde pushed a commit to koenvandevelde/scikit-learn that referenced this pull request Jul 12, 2019
3 participants