
ENH Add CalibrationDisplay plotting class #17443


Merged: 151 commits merged into scikit-learn:main from lucyleeow:IS/8425 on Aug 31, 2021

Conversation

@lucyleeow (Member) commented Jun 4, 2020

Reference Issues/PRs

closes #8425

What does this implement/fix? Explain your changes.

Adds CalibrationDisplay for binary classifiers, using the visualization API.

Any other comments?

The plot currently looks like this:
[image: calibration curve plot]

I have yet to add tests and update the examples, as I am unsure about the API:

Should a histogram be added? If so, should the histogram be a separate plot (e.g., here) or on the same plot (as suggested by @amueller: #8425 (comment))?
If we put it on the same plot, I'm worried it will be too crowded. If we want two separate plots, the API becomes awkward, as we would need two different plot **kwargs parameters, one per plot, so people can amend them separately. You also couldn't use ax = plt.gca() to get the current axes when you want to add lines to an existing plot (as for the current plots, e.g., here). I think you could use CalibrationDisplay.ax_, though.

@NicolasHug (Member) left a comment


Thanks @lucyleeow, I made a quick pass but it looks good.

Regarding the histograms: since it's just a simple call to plt.hist, I think we should let users call that themselves and rely on the prob_pred attribute of the visualizer. Since this would be illustrated in the examples, that's fine IMO.
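For reference, a minimal sketch of this approach, assuming the CalibrationDisplay API as merged (from_estimator, ax_, and y_prob holding the raw per-sample probabilities, while prob_pred holds the binned means); the secondary-axis styling is purely illustrative, not part of this PR:

# Draw the reliability curve with the display, then add the histogram
# by hand with a plain plt.hist call on the stored predictions.
import matplotlib.pyplot as plt
from sklearn.calibration import CalibrationDisplay
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)

disp = CalibrationDisplay.from_estimator(clf, X_test, y_test, n_bins=10)

# Histogram of the predicted probabilities, sharing the x-axis with the
# reliability curve via a secondary y-axis (illustrative choice).
hist_ax = disp.ax_.twinx()
hist_ax.hist(disp.y_prob, bins=10, alpha=0.3)
hist_ax.set_ylabel("Count")
plt.show()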

@glemaitre glemaitre self-assigned this Jun 4, 2020
@glemaitre glemaitre removed their assignment Jun 4, 2020
@glemaitre (Member) left a comment


Posting these comments now; I see this is still WIP.

self.prob_pred = prob_pred
self.estimator_name = estimator_name

@_deprecate_positional_args
Member

Would we even need this deprecation, given that this is a new class and a new function?

Member Author

Should I move the * so that no positional args are allowed? (I'm a bit confused about this.)

@lucyleeow (Member Author) commented Jun 4, 2020

Specifically here: in def plot(self, ax=None, *, name=None, ref_line=True, **kwargs), ax arguably should/could be keyword-only too.
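As a small illustration of what the bare * in that signature does (the class and call arguments below are hypothetical, not the actual implementation):

# Parameters after the bare `*` are keyword-only; `ax` stays
# positional-or-keyword.
class Display:
    def plot(self, ax=None, *, name=None, ref_line=True, **kwargs):
        print(ax, name, ref_line, kwargs)

disp = Display()
disp.plot("my_ax")                 # OK: ax may be passed positionally
disp.plot("my_ax", name="LogReg")  # OK: name passed by keyword
# disp.plot("my_ax", "LogReg")     # TypeError: name is keyword-only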

@lucyleeow (Member Author)

Thanks for the reviews, guys! I would like to update the examples plot_calibration_curve.py and plot_compare_calibration.py to notebook style and add links, as well as amend them to use plot_calibration_curve.

Can I make the non-plot_calibration_curve changes in this PR, or should they go in a different PR?

@NicolasHug (Member)

I'm fine with doing this here so we can see it in action.

Regarding the module, I'm not sure metrics is the right choice here. I think we should either use sklearn.inspection, or turn calibration into a sub-package?

@lucyleeow (Member Author)

Good point; it's not really a metric. sklearn.inspection sounds reasonable. From the user guide:

The sklearn.inspection module provides tools to help understand the predictions from a model and what affects them. This can be used to evaluate assumptions and biases of a model, design a better model, or to diagnose issues with model performance.

I'll wait a bit to see if there are any objections and if there are none, I'll move it there.

@lucyleeow (Member Author) commented Jun 5, 2020

@NicolasHug would it be appropriate to put this inside sklearn/calibration.py and sklearn/tests/test_calibration.py, or should this plotting code be in its own files?

@lucyleeow (Member Author) commented Jul 22, 2021

ping @glemaitre: changes made, thanks! (and the CIs are greeeeen!)

@glemaitre (Member)

LGTM apart from the small comment where I would check what failure message we get.

@glemaitre (Member)

I will not merge right now; I would like to know if @ogrisel's +1 still stands after the changes.

@lucyleeow (Member Author)

ping @ogrisel...?

@lorentzenchr (Member)

@ogrisel Can we merge?

@lorentzenchr lorentzenchr added this to the 1.0 milestone Aug 23, 2021
@lorentzenchr (Member)

@adrinjalali I added this PR to the 1.0 milestone, as I think this should really be in. It already has +2 and is just waiting for a final OK from @ogrisel, since there have been some changes since his approval.

@lucyleeow (Member Author)

Yes please!

@ogrisel (Member) left a comment


There is now a failing test, not sure why:

In sklearn/calibration.py at line 650:

        if y_pred.ndim != 1:  # `predict_proba`
            if y_pred.shape[1] != 2:
>               raise ValueError(classification_error)
E               ValueError: Expected 'estimator' to be a binary classifier, but got DecisionTreeClassifier

Beyond fixing the test failure, I think the error message should be improved to report the observed shape of y_pred.

@glemaitre (Member)

The failure is due to the fact that we already improved the previous error message in another PR.
Here, we could maybe catch this message and add the part about y_pred that @ogrisel would like.
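For illustration, a minimal sketch combining the check from the traceback above with the message wording that appears later in this thread; the helper name is hypothetical and the exact merged code may differ:

# Report the observed number of classes in `y_pred` when the estimator
# is not a binary classifier (sketch only; name is hypothetical).
def _validate_binary_proba(y_pred, classification_error):
    # `predict_proba` output is 2D; anything other than two columns
    # means the estimator was fit on multiclass data.
    if y_pred.ndim != 1 and y_pred.shape[1] != 2:
        raise ValueError(
            f"{classification_error} fit on multiclass "
            f"({y_pred.shape[1]} classes) data"
        )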

@ogrisel (Member) left a comment


Another detail to fix below.

Other than that and the failing test, I confirm that this PR looks good to me. Thanks again @lucyleeow.

Comment on lines +107 to +108
f"{classification_error} fit on multiclass ({y_pred_shape} classes)"
" data"
@lucyleeow (Member Author) commented Aug 31, 2021

Happy to change this message if you had something different in mind @ogrisel.

Member

Thanks, the message looks good!

@lucyleeow (Member Author)

Thank you @ogrisel @glemaitre, changes made!

@ogrisel ogrisel merged commit da36f72 into scikit-learn:main Aug 31, 2021
@ogrisel (Member) commented Aug 31, 2021

Merged! Thank you very much @lucyleeow!

@lucyleeow lucyleeow deleted the IS/8425 branch August 31, 2021 07:44
@lucyleeow (Member Author)

Thanks! I'm so happy! :D

@glemaitre (Member)

Nice. Thanks @lucyleeow

samronsin pushed a commit to samronsin/scikit-learn that referenced this pull request Nov 30, 2021
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Development

Successfully merging this pull request may close these issues.

Add calibration curve to plotting module
9 participants