SLEP006: CalibratedClassifierCV #24126

Conversation
This PR adds metadata routing to CalibratedClassifierCV (CCV). CCV uses a subestimator to create (out of sample) probabilities, which are in turn used to calibrate the probabilities. The metaestimator uses sample_weight. The subestimator may or may not use sample_weight and additional metadata. So far, the check was whether the subestimator has sample_weight in its fit signature; if so, the weights were routed, otherwise not. This is, however, not always ideal, e.g. when the subestimator is itself a pipeline (scikit-learn#21134). With routing, this problem disappears. In addition to these changes, the tests in test_metaestimator_metadata_routing.py have been amended to make them more generic, as right now, they are specific to multioutput.
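For context, a rough sketch of how this looks from the user side, assuming the routing API as it eventually landed (the set_fit_request method and the enable_metadata_routing config flag); details may differ from the state of this PR:

import numpy as np
import sklearn
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# flag name as in later releases; the PR itself may gate this differently
sklearn.set_config(enable_metadata_routing=True)

X, y = make_classification(random_state=0)
sample_weight = np.ones(len(y))

pipe = make_pipeline(
    # the scaler explicitly declines the weights ...
    StandardScaler().set_fit_request(sample_weight=False),
    # ... while the final step requests them for its fit
    LogisticRegression().set_fit_request(sample_weight=True),
)
# with routing, the weights reach the pipeline's final step even though
# "sample_weight" does not appear in Pipeline.fit's signature
CalibratedClassifierCV(pipe).fit(X, y, sample_weight=sample_weight)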
Linting errors seem to be unrelated to this PR.

Syncing with …?

I think since this PR is against …

Yeah, I'll sync that branch with …

I updated, but the linting still fails for unrelated reasons. It also fails on the … I think this PR can thus be reviewed despite the linting errors.
Thanks @BenjaminBossan, this looks great!
@@ -259,6 +259,31 @@ def __init__(
        self.ensemble = ensemble
        self.base_estimator = base_estimator

    def _get_estimator(self):
note to other reviewers: this is only a refactoring. Used in fit and get_metadata_routing.
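For illustration, roughly what the helper centralizes (a sketch, not the verbatim PR code):

from sklearn.svm import LinearSVC

def _get_estimator(self):
    """Resolve the estimator to be calibrated; shared by fit and
    get_metadata_routing."""
    if self.base_estimator is None:
        # LinearSVC(random_state=0) has been CCV's historical default
        return LinearSVC(random_state=0)
    return self.base_estimator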
# sample_weight checks
fit_parameters = signature(estimator.fit).parameters
supports_sw = "sample_weight" in fit_parameters
if sample_weight is not None and not supports_sw:
    estimator_name = type(estimator).__name__
    warnings.warn(
        f"Since {estimator_name} does not appear to accept sample_weight, "
        "sample weights will only be used for the calibration itself. This "
        "can be caused by a limitation of the current scikit-learn API. "
        "See the following issue for more details: "
        "https://github.com/scikit-learn/scikit-learn/issues/21134. Be "
        "warned that the result of the calibration is likely to be "
        "incorrect."
    )
note: metadata routing removes the need for this warning. The user will get the right warnings/errors if the metadata is not requested properly.
@@ -380,20 +378,14 @@ def fit(self, X, y, sample_weight=None, **fit_params):
                test=test,
                method=self.method,
                classes=self.classes_,
                supports_sw=supports_sw,
note: we don't need this parameter since routing will know what to route and what not.
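A sketch of the routing pattern that replaces the supports_sw flag, assuming the process_routing helper from the routing utilities; the exact internals in the PR may differ:

from sklearn.utils.metadata_routing import process_routing

def fit(self, X, y, sample_weight=None, **fit_params):
    # ask the router which params each sub-object requested, instead of
    # inspecting the sub-estimator's signature
    routed_params = process_routing(
        self, "fit", sample_weight=sample_weight, **fit_params
    )
    estimator = self._get_estimator()
    # only the metadata the sub-estimator actually requested is forwarded
    estimator.fit(X, y, **routed_params.estimator.fit)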
""" | ||
router = ( | ||
MetadataRouter(owner=self.__class__.__name__) | ||
.add_self(self) |
note: self is added since this CCV is both a consumer and a router. One can do weighted CCV but unweighted fit for the underlying estimator.
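A hedged illustration of that point, again assuming the user-facing API as it later landed:

import numpy as np
import sklearn
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

sklearn.set_config(enable_metadata_routing=True)
X, y = make_classification(random_state=0)

# the sub-estimator explicitly declines the weights ...
unweighted = LogisticRegression().set_fit_request(sample_weight=False)
# ... yet CCV, as a consumer, still uses them for the calibration itself
CalibratedClassifierCV(unweighted).fit(X, y, sample_weight=np.ones(len(y)))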
def _weighted(estimator):
    return estimator.set_fit_request(sample_weight=True)
note: only fit can be weighted in CCV, hence only requesting sample_weight for fit.
method = getattr(instance, method_name)
method(X, y, sample_weight=sample_weight, metadata=metadata)
this assumes only fit and partial_fit route things around, but other methods could do the same, like transform, score, etc.
Do you mean the assumption is implicitly made because of how the method is called? I.e. we need a more generic way to call the method? If so, what needs to be generic: calling with y, calling with sample_weight?
the way the method is called is fine, but if the method being called is not fit or partial_fit, it'll raise an exception that the estimator has not been fitted.
Hmm, I see. I have no really good solution for this that would work in all cases. Just something off the top of my head:

# before calling the method
if "fit" not in method_name:
    instance.fit(X, y)  # <= would probably still fail on some transformers
...
yeah that would work I think.
Done
validate_keys : bool, default=True
    Whether to check if the requested parameters fit the actual parameters
    of the method.
note: this was added so that we could simply add an instance of this descriptor to CheckingClassifier
- Add __copy__ method to Registry
- Fix parameter docstrings
- Don't repeat metaestimator ids code
- check_recorded_metadata runs for all registered estimators
- More fine-grained check for warnings, so as _not_ to error on unrelated warnings
I addressed some of your comments and had some questions on others, please take a look.
Change around structure of the generic test to make more sense. Use the values to check for specific arguments that should be warned on, instead of all arguments at once.
Nice!
# only check keys whose value was explicitly passed
expected_keys = {key for key, val in records.items() if val is not None}
assert set(kwargs.keys()) == expected_keys
Thinking more about this, the change is making the test weaker, and it makes it weaker everywhere. I think a safer option would be to leave this function as is, and change record_metadata to accept a record_none arg, which is True by default, and in our Consumer estimators in this PR we can set that arg to False.
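A minimal sketch of that suggestion (record_metadata is the helper from these test modules; this is not its verbatim code):

def record_metadata(obj, method, record_none=True, **kwargs):
    """Store the metadata a consumer received, keyed by method name."""
    if not record_none:
        # drop kwargs that were left at their default of None
        kwargs = {key: val for key, val in kwargs.items() if val is not None}
    if not hasattr(obj, "_records"):
        obj._records = {}
    obj._records[method] = kwargs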
Makes sense, I changed it as you suggested.
if method_name in warns_on:
    # this method is expected to warn, not raise
    continue
same as the above test, if the method's name is in warns_on, it doesn't mean it always warns; it warns only for those attributes which are listed there. So we need to test for the other attributes.
You're right, the error test now works analogously to the warning test.
There is now an option for record_metadata to not store None values from kwargs. This is used in the tests now.
There was no error because it's not being used right now.
Analogous to the warning test, we want to check each argument in isolation for the error case.
There could be methods like "score" that require a fitted metaestimator. Therefore, it is fitted before calling the tested method, unless the tested method is a fitting method. Note that right now, this never happens, so in a way that code path is untested.
If record_none is False, kwargs whose values are None are skipped. This is
so that checks on keyword arguments whose default was not changed are
skipped.
I see that record_metadata and check_recorded_metadata are only used for testing. It is strange how some tests such as test_simple_metadata_routing expect None to be recorded while tests in test_metaestimators_metadata_routing.py expect them not to be recorded. Can we assume that None is never recorded, consistently?
We record and check None to make sure that if a metadata is not requested, it is not passed, not even as None. The tests which don't record None work because the user is not passing any metadata as None, and we need to ignore them because the default value is None and explicitly set in those sub-estimator methods.
Alternatively, we could change the default to "default", let record_metadata know what the default value is, and only ignore that. It might be cleaner.
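A sketch of that sentinel idea (hypothetical names throughout, reusing the record_metadata sketch above):

_DEFAULT = "default"

def fit(self, X, y, sample_weight=_DEFAULT, metadata=_DEFAULT):
    passed = {
        key: val
        for key, val in {"sample_weight": sample_weight, "metadata": metadata}.items()
        # an explicitly passed None is kept; only untouched defaults drop out
        if val is not _DEFAULT
    }
    record_metadata(self, "fit", record_none=True, **passed)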
What do you think about always recording None and having a flag in check_recorded_metadata to switch between the two modes of checking? Concretely:

import operator

def check_recorded_metadata(obj, method, strict=True, **kwargs):
    """Check whether the expected metadata is passed to the object's method."""
    records = getattr(obj, "_records", dict()).get(method, dict())
    # strict: exact key match; non-strict: expected keys form a subset
    cmp = operator.eq if strict else operator.le
    assert cmp(set(kwargs.keys()), set(records.keys()))
    for key, value in kwargs.items():
        assert records[key] is value
I kinda prefer the current implementation because it's easier for the test to test exactly what it needs to. What you have here is kind of a subset of what the current implementation does; as in, there's no way to test if the router has routed None explicitly or not, and that's something we need to check. I might be missing something here though.
Specifically, overriding set_fit_request.
Other than this (#24126 (comment)), which kinda makes the tests more robust, I'm happy with the PR.
Give option to not check routed metadata if it is the literal string "default" (instead of checking for None).
Why did it not fail when @thomasjpfan pushed? 🤯
I don't mind the custom group splitter either way. LGTM.
sklearn/tests/test_calibration.py
class MyGroupKFold(GroupKFold):
    """Custom splitter that checks that the values of groups are correct"""

    def split(self, X, y=None, groups=None):
        assert (groups == split_groups).all()
        return super().split(X, y=y, groups=groups)
GroupKFold raises if groups is None anyway, you won't need this custom class to test.
This was more about checking that the correct groups data is being passed. If you think it's not necessary, I'd rather remove it to simplify the test.
I don't think we have to check for the correctness of the groups. Those mechanisms are tested elsewhere.
I removed the custom class.
Could you or someone else please push an empty commit for CI? It seems like CircleCI hasn't solved the issue yet.
This is already covered by the general routing tests.
LGTM
Reference Issues/PRs
#22893

What does this implement/fix? Explain your changes.
This PR adds metadata routing to CalibratedClassifierCV (CCV). CCV uses a subestimator to create (out of sample) probabilities, which are in turn used to calibrate the probabilities.
The metaestimator uses sample_weight. The subestimator may or may not use sample_weight and additional metadata. So far, the check was whether the subestimator has sample_weight in its signature; if so, the weights were routed, otherwise not. This is, however, not always ideal, e.g. when the subestimator is itself a pipeline (#21134). With routing, this problem disappears.

Any other comments?
The majority of the work here was done pair-programming with @adrinjalali. Therefore, having a fresh set of eyes to review would be appreciated.
In addition to these changes, the tests in test_metaestimator_metadata_routing.py have been amended to make them more generic, as right now, they are specific to multioutput.
A current limitation of the generic tests is that check_recorded_metadata cannot be performed for CCV. The reason is that CCV internally creates a slice of the metadata before passing it to the subestimator, so exact equality fails in this case. The possibility was discussed to check for exact equality or for the passed data being a subset; this would work in this case but not in others, e.g. when sample weights are normalized. Therefore, the solution for now is that in the tests, it can be declared that this specific metaestimator opts out of check_recorded_metadata.
@adrinjalali I still don't use the exact values in "warns_on", please let me know how to use them exactly. I thought it's easier to discuss this with the code out.