-
-
Notifications
You must be signed in to change notification settings - Fork 26k
fix mixin inheritance order, allow overwriting tags #14884
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have not checked that changes include all classes, but the included changes LGTM.
I would really prefer using super but this is a mostly backward-compatible change and seems less controversial.
OK, but I think in either case these changes with mixin reordering would be necessary so it's good to have this no?
cc @NicolasHug in case you have bandwidth for a second quick review :)
Yes these changes here are are prerequisite so I'm happy to merge them. |
If we want to simulate super we should go down the MRO, not up. from sklearn.base import BaseEstimator
class A(BaseEstimator):
def _more_tags(self):
return {'allow_nan': True,
'multioutput': False}
class B(A):
def _more_tags(self):
return {'multioutput': True}
class C(B):
def _more_tags(self):
return {'allow_nan': False}
class D(C):
# No _more_tags(), yet self._more_tags() exists and it's C's method
pass
# "correctly" resolves to C's allow_nan only because C has _more_tags().
print(D()._get_tags()['allow_nan']) # False, OK
# resolves to A's multioutput instead of B's since C's _more_tags does not
# specifies multioutput, and because we're going up the mro instead of down the
# mro.
print(D()._get_tags()['multioutput']) # False, KO |
I need to make sure I have the tests from the other PR, will fix this later. |
sklearn/tests/test_base.py
Outdated
|
||
diamond_tag_est = DiamondOverwriteTag() | ||
with pytest.raises(TypeError, match="Inconsistent values for tag"): | ||
diamond_tag_est._get_tags() | ||
assert diamond_tag_est._get_tags() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lol this is missing an ['allow_nan']
I'm confused why the current tests pass ... |
tests were passing because the logic was to complicated and only wrong in complex cases. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM when all goes green!
Thanks! |
In a follow up, we should document how the order of the inheritance matters now when using tags. |
* see scikit-learn/scikit-learn#14884 * see https://twitter.com/dabeaz/status/809084586487664641 * BaseEnsemble already inherits from MetaEstimatorMixin
Incremental change from #14644, keeping the old logic manually looping over the MRO.
I would really prefer using
super
but this is a mostly backward-compatible change and seems less controversial.