
ENH Provide a pos_label parameter in plot_precision_recall_curve #17569


Merged: 12 commits, Jun 24, 2020
5 changes: 5 additions & 0 deletions doc/whats_new/v0.24.rst
@@ -117,6 +117,11 @@ Changelog
:func:`metrics.median_absolute_error`. :pr:`17225` by
:user:`Lucy Liu <lucyleeow>`.

- |Enhancement| Add `pos_label` parameter in
:func:`metrics.plot_precision_recall_curve` in order to specify the positive
class to be used when computing the precision and recall statistics.
:pr:`17569` by :user:`Guillaume Lemaitre <glemaitre>`.

- |Enhancement| Add `pos_label` parameter to :func:`roc_auc_score`.
:pr:`17594` by :user:`Guillaume Lemaitre <glemaitre>`.
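To illustrate the first entry above, a minimal usage sketch of the new parameter; the dataset and classifier here are illustrative choices, not taken from the PR:

```python
# Minimal sketch of the new `pos_label` parameter; dataset and classifier
# are illustrative choices, not part of the changelog itself.
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import plot_precision_recall_curve

X, y = load_breast_cancer(return_X_y=True)
clf = LogisticRegression(max_iter=10_000).fit(X, y)

# Treat class 0 (malignant) as the positive class instead of the default
# `clf.classes_[1]`; the axis labels then read "... (Positive label: 0)".
disp = plot_precision_recall_curve(clf, X, y, pos_label=0)
```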

51 changes: 43 additions & 8 deletions sklearn/metrics/_plot/precision_recall_curve.py
@@ -1,3 +1,5 @@
import numpy as np

from .base import _check_classifier_response_method

from .. import average_precision_score
@@ -30,6 +32,12 @@ class PrecisionRecallDisplay:
estimator_name : str, default=None
Name of estimator. If None, then the estimator name is not shown.

pos_label : str or int, default=None
The class considered as the positive class. If None, the positive class
is not shown in the axis labels.

.. versionadded:: 0.24

Attributes
----------
line_ : matplotlib Artist
@@ -60,11 +68,12 @@ class PrecisionRecallDisplay:
>>> disp.plot() # doctest: +SKIP
"""
     def __init__(self, precision, recall, *,
-                 average_precision=None, estimator_name=None):
+                 average_precision=None, estimator_name=None, pos_label=None):
self.precision = precision
self.recall = recall
self.average_precision = average_precision
self.estimator_name = estimator_name
self.pos_label = pos_label

@_deprecate_positional_args
def plot(self, ax=None, *, name=None, **kwargs):
@@ -110,7 +119,11 @@ def plot(self, ax=None, *, name=None, **kwargs):
line_kwargs.update(**kwargs)

self.line_, = ax.plot(self.recall, self.precision, **line_kwargs)
ax.set(xlabel="Recall", ylabel="Precision")
info_pos_label = (f" (Positive label: {self.pos_label})"
if self.pos_label is not None else "")
xlabel = "Recall" + info_pos_label
ylabel = "Precision" + info_pos_label
ax.set(xlabel=xlabel, ylabel=ylabel)

if "label" in line_kwargs:
ax.legend(loc='lower left')
@@ -123,7 +136,7 @@ def plot(self, ax=None, *, name=None, **kwargs):
@_deprecate_positional_args
def plot_precision_recall_curve(estimator, X, y, *,
sample_weight=None, response_method="auto",
-                                name=None, ax=None, **kwargs):
+                                name=None, ax=None, pos_label=None, **kwargs):
"""Plot Precision Recall Curve for binary classifiers.

Extra keyword arguments will be passed to matplotlib's `plot`.
@@ -159,6 +172,13 @@ def plot_precision_recall_curve(estimator, X, y, *,
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.

pos_label : str or int, default=None
The class considered as the positive class when computing the precision
and recall metrics. By default, `estimator.classes_[1]` is considered
as the positive class.

.. versionadded:: 0.24

**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.

@@ -180,16 +200,30 @@ def plot_precision_recall_curve(estimator, X, y, *,
raise ValueError(classification_error)

    prediction_method = _check_classifier_response_method(estimator,
                                                          response_method)
    y_pred = prediction_method(X)

-    if y_pred.ndim != 1:
+    if pos_label is not None and pos_label not in estimator.classes_:
+        raise ValueError(
+            f"The class provided by 'pos_label' is unknown. Got "
+            f"{pos_label} instead of one of {estimator.classes_}"
+        )
+
+    if y_pred.ndim != 1:  # `predict_proba`
         if y_pred.shape[1] != 2:
             raise ValueError(classification_error)
         else:
-            y_pred = y_pred[:, 1]
+            if pos_label is None:
+                pos_label = estimator.classes_[1]
+                y_pred = y_pred[:, 1]
+            else:
+                class_idx = np.flatnonzero(estimator.classes_ == pos_label)
+                y_pred = y_pred[:, class_idx]
+    else:  # `decision_function`
+        if pos_label is None:
+            pos_label = estimator.classes_[1]
+        elif pos_label == estimator.classes_[0]:
+            y_pred *= -1

-    pos_label = estimator.classes_[1]
     precision, recall, _ = precision_recall_curve(y, y_pred,
                                                   pos_label=pos_label,
                                                   sample_weight=sample_weight)
@@ -199,6 +233,7 @@ def plot_precision_recall_curve(estimator, X, y, *,
name = name if name is not None else estimator.__class__.__name__
viz = PrecisionRecallDisplay(
precision=precision, recall=recall,
-        average_precision=average_precision, estimator_name=name
+        average_precision=average_precision, estimator_name=name,
+        pos_label=pos_label,
)
return viz.plot(ax=ax, name=name, **kwargs)
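One subtlety in the `decision_function` branch above: binary decision scores are oriented toward `estimator.classes_[1]`, so asking for `classes_[0]` as the positive class is handled by negating the scores, which reverses the ranking. A small sketch of that equivalence, under the assumption of a plain `LogisticRegression` on synthetic data (not code from the PR):

```python
# Sketch: negating decision scores swaps which class the ranking favors.
# Synthetic data and LogisticRegression are assumptions for illustration.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve

X, y = make_classification(random_state=0)
clf = LogisticRegression(max_iter=1_000).fit(X, y)

scores = clf.decision_function(X)  # oriented toward clf.classes_[1]

# The PR curve for classes_[0] computed from the negated scores...
prec_neg, rec_neg, _ = precision_recall_curve(
    y, -scores, pos_label=clf.classes_[0])
# ...matches the one computed from the probability of classes_[0],
# because expit(-s) is a strictly increasing function of -s.
prec_proba, rec_proba, _ = precision_recall_curve(
    y, clf.predict_proba(X)[:, 0], pos_label=clf.classes_[0])

assert np.allclose(prec_neg, prec_proba)
assert np.allclose(rec_neg, rec_proba)
```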
53 changes: 51 additions & 2 deletions sklearn/metrics/_plot/tests/test_plot_precision_recall.py
@@ -11,9 +11,11 @@
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.exceptions import NotFittedError
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle
from sklearn.compose import make_column_transformer


@@ -114,8 +116,8 @@ def test_plot_precision_recall(pyplot, response_method, with_sample_weight):

expected_label = "LogisticRegression (AP = {:0.2f})".format(avg_prec)
assert disp.line_.get_label() == expected_label
assert disp.ax_.get_xlabel() == "Recall"
assert disp.ax_.get_ylabel() == "Precision"
assert disp.ax_.get_xlabel() == "Recall (Positive label: 1)"
assert disp.ax_.get_ylabel() == "Precision (Positive label: 1)"

# draw again with another label
disp.plot(name="MySpecialEstimator")
@@ -190,3 +192,50 @@ def test_default_labels(pyplot, average_precision, estimator_name,
estimator_name=estimator_name)
disp.plot()
assert disp.line_.get_label() == expected_label


@pytest.mark.parametrize(
"response_method", ["predict_proba", "decision_function"]
)
def test_plot_precision_recall_pos_label(pyplot, response_method):
# check that we can provide the positive label and display the proper
# statistics
X, y = load_breast_cancer(return_X_y=True)
# create a highly imbalanced version of the breast cancer dataset
idx_positive = np.flatnonzero(y == 1)
idx_negative = np.flatnonzero(y == 0)
idx_selected = np.hstack([idx_negative, idx_positive[:25]])
X, y = X[idx_selected], y[idx_selected]
X, y = shuffle(X, y, random_state=42)
# only use 2 features to make the problem even harder
X = X[:, :2]
y = np.array(
["cancer" if c == 1 else "not cancer" for c in y], dtype=object
)
X_train, X_test, y_train, y_test = train_test_split(
X, y, stratify=y, random_state=0,
)

classifier = LogisticRegression()
classifier.fit(X_train, y_train)

# sanity check to be sure the positive class is classes_[0], so that the
# default pos_label (classes_[1]) would otherwise select "not cancer"
assert classifier.classes_.tolist() == ["cancer", "not cancer"]

disp = plot_precision_recall_curve(
classifier, X_test, y_test, pos_label="cancer",
response_method=response_method
)
# we should obtain the statistics of the "cancer" class
avg_prec_limit = 0.65
assert disp.average_precision < avg_prec_limit
assert -np.trapz(disp.precision, disp.recall) < avg_prec_limit

# otherwise we should obtain the statistics of the "not cancer" class
disp = plot_precision_recall_curve(
classifier, X_test, y_test, response_method=response_method,
)
avg_prec_limit = 0.95
assert disp.average_precision > avg_prec_limit
assert -np.trapz(disp.precision, disp.recall) > avg_prec_limit
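A note on the `-np.trapz(disp.precision, disp.recall)` assertions: `precision_recall_curve` returns the curve with recall decreasing from 1 to 0, so the trapezoidal integral comes out negative and the leading minus sign recovers the area under the curve; this area tracks `average_precision_score` without being the same estimator. A toy illustration, with labels and scores made up for the example:

```python
# Toy comparison of trapezoidal PR-curve area against average precision;
# y_true and y_score below are made up for illustration.
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve

y_true = np.array([0, 0, 1, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.9])

precision, recall, _ = precision_recall_curve(y_true, y_score)
area = -np.trapz(precision, recall)  # recall is decreasing, hence the sign
ap = average_precision_score(y_true, y_score)
print(f"trapezoidal area: {area:.3f}, average precision: {ap:.3f}")
```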