[MRG] Adds plot_precision_recall_curve #14936
Conversation
Looks good, only a couple of changes.
```
Parameters
-----------
precision : ndarray of shape (n_thresholds + 1, )
```
```diff
- precision : ndarray of shape (n_thresholds + 1, )
+ precision : ndarray of shape (n_thresholds + 1,)
```
You can also add an entry in what's new
```python
if y_pred.ndim != 1:
    if y_pred.shape[1] > 2:
        raise ValueError("Estimator should solve a "
```
isn't it possible to use `check_classification_targets`?
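For reference, `check_classification_targets` (in `sklearn.utils.multiclass`) validates the target `y` rather than `y_pred`, so it would only cover part of the check quoted above. A minimal sketch of what it does:

```python
from sklearn.utils.multiclass import check_classification_targets, type_of_target

y = [0, 1, 1, 0]
check_classification_targets(y)   # passes silently for classification targets
print(type_of_target(y))          # "binary"

# a continuous target raises ValueError("Unknown label type: ...")
check_classification_targets([0.5, 1.3, 2.7])
```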
conflicts ;)
```python
y_pred = prediction_method(X)

if is_predict_proba and y_pred.ndim != 1:
```
if `is_predict_proba`, `y_pred.ndim` is never 1, right?
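The observation: for a fitted classifier, `predict_proba` returns a 2-D array of shape `(n_samples, n_classes)`, so on the `is_predict_proba` branch `y_pred.ndim` is always 2. A quick illustration:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(random_state=0)   # binary problem by default
clf = LogisticRegression().fit(X, y)

proba = clf.predict_proba(X)
print(proba.shape)        # (100, 2): always 2-D, never a 1-D array
print(proba[:, 1].shape)  # (100,): the positive-class column the curve needs
```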
```python
plot_precision_recall_curve(clf, X, y)

msg = "Estimator should solve a binary classification problem"
y_binary = y == 1
```
I don't understand why this raises this error, both semantically and why that is what the code does. I thought the code checked `y_pred`, which we're not changing here, right?
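For context, the straightforward way such a test usually asserts the error is with `pytest.raises`; a sketch under the assumption that the check quoted earlier rejects targets with more than two classes:

```python
import pytest
from sklearn.datasets import make_classification
from sklearn.metrics import plot_precision_recall_curve
from sklearn.tree import DecisionTreeClassifier

# a three-class problem, which the plotting function should reject
X, y = make_classification(n_classes=3, n_informative=4, random_state=0)
clf = DecisionTreeClassifier().fit(X, y)

msg = "Estimator should solve a binary classification problem"
with pytest.raises(ValueError, match=msg):
    plot_precision_recall_curve(clf, X, y)
```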
looks good apart from nitpicks
The error raised does not match.
CC @NicolasHug
I agree that this is a blocker, but we need to figure out a solution for #15303
The user guide link of plot_precision_recall_curve is wrong: there's no point linking to the visualization API UG. Also some of the links are broken.
see #15405 (comment). The easy fix is removing it and inferring it from the estimator. The better fix is to actually ensure we correctly slice predict_proba / decision_function.
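A hedged sketch of the "better fix" (the helper name is hypothetical): map `pos_label` onto the right `predict_proba` column, or flip the sign of `decision_function`, using `estimator.classes_`:

```python
import numpy as np

def _scores_for_pos_label(estimator, X, pos_label):
    """Hypothetical helper: scores for `pos_label` from either response method."""
    if hasattr(estimator, "predict_proba"):
        # predict_proba columns follow the order of estimator.classes_
        col = np.flatnonzero(estimator.classes_ == pos_label)[0]
        return estimator.predict_proba(X)[:, col]
    # decision_function is signed with respect to classes_[1]; flip otherwise
    y_score = estimator.decision_function(X)
    return y_score if pos_label == estimator.classes_[1] else -y_score
```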
Went with removing it.
Then we should do the same for
Perhaps it's better to keep it consistent with plot_roc_curve/RocCurveDisplay (API and code)
doc/visualizations.rst
```diff
@@ -71,6 +71,7 @@ Functions
 .. autosummary::
+   metrics.plot_precision_recall_curve
```
alphabetic order?
doc/visualizations.rst
```diff
@@ -82,5 +83,6 @@ Display Objects
 .. autosummary::
+   metrics.PrecisionRecallDisplay
```
alphabetic order?
sklearn/metrics/__init__.py
```diff
@@ -79,6 +79,8 @@
 from ._plot.roc_curve import plot_roc_curve
 from ._plot.roc_curve import RocCurveDisplay
+from ._plot.precision_recall import plot_precision_recall_curve
```
rename the file to `precision_recall_curve.py`?
```
    Axes object to plot on. If `None`, a new figure and axes is
    created.

label_name : str, default=None
```
why is it different from RocCurveDisplay?
```
Parameters
-----------
precision : ndarray of shape (n_thresholds + 1,)
```
n_thresholds is not defined in this context.
```python
line_kwargs.update(**kwargs)

self.line_, = ax.plot(self.recall, self.precision, **line_kwargs)
ax.set(xlabel="Recall", ylabel="Precision", ylim=[0.0, 1.05],
```
rely on default xlim/ylim?
```
precision : ndarray of shape (n_thresholds + 1,)
    Precision values.

recall : ndarray of shape (n_thresholds + 1,)
```
n_thresholds is not defined in this context.
```
    :term:`predict_proba` is tried first and if it does not exist
    :term:`decision_function` is tried next.

label_name : str, default=None
```
not consistent with plot_roc_curve
Changed this to `name` to be consistent with `plot_roc_curve`.
Looks good but need to link to UG with small updates
```
It is recommended to use :func:`~sklearn.metrics.plot_precision_recall_curve`
to create a visualizer. All parameters are stored as attributes.
```
Add link to Visualization UG
"""Plot Precision Recall Curve for binary classifers. | ||
|
||
Extra keyword arguments will be passed to matplotlib's `plot`. | ||
|
and include a link to `plot_precision_recall_curve` in the UG there
Not sure why tests are failing but LGTM
I think tests are failing because we no longer set xlim and ylim manually but didn't update the test.
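A sketch of the kind of test update implied (the `ax_` attribute follows the display API; the exact assertions are assumptions):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import plot_precision_recall_curve

X, y = make_classification(random_state=0)
clf = LogisticRegression().fit(X, y)

# hypothetical test update: the display no longer forces axis limits,
# so assertions on the hard-coded ylim would be dropped
disp = plot_precision_recall_curve(clf, X, y)
assert disp.ax_.get_xlabel() == "Recall"
assert disp.ax_.get_ylabel() == "Precision"
# removed: assert disp.ax_.get_ylim() == (0.0, 1.05)
```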
I feel a little uncomfortable that plot_roc_curve and plot_precision_recall_curve are written in different ways, e.g., we introduce is_predict_proba in plot_precision_recall_curve, but do not introduce it in plot_roc_curve. If we keep these two functions consistent, it will be much easier to maintain, but perhaps it's not so important.
I refactored the response method checking into a helper function.
Should we rename the files to `_precision_recall_curve.py` and `_roc_curve.py`?
sklearn/metrics/_plot/__init__.py
```diff
@@ -0,0 +1,40 @@
+def _check_classifer_response_method(estimator, response_method):
```
Is it good to include things in `__init__.py`? Perhaps `base.py`?
Let's update plot_roc_curve in this PR? Since they are both in `_plot`.
Done
```diff
@@ -180,18 +181,8 @@ def plot_roc_curve(estimator, X, y, sample_weight=None,
     else:
         raise ValueError(classification_error)

     if response_method != "auto":
```
also need to remove the following above:
```python
if response_method not in ("predict_proba", "decision_function", "auto"):
    raise ValueError("response_method must be 'predict_proba', "
                     "'decision_function' or 'auto'")
```
Reference Issues/PRs
Related to #7116

What does this implement/fix? Explain your changes.
This PR adds `plot_precision_recall_curve`.

Any other comments?
Only supports binary classifiers.
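For readers landing here, a minimal usage sketch of the function this PR adds (dataset and estimator are illustrative):

```python
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import plot_precision_recall_curve
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=500, random_state=0)  # binary by default
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = LogisticRegression().fit(X_train, y_train)
disp = plot_precision_recall_curve(clf, X_test, y_test)  # PrecisionRecallDisplay
plt.show()
```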