FIX Fix error when using Calibrated with Voting #20087

Merged 1 commit on Jun 4, 2021

Conversation

cfauchereau
Contributor

Reference Issues/PRs

Fixes #20053

What does this implement/fix? Explain your changes.

The change in #17856 modified how CalibratedClassifierCV works internally. Because of the unusual implementation of VotingClassifier.predict_proba, it broke compatibility when a VotingClassifier is used as the base_estimator of CalibratedClassifierCV.
This is a simple fix to restore compatibility.

Any other comments?

The real issue is the way VotingClassifier.predict_proba is implemented. However, it seems to me that it can't be resolved without breaking changes.
The issue is that predict_proba is not a regular method but a property that returns another method. The goal of this trick was to get a form of polymorphism: the getter exposes predict_proba only when voting="soft".
I think VotingClassifier should be an abstract class and we should implement SoftVotingClassifier and HardVotingClassifier since they don't implement the same methods. It is however a big API change.
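
To make the property trick concrete, here is a minimal sketch in plain Python. `ToyVotingClassifier` is a hypothetical stand-in, not the scikit-learn class, and the uniform probabilities are placeholders. It only illustrates why code that inspects the name of the retrieved method sees `_predict_proba` rather than `predict_proba`:

```python
class ToyVotingClassifier:
    """Hypothetical stand-in used for illustration only."""

    def __init__(self, voting="hard"):
        self.voting = voting

    def _predict_proba(self, X):
        # Placeholder: uniform probabilities over two classes.
        return [[0.5, 0.5] for _ in X]

    @property
    def predict_proba(self):
        # Raising AttributeError in the getter makes
        # hasattr(clf, "predict_proba") return False for hard voting.
        if self.voting == "hard":
            raise AttributeError(
                "predict_proba is not available when voting='hard'"
            )
        # The attribute holds another method, defined under a private name.
        return self._predict_proba


soft = ToyVotingClassifier(voting="soft")
hard = ToyVotingClassifier(voting="hard")

method = soft.predict_proba
print(method.__name__)                 # _predict_proba
print(hasattr(hard, "predict_proba"))  # False
```

Any caller that branches on `method.__name__` will not recognize `_predict_proba`, which would be consistent with the "Invalid prediction method: _predict_proba" error reported in the linked issue.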

It would be simpler to only raise an error if voting="hard", but if I understand correctly, it is assumed that if predict_proba exists then it must work. Raising an error would therefore lead to other incompatibilities.

I am not familiar enough with the code base to know how it is usually dealt with. Anyway, I think it needs further discussion.

@cfauchereau cfauchereau force-pushed the calibrate_and_voting branch 3 times, most recently from bc43de1 to 502b8a7 Compare May 13, 2021 14:10
Member

@thomasjpfan thomasjpfan left a comment

Thank you for working on this issue @Clement-F !

Comment on lines 511 to 512
elif method_name == 'predict_proba' or method_name == '_predict_proba':
# The `_predict_proba` option is needed for `VotingClassifier`
Member

I do not think special casing VotingClassifier is great. I think it would be better to get _get_prediction_method to return a tuple (callable, method_name) and then use the method_name here.

(_get_prediction_method is always called before _compute_predictions)
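
A rough sketch of that suggestion, written in plain Python rather than against the actual scikit-learn code. The helper names `get_prediction_method` and `compute_predictions`, the `SoftVoter` class, and the returned `kind` label are all hypothetical. The point is only that the name used for the lookup travels with the callable, so the caller never has to inspect `__name__`:

```python
class SoftVoter:
    """Hypothetical estimator whose predict_proba is a property."""

    def _predict_proba(self, X):
        # Placeholder: uniform probabilities over two classes.
        return [[0.5, 0.5] for _ in X]

    @property
    def predict_proba(self):
        return self._predict_proba


def get_prediction_method(clf):
    """Return (callable, public_name) for the first supported method."""
    for name in ("decision_function", "predict_proba"):
        if hasattr(clf, name):
            # Return the name that was looked up, not the callable's __name__.
            return getattr(clf, name), name
    raise RuntimeError(
        "clf has no decision_function or predict_proba method."
    )


def compute_predictions(pred_method, method_name, X):
    """Branch on the public name passed in by the caller."""
    predictions = pred_method(X)
    if method_name == "decision_function":
        kind = "scores"
    elif method_name == "predict_proba":
        kind = "probabilities"
    else:
        raise ValueError(f"Invalid prediction method: {method_name}")
    return kind, predictions


pred_method, method_name = get_prediction_method(SoftVoter())
kind, preds = compute_predictions(pred_method, method_name, [[0.0], [1.0]])
print(method_name, pred_method.__name__, kind)
# predict_proba _predict_proba probabilities
```

With this shape, no estimator needs to be special-cased: the branch depends on which public attribute was found, whatever the underlying callable happens to be named.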

Contributor Author

I agree. I have updated the code.

@cfauchereau cfauchereau force-pushed the calibrate_and_voting branch from 502b8a7 to 34322ee Compare May 22, 2021 16:43
Member

@ogrisel ogrisel left a comment

LGTM, thanks for the fix. Just a few minor suggestions:

@cfauchereau cfauchereau force-pushed the calibrate_and_voting branch from 34322ee to 8eef25f Compare June 4, 2021 10:10
@thomasjpfan thomasjpfan changed the title Fix error when using Calibrated with Voting FIX Fix error when using Calibrated with Voting Jun 4, 2021
Member

@thomasjpfan thomasjpfan left a comment

LGTM

@thomasjpfan thomasjpfan merged commit 7b965c7 into scikit-learn:main Jun 4, 2021
@cfauchereau cfauchereau deleted the calibrate_and_voting branch June 4, 2021 19:04
thomasjpfan added a commit to thomasjpfan/scikit-learn that referenced this pull request Jun 8, 2021
* TST enable test docstring params for feature extraction module (scikit-learn#20188)

* DOC fix a reference in sklearn.ensemble.GradientBoostingRegressor (scikit-learn#20198)

* FIX mcc zero divsion  (scikit-learn#19977)

* TST Add TransformedTargetRegressor to test_meta_estimators_delegate_data_validation (scikit-learn#20175)

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>

* TST enable n_feature_in_ test for feature_extraction module

* FIX Uses points instead of pixels in plot_tree (scikit-learn#20023)

* MNT n_features_in through the multiclass module (scikit-learn#20193)

* CI Removes python 3.6 builds from wheel building (scikit-learn#20184)

* FIX Fix typo in error message in `fetch_openml` (scikit-learn#20201)

* FIX Fix error when using Calibrated with Voting (scikit-learn#20087)

* FIX Fix RandomForestRegressor doesn't accept max_samples=1.0 (scikit-learn#20159)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>

* ENH Adds Poisson criterion in RandomForestRegressor (scikit-learn#19836)

Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>
Co-authored-by: Alihan Zihna <alihanz@gmail.com>
Co-authored-by: Alihan Zihna <a.zihna@ckhgbdp.onmicrosoft.com>
Co-authored-by: Chiara Marmo <cmarmo@users.noreply.github.com>
Co-authored-by: Olivier Grisel <olivier.grisel@gmail.com>
Co-authored-by: naozin555 <37050583+naozin555@users.noreply.github.com>
Co-authored-by: Venkatachalam N <venky.yuvy@gmail.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>

* TST Replace assert_warns from decomposition/tests (scikit-learn#20214)

* TST check n_features_in_ in pipeline module (scikit-learn#20192)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
Co-authored-by: Olivier Grisel <olivier.grisel@gmail.com>

* Allow `n_knots=None` if knots are explicitly specified in `SplineTransformer` (scikit-learn#20191)


Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>

* FIX make check_complex_data deterministic (scikit-learn#20221)

* TST test_fit_docstring_attributes include properties (scikit-learn#20190)

* FIX Uses the color max for colormap in ConfusionMatrixDisplay (scikit-learn#19784)

* STY Changing .format method to f-string formatting (scikit-learn#20215)

* CI Adds permissions for label action

Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
Co-authored-by: tsuga <2888173+tsuga@users.noreply.github.com>
Co-authored-by: Conner Shen <connershen98@hotmail.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-authored-by: mlondschien <61679398+mlondschien@users.noreply.github.com>
Co-authored-by: Clément Fauchereau <clement.fauchereau@ensta-bretagne.org>
Co-authored-by: murata-yu <67666318+murata-yu@users.noreply.github.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Brian Sun <52805678+bsun94@users.noreply.github.com>
Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>
Co-authored-by: Alihan Zihna <alihanz@gmail.com>
Co-authored-by: Alihan Zihna <a.zihna@ckhgbdp.onmicrosoft.com>
Co-authored-by: Chiara Marmo <cmarmo@users.noreply.github.com>
Co-authored-by: Olivier Grisel <olivier.grisel@gmail.com>
Co-authored-by: naozin555 <37050583+naozin555@users.noreply.github.com>
Co-authored-by: Venkatachalam N <venky.yuvy@gmail.com>
Co-authored-by: Nanshan Li <nanshanli@dsaid.gov.sg>
Co-authored-by: solosilence <abhishekkr23rs@gmail.com>
Successfully merging this pull request may close these issues.

CalibratedClassifierCV Invalid prediction method: _predict_proba
4 participants