ENH Remove unnecessary OOB computation when n_more_estimators == 0 #26318


Merged: 11 commits merged into scikit-learn:main on May 30, 2023

Conversation


@choo8 choo8 commented May 2, 2023

Reference Issues/PRs

Fixes #20435, based on the comments by @NicolasHug (#20435 (comment))

What does this implement/fix? Explain your changes.

Removes unnecessary OOB computation when n_more_estimators == 0

Any other comments?

I understand from our previous conversation (#24579 (comment)) that @glemaitre would like some unit tests for this change.

I noticed that the unit test below already checks the case where oob_score is toggled from False to True:
```python
def check_warm_start_oob(name):
    # Test that the warm start computes oob score when asked.
    X, y = hastie_X, hastie_y
    ForestEstimator = FOREST_ESTIMATORS[name]
    # Use 15 estimators to avoid 'some inputs do not have OOB scores' warning.
    est = ForestEstimator(
        n_estimators=15,
        max_depth=3,
        warm_start=False,
        random_state=1,
        bootstrap=True,
        oob_score=True,
    )
    est.fit(X, y)

    est_2 = ForestEstimator(
        n_estimators=5,
        max_depth=3,
        warm_start=False,
        random_state=1,
        bootstrap=True,
        oob_score=False,
    )
    est_2.fit(X, y)

    est_2.set_params(warm_start=True, oob_score=True, n_estimators=15)
    est_2.fit(X, y)
    assert hasattr(est_2, "oob_score_")
    assert est.oob_score_ == est_2.oob_score_

    # Test that oob_score is computed even if we don't need to train
    # additional trees.
    est_3 = ForestEstimator(
        n_estimators=15,
        max_depth=3,
        warm_start=True,
        random_state=1,
        bootstrap=True,
        oob_score=False,
    )
    est_3.fit(X, y)
    assert not hasattr(est_3, "oob_score_")

    est_3.set_params(oob_score=True)
    ignore_warnings(est_3.fit)(X, y)
    assert est.oob_score_ == est_3.oob_score_
```

However, I am not sure how to check that the OOB computation, i.e. the call to `self._set_oob_score_and_attributes(X, y, scoring_function=self.oob_score)` or `self._set_oob_score_and_attributes(X, y)`, is not made when `n_more_estimators == 0`.

I thought of checking an object attribute that might be changed by `self._set_oob_score_and_attributes()`, but I couldn't find any candidates. From some research online, it seems others create a mock function when they want to check that a function is called at runtime.

Do you have any suggestions on how I can go about writing the test for this?

@thomasjpfan (Member) left a comment

Thank you for the PR @choo8 !

> Do you have any suggestions on how I can go about writing the test for this?

As for testing, I think we can use Python's `Mock` object with a callable as the `side_effect`, pass it to `self.oob_score`, and assert the number of times it is called. With this PR, `self.oob_score` should be called fewer times than on `main`.
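For illustration, here is a minimal, self-contained sketch of the pattern being suggested (the `accuracy` function below is a hypothetical stand-in, not a scikit-learn internal): wrapping a real callable in a `Mock` via `side_effect` preserves its behavior while recording every call.

```python
from unittest.mock import Mock

# Hypothetical callable whose invocations we want to count.
def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

# side_effect delegates to the wrapped callable, so return values are
# preserved while the Mock records each call and its arguments.
mock_scorer = Mock(side_effect=accuracy)

result = mock_scorer([1, 0, 1], [1, 1, 1])

assert result == 2 / 3              # behavior of the wrapped callable is intact
assert mock_scorer.call_count == 1  # and the call was recorded
mock_scorer.assert_called_once_with([1, 0, 1], [1, 1, 1])
```

The same wrapper can then be passed wherever the real callable would go, and the test asserts on `call_count` afterwards.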

@choo8 (Author) commented May 23, 2023

Thanks for the suggestion and comments @thomasjpfan! I've committed my test; please let me know if you have any comments. I decided to add on to the `check_warm_start_oob` test function, as I thought the test logic belongs there.

@thomasjpfan (Member) left a comment

Thank you for the update!

Please add an entry to the change log at `doc/whats_new/v1.3.rst` with tag `|Efficiency|`. Like the other entries there, please reference this pull request with `:pr:` and credit yourself (and other contributors if applicable) with `:user:`.


Review comment on the test context:

```python
    est_3.set_params(oob_score=True)
    ignore_warnings(est_3.fit)(X, y)
    # Patch _set_oob_score_and_attributes() to track OOB computation
```

@thomasjpfan (Member): I think extending this test complicates the original test a little too much. I prefer adding another test:

```python
def test_oob_not_computed_twice(monkeypatch):
    """Check that oob_score is not computed twice when warm_start=True."""
    est = RandomForestClassifier(n_estimators=10, oob_score=True, warm_start=True)

    mock = Mock(side_effect=est._set_oob_score_and_attributes)
    monkeypatch.setattr(est, "_set_oob_score_and_attributes", mock)

    est.fit(X, y)

    with pytest.warns(UserWarning, match="Warm-start fitting without increasing"):
        est.fit(X, y)

    mock.assert_called_once()
```

where `from unittest.mock import Mock`.

" expected. Please change the shape of y to "
"(n_samples,), for example using ravel()."
),
"A column-vector y was passed when a 1d array was"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few months ago we updated to black==23.3.0. May you update to black==23.3.0 and run the linter again?

@choo8 (Author) commented May 27, 2023

Thank you for your comments @thomasjpfan. I've created a separate test case as suggested, updated the changelog and also updated the version of the linter.

@thomasjpfan (Member)

There seems to be an issue with the merge because the diff is now at +3,998 −1,450. Can you try running the following:

```shell
git fetch upstream main
git rebase upstream/main
```

and do a force push?
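For completeness, a force push after a rebase is typically done as below; the remote and branch names here are placeholders, and `--force-with-lease` is a commonly recommended, safer variant of `--force` that refuses to overwrite work you haven't seen.

```shell
# Push the rebased branch to your fork, rewriting its remote history.
# "origin" and "my-branch" are placeholders for your remote and branch names.
git push --force-with-lease origin my-branch
```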

@choo8 (Author) commented May 28, 2023

Hi @thomasjpfan, I've fixed the diff with the commands you suggested.

@thomasjpfan thomasjpfan changed the title [WIP] Remove unnecessary OOB computation when n_more_estimators == 0 ENH Remove unnecessary OOB computation when n_more_estimators == 0 May 28, 2023
@thomasjpfan (Member) left a comment

Thank you for the update!

Comment on lines 290 to 291 of `doc/whats_new/v1.3.rst`:

```rst
- |Efficiency| :class:`ensemble.BaseForest` now only recomputes out-of-bag scores
  if `n_more_estimators > 0` in subsequent `fit` calls.
```

@thomasjpfan (Member): `BaseForest` is not a public class and `n_more_estimators` is not a public parameter. The user guide is public facing, so it's best to state the change in terms of the public API.

Suggested change:

```rst
- |Efficiency| :class:`ensemble.RandomForestClassifier` and
  :class:`ensemble.RandomForestRegressor` with `warm_start=True` now only
  recompute out-of-bag scores when there are actually more `n_estimators`
  in subsequent `fit` calls.
```

Comment on this hunk:

```python
@@ -1470,6 +1471,31 @@ def test_warm_start_oob(name):
    check_warm_start_oob(name)


def check_oob_not_computed_twice(name):
```
@thomasjpfan (Member) commented:

I can see there is a convention in this file of having a separate function that the test calls, but I think that adds more indirection. Can you place the body of the function directly in the `test_oob_not_computed_twice` test?

@choo8 (Author) commented May 29, 2023

Hi @thomasjpfan, I've updated the PR with regards to your latest comments.

Comment on this hunk:

```python
@@ -1466,8 +1467,29 @@ def check_warm_start_oob(name):


@pytest.mark.parametrize("name", FOREST_CLASSIFIERS_REGRESSORS)
def test_warm_start_oob(name):
```
@thomasjpfan (Member) commented:

Looks like something went wrong with the last patch.

Can you add test_warm_start_oob back and move the body of the new function into test_oob_not_computed_twice?

@choo8 (Author) commented:

Sorry, I deleted the wrong function by mistake. I've run pytest to verify and pushed the corrected edit.

@thomasjpfan (Member) left a comment

Thank you for the update! LGTM

@thomasjpfan thomasjpfan added the Waiting for Second Reviewer First reviewer is done, need a second one! label May 29, 2023
@glemaitre glemaitre self-requested a review May 30, 2023 07:17
@glemaitre (Member) left a comment

I will merge these two nitpicks.

@glemaitre glemaitre merged commit 1415a28 into scikit-learn:main May 30, 2023
@glemaitre (Member)

Thanks @choo8

@choo8 (Author) commented May 31, 2023

@thomasjpfan thanks for patiently guiding me through this PR!

REDVM pushed a commit to REDVM/scikit-learn that referenced this pull request Nov 16, 2023
Labels: module:ensemble, Waiting for Second Reviewer

Successfully merging this pull request may close these issues:

- Incorrect documentation for warm_start behavior on BaseForest-derived classes