ENH PrecisionRecallDisplay add option to plot chance level #26019

Merged 31 commits on May 22, 2023
1a1c53b
add option to plot chance level line and customize rendering for PR c…
Charlie-XIAO Mar 29, 2023
161fd9f
changelog added
Charlie-XIAO Mar 29, 2023
2863ddf
changed chance_level_kwargs to chance_level_kw for consistency with o…
Charlie-XIAO Mar 30, 2023
d4e44d1
changelog updated
Charlie-XIAO Mar 30, 2023
b2df9e2
resolved conversations: default value of pos_prevalence changed to None
Charlie-XIAO Apr 11, 2023
a570ba8
Merge branch 'main' into pr-vis-enh
Charlie-XIAO Apr 11, 2023
206a926
added check and solved linting issues
Charlie-XIAO Apr 11, 2023
edbbdc5
improved test coverage
Charlie-XIAO Apr 12, 2023
54b5eae
resolved conversations
Charlie-XIAO Apr 14, 2023
39ff08e
partially resolved conversations, the rest TBD soon
Charlie-XIAO Apr 14, 2023
46a0679
fixed typo in attribute name
Charlie-XIAO Apr 14, 2023
c6460e5
added test to check that prevalence_pos_label is reusable via plot me…
Charlie-XIAO Apr 14, 2023
d729fcb
added example
Charlie-XIAO Apr 14, 2023
2eb5bec
Merge branch 'main' into pr-vis-enh
Charlie-XIAO Apr 14, 2023
46807c3
'secretly' fix a changelog typo in my previous contributions
Charlie-XIAO Apr 14, 2023
401a370
Merge branch 'main' into pr-vis-enh
Charlie-XIAO Apr 20, 2023
21becf3
Merge remote-tracking branch 'upstream/main' into pr-vis-enh
Charlie-XIAO Apr 21, 2023
5f3e114
Merge remote-tracking branch 'upstream/main' into pr-vis-enh
Charlie-XIAO Apr 25, 2023
4503504
Merge remote-tracking branch 'upstream/main' into pr-vis-enh
Charlie-XIAO Apr 26, 2023
d96210e
Merge branch 'main' into pr-vis-enh
glemaitre May 4, 2023
ab8873c
reverted suspicious additions
Charlie-XIAO May 4, 2023
a825b92
resolved conversations
Charlie-XIAO May 4, 2023
2afede7
counter object does not have total()
Charlie-XIAO May 4, 2023
7b21acc
reverted unnecessary change
Charlie-XIAO May 4, 2023
3de2722
Merge branch 'main' into pr-vis-enh
Charlie-XIAO May 4, 2023
2457e90
minor modification
Charlie-XIAO May 14, 2023
8e11c09
Merge remote-tracking branch 'upstream/main' into pr-vis-enh
Charlie-XIAO May 14, 2023
e93a243
minor fix
Charlie-XIAO May 14, 2023
1e46101
Merge branch 'main' into pr-vis-enh
Charlie-XIAO May 16, 2023
8338521
raises when plotting chance level but no prevalence level is given
Charlie-XIAO May 16, 2023
0c91fc0
Merge branch 'main' into pr-vis-enh
Charlie-XIAO May 16, 2023
6 changes: 6 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -403,6 +403,12 @@ Changelog
level. This line is exposed in the `chance_level_` attribute.
:pr:`25987` by :user:`Yao Xiao <Charlie-XIAO>`.

- |Enhancement| :meth:`metrics.PrecisionRecallDisplay.from_estimator` and
:meth:`metrics.PrecisionRecallDisplay.from_predictions` now accept two new
keywords, `plot_chance_level` and `chance_level_kw` to plot the baseline
chance level. This line is exposed in the `chance_level_` attribute.
:pr:`26019` by :user:`Yao Xiao <Charlie-XIAO>`.

- |Fix| :func:`log_loss` raises a warning if the values of the parameter `y_pred` are
not normalized, instead of actually normalizing them in the metric. Starting from
1.5 this will raise an error. :pr:`25299` by :user:`Omar Salman <OmarManzoor>`.
11 changes: 8 additions & 3 deletions examples/model_selection/plot_precision_recall.py
@@ -142,7 +142,7 @@
from sklearn.metrics import PrecisionRecallDisplay

display = PrecisionRecallDisplay.from_estimator(
classifier, X_test, y_test, name="LinearSVC"
classifier, X_test, y_test, name="LinearSVC", plot_chance_level=True
)
_ = display.ax_.set_title("2-class Precision-Recall curve")

@@ -152,7 +152,9 @@
# :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions`.
y_score = classifier.decision_function(X_test)

display = PrecisionRecallDisplay.from_predictions(y_test, y_score, name="LinearSVC")
display = PrecisionRecallDisplay.from_predictions(
y_test, y_score, name="LinearSVC", plot_chance_level=True
)
_ = display.ax_.set_title("2-class Precision-Recall curve")

# %%
@@ -214,12 +216,15 @@
# %%
# Plot the micro-averaged Precision-Recall curve
# ..............................................
from collections import Counter

display = PrecisionRecallDisplay(
recall=recall["micro"],
precision=precision["micro"],
average_precision=average_precision["micro"],
prevalence_pos_label=Counter(Y_test.ravel())[1] / Y_test.size,
)
display.plot()
display.plot(plot_chance_level=True)
_ = display.ax_.set_title("Micro-averaged over all classes")

# %%
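
The micro-averaged hunk above passes `Counter(Y_test.ravel())[1] / Y_test.size` as `prevalence_pos_label`. As a rough illustration of what that expression computes, the sketch below reproduces it with the standard library only; the small label-indicator matrix is hypothetical, and `chain.from_iterable` and `len` stand in for `ravel()` and `.size`.

```python
from collections import Counter
from itertools import chain

# Hypothetical 4-sample, 3-class label-indicator matrix (one row per sample).
Y_test = [
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
    [1, 0, 0],
]

# Flatten every (sample, class) cell into one sequence, as Y_test.ravel() would.
flat = list(chain.from_iterable(Y_test))

# Fraction of cells equal to 1; len(flat) plays the role of Y_test.size.
prevalence = Counter(flat)[1] / len(flat)
print(prevalence)
```

With exactly one positive entry per row, as in this toy matrix, the value reduces to one over the number of classes.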
113 changes: 110 additions & 3 deletions sklearn/metrics/_plot/precision_recall_curve.py
@@ -1,3 +1,5 @@
from collections import Counter

from .. import average_precision_score
from .. import precision_recall_curve
from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
@@ -34,11 +36,23 @@ class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin):

.. versionadded:: 0.24

prevalence_pos_label : float, default=None
The prevalence of the positive label. It is used for plotting the
chance level line. If None, calling `plot` with
`plot_chance_level=True` raises a `ValueError`.

.. versionadded:: 1.3

Attributes
----------
line_ : matplotlib Artist
Precision recall curve.

chance_level_ : matplotlib Artist or None
The chance level line. It is `None` if the chance level is not plotted.

.. versionadded:: 1.3

ax_ : matplotlib Axes
Axes with precision recall curve.

@@ -96,14 +110,24 @@ def __init__(
average_precision=None,
estimator_name=None,
pos_label=None,
prevalence_pos_label=None,
):
self.estimator_name = estimator_name
self.precision = precision
self.recall = recall
self.average_precision = average_precision
self.pos_label = pos_label
self.prevalence_pos_label = prevalence_pos_label

def plot(self, ax=None, *, name=None, **kwargs):
def plot(
self,
ax=None,
*,
name=None,
plot_chance_level=False,
chance_level_kw=None,
**kwargs,
):
"""Plot visualization.

Extra keyword arguments will be passed to matplotlib's `plot`.
@@ -118,6 +142,19 @@ def plot(self, ax=None, *, name=None, **kwargs):
Name of precision recall curve for labeling. If `None`, use
`estimator_name` if not `None`, otherwise no labeling is shown.

plot_chance_level : bool, default=False
Whether to plot the chance level. The chance level is the prevalence
of the positive label computed from the data passed during
:meth:`from_estimator` or :meth:`from_predictions` call.

.. versionadded:: 1.3

chance_level_kw : dict, default=None
Keyword arguments to be passed to matplotlib's `plot` for rendering
the chance level line.

.. versionadded:: 1.3

**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.

@@ -149,6 +186,7 @@ def plot(self, ax=None, *, name=None, **kwargs):
line_kwargs.update(**kwargs)

(self.line_,) = self.ax_.plot(self.recall, self.precision, **line_kwargs)

info_pos_label = (
f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
)
@@ -157,7 +195,34 @@ def plot(self, ax=None, *, name=None, **kwargs):
ylabel = "Precision" + info_pos_label
self.ax_.set(xlabel=xlabel, ylabel=ylabel)

if "label" in line_kwargs:
if plot_chance_level:
if self.prevalence_pos_label is None:
raise ValueError(
"You must provide prevalence_pos_label when constructing the "
"PrecisionRecallDisplay object in order to plot the chance "
"level line. Alternatively, you may use "
"PrecisionRecallDisplay.from_estimator or "
"PrecisionRecallDisplay.from_predictions "
"to automatically set prevalence_pos_label"
)

chance_level_line_kw = {
"label": f"Chance level (AP = {self.prevalence_pos_label:0.2f})",
"color": "k",
"linestyle": "--",
}
if chance_level_kw is not None:
chance_level_line_kw.update(chance_level_kw)

(self.chance_level_,) = self.ax_.plot(
(0, 1),
(self.prevalence_pos_label, self.prevalence_pos_label),
**chance_level_line_kw,
)
else:
self.chance_level_ = None

if "label" in line_kwargs or plot_chance_level:
self.ax_.legend(loc="lower left")

return self
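
The `plot` hunk above builds a dict of default line styles and then lets `chance_level_kw` override individual keys. A standalone sketch of that default-then-override pattern, with a hypothetical helper name (`merge_line_kwargs` is not part of scikit-learn), might look like this:

```python
def merge_line_kwargs(prevalence, chance_level_kw=None):
    """Hypothetical helper mirroring the default-then-override pattern."""
    line_kw = {
        "label": f"Chance level (AP = {prevalence:0.2f})",
        "color": "k",
        "linestyle": "--",
    }
    if chance_level_kw is not None:
        # User-supplied keys replace defaults; untouched defaults survive.
        line_kw.update(chance_level_kw)
    return line_kw


defaults = merge_line_kwargs(0.3)
custom = merge_line_kwargs(0.3, {"color": "r", "linewidth": 2})
```

Because `update` only touches the keys the caller supplies, passing `{"color": "r"}` changes the colour while keeping the dashed style and the default label.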
@@ -175,6 +240,8 @@ def from_estimator(
response_method="auto",
name=None,
ax=None,
plot_chance_level=False,
chance_level_kw=None,
**kwargs,
):
"""Plot precision-recall curve given an estimator and some data.
@@ -219,6 +286,19 @@ def from_estimator(
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.

plot_chance_level : bool, default=False
Whether to plot the chance level. The chance level is the prevalence
of the positive label computed from the data passed during
:meth:`from_estimator` or :meth:`from_predictions` call.

.. versionadded:: 1.3

chance_level_kw : dict, default=None
Keyword arguments to be passed to matplotlib's `plot` for rendering
the chance level line.

.. versionadded:: 1.3

**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.

@@ -277,6 +357,8 @@ def from_estimator(
pos_label=pos_label,
drop_intermediate=drop_intermediate,
ax=ax,
plot_chance_level=plot_chance_level,
chance_level_kw=chance_level_kw,
**kwargs,
)

@@ -291,6 +373,8 @@ def from_predictions(
drop_intermediate=False,
name=None,
ax=None,
plot_chance_level=False,
chance_level_kw=None,
**kwargs,
):
"""Plot precision-recall curve given binary class predictions.
@@ -324,6 +408,19 @@ def from_predictions(
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.

plot_chance_level : bool, default=False
Whether to plot the chance level. The chance level is the prevalence
of the positive label computed from the data passed during
:meth:`from_estimator` or :meth:`from_predictions` call.

.. versionadded:: 1.3

chance_level_kw : dict, default=None
Keyword arguments to be passed to matplotlib's `plot` for rendering
the chance level line.

.. versionadded:: 1.3

**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.

@@ -381,12 +478,22 @@ def from_predictions(
y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
)

class_count = Counter(y_true)
prevalence_pos_label = class_count[pos_label] / sum(class_count.values())

viz = PrecisionRecallDisplay(
precision=precision,
recall=recall,
average_precision=average_precision,
estimator_name=name,
pos_label=pos_label,
prevalence_pos_label=prevalence_pos_label,
)

return viz.plot(ax=ax, name=name, **kwargs)
return viz.plot(
ax=ax,
name=name,
plot_chance_level=plot_chance_level,
chance_level_kw=chance_level_kw,
**kwargs,
)
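
For reference, the prevalence computation added to `from_predictions` can be tried in isolation. The sketch below wraps it in a hypothetical helper (`prevalence_of` is an illustrative name, and the string labels are made up). It sums `Counter.values()` rather than calling `Counter.total()`, since `total()` is only available from Python 3.10, which appears to be what the commit "counter object does not have total()" refers to.

```python
from collections import Counter


def prevalence_of(y_true, pos_label):
    """Hypothetical helper: fraction of samples whose label equals pos_label."""
    class_count = Counter(y_true)
    # Counter.total() exists only on Python 3.10+, so sum the values instead.
    return class_count[pos_label] / sum(class_count.values())


y_true = ["spam", "ham", "ham", "spam", "ham"]
p_spam = prevalence_of(y_true, "spam")  # 2 of 5 samples
p_missing = prevalence_of(y_true, "eggs")  # label absent: Counter yields 0
```

Note that, as written in the diff, this count is unweighted: `sample_weight` is passed to the metric computation but does not enter the prevalence.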
97 changes: 97 additions & 0 deletions sklearn/metrics/_plot/tests/test_precision_recall_display.py
@@ -1,3 +1,5 @@
from collections import Counter

import numpy as np
import pytest

@@ -76,6 +78,52 @@ def test_precision_recall_display_plotting(
assert display.line_.get_label() == expected_label
assert display.line_.get_alpha() == pytest.approx(0.8)

# Check that the chance level line is not plotted by default
assert display.chance_level_ is None


@pytest.mark.parametrize("chance_level_kw", [None, {"color": "r"}])
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_precision_recall_chance_level_line(
pyplot,
chance_level_kw,
constructor_name,
):
"""Check the chance level line plotting behavior."""
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)
pos_prevalence = Counter(y)[1] / len(y)

lr = LogisticRegression()
y_pred = lr.fit(X, y).predict_proba(X)[:, 1]

if constructor_name == "from_estimator":
display = PrecisionRecallDisplay.from_estimator(
lr,
X,
y,
plot_chance_level=True,
chance_level_kw=chance_level_kw,
)
else:
display = PrecisionRecallDisplay.from_predictions(
y,
y_pred,
plot_chance_level=True,
chance_level_kw=chance_level_kw,
)

import matplotlib as mpl # noqa

assert isinstance(display.chance_level_, mpl.lines.Line2D)
assert tuple(display.chance_level_.get_xdata()) == (0, 1)
assert tuple(display.chance_level_.get_ydata()) == (pos_prevalence, pos_prevalence)

# Checking for chance level line styles
if chance_level_kw is None:
assert display.chance_level_.get_color() == "k"
else:
assert display.chance_level_.get_color() == "r"


@pytest.mark.parametrize(
"constructor_name, default_label",
@@ -256,3 +304,52 @@ def test_plot_precision_recall_pos_label(pyplot, constructor_name, response_meth
avg_prec_limit = 0.95
assert display.average_precision > avg_prec_limit
assert -np.trapz(display.precision, display.recall) > avg_prec_limit


@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
def test_precision_recall_prevalence_pos_label_reusable(pyplot, constructor_name):
# Check that even if one passes plot_chance_level=False the first time
# one can still call disp.plot with plot_chance_level=True and get the
# chance level line
X, y = make_classification(n_classes=2, n_samples=50, random_state=0)

lr = LogisticRegression()
y_pred = lr.fit(X, y).predict_proba(X)[:, 1]

if constructor_name == "from_estimator":
display = PrecisionRecallDisplay.from_estimator(
lr, X, y, plot_chance_level=False
)
else:
display = PrecisionRecallDisplay.from_predictions(
y, y_pred, plot_chance_level=False
)
assert display.chance_level_ is None

import matplotlib as mpl # noqa

# from_estimator and from_predictions set prevalence_pos_label, so
# calling plot with plot_chance_level=True afterwards should draw
# the chance level line
display.plot(plot_chance_level=True)
assert isinstance(display.chance_level_, mpl.lines.Line2D)


def test_precision_recall_raise_no_prevalence(pyplot):
# Check that an error is raised when plotting the chance level while
# no prevalence_pos_label is provided
precision = np.array([1, 0.5, 0])
recall = np.array([0, 0.5, 1])
display = PrecisionRecallDisplay(precision, recall)

msg = (
"You must provide prevalence_pos_label when constructing the "
"PrecisionRecallDisplay object in order to plot the chance "
"level line. Alternatively, you may use "
"PrecisionRecallDisplay.from_estimator or "
"PrecisionRecallDisplay.from_predictions "
"to automatically set prevalence_pos_label"
)

with pytest.raises(ValueError, match=msg):
display.plot(plot_chance_level=True)
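
A side note on the last test: `pytest.raises(..., match=msg)` treats `msg` as a regular expression and applies `re.search`, so each literal `.` in the message matches any character. The test still passes as written; the sketch below only illustrates the difference that `re.escape` would make, using a shortened, hypothetical message.

```python
import re

# Shortened stand-in for the error message; it contains a literal ".".
msg = "use PrecisionRecallDisplay.from_estimator"

# Differs from msg only at the position of the ".".
near_miss = "use PrecisionRecallDisplayXfrom_estimator"

# Unescaped, "." is a wildcard, so the near miss is accepted.
loose_match = re.search(msg, near_miss) is not None

# Escaped, the pattern requires the literal "." and rejects the near miss.
strict_match = re.search(re.escape(msg), near_miss) is not None

# The escaped pattern still matches the exact message.
exact_match = re.search(re.escape(msg), msg) is not None
```

Wrapping the expected message in `re.escape` is one possible way to make such a check exact, if stricter matching is wanted.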