[MRG] Adds plot_precision_recall_curve #14936


Merged · 30 commits · Nov 11, 2019

Commits:
720f9ac
WIP
thomasjpfan Aug 20, 2019
8ac4469
DOC Uses plot_precision_recall in example
thomasjpfan Aug 22, 2019
0b81383
DOC Adds to userguide
thomasjpfan Aug 22, 2019
e9c8131
DOC style
thomasjpfan Sep 3, 2019
241845f
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan Sep 3, 2019
3d86867
DOC Better docs
thomasjpfan Sep 5, 2019
c7029a6
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan Sep 5, 2019
9bf152b
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan Sep 9, 2019
66deaac
CLN
thomasjpfan Sep 9, 2019
10dc97e
CLN Address @glemaitre comments
thomasjpfan Sep 19, 2019
affec16
CLN Address @glemaitre comments
thomasjpfan Sep 20, 2019
99b18ed
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan Sep 20, 2019
1796020
DOC Remove whatsnew
thomasjpfan Sep 24, 2019
ee62183
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan Sep 24, 2019
d7d448f
DOC Style
thomasjpfan Sep 24, 2019
294c29a
CLN Addresses @amuller comments
thomasjpfan Sep 25, 2019
dbe9a3a
CLN Addresses @amuller comments
thomasjpfan Sep 25, 2019
3fade4b
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan Oct 2, 2019
fd1cc44
TST Clearier error messages
thomasjpfan Oct 2, 2019
fdb60ae
TST Modify test name
thomasjpfan Oct 2, 2019
7cbb1ac
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan Nov 6, 2019
abbbb9d
BUG Quick fix
thomasjpfan Nov 6, 2019
a589f92
BUG Fix test
thomasjpfan Nov 6, 2019
c9b3d60
ENH Better error message
thomasjpfan Nov 6, 2019
bfd5634
CLN Address comments
thomasjpfan Nov 7, 2019
dbae9f8
Merge remote-tracking branch 'upstream/master' into plot_precision_re…
thomasjpfan Nov 8, 2019
2c3d78d
CLN Address comments
thomasjpfan Nov 8, 2019
7736f77
CLN Move to base
thomasjpfan Nov 10, 2019
91f0d05
CLN Unify response detection
thomasjpfan Nov 10, 2019
a559342
CLN Removes unneeded check
thomasjpfan Nov 10, 2019
2 changes: 2 additions & 0 deletions doc/modules/classes.rst
Original file line number Diff line number Diff line change
@@ -1082,12 +1082,14 @@ See the :ref:`visualizations` section of the user guide for further details.
   :toctree: generated/
   :template: function.rst

   metrics.plot_precision_recall_curve
   metrics.plot_roc_curve

.. autosummary::
   :toctree: generated/
   :template: class.rst

   metrics.PrecisionRecallDisplay
   metrics.RocCurveDisplay


10 changes: 8 additions & 2 deletions doc/modules/model_evaluation.rst
@@ -744,8 +744,14 @@ score:

Note that the :func:`precision_recall_curve` function is restricted to the
binary case. The :func:`average_precision_score` function works only in
binary classification and multilabel indicator format.

binary classification and multilabel indicator format. The
:func:`plot_precision_recall_curve` function plots the precision-recall
curve as follows.

.. image:: ../auto_examples/model_selection/images/sphx_glr_plot_precision_recall_001.png
   :target: ../auto_examples/model_selection/plot_precision_recall.html#plot-the-precision-recall-curve
   :scale: 75
   :align: center

.. topic:: Examples:

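The average precision reported on the curve can be computed by hand. The following standalone sketch (not the scikit-learn implementation; it ignores score ties and uses plain Python lists) sweeps the decision threshold through the sorted scores and accumulates AP as the sum of (R_n - R_{n-1}) * P_n:

```python
def precision_recall_points(y_true, y_score):
    """Return (recall, precision) pairs as the threshold sweeps down."""
    # Rank samples by decreasing score; each prefix is one threshold.
    order = sorted(range(len(y_score)), key=lambda i: -y_score[i])
    total_pos = sum(y_true)
    tp = 0
    points = []
    for n, i in enumerate(order, start=1):
        tp += y_true[i]
        points.append((tp / total_pos, tp / n))
    return points


def average_precision(y_true, y_score):
    """Step-wise AP: sum of (R_n - R_{n-1}) * P_n over the curve points."""
    ap, prev_recall = 0.0, 0.0
    for recall, precision in precision_recall_points(y_true, y_score):
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap
```

On the toy data `y_true = [0, 0, 1, 1]`, `y_score = [0.1, 0.4, 0.35, 0.8]` this yields 5/6 ≈ 0.83, matching `average_precision_score` on the same inputs.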
2 changes: 2 additions & 0 deletions doc/visualizations.rst
@@ -72,6 +72,7 @@ Functions
.. autosummary::

   inspection.plot_partial_dependence
   metrics.plot_precision_recall_curve
   metrics.plot_roc_curve


@@ -83,4 +84,5 @@ Display Objects
.. autosummary::

   inspection.PartialDependenceDisplay
   metrics.PrecisionRecallDisplay
   metrics.RocCurveDisplay
3 changes: 3 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -570,6 +570,9 @@ Changelog
  Gain and Normalized Discounted Cumulative Gain. :pr:`9951` by :user:`Jérôme
  Dockès <jeromedockes>`.

- |Feature| :func:`metrics.plot_precision_recall_curve` has been added to plot
  precision recall curves. :pr:`14936` by `Thomas Fan`_.

- |Feature| Added multiclass support to :func:`metrics.roc_auc_score` with
  corresponding scorers `'roc_auc_ovr'`, `'roc_auc_ovo'`,
  `'roc_auc_ovr_weighted'`, and `'roc_auc_ovo_weighted'`.
26 changes: 5 additions & 21 deletions examples/model_selection/plot_precision_recall.py
@@ -134,25 +134,12 @@
# Plot the Precision-Recall curve
# ................................
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import plot_precision_recall_curve
import matplotlib.pyplot as plt
from inspect import signature

precision, recall, _ = precision_recall_curve(y_test, y_score)

# In matplotlib < 1.5, plt.fill_between does not have a 'step' argument
step_kwargs = ({'step': 'post'}
               if 'step' in signature(plt.fill_between).parameters
               else {})
plt.step(recall, precision, color='b', alpha=0.2,
         where='post')
plt.fill_between(recall, precision, alpha=0.2, color='b', **step_kwargs)

plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.0])
plt.title('2-class Precision-Recall curve: AP={0:0.2f}'.format(
    average_precision))
disp = plot_precision_recall_curve(classifier, X_test, y_test)
disp.ax_.set_title('2-class Precision-Recall curve: '
                   'AP={0:0.2f}'.format(average_precision))

###############################################################################
# In multi-label settings
@@ -212,10 +199,7 @@
#

plt.figure()
plt.step(recall['micro'], precision['micro'], color='b', alpha=0.2,
         where='post')
plt.fill_between(recall["micro"], precision["micro"], alpha=0.2, color='b',
                 **step_kwargs)
plt.step(recall['micro'], precision['micro'], where='post')

plt.xlabel('Recall')
plt.ylabel('Precision')
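The example now leans on the step style that the new Display object also uses: `drawstyle="steps-post"` on a single `plot` call replaces the old `plt.step`/`fill_between` pair. A minimal matplotlib-only sketch of that style, using hypothetical precision/recall values:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so no display is required
import matplotlib.pyplot as plt

# Hypothetical (recall, precision) points for a toy binary classifier.
recall = [0.0, 0.5, 0.5, 1.0, 1.0]
precision = [1.0, 1.0, 0.5, 2 / 3, 0.5]

fig, ax = plt.subplots()
# steps-post holds each precision value until the next recall point,
# mirroring how PrecisionRecallDisplay.plot draws the curve.
line, = ax.plot(recall, precision, drawstyle="steps-post",
                label="clf (AP = 0.83)")
ax.set(xlabel="Recall", ylabel="Precision", xlim=(0.0, 1.0), ylim=(0.0, 1.05))
ax.legend(loc="lower left")
```

The `"clf (AP = 0.83)"` label is an invented placeholder; in the real function the estimator name and computed average precision fill that slot.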
4 changes: 4 additions & 0 deletions sklearn/metrics/__init__.py
@@ -79,6 +79,8 @@

from ._plot.roc_curve import plot_roc_curve
from ._plot.roc_curve import RocCurveDisplay
from ._plot.precision_recall_curve import plot_precision_recall_curve
from ._plot.precision_recall_curve import PrecisionRecallDisplay


__all__ = [
@@ -135,7 +137,9 @@
    'pairwise_distances_argmin_min',
    'pairwise_distances_chunked',
    'pairwise_kernels',
    'plot_precision_recall_curve',
    'plot_roc_curve',
    'PrecisionRecallDisplay',
    'precision_recall_curve',
    'precision_recall_fscore_support',
    'precision_score',
40 changes: 40 additions & 0 deletions sklearn/metrics/_plot/base.py
@@ -0,0 +1,40 @@
def _check_classifer_response_method(estimator, response_method):
    """Return prediction method from the response_method

    Parameters
    ----------
    estimator : object
        Classifier to check.

    response_method : {'auto', 'predict_proba', 'decision_function'}
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    Returns
    -------
    prediction_method : callable
        Prediction method of estimator.
    """
    if response_method not in ("predict_proba", "decision_function", "auto"):
        raise ValueError("response_method must be 'predict_proba', "
                         "'decision_function' or 'auto'")

    error_msg = "response method {} is not defined in {}"
    if response_method != "auto":
        prediction_method = getattr(estimator, response_method, None)
        if prediction_method is None:
            raise ValueError(error_msg.format(response_method,
                                              estimator.__class__.__name__))
    else:
        predict_proba = getattr(estimator, 'predict_proba', None)
        decision_function = getattr(estimator, 'decision_function', None)
        prediction_method = predict_proba or decision_function
        if prediction_method is None:
            raise ValueError(error_msg.format(
                "decision_function or predict_proba",
                estimator.__class__.__name__))

    return prediction_method
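The resolution order of this helper can be illustrated with a standalone re-implementation. The names below (`resolve_response_method`, `ProbaOnly`, `DecisionOnly`, `Neither`) are invented for this sketch; only the behavior mirrors the helper:

```python
class ProbaOnly:
    """Dummy estimator exposing only predict_proba."""
    def predict_proba(self, X):
        return [[0.4, 0.6] for _ in X]


class DecisionOnly:
    """Dummy estimator exposing only decision_function."""
    def decision_function(self, X):
        return [0.25 for _ in X]


class Neither:
    """Dummy estimator with no usable response method."""


def resolve_response_method(estimator, response_method="auto"):
    # Validate the requested method name first.
    if response_method not in ("predict_proba", "decision_function", "auto"):
        raise ValueError("response_method must be 'predict_proba', "
                         "'decision_function' or 'auto'")
    if response_method != "auto":
        method = getattr(estimator, response_method, None)
    else:
        # 'auto' prefers predict_proba, then falls back to decision_function.
        method = (getattr(estimator, "predict_proba", None)
                  or getattr(estimator, "decision_function", None))
    if method is None:
        raise ValueError("response method is not defined in {}".format(
            estimator.__class__.__name__))
    return method
```

With `response_method="auto"`, `ProbaOnly` resolves to `predict_proba`, `DecisionOnly` falls back to `decision_function`, and `Neither` raises `ValueError`.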
169 changes: 169 additions & 0 deletions sklearn/metrics/_plot/precision_recall_curve.py
@@ -0,0 +1,169 @@
from .base import _check_classifer_response_method

from .. import average_precision_score
from .. import precision_recall_curve

from ...utils import check_matplotlib_support
from ...utils.validation import check_is_fitted
from ...base import is_classifier


class PrecisionRecallDisplay:
    """Precision Recall visualization.

    It is recommended to use
    :func:`~sklearn.metrics.plot_precision_recall_curve` to create a
    visualizer. All parameters are stored as attributes.

    Read more in the :ref:`User Guide <visualizations>`.

    Parameters
    ----------
    precision : ndarray
        Precision values.

    recall : ndarray
        Recall values.

    average_precision : float
        Average precision.

    estimator_name : str
        Name of estimator.

    Attributes
    ----------
    line_ : matplotlib Artist
        Precision recall curve.

    ax_ : matplotlib Axes
        Axes with precision recall curve.

    figure_ : matplotlib Figure
        Figure containing the curve.
    """

    def __init__(self, precision, recall, average_precision, estimator_name):
        self.precision = precision
        self.recall = recall
        self.average_precision = average_precision
        self.estimator_name = estimator_name

    def plot(self, ax=None, name=None, **kwargs):
        """Plot visualization.

        Extra keyword arguments will be passed to matplotlib's `plot`.

        Parameters
        ----------
        ax : Matplotlib Axes, default=None
            Axes object to plot on. If `None`, a new figure and axes is
            created.

        name : str, default=None
            Name of precision recall curve for labeling. If `None`, use the
            name of the estimator.

        **kwargs : dict
            Keyword arguments to be passed to matplotlib's `plot`.

        Returns
        -------
        display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
            Object that stores computed values.
        """
        check_matplotlib_support("PrecisionRecallDisplay.plot")
        import matplotlib.pyplot as plt

        if ax is None:
            fig, ax = plt.subplots()

        name = self.estimator_name if name is None else name

        line_kwargs = {
            "label": "{} (AP = {:0.2f})".format(name,
                                                self.average_precision),
            "drawstyle": "steps-post"
        }
        line_kwargs.update(**kwargs)

        self.line_, = ax.plot(self.recall, self.precision, **line_kwargs)
        ax.set(xlabel="Recall", ylabel="Precision")
        ax.legend(loc='lower left')

        self.ax_ = ax
        self.figure_ = ax.figure
        return self


def plot_precision_recall_curve(estimator, X, y,
                                sample_weight=None, response_method="auto",
                                name=None, ax=None, **kwargs):
    """Plot Precision Recall Curve for binary classifiers.

    Extra keyword arguments will be passed to matplotlib's `plot`.

    Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.

    Parameters
    ----------
    estimator : estimator instance
        Trained classifier.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y : array-like of shape (n_samples,)
        Binary target values.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    response_method : {'predict_proba', 'decision_function', 'auto'}, \
            default='auto'
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    name : str, default=None
        Name for labeling curve. If `None`, the name of the
        estimator is used.

    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is created.

    **kwargs : dict
        Keyword arguments to be passed to matplotlib's `plot`.

    Returns
    -------
    display : :class:`~sklearn.metrics.PrecisionRecallDisplay`
        Object that stores computed values.
    """
    check_matplotlib_support("plot_precision_recall_curve")
    check_is_fitted(estimator)

    classification_error = ("{} should be a binary classifier".format(
        estimator.__class__.__name__))
    if is_classifier(estimator):
        if len(estimator.classes_) != 2:
            raise ValueError(classification_error)
        pos_label = estimator.classes_[1]
    else:
        raise ValueError(classification_error)

    prediction_method = _check_classifer_response_method(estimator,
                                                         response_method)
    y_pred = prediction_method(X)

    if y_pred.ndim != 1:
        y_pred = y_pred[:, 1]

    precision, recall, _ = precision_recall_curve(y, y_pred,
                                                  pos_label=pos_label,
                                                  sample_weight=sample_weight)
    average_precision = average_precision_score(y, y_pred,
                                                sample_weight=sample_weight)
    viz = PrecisionRecallDisplay(precision, recall, average_precision,
                                 estimator.__class__.__name__)
    return viz.plot(ax=ax, name=name, **kwargs)
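A usage sketch for the new entry point, assuming scikit-learn and matplotlib are installed. Note that `plot_precision_recall_curve` was deprecated after this PR's era and removed in scikit-learn 1.2 in favor of `PrecisionRecallDisplay.from_estimator`, so the sketch tries both spellings:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so no window is needed

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=200, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)

try:
    # Entry point added by this PR (scikit-learn 0.22 through 1.1).
    from sklearn.metrics import plot_precision_recall_curve
    disp = plot_precision_recall_curve(clf, X_test, y_test)
except ImportError:
    # Later releases moved the functionality onto the Display class.
    from sklearn.metrics import PrecisionRecallDisplay
    disp = PrecisionRecallDisplay.from_estimator(clf, X_test, y_test)

disp.ax_.set_title("AP = {:.2f}".format(disp.average_precision))
```

Either path returns a display object exposing the `line_`, `ax_`, and `figure_` attributes documented above, so further matplotlib customization happens on `disp.ax_`.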
19 changes: 3 additions & 16 deletions sklearn/metrics/_plot/roc_curve.py
@@ -1,6 +1,7 @@
from .. import auc
from .. import roc_curve

from .base import _check_classifer_response_method
from ...utils import check_matplotlib_support
from ...base import is_classifier
from ...utils.validation import check_is_fitted
@@ -166,10 +167,6 @@ def plot_roc_curve(estimator, X, y, sample_weight=None,
    check_matplotlib_support('plot_roc_curve')
    check_is_fitted(estimator)

    if response_method not in ("predict_proba", "decision_function", "auto"):
        raise ValueError("response_method must be 'predict_proba', "
                         "'decision_function' or 'auto'")

    classification_error = ("{} should be a binary classifier".format(
        estimator.__class__.__name__))

@@ -180,18 +177,8 @@ def plot_roc_curve(estimator, X, y, sample_weight=None,
    else:
        raise ValueError(classification_error)

    if response_method != "auto":
Review comment (maintainer): also need to remove the following things above:

if response_method not in ("predict_proba", "decision_function", "auto"):
        raise ValueError("response_method must be 'predict_proba', "
                         "'decision_function' or 'auto'")

        prediction_method = getattr(estimator, response_method, None)
        if prediction_method is None:
            raise ValueError(
                "response method {} is not defined".format(response_method))
    else:
        predict_proba = getattr(estimator, 'predict_proba', None)
        decision_function = getattr(estimator, 'decision_function', None)
        prediction_method = predict_proba or decision_function

        if prediction_method is None:
            raise ValueError('response methods not defined')
    prediction_method = _check_classifer_response_method(estimator,
                                                         response_method)

    y_pred = prediction_method(X)
