
SLEP006: CalibratedClassifierCV #24126


Conversation

@BenjaminBossan (Contributor):

Reference Issues/PRs

#22893

What does this implement/fix? Explain your changes.

This PR adds metadata routing to CalibratedClassifierCV (CCV). CCV uses
a subestimator to create (out-of-sample) probabilities, which are in
turn used to calibrate the probabilities.

The metaestimator uses sample_weight. The subestimator may or may not
use sample_weight and additional metadata. So far, the weights were
routed only if sample_weight appeared in the subestimator's fit
signature. That heuristic is not always ideal, e.g. when the
subestimator is itself a pipeline (#21134). With routing, this problem disappears.
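
A minimal sketch of the resulting usage (assuming the set_fit_request API from the sample-props branch):

import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression

X = np.random.rand(30, 3)
y = np.arange(30) % 2
sample_weight = np.ones(30)

# The sub-estimator explicitly requests sample_weight for its fit; the
# signature check is no longer used to decide whether the weights are routed.
estimator = LogisticRegression().set_fit_request(sample_weight=True)
CalibratedClassifierCV(estimator).fit(X, y, sample_weight=sample_weight)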

Any other comments?

The majority of the work here was done pair-programming with @adrinjalali.
Therefore, having a fresh set of eyes to review would be appreciated.

In addition to these changes, the tests in
test_metaestimator_metadata_routing.py have been amended to make them
more generic, as right now, they are specific to multioutput.

A current limitation of the generic tests is that check_recorded_metadata cannot
be run for CCV. The reason is that CCV internally creates a slice of the
metadata before passing it to the subestimator, so the exact-equality check fails
in this case. We discussed checking either for exact equality or for the passed
data being a subset; that would work in this case but not in others, e.g. when
sample weights are normalized. Therefore, the solution for now is that a
metaestimator can be declared in the tests to opt out of check_recorded_metadata.
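
To illustrate (a sketch, not the actual test code): the metaestimator receives the full metadata, but each subestimator fit only sees the rows of one CV split.

import numpy as np

sample_weight = np.arange(10, dtype=float)  # metadata passed to CCV.fit
train_idx = np.array([0, 1, 2, 3, 4])       # training rows of one CV fold
routed = sample_weight[train_idx]           # what the subestimator records

# an exact-equality check between recorded and passed metadata fails here,
# even though the routing itself was correct
assert not np.array_equal(routed, sample_weight)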

@adrinjalali I still don't use the exact values in "warns_on"; please let me know
how exactly to use them. I thought it would be easier to discuss this with the code out.

@BenjaminBossan (Contributor, Author):

Linting errors seem to be unrelated to this PR

@thomasjpfan (Member):

Syncing with main should help. The linting error was recently fixed in #24065

@BenjaminBossan (Contributor, Author):

> Syncing with main should help. The linting error was recently fixed in #24065

I think since this PR is against sample-props, the sample-props branch needs to be synced first, right?

@adrinjalali (Member):

Yeah I'll sync that branch with main today.

@adrinjalali (Member):

sample-props branch is now synced with main.

@BenjaminBossan (Contributor, Author):

I updated but the linting still fails for unrelated reasons. It also fails on the sample-props branch itself because of what looks like changes introduced by the multioutput PR:

https://dev.azure.com/scikit-learn/scikit-learn/_build/results?buildId=45465&view=logs&j=32e2e1bb-a28f-5b18-6cfc-3f01273f5609&t=8a54543f-0728-5134-6642-bedd98e03dd0

I think this PR can thus be reviewed despite the linting errors.

@adrinjalali (Member) left a review:

Thanks @BenjaminBossan, this looks great!

@@ -259,6 +259,31 @@ def __init__(
self.ensemble = ensemble
self.base_estimator = base_estimator

def _get_estimator(self):
Member: note to other reviewers: this is only a refactoring. Used in fit and get_metadata_routing.
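
A hedged sketch of what such a helper might look like (the fallback logic is an assumption based on CCV's documented default, not the PR's verbatim code):

def _get_estimator(self):
    """Resolve the sub-estimator to calibrate; shared by fit and
    get_metadata_routing so that both see the same estimator."""
    if self.base_estimator is None:
        # CalibratedClassifierCV defaults to LinearSVC when no estimator is given
        from sklearn.svm import LinearSVC
        return LinearSVC(random_state=0)
    return self.base_estimator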

Comment on lines -339 to -352
# sample_weight checks
fit_parameters = signature(estimator.fit).parameters
supports_sw = "sample_weight" in fit_parameters
if sample_weight is not None and not supports_sw:
estimator_name = type(estimator).__name__
warnings.warn(
f"Since {estimator_name} does not appear to accept sample_weight, "
"sample weights will only be used for the calibration itself. This "
"can be caused by a limitation of the current scikit-learn API. "
"See the following issue for more details: "
"https://github.com/scikit-learn/scikit-learn/issues/21134. Be "
"warned that the result of the calibration is likely to be "
"incorrect."
)
Member: note: metadata routing removes the need for this warning. The user will get the right warnings / errors if the metadata is not requested properly.
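
A hypothetical illustration of the new behavior (error type and wording paraphrased, not copied from the branch):

import numpy as np
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression

X = np.random.rand(30, 3)
y = np.arange(30) % 2

ccv = CalibratedClassifierCV(LogisticRegression())
# With routing, this raises instead of silently dropping the weights,
# roughly: "sample_weight is passed but is not explicitly set as
# requested or not requested for LogisticRegression.fit"
ccv.fit(X, y, sample_weight=np.ones(30))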

@@ -380,20 +378,14 @@ def fit(self, X, y, sample_weight=None, **fit_params):
test=test,
method=self.method,
classes=self.classes_,
supports_sw=supports_sw,
Member: note: we don't need this parameter since routing will know what to route and what not.

"""
router = (
MetadataRouter(owner=self.__class__.__name__)
.add_self(self)
Member: note: self is added since CCV is both a consumer and a router. One can do weighted CCV but unweighted fit for the underlying estimator.
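
For context, a hedged sketch of how the full router might be assembled (a fragment of get_metadata_routing; the MethodMapping usage is assumed from the branch's API):

router = (
    MetadataRouter(owner=self.__class__.__name__)
    # CCV itself consumes sample_weight for the calibration step
    .add_self(self)
    # ...and routes fit metadata on to the sub-estimator's fit
    .add(
        estimator=self._get_estimator(),
        method_mapping=MethodMapping().add(callee="fit", caller="fit"),
    )
)
return router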

Comment on lines +53 to +54
def _weighted(estimator):
return estimator.set_fit_request(sample_weight=True)
Member: note: only fit can be weighted in CCV, hence only requesting sample_weight for fit.

Comment on lines 230 to 231
method = getattr(instance, method_name)
method(X, y, sample_weight=sample_weight, metadata=metadata)
Member: this assumes only fit and partial_fit route things around, but other methods could do the same, like transform, score, etc.

Contributor Author: Do you mean the assumption is implicitly made because of how the method is called, i.e. that we need a more generic way to call the method? If so, what needs to be generic: calling with y, calling with sample_weight?

Member: The way the method is called is fine, but if the method being called is not fit or partial_fit, it'll raise an exception that the estimator has not been fitted.

Contributor Author: Hmm, I see. I have no really good solution for this that would work in all cases. Just something off the top of my head:

# before calling the method
if "fit" not in method_name:
    instance.fit(X, y)  # <= would probably still fail on some transformers
...

Member: Yeah, that would work, I think.

Contributor Author: Done.

Comment on lines +1054 to +1056
validate_keys : bool, default=True
Whether to check if the requested parameters fit the actual parameters
of the method.
Member: note: this was added so that we could simply add an instance of this descriptor to CheckingClassifier.

- Add __copy__ method to Registry
- Fix parameter docstrings
- Don't repeat metaestimator ids code
- Check_recorded_metadata runs for all registered estimators
- More fine-grained check for warnings, so as _not_ to error on
  unrelated warnings
@BenjaminBossan (Contributor, Author) left a review:

I addressed some of your comments and had some questions on others, please take a look.


Restructure the generic test so that it makes more sense. Use the values
to check for specific arguments that should be warned on, instead
of all arguments at once.
@adrinjalali (Member) left a review:

Nice!

Comment on lines 83 to 85
# only check keys whose value was explicitly passed
expected_keys = {key for key, val in records.items() if val is not None}
assert set(kwargs.keys()) == expected_keys
Member: Thinking more about this, the change is making the test weaker, and it makes it weaker everywhere.

I think a safer option would be to leave this function as is, and change record_metadata to accept a record_none arg, which is True by default, and in our Consumer estimators in this PR we can set that arg to False.
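
A minimal sketch of the suggested record_none flag (hypothetical, paraphrasing the suggestion above):

def record_metadata(obj, method, record_none=True, **kwargs):
    """Store the metadata a method received so tests can assert on it later."""
    if not record_none:
        # drop arguments that were left at their default of None
        kwargs = {key: val for key, val in kwargs.items() if val is not None}
    if not hasattr(obj, "_records"):
        obj._records = {}
    obj._records[method] = kwargs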

Contributor Author: Makes sense, I changed it as you suggested.

Comment on lines 271 to 273
if method_name in warns_on:
# this method is expected to warn, not raise
continue
Member: Same as in the above test: if the method's name is in warns_on, it doesn't mean it always warns; it means it warns only for those attributes which are listed there. So we need to test for the other attributes.

Contributor Author: You're right, the error test now works analogously to the warning test.

There is now an option for record_metadata to not store None values from
kwargs. This is used in the tests now.
There was no error because it's not being used right now.
Analogous to the warning test, we want to check each argument in
isolation for the error case.
There could be methods like "score" that require a fitted metaestimator.
Therefore, the metaestimator is fitted before calling the tested method,
except when the tested method is itself a fitting method.

Note that right now, this never happens, so in a way that code path is
untested.
Comment on lines 76 to 78
If record_none is False, kwargs whose values are None are skipped. This is
so that checks on keyword arguments whose default was not changed are
skipped.
Member: I see that record_metadata and check_recorded_metadata are only used for testing.

It is strange how some tests, such as test_simple_metadata_routing, expect None to be recorded, while the tests in test_metaestimators_metadata_routing.py expect it not to be recorded.

Can we assume that None is never recorded, in all cases?

Member: We record and check None to make sure that if a metadata is not requested, it is not passed, not even as None.

The tests which don't record None work because the user is not passing any metadata as None, and we need to ignore those values because the default is None and is explicitly set in those sub-estimator methods.

Alternatively, we could change the default to "default", and let record_metadata know what the default value is, and only ignore that. It might be cleaner.
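
A hedged sketch of that alternative (names are illustrative): sub-estimator methods would use a string sentinel instead of None, so record_metadata can skip exactly the unchanged defaults while still recording an explicitly routed None.

def record_metadata(obj, method, **kwargs):
    # skip arguments still at the "default" sentinel; everything else,
    # including an explicitly passed None, is recorded
    recorded = {
        key: val
        for key, val in kwargs.items()
        if not (isinstance(val, str) and val == "default")
    }
    if not hasattr(obj, "_records"):
        obj._records = {}
    obj._records[method] = recorded

class ConsumingClassifier:  # illustrative consumer
    def fit(self, X, y, sample_weight="default", metadata="default"):
        record_metadata(self, "fit", sample_weight=sample_weight, metadata=metadata)
        return self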

Member: What do you think about always recording None and having a flag in check_recorded_metadata to switch between the two modes of checking? Concretely:

import operator

def check_recorded_metadata(obj, method, strict=True, **kwargs):
    """Check whether the expected metadata is passed to the object's method."""
    records = getattr(obj, "_records", dict()).get(method, dict())

    # strict: expected and recorded keys must match exactly; otherwise the
    # expected keys only need to be a subset of the recorded ones
    cmp = operator.eq if strict else operator.le
    assert cmp(set(kwargs.keys()), set(records.keys()))
    for key, value in kwargs.items():
        assert records[key] is value

Member: I kinda prefer the current implementation because it's easier for the test to test exactly what it needs to. What you have here is kind of a subset of what the current implementation does: there's no way to test whether the router has routed None explicitly or not, and that's something we need to check. I might be missing something here though.

@adrinjalali (Member) left a review:

Other than this (#24126 (comment)), which kinda makes the tests more robust, I'm happy with the PR.

BenjaminBossan and others added 4 commits August 16, 2022 16:14
Give option to not check if the routed metadata are the literal string
"default" (instead of checking for None).
@adrinjalali (Member):

Why did it not fail when @thomasjpfan pushed? 🤯

@adrinjalali (Member) left a review:

I don't mind the custom group splitter either way. LGTM.

Comment on lines 1073 to 1078
class MyGroupKFold(GroupKFold):
"""Custom Splitter that checks that the values of groups are correct"""

def split(self, X, y=None, groups=None):
assert (groups == split_groups).all()
return super().split(X, y=y, groups=groups)
Member: GroupKFold raises if groups is None anyway; you won't need this custom class to test.
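
A quick check of that claim (standard scikit-learn behavior):

import numpy as np
from sklearn.model_selection import GroupKFold

X, y = np.zeros((4, 1)), np.zeros(4)
try:
    next(GroupKFold(n_splits=2).split(X, y, groups=None))
except ValueError as exc:
    print(exc)  # The 'groups' parameter should not be None.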

Contributor Author: This was more about checking that the correct groups data is being passed. If you think it's not necessary, I'd rather remove it to simplify the test.

Member: I don't think we have to check for the correctness of the groups. Those mechanisms are tested elsewhere.

Contributor Author: I removed the custom class.

Could you or someone else please push an empty commit for CI? It seems like CircleCI hasn't solved the issue yet.

BenjaminBossan and others added 2 commits August 17, 2022 15:23
This is already covered by the general routing tests.
@thomasjpfan (Member) left a review:

LGTM

@thomasjpfan thomasjpfan merged commit 0afaa63 into scikit-learn:sample-props Aug 17, 2022
@BenjaminBossan BenjaminBossan deleted the slep006/calibratedclassifiercv branch August 18, 2022 07:43