-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG] Adds plot_precision_recall_curve #14936
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
qinhanmin2014
merged 30 commits into
scikit-learn:master
from
thomasjpfan:plot_precision_recall
Nov 11, 2019
Merged
Changes from all commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
720f9ac
WIP
thomasjpfan 8ac4469
DOC Uses plot_precision_recall in example
thomasjpfan 0b81383
DOC Adds to userguide
thomasjpfan e9c8131
DOC style
thomasjpfan 241845f
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan 3d86867
DOC Better docs
thomasjpfan c7029a6
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan 9bf152b
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan 66deaac
CLN
thomasjpfan 10dc97e
CLN Address @glemaitre comments
thomasjpfan affec16
CLN Address @glemaitre comments
thomasjpfan 99b18ed
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan 1796020
DOC Remove whatsnew
thomasjpfan ee62183
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan d7d448f
DOC Style
thomasjpfan 294c29a
CLN Addresses @amuller comments
thomasjpfan dbe9a3a
CLN Addresses @amuller comments
thomasjpfan 3fade4b
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan fd1cc44
TST Clearier error messages
thomasjpfan fdb60ae
TST Modify test name
thomasjpfan 7cbb1ac
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan abbbb9d
BUG Quick fix
thomasjpfan a589f92
BUG Fix test
thomasjpfan c9b3d60
ENH Better error message
thomasjpfan bfd5634
CLN Address comments
thomasjpfan dbae9f8
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan 2c3d78d
CLN Address comments
thomasjpfan 7736f77
CLN Move to base
thomasjpfan 91f0d05
CLN Unify response detection
thomasjpfan a559342
CLN Removes unneeded check
thomasjpfan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
def _check_classifer_response_method(estimator, response_method): | ||
"""Return prediction method from the response_method | ||
|
||
Parameters | ||
---------- | ||
estimator: object | ||
Classifier to check | ||
|
||
response_method: {'auto', 'predict_proba', 'decision_function'} | ||
Specifies whether to use :term:`predict_proba` or | ||
:term:`decision_function` as the target response. If set to 'auto', | ||
:term:`predict_proba` is tried first and if it does not exist | ||
:term:`decision_function` is tried next. | ||
|
||
Returns | ||
------- | ||
prediction_method: callable | ||
prediction method of estimator | ||
""" | ||
|
||
if response_method not in ("predict_proba", "decision_function", "auto"): | ||
raise ValueError("response_method must be 'predict_proba', " | ||
"'decision_function' or 'auto'") | ||
|
||
error_msg = "response method {} is not defined in {}" | ||
if response_method != "auto": | ||
prediction_method = getattr(estimator, response_method, None) | ||
if prediction_method is None: | ||
raise ValueError(error_msg.format(response_method, | ||
estimator.__class__.__name__)) | ||
else: | ||
predict_proba = getattr(estimator, 'predict_proba', None) | ||
decision_function = getattr(estimator, 'decision_function', None) | ||
prediction_method = predict_proba or decision_function | ||
if prediction_method is None: | ||
raise ValueError(error_msg.format( | ||
"decision_function or predict_proba", | ||
estimator.__class__.__name__)) | ||
|
||
return prediction_method |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
from .base import _check_classifer_response_method | ||
|
||
from .. import average_precision_score | ||
from .. import precision_recall_curve | ||
|
||
from ...utils import check_matplotlib_support | ||
from ...utils.validation import check_is_fitted | ||
from ...base import is_classifier | ||
|
||
|
||
class PrecisionRecallDisplay: | ||
"""Precision Recall visualization. | ||
|
||
It is recommend to use :func:`~sklearn.metrics.plot_precision_recall_curve` | ||
to create a visualizer. All parameters are stored as attributes. | ||
|
||
Read more in the :ref:`User Guide <visualizations>`. | ||
|
||
Parameters | ||
----------- | ||
precision : ndarray | ||
Precision values. | ||
|
||
recall : ndarray | ||
Recall values. | ||
|
||
average_precision : float | ||
Average precision. | ||
|
||
estimator_name : str | ||
Name of estimator. | ||
|
||
Attributes | ||
---------- | ||
line_ : matplotlib Artist | ||
Precision recall curve. | ||
|
||
ax_ : matplotlib Axes | ||
Axes with precision recall curve. | ||
|
||
figure_ : matplotlib Figure | ||
Figure containing the curve. | ||
""" | ||
|
||
def __init__(self, precision, recall, average_precision, estimator_name): | ||
self.precision = precision | ||
self.recall = recall | ||
self.average_precision = average_precision | ||
self.estimator_name = estimator_name | ||
|
||
def plot(self, ax=None, name=None, **kwargs): | ||
"""Plot visualization. | ||
|
||
Extra keyword arguments will be passed to matplotlib's `plot`. | ||
|
||
Parameters | ||
---------- | ||
ax : Matplotlib Axes, default=None | ||
Axes object to plot on. If `None`, a new figure and axes is | ||
created. | ||
|
||
name : str, default=None | ||
Name of precision recall curve for labeling. If `None`, use the | ||
name of the estimator. | ||
|
||
**kwargs : dict | ||
Keyword arguments to be passed to matplotlib's `plot`. | ||
|
||
Returns | ||
------- | ||
display : :class:`~sklearn.metrics.PrecisionRecallDisplay` | ||
Object that stores computed values. | ||
""" | ||
check_matplotlib_support("PrecisionRecallDisplay.plot") | ||
import matplotlib.pyplot as plt | ||
|
||
if ax is None: | ||
fig, ax = plt.subplots() | ||
|
||
name = self.estimator_name if name is None else name | ||
|
||
line_kwargs = { | ||
"label": "{} (AP = {:0.2f})".format(name, | ||
self.average_precision), | ||
"drawstyle": "steps-post" | ||
} | ||
line_kwargs.update(**kwargs) | ||
|
||
self.line_, = ax.plot(self.recall, self.precision, **line_kwargs) | ||
ax.set(xlabel="Recall", ylabel="Precision") | ||
ax.legend(loc='lower left') | ||
|
||
self.ax_ = ax | ||
self.figure_ = ax.figure | ||
return self | ||
|
||
|
||
def plot_precision_recall_curve(estimator, X, y, | ||
sample_weight=None, response_method="auto", | ||
name=None, ax=None, **kwargs): | ||
"""Plot Precision Recall Curve for binary classifers. | ||
|
||
Extra keyword arguments will be passed to matplotlib's `plot`. | ||
|
||
Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`. | ||
|
||
Parameters | ||
---------- | ||
estimator : estimator instance | ||
Trained classifier. | ||
|
||
X : {array-like, sparse matrix} of shape (n_samples, n_features) | ||
Input values. | ||
|
||
y : array-like of shape (n_samples,) | ||
Binary target values. | ||
|
||
sample_weight : array-like of shape (n_samples,), default=None | ||
Sample weights. | ||
|
||
response_method : {'predict_proba', 'decision_function', 'auto'}, \ | ||
default='auto' | ||
Specifies whether to use :term:`predict_proba` or | ||
:term:`decision_function` as the target response. If set to 'auto', | ||
:term:`predict_proba` is tried first and if it does not exist | ||
:term:`decision_function` is tried next. | ||
|
||
name : str, default=None | ||
Name for labeling curve. If `None`, the name of the | ||
estimator is used. | ||
|
||
ax : matplotlib axes, default=None | ||
Axes object to plot on. If `None`, a new figure and axes is created. | ||
|
||
**kwargs : dict | ||
Keyword arguments to be passed to matplotlib's `plot`. | ||
|
||
Returns | ||
------- | ||
display : :class:`~sklearn.metrics.PrecisionRecallDisplay` | ||
Object that stores computed values. | ||
""" | ||
check_matplotlib_support("plot_precision_recall_curve") | ||
check_is_fitted(estimator) | ||
|
||
classificaiton_error = ("{} should be a binary classifer".format( | ||
estimator.__class__.__name__)) | ||
if is_classifier(estimator): | ||
if len(estimator.classes_) != 2: | ||
raise ValueError(classificaiton_error) | ||
pos_label = estimator.classes_[1] | ||
else: | ||
raise ValueError(classificaiton_error) | ||
|
||
prediction_method = _check_classifer_response_method(estimator, | ||
response_method) | ||
y_pred = prediction_method(X) | ||
|
||
if y_pred.ndim != 1: | ||
y_pred = y_pred[:, 1] | ||
|
||
precision, recall, _ = precision_recall_curve(y, y_pred, | ||
pos_label=pos_label, | ||
sample_weight=sample_weight) | ||
average_precision = average_precision_score(y, y_pred, | ||
sample_weight=sample_weight) | ||
viz = PrecisionRecallDisplay(precision, recall, average_precision, | ||
estimator.__class__.__name__) | ||
return viz.plot(ax=ax, name=name, **kwargs) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also need to remove following things above