[MRG] Adds plot_precision_recall_curve #14936


Merged: 30 commits, Nov 11, 2019

Conversation

thomasjpfan (Member)

Reference Issues/PRs

Related to #7116

What does this implement/fix? Explain your changes.

This PR adds plot_precision_recall_curve.

Any other comments?

Only supports binary classifiers.
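
A minimal usage sketch of the new API (the dataset and estimator here are illustrative, not part of this PR):

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import plot_precision_recall_curve

    X, y = make_classification(n_classes=2, random_state=0)
    clf = LogisticRegression().fit(X, y)

    # Returns a PrecisionRecallDisplay; extra kwargs go to matplotlib's plot.
    disp = plot_precision_recall_curve(clf, X, y)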

@glemaitre (Member) left a comment:

Looks good, only a couple of changes.


Parameters
-----------
precision : ndarray of shape (n_thresholds + 1, )

Suggested change:
- precision : ndarray of shape (n_thresholds + 1, )
+ precision : ndarray of shape (n_thresholds + 1,)

@glemaitre (Member) left a comment:

You can also add an entry in what's new


if y_pred.ndim != 1:
    if y_pred.shape[1] > 2:
        raise ValueError("Estimator should solve a "
                         "binary classification problem")

isn't it possible to use check_classification_targets?
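
For reference, a hedged sketch of what that might look like; note that check_classification_targets validates y rather than y_pred, so type_of_target would be needed for the binary check:

    from sklearn.utils.multiclass import check_classification_targets, type_of_target

    check_classification_targets(y)  # raises ValueError for non-classification targets
    if type_of_target(y) != "binary":
        raise ValueError("Estimator should solve a binary classification problem")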

@amueller (Member):

conflicts ;)


y_pred = prediction_method(X)

if is_predict_proba and y_pred.ndim != 1:

If is_predict_proba is true, y_pred.ndim is never 1, right?
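
(Context for readers: with a fitted binary classifier such as the clf from the sketch above, predict_proba is always 2-D while decision_function is 1-D:)

    # predict_proba returns (n_samples, n_classes) even for binary problems,
    # so y_pred.ndim is 2 whenever is_predict_proba is True.
    print(clf.predict_proba(X).shape)      # (n_samples, 2) -> ndim == 2
    print(clf.decision_function(X).shape)  # (n_samples,)   -> ndim == 1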

plot_precision_recall_curve(clf, X, y)

msg = "Estimator should solve a binary classification problem"
y_binary = y == 1

I don't understand why this raises this error, both semantically and in terms of what the code does. I thought the code checked y_pred, which we're not changing here, right?

@amueller (Member):

looks good apart from nitpicks

@glemaitre (Member):

The error raised does not match.

@thomasjpfan (Author):

CC @NicolasHug

@qinhanmin2014 (Member):

I agree that this is a blocker, but we need to figure out a solution for #15303

@NicolasHug (Member):

The user guide link of plot_precision_recall_curve is wrong: there's no point in linking to the visualization API UG. Also, some of the links are broken.

@thomasjpfan

@amueller (Member) commented on Nov 6, 2019:

see #15405 (comment)

The easy fix is removing it and inferring it from the estimator. The better fix is to actually ensure that we correctly slice predict_proba / decision_function.
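
A rough sketch of that "better fix" (illustrative only; pos_label is the hypothetical parameter under discussion, clf a fitted binary classifier):

    import numpy as np

    # Slice the predict_proba column matching pos_label; decision_function
    # scores would instead need their sign flipped when pos_label is classes_[0].
    class_idx = np.flatnonzero(clf.classes_ == pos_label)[0]
    y_score = clf.predict_proba(X)[:, class_idx]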

@thomasjpfan (Author):

Went with removing pos_label and inferring it from the estimator.

@amueller (Member) commented on Nov 6, 2019:

Then we should do the same for plot_roc_curve and open an issue to do the fix for the next release?

@qinhanmin2014 (Member) left a comment:

Perhaps it's better to keep this consistent with plot_roc_curve/RocCurveDisplay (API and code).

@@ -71,6 +71,7 @@ Functions

.. autosummary::

metrics.plot_precision_recall_curve

Alphabetical order?

@@ -82,5 +83,6 @@ Display Objects

.. autosummary::

metrics.PrecisionRecallDisplay

Alphabetical order?

@@ -79,6 +79,8 @@

from ._plot.roc_curve import plot_roc_curve
from ._plot.roc_curve import RocCurveDisplay
from ._plot.precision_recall import plot_precision_recall_curve

rename the file to precision_recall_curve.py?

Axes object to plot on. If `None`, a new figure and axes is
created.

label_name : str, default=None

why is it different from RocCurveDisplay?


Parameters
-----------
precision : ndarray of shape (n_thresholds + 1,)

n_thresholds is not defined in this context.
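
(The shape comes from precision_recall_curve, whose precision and recall arrays are one element longer than thresholds; y_true and y_score here are assumed inputs:)

    from sklearn.metrics import precision_recall_curve

    precision, recall, thresholds = precision_recall_curve(y_true, y_score)
    # hence the (n_thresholds + 1,) shape documented above
    assert precision.shape == (thresholds.shape[0] + 1,)
    assert recall.shape == (thresholds.shape[0] + 1,)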

line_kwargs.update(**kwargs)

self.line_, = ax.plot(self.recall, self.precision, **line_kwargs)
ax.set(xlabel="Recall", ylabel="Precision", ylim=[0.0, 1.05],

rely on default xlim/ylim?

@thomasjpfan (Author) replied:

For x, going without the explicit limit is fine. I think for y it is kind of important because of the scaling:

With ylim explicitly set: [screenshot]

Not set: [screenshot]
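
A minimal sketch of the comparison (assuming precision and recall arrays are already computed):

    import matplotlib.pyplot as plt

    fig, (ax_set, ax_auto) = plt.subplots(1, 2)
    ax_set.plot(recall, precision)
    ax_set.set(xlabel="Recall", ylabel="Precision", ylim=[0.0, 1.05])
    ax_auto.plot(recall, precision)  # autoscaled y-axis magnifies small changes
    ax_auto.set(xlabel="Recall", ylabel="Precision")
    plt.show()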

precision : ndarray of shape (n_thresholds + 1,)
Precision values.

recall : ndarray of shape (n_thresholds + 1,)

n_thresholds is not defined in this context.

:term:`predict_proba` is tried first and if it does not exist
:term:`decision_function` is tried next.

label_name : str, default=None

not consistent with plot_roc_curve

@thomasjpfan (Author) replied:

Changed this to `name` to be consistent with plot_roc_curve.

@NicolasHug (Member) left a comment:

Looks good, but it needs to link to the UG, with small updates.


It is recommended to use :func:`~sklearn.metrics.plot_precision_recall_curve`
to create a visualizer. All parameters are stored as attributes.


Add link to Visualization UG

"""Plot Precision Recall Curve for binary classifers.

Extra keyword arguments will be passed to matplotlib's `plot`.
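
(For example, matplotlib line properties can be forwarded directly; clf, X, y as in the sketch at the top of this thread:)

    disp = plot_precision_recall_curve(clf, X, y, linestyle="--", alpha=0.8)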


@NicolasHug (Member) left a comment:

Not sure why tests are failing but LGTM

@qinhanmin2014 (Member) left a comment:

I think tests are failing because we no longer set xlim and ylim manually, but we didn't update the tests.

I feel a little uncomfortable that plot_roc_curve and plot_precision_recall_curve are written in different ways, e.g., we introduce is_predict_proba in plot_precision_recall_curve but do not introduce it in plot_roc_curve. If we keep these two functions consistent, it will be much easier to maintain, but perhaps it's not so important.

@thomasjpfan (Author):

> If we keep these two functions consistent, it will be much easier to maintain, but perhaps it's not so important.

I refactored the response method checking into a _check_classifer_response_method helper that can be used by plot_roc_curve. We can have a follow-up PR to have plot_roc_curve use it as well, to keep the error messages and code consistent.
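
A hedged sketch of what such a helper could look like (a reconstruction, not necessarily the exact merged code):

    def _check_classifer_response_method(estimator, response_method):
        """Return the prediction method matching `response_method`,
        raising ValueError if the estimator lacks a suitable one."""
        if response_method not in ("predict_proba", "decision_function", "auto"):
            raise ValueError("response_method must be 'predict_proba', "
                             "'decision_function' or 'auto'")
        error_msg = "response method {} is not defined"
        if response_method != "auto":
            prediction_method = getattr(estimator, response_method, None)
            if prediction_method is None:
                raise ValueError(error_msg.format(response_method))
        else:
            # Try predict_proba first, then fall back to decision_function.
            predict_proba = getattr(estimator, "predict_proba", None)
            decision_function = getattr(estimator, "decision_function", None)
            prediction_method = predict_proba or decision_function
            if prediction_method is None:
                raise ValueError(
                    error_msg.format("decision_function or predict_proba"))
        return prediction_method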

@qinhanmin2014 (Member) left a comment:

Should we rename the files to _precision_recall_curve.py and _roc_curve.py?

@@ -0,0 +1,40 @@
def _check_classifer_response_method(estimator, response_method):

Is it good to include things in __init__.py? Perhaps base.py?
Let's update plot_roc_curve in this PR?

@thomasjpfan (Author):

> Should we rename the files to _precision_recall_curve.py and _roc_curve.py?

Since they are both in _plot, either way works for me.

> Let's update plot_roc_curve in this PR?

Done

@@ -180,18 +181,8 @@ def plot_roc_curve(estimator, X, y, sample_weight=None,
    else:
        raise ValueError(classification_error)

    if response_method != "auto":

We also need to remove the following lines above:

if response_method not in ("predict_proba", "decision_function", "auto"):
        raise ValueError("response_method must be 'predict_proba', "
                         "'decision_function' or 'auto'")

@qinhanmin2014 qinhanmin2014 merged commit 968252d into scikit-learn:master Nov 11, 2019
panpiort8 pushed a commit to panpiort8/scikit-learn that referenced this pull request Mar 3, 2020