FIX CalibratedClassifierCV should not ignore sample_weight if estimator does not support it #21143


Closed
wants to merge 6 commits

Conversation

glemaitre
Member

Partially addresses #21134

Forces an error to be raised when a Pipeline is used inside a meta-estimator, so that sample_weight is not silently ignored.
In the future, we should address #18159; then the test should not raise an error but instead delegate the weights to the right estimator in the Pipeline.
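
For context, a minimal sketch of the kind of silent ignore in question (the exact warning/error behaviour depends on the scikit-learn version): Pipeline.fit exposes no plain sample_weight argument, so the meta-estimator concludes the inner estimator does not support weights.

    import numpy as np
    from sklearn.calibration import CalibratedClassifierCV
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler

    X, y = make_classification(random_state=0)
    sample_weight = np.ones(X.shape[0])

    # Pipeline.fit only accepts weights as `<step>__sample_weight` fit params,
    # so the signature check in CalibratedClassifierCV finds no `sample_weight`
    # and the weights are dropped when fitting the pipeline (they are still
    # used for the calibration step).
    model = CalibratedClassifierCV(
        make_pipeline(StandardScaler(), LogisticRegression())
    )
    model.fit(X, y, sample_weight=sample_weight)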

@glemaitre glemaitre added this to the 1.0.1 milestone Sep 24, 2021
@glemaitre glemaitre marked this pull request as draft September 24, 2021 15:25
@glemaitre
Member Author

Now that I am looking at the calibration code, it seems that we intend to raise a warning (which I ignored while doing my initial tests). @lucyleeow do you remember why it was not considered controversial to fit a model discarding sample_weight and only use the weights for calibration?

@lucyleeow
Member

lucyleeow commented Sep 25, 2021

I think the intention when I refactored this function was to keep all functionality the same, with any fixes/changes to be done afterwards, separately (not that I can remember any fixes!).

Looking through git blame, it seems that the warning about sample weight being ignored was added here: 70d49de

And it seems this ignoring of sample weight originates from the start? ecfc93d:

            for train, test in cv:
                this_estimator = clone(self.base_estimator)
                if sample_weight is not None and \
                   "sample_weight" in inspect.getargspec(
                        this_estimator.fit)[0]:
                    this_estimator.fit(X[train], y[train],
                                       sample_weight[train])
                else:
                    this_estimator.fit(X[train], y[train])
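
(Side note: in current scikit-learn the same capability check can be done with the has_fit_parameter helper instead of inspect.getargspec, which has since been removed from Python; a minimal sketch:)

    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.svm import LinearSVC
    from sklearn.utils.validation import has_fit_parameter

    # has_fit_parameter inspects the signature of estimator.fit, which is what
    # the inspect.getargspec call above did by hand.
    print(has_fit_parameter(LinearSVC(), "sample_weight"))             # True
    print(has_fit_parameter(KNeighborsClassifier(), "sample_weight"))  # False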

@glemaitre
Member Author

And it seems this ignoring of sample weight originates from the start?

Thanks @lucyleeow for the insights. I doubt that this is a good strategy, though. I will raise this issue at the next dev meeting, then.

@glemaitre glemaitre changed the title TST add common tests for meta-estimators FIX CalibratedClassifierCV should not ignore sample_weight if estimator does not support it Sep 27, 2021
@ogrisel
Member

ogrisel commented Sep 27, 2021

I am not sure whether passing sample_weight both to the calibrator and the base_estimator is a form of "double-accounting" or not. Because we do a cross-validation split, I think it is not, but I am not 100% sure.

Maybe the best way to proceed would be to consider 2 datasets:

  • X, y, sample_weight where all samples have a weight of 1 except the last one, which has a weight of 2.
  • X, y and sample_weight=None with the same samples, except that the last one is duplicated.

Intuitively, we would like calling CalibratedClassifierCV(some_estimator) on those two datasets to yield exactly the same decision function (same predict_proba values), in expectation.

  • When ensemble=True, it seems that this can only hold if the weights are propagated to both the calibrators and the base estimators. This means that the current situation probably yields bad results if the base estimator does not accept sample weights (a rough sketch of such an equivalence check follows this list).
  • When ensemble=False, I am not sure...
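
A rough sketch of the equivalence check discussed above, assuming a small synthetic dataset and a fixed, non-shuffling CV so that the folds of the two datasets line up (the up-weighted/duplicated sample ends up in the last fold in both cases); illustrative only, not the final test:

    import numpy as np
    from sklearn.calibration import CalibratedClassifierCV
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import KFold

    X, y = make_classification(n_samples=98, random_state=0)

    # Dataset A: all samples have weight 1 except the last one, which has weight 2.
    sw = np.ones(X.shape[0])
    sw[-1] = 2.0

    # Dataset B: same samples, no weights, but the last sample is duplicated.
    X_dup = np.vstack([X, X[-1:]])
    y_dup = np.concatenate([y, y[-1:]])

    # With 98 (resp. 99) samples and contiguous 3-fold splits, the first two
    # folds are identical for both datasets and only the last fold differs.
    cv = KFold(n_splits=3, shuffle=False)

    clf_weighted = CalibratedClassifierCV(LogisticRegression(), cv=cv, ensemble=True)
    clf_weighted.fit(X, y, sample_weight=sw)

    clf_duplicated = CalibratedClassifierCV(LogisticRegression(), cv=cv, ensemble=True)
    clf_duplicated.fit(X_dup, y_dup)

    # If the weights were propagated consistently to both the base estimators
    # and the calibrators, these probabilities should (nearly) coincide.
    diff = np.abs(clf_weighted.predict_proba(X) - clf_duplicated.predict_proba(X))
    print(diff.max())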

We could have a similar test to check that dropping a sample is equivalent to setting its weight to 0. There is a common test for these semantics, but it is XFAILing for CalibratedClassifierCV:

https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/calibration.py#L437-L440

So I think we should at least add dedicated tests for CalibratedClassifierCV to check for some semantics and document cases that work as expected.

For cases that are not working, we should make sure that we include them as part of the PR prototype for SLEP006 on meta-data routing, e.g. as part of #20350.

Whether we should raise a ValueError or a warning with a stronger message in the meantime, while waiting for SLEP006... I don't have a strong opinion. Maybe a warning is enough if we reword it.

@ogrisel
Member

ogrisel commented Sep 28, 2021

Intuitively, we would like calling CalibratedClassifierCV(some_estimator) on those two datasets to yield exactly the same decision function (same predict_proba values), in expectation.

Thinking a bit more about this, it might be challenging to test with a limited computational budget, because the cross-validation strategy might be non-deterministic and the "in expectation" part would require a statistical test.

To simplify the problem, we could make sure that we run this test with a simplistic, deterministic CV loop (a simple 3- or 5-fold CV without shuffling or stratification) and put the duplicated sample and the sample with weight 2 in the same position (e.g. in the last CV fold in both cases).

The same strategy could be adapted to check the 0-weight / sample-drop equivalence.
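
Adapting the same deterministic setup to the 0-weight / sample-drop check could look like the sketch below (again purely illustrative, with the affected sample placed in the last fold):

    import numpy as np
    from sklearn.calibration import CalibratedClassifierCV
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import KFold

    X, y = make_classification(n_samples=99, random_state=0)
    cv = KFold(n_splits=3, shuffle=False)  # deterministic, no shuffling

    # Give the last sample a weight of 0 ...
    sw = np.ones(X.shape[0])
    sw[-1] = 0.0
    clf_zero = CalibratedClassifierCV(LogisticRegression(), cv=cv)
    clf_zero.fit(X, y, sample_weight=sw)

    # ... versus dropping it entirely. With contiguous 3-fold splits on 99
    # (resp. 98) samples, only the last fold differs between the two fits.
    clf_drop = CalibratedClassifierCV(LogisticRegression(), cv=cv)
    clf_drop.fit(X[:-1], y[:-1])

    # If sample_weight is honoured by both the base estimators and the
    # calibrators, the two models should make (nearly) identical predictions.
    diff = np.abs(clf_zero.predict_proba(X) - clf_drop.predict_proba(X))
    print(diff.max())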

@glemaitre
Member Author

To simplify the problem, we could make sure that we run this test with a simplistic, deterministic CV loop (a simple 3- or 5-fold CV without shuffling or stratification) and put the duplicated sample and the sample with weight 2 in the same position (e.g. in the last CV fold in both cases).

I was indeed starting to set up a 2-fold cross-validation with iris (only the first 2 classes), where it would be easy to check the underlying weights of the classifier and the parameters of the calibrator to understand exactly what we are doing with the weights.
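
Something along these lines, where the per-fold fitted classifiers and calibrators can be inspected directly (the attributes accessed here are implementation details of recent scikit-learn versions and may differ between releases):

    import numpy as np
    from sklearn.calibration import CalibratedClassifierCV
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression

    # Keep only the first two classes of iris to get a binary problem.
    X, y = load_iris(return_X_y=True)
    X, y = X[y < 2], y[y < 2]

    # Arbitrary non-uniform weights so their effect is visible.
    sample_weight = np.linspace(1.0, 2.0, X.shape[0])

    # cv=2 uses a stratified 2-fold split for classifiers.
    calibrated = CalibratedClassifierCV(LogisticRegression(), cv=2)
    calibrated.fit(X, y, sample_weight=sample_weight)

    # Inspect each fold's fitted pair: the coefficients of the underlying
    # classifier and the parameters of the (sigmoid) calibrator.
    for fold in calibrated.calibrated_classifiers_:
        print(fold.estimator.coef_)
        print([(cal.a_, cal.b_) for cal in fold.calibrators])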

@glemaitre
Member Author

We can postpone this PR until we have proper dispatching with sample props.

@glemaitre glemaitre modified the milestones: 1.0.1, 1.1 Oct 20, 2021
@jeremiedbb jeremiedbb modified the milestones: 1.1, 1.2 Apr 7, 2022
@glemaitre glemaitre modified the milestones: 1.2, 1.3 Nov 16, 2022
@jeremiedbb jeremiedbb modified the milestones: 1.3, 1.4 Jun 8, 2023
@glemaitre
Member Author

Solved by using meta-data routing.
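
For reference, a minimal sketch of the routing solution, assuming a scikit-learn version in which CalibratedClassifierCV supports metadata routing (the enable_metadata_routing config flag exists since 1.3):

    import numpy as np
    import sklearn
    from sklearn.calibration import CalibratedClassifierCV
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    # Opt in to metadata routing.
    sklearn.set_config(enable_metadata_routing=True)

    X, y = make_classification(random_state=0)
    sample_weight = np.random.RandomState(0).uniform(size=X.shape[0])

    # Explicitly request that the inner estimator receives sample_weight;
    # the weights are then routed to it instead of being silently dropped.
    base = LogisticRegression().set_fit_request(sample_weight=True)

    CalibratedClassifierCV(base).fit(X, y, sample_weight=sample_weight)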

@glemaitre glemaitre closed this Dec 7, 2023