ENH Remove unnecessary OOB computation when n_more_estimators == 0 #26318
Conversation
Thank you for the PR @choo8!

> Do you have any suggestions on how I can go about writing the test for this?

As for testing, I think we can use Python's Mock object with a callable as the side_effect, pass it to self._set_oob_score_and_attributes, and assert the number of times it is called. With this PR, self._set_oob_score_and_attributes should be called fewer times than on main.
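A minimal standalone sketch of that counting mechanism (the double function here is a made-up stand-in, purely illustrative):

    from unittest.mock import Mock

    def double(x):
        return x * 2

    # With side_effect set to a callable, calling the mock runs that
    # callable and returns its result, while the mock records each call.
    mock = Mock(side_effect=double)
    assert mock(3) == 6
    assert mock(4) == 8
    assert mock.call_count == 2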
Thanks for the suggestion and comments @thomasjpfan! I've committed my test; please let me know if you have any comments. I decided to add on to the check_warm_start_oob test function, as I thought the test logic belonged there.
Thank you for the update!

Please add an entry to the change log at doc/whats_new/v1.3.rst with the tag |Efficiency|. Like the other entries there, please reference this pull request with :pr: and credit yourself (and other contributors if applicable) with :user:.
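For illustration, a hedged sketch of the shape such an entry usually takes (the PR number and username are taken from this thread; the description wording is a placeholder):

    - |Efficiency| Avoid recomputing out-of-bag scores unnecessarily in forests.
      :pr:`26318` by :user:`choo8`.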
    est_3.set_params(oob_score=True)
    ignore_warnings(est_3.fit)(X, y)
    # Patch _set_oob_score_and_attributes() to track OOB computation
I think extending this test complicates the original test a little too much. I prefer adding another test:
    def test_oob_not_computed_twice(monkeypatch):
        """Check that oob_score is not computed twice when warm_start=True."""
        est = RandomForestClassifier(n_estimators=10, oob_score=True, warm_start=True)
        mock = Mock(side_effect=est._set_oob_score_and_attributes)
        monkeypatch.setattr(est, "_set_oob_score_and_attributes", mock)
        est.fit(X, y)

        with pytest.warns(UserWarning, match="Warm-start fitting without increasing"):
            est.fit(X, y)

        mock.assert_called_once()
where Mock comes from unittest.mock (from unittest.mock import Mock).
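As a side note, the same wrap-and-count idea works without the pytest monkeypatch fixture; a rough standalone sketch using unittest.mock.patch.object (the Estimator class below is a made-up stand-in, not scikit-learn code):

    from unittest.mock import Mock, patch

    class Estimator:
        # Illustrative stand-in with a method whose calls we want to count.
        def _set_oob_score_and_attributes(self):
            self.oob_score_ = 0.9

    est = Estimator()
    # Wrap the real method: it still runs, but the mock records every call.
    mock = Mock(side_effect=est._set_oob_score_and_attributes)
    with patch.object(est, "_set_oob_score_and_attributes", mock):
        est._set_oob_score_and_attributes()  # e.g. triggered by the first fit
        # A warm-start refit that adds no trees should not reach this call.
    mock.assert_called_once()  # fails if the call count is anything but one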
sklearn/ensemble/_forest.py (outdated)
" expected. Please change the shape of y to " | ||
"(n_samples,), for example using ravel()." | ||
), | ||
"A column-vector y was passed when a 1d array was" |
A few months ago we updated to black==23.3.0. Could you update to black==23.3.0 and run the linter again?
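For reference, one way to do that locally (assuming pip and the repository root; the project's own lint tooling may differ):

    pip install black==23.3.0
    black sklearn/ensemble/_forest.py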
Thank you for your comments @thomasjpfan. I've created a separate test case as suggested, updated the changelog, and updated the linter version.
There seems to be an issue with the merge because the diff now includes unrelated changes. Can you run:

    git fetch upstream main
    git rebase upstream/main

and do a force push?
Hi @thomasjpfan, I've fixed the diff with the commands you suggested.
Thank you for the update!
doc/whats_new/v1.3.rst (outdated)
    - |Efficiency| :class:`ensemble.BaseForest` now only recomputes out-of-bag scores
      if `n_more_estimators > 0` in subsequent `fit` calls.
BaseForest is not a public class and n_more_estimators is not a public parameter. The user guide is public facing, so it's best to state the change in terms of the public API:
Suggested change:

    - |Efficiency| :class:`ensemble.RandomForestClassifier` and
      :class:`ensemble.RandomForestRegressor` with `warm_start=True` now only
      recompute out-of-bag scores when there are actually more `n_estimators`
      in subsequent `fit` calls.
    @@ -1470,6 +1471,31 @@ def test_warm_start_oob(name):
        check_warm_start_oob(name)
def check_oob_not_computed_twice(name): |
I can see there is a convention in this file of having a separate function that the test calls, but I think that adds more indirection. Can you place the body of the function directly in the test_oob_not_computed_twice test?
Hi @thomasjpfan, I've updated the PR according to your latest comments.
    @@ -1466,8 +1467,29 @@ def check_warm_start_oob(name):

    @pytest.mark.parametrize("name", FOREST_CLASSIFIERS_REGRESSORS)
    def test_warm_start_oob(name):
Looks like something went wrong with the last patch. Can you add test_warm_start_oob back and move the body of the new function into test_oob_not_computed_twice?
Sorry, I deleted the wrong function by mistake. I've run pytest to make sure and pushed the corrected edit.
Thank you for the update! LGTM
I will merge these two nitpicks.
Thanks @choo8

@thomasjpfan thanks for patiently guiding me through this PR!
ENH Remove unnecessary OOB computation when n_more_estimators == 0 (scikit-learn#26318) Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Reference Issues/PRs
Fixes #20435, based on the comments by @NicolasHug (#20435 (comment))
What does this implement/fix? Explain your changes.
Removes unnecessary OOB computation when n_more_estimators == 0
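In spirit, the change guards the out-of-bag computation on whether new trees were actually trained. A much-simplified, self-contained sketch, not the actual diff (the real logic lives in BaseForest.fit in sklearn/ensemble/_forest.py; the ForestSketch class, _fit_more_trees helper, and abridged signatures here are hypothetical stand-ins):

    class ForestSketch:
        def __init__(self, n_estimators=100, oob_score=False, warm_start=False):
            self.n_estimators = n_estimators
            self.oob_score = oob_score
            self.warm_start = warm_start
            self.estimators_ = []

        def _fit_more_trees(self, X, y, n):
            # Placeholder for actually training n additional trees.
            self.estimators_.extend(object() for _ in range(n))

        def _set_oob_score_and_attributes(self, X, y):
            # Placeholder for the (expensive) real OOB estimate.
            self.oob_score_ = 0.0

        def fit(self, X, y):
            if not self.warm_start:
                self.estimators_ = []
            # New trees requested beyond those already fitted (warm start).
            n_more_estimators = self.n_estimators - len(self.estimators_)
            if n_more_estimators > 0:
                self._fit_more_trees(X, y, n_more_estimators)
            # Recompute the OOB score only when new trees were added, or when
            # it was never computed (e.g. oob_score toggled on for an already
            # warm-started model), instead of on every fit call.
            if self.oob_score and (
                n_more_estimators > 0 or not hasattr(self, "oob_score_")
            ):
                self._set_oob_score_and_attributes(X, y)
            return self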
Any other comments?
I understand from our previous conversation (#24579 (comment)) that @glemaitre would like some unit tests for this change.
I noticed that the unit test below will check for the case where oob_score is toggled from
False
toTrue
scikit-learn/sklearn/ensemble/tests/test_forest.py
Lines 1418 to 1465 in c5f10c8
However, I am not sure how I should go about checking that OOB computation (call to
self._set_oob_score_and_attributes()
atscikit-learn/sklearn/ensemble/_forest.py
Lines 492 to 494 in c5f10c8
scikit-learn/sklearn/ensemble/_forest.py
Line 496 in c5f10c8
n_more_estimators == 0
.I thought of checking on some object attribute that might be changed by
self._set_oob_score_and_attributes()
but I couldn't find any candidates. I did some research online and it seems like others create a mock function if they want to check for functions being called in runtime.Do you have any suggestions on how I can go about writing the test for this?