From 1a1c53b90de80f80ec1de456f54e63b9cb606bf3 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 30 Mar 2023 04:05:51 +0800 Subject: [PATCH 01/20] add option to plot chance level line and customize rendering for PR curve --- .../metrics/_plot/precision_recall_curve.py | 95 ++++++++++++++++++- .../tests/test_precision_recall_display.py | 76 +++++++++++++++ 2 files changed, 169 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index 209f4dd0c3862..825556b08a02d 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -6,6 +6,8 @@ from ...utils import check_matplotlib_support from ...utils._response import _get_response_values_binary +from collections import Counter + class PrecisionRecallDisplay: """Precision Recall visualization. @@ -43,6 +45,11 @@ class PrecisionRecallDisplay: line_ : matplotlib Artist Precision recall curve. + chance_level_ : matplotlib Artist or None + The chance level line. It is `None` if the chance level is not plotted. + + .. versionadded:: 1.3 + ax_ : matplotlib Axes Axes with precision recall curve. @@ -107,7 +114,16 @@ def __init__( self.average_precision = average_precision self.pos_label = pos_label - def plot(self, ax=None, *, name=None, **kwargs): + def plot( + self, + ax=None, + *, + name=None, + plot_chance_level=False, + pos_prevalence=0, + chance_level_kwargs=None, + **kwargs, + ): """Plot visualization. Extra keyword arguments will be passed to matplotlib's `plot`. @@ -122,6 +138,23 @@ def plot(self, ax=None, *, name=None, **kwargs): Name of precision recall curve for labeling. If `None`, use `estimator_name` if not `None`, otherwise no labeling is shown. + plot_chance_level : bool, default=False + Whether to plot the chance level. + + .. versionadded:: 1.3 + + pos_prevalence : float, default=0 + The prevalence of the positive label. 
It is used for plotting the + chance level line. + + .. versionadded:: 1.3 + + chance_level_kwargs : dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + .. versionadded:: 1.3 + **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. @@ -154,6 +187,15 @@ def plot(self, ax=None, *, name=None, **kwargs): line_kwargs["label"] = name line_kwargs.update(**kwargs) + chance_level_line_kwargs = { + "label": f"Chance level (AP = {pos_prevalence:0.2f})", + "color": "k", + "linestyle": "--", + } + + if chance_level_kwargs is not None: + chance_level_line_kwargs.update(chance_level_kwargs) + import matplotlib.pyplot as plt if ax is None: @@ -168,6 +210,15 @@ def plot(self, ax=None, *, name=None, **kwargs): ylabel = "Precision" + info_pos_label ax.set(xlabel=xlabel, ylabel=ylabel) + if plot_chance_level: + (self.chance_level_,) = ax.plot( + (0, 1), + (pos_prevalence, pos_prevalence), + **chance_level_line_kwargs, + ) + else: + self.chance_level_ = None + if "label" in line_kwargs: ax.legend(loc="lower left") @@ -188,6 +239,8 @@ def from_estimator( response_method="auto", name=None, ax=None, + plot_chance_level=False, + chance_level_kwargs=None, **kwargs, ): """Plot precision-recall curve given an estimator and some data. @@ -232,6 +285,17 @@ def from_estimator( ax : matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is created. + plot_chance_level : bool, default=False + Whether to plot the chance level. + + .. versionadded:: 1.3 + + chance_level_kwargs : dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + .. versionadded:: 1.3 + **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. 
@@ -293,6 +357,8 @@ def from_estimator( pos_label=pos_label, drop_intermediate=drop_intermediate, ax=ax, + plot_chance_level=plot_chance_level, + chance_level_kwargs=chance_level_kwargs, **kwargs, ) @@ -307,6 +373,8 @@ def from_predictions( drop_intermediate=False, name=None, ax=None, + plot_chance_level=False, + chance_level_kwargs=None, **kwargs, ): """Plot precision-recall curve given binary class predictions. @@ -340,6 +408,17 @@ def from_predictions( ax : matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is created. + plot_chance_level : bool, default=False + Whether to plot the chance level. + + .. versionadded:: 1.3 + + chance_level_kwargs : dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + .. versionadded:: 1.3 + **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. @@ -400,6 +479,11 @@ def from_predictions( name = name if name is not None else "Classifier" + if plot_chance_level: + pos_prevalence = Counter(y_true)[pos_label] / len(y_true) + else: + pos_prevalence = 0 + viz = PrecisionRecallDisplay( precision=precision, recall=recall, @@ -408,4 +492,11 @@ def from_predictions( pos_label=pos_label, ) - return viz.plot(ax=ax, name=name, **kwargs) + return viz.plot( + ax=ax, + name=name, + plot_chance_level=plot_chance_level, + pos_prevalence=pos_prevalence, + chance_level_kwargs=chance_level_kwargs, + **kwargs, + ) diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py index e7e1917c79776..58acc15fa5204 100644 --- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -1,6 +1,8 @@ import numpy as np import pytest +from collections import Counter + from sklearn.compose import make_column_transformer from sklearn.datasets import load_breast_cancer, make_classification from 
sklearn.exceptions import NotFittedError @@ -120,6 +122,80 @@ def test_precision_recall_display_plotting( assert display.line_.get_alpha() == pytest.approx(0.8) +@pytest.mark.parametrize("plot_chance_level", [True, False]) +@pytest.mark.parametrize( + "chance_level_kwargs", + [None, {"linewidth": 1, "color": "red", "label": "DummyEstimator"}], +) +@pytest.mark.parametrize( + "constructor_name", + ["from_estimator", "from_predictions"], +) +def test_precision_recall_chance_level_line( + pyplot, + plot_chance_level, + chance_level_kwargs, + constructor_name, +): + """Check the chance leve line plotting behaviour.""" + X, y = make_classification(n_classes=2, n_samples=50, random_state=0) + pos_prevalence = Counter(y)[1] / len(y) + + lr = LogisticRegression() + lr.fit(X, y) + + y_pred = getattr(lr, "predict_proba")(X) + y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, 1] + + if constructor_name == "from_estimator": + display = PrecisionRecallDisplay.from_estimator( + lr, + X, + y, + alpha=0.8, + plot_chance_level=plot_chance_level, + chance_level_kwargs=chance_level_kwargs, + ) + else: + display = PrecisionRecallDisplay.from_predictions( + y, + y_pred, + alpha=0.8, + plot_chance_level=plot_chance_level, + chance_level_kwargs=chance_level_kwargs, + ) + + import matplotlib as mpl # noqal + + assert isinstance(display.line_, mpl.lines.Line2D) + assert display.line_.get_alpha() == 0.8 + assert isinstance(display.ax_, mpl.axes.Axes) + assert isinstance(display.figure_, mpl.figure.Figure) + + if plot_chance_level: + assert isinstance(display.chance_level_, mpl.lines.Line2D) + assert tuple(display.chance_level_.get_xdata()) == (0, 1) + assert tuple(display.chance_level_.get_ydata()) == ( + pos_prevalence, + pos_prevalence, + ) + + # Checking for chance level line styles + if plot_chance_level and chance_level_kwargs is None: + assert display.chance_level_.get_color() == "k" + assert display.chance_level_.get_linestyle() == "--" + assert ( + display.chance_level_.get_label() 
+ == f"Chance level (AP = {pos_prevalence:0.2f})" + ) + elif plot_chance_level: + for k, v in chance_level_kwargs.items(): + if hasattr(display.chance_level_, "get_" + k): + assert getattr(display.chance_level_, "get_" + k)() == v + else: + assert display.chance_level_ is None + + @pytest.mark.parametrize( "constructor_name, default_label", [ From 161fd9fabfeac8586576664170bc38ff7f1ec1a6 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 30 Mar 2023 04:17:23 +0800 Subject: [PATCH 02/20] changelog added --- doc/whats_new/v1.3.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 07b581f2104b1..72e44ccde9cc8 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -321,6 +321,12 @@ Changelog curves. :pr:`24668` by :user:`dberenbaum`. +- |Enhancement| :meth:`metrics.PrecisionRecallDisplay.from_estimator` and + :meth:`metrics.PrecisionRecallDisplay.from_predictions` now accept two new + keywords, `plot_chance_level` and `chance_level_kwargs` to plot the baseline + chance level. This line is exposed in the `chance_level_` attribute. + :pr:`26019` by :user:`Yao Xiao `. + - |Fix| :func:`log_loss` raises a warning if the values of the parameter `y_pred` are not normalized, instead of actually normalizing them in the metric. Starting from 1.5 this will raise an error. 
:pr:`25299` by :user:`Omar Salman Date: Thu, 30 Mar 2023 20:48:40 +0800 Subject: [PATCH 03/20] changed chance_level_kwargs to chance_level_kw for consistency with other display --- .../metrics/_plot/precision_recall_curve.py | 24 +++++++++---------- .../tests/test_precision_recall_display.py | 12 +++++----- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index 825556b08a02d..d507ef96b3d0d 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -121,7 +121,7 @@ def plot( name=None, plot_chance_level=False, pos_prevalence=0, - chance_level_kwargs=None, + chance_level_kw=None, **kwargs, ): """Plot visualization. @@ -149,7 +149,7 @@ def plot( .. versionadded:: 1.3 - chance_level_kwargs : dict, default=None + chance_level_kw : dict, default=None Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. @@ -187,14 +187,14 @@ def plot( line_kwargs["label"] = name line_kwargs.update(**kwargs) - chance_level_line_kwargs = { + chance_level_line_kw = { "label": f"Chance level (AP = {pos_prevalence:0.2f})", "color": "k", "linestyle": "--", } - if chance_level_kwargs is not None: - chance_level_line_kwargs.update(chance_level_kwargs) + if chance_level_kw is not None: + chance_level_line_kw.update(chance_level_kw) import matplotlib.pyplot as plt @@ -214,7 +214,7 @@ def plot( (self.chance_level_,) = ax.plot( (0, 1), (pos_prevalence, pos_prevalence), - **chance_level_line_kwargs, + **chance_level_line_kw, ) else: self.chance_level_ = None @@ -240,7 +240,7 @@ def from_estimator( name=None, ax=None, plot_chance_level=False, - chance_level_kwargs=None, + chance_level_kw=None, **kwargs, ): """Plot precision-recall curve given an estimator and some data. @@ -290,7 +290,7 @@ def from_estimator( .. 
versionadded:: 1.3 - chance_level_kwargs : dict, default=None + chance_level_kw : dict, default=None Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. @@ -358,7 +358,7 @@ def from_estimator( drop_intermediate=drop_intermediate, ax=ax, plot_chance_level=plot_chance_level, - chance_level_kwargs=chance_level_kwargs, + chance_level_kw=chance_level_kw, **kwargs, ) @@ -374,7 +374,7 @@ def from_predictions( name=None, ax=None, plot_chance_level=False, - chance_level_kwargs=None, + chance_level_kw=None, **kwargs, ): """Plot precision-recall curve given binary class predictions. @@ -413,7 +413,7 @@ def from_predictions( .. versionadded:: 1.3 - chance_level_kwargs : dict, default=None + chance_level_kw : dict, default=None Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. @@ -497,6 +497,6 @@ def from_predictions( name=name, plot_chance_level=plot_chance_level, pos_prevalence=pos_prevalence, - chance_level_kwargs=chance_level_kwargs, + chance_level_kw=chance_level_kw, **kwargs, ) diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py index 58acc15fa5204..6b1c484f0a950 100644 --- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -124,7 +124,7 @@ def test_precision_recall_display_plotting( @pytest.mark.parametrize("plot_chance_level", [True, False]) @pytest.mark.parametrize( - "chance_level_kwargs", + "chance_level_kw", [None, {"linewidth": 1, "color": "red", "label": "DummyEstimator"}], ) @pytest.mark.parametrize( @@ -134,7 +134,7 @@ def test_precision_recall_display_plotting( def test_precision_recall_chance_level_line( pyplot, plot_chance_level, - chance_level_kwargs, + chance_level_kw, constructor_name, ): """Check the chance leve line plotting behaviour.""" @@ -154,7 +154,7 @@ def test_precision_recall_chance_level_line( y, 
alpha=0.8, plot_chance_level=plot_chance_level, - chance_level_kwargs=chance_level_kwargs, + chance_level_kw=chance_level_kw, ) else: display = PrecisionRecallDisplay.from_predictions( @@ -162,7 +162,7 @@ def test_precision_recall_chance_level_line( y_pred, alpha=0.8, plot_chance_level=plot_chance_level, - chance_level_kwargs=chance_level_kwargs, + chance_level_kw=chance_level_kw, ) import matplotlib as mpl # noqal @@ -181,7 +181,7 @@ def test_precision_recall_chance_level_line( ) # Checking for chance level line styles - if plot_chance_level and chance_level_kwargs is None: + if plot_chance_level and chance_level_kw is None: assert display.chance_level_.get_color() == "k" assert display.chance_level_.get_linestyle() == "--" assert ( @@ -189,7 +189,7 @@ def test_precision_recall_chance_level_line( == f"Chance level (AP = {pos_prevalence:0.2f})" ) elif plot_chance_level: - for k, v in chance_level_kwargs.items(): + for k, v in chance_level_kw.items(): if hasattr(display.chance_level_, "get_" + k): assert getattr(display.chance_level_, "get_" + k)() == v else: From d4e44d149cd8b959051acc14480e6d7de8d171b0 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 30 Mar 2023 20:50:57 +0800 Subject: [PATCH 04/20] changelog updated --- doc/whats_new/v1.3.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 72e44ccde9cc8..1def65cb8e672 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -323,7 +323,7 @@ Changelog - |Enhancement| :meth:`metrics.PrecisionRecallDisplay.from_estimator` and :meth:`metrics.PrecisionRecallDisplay.from_predictions` now accept two new - keywords, `plot_chance_level` and `chance_level_kwargs` to plot the baseline + keywords, `plot_chance_level` and `chance_level_kw` to plot the baseline chance level. This line is exposed in the `chance_level_` attribute. :pr:`26019` by :user:`Yao Xiao `. 
From b2df9e2a289e50915541c518b765273d36271ddf Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Wed, 12 Apr 2023 01:45:39 +0800 Subject: [PATCH 05/20] resolved conversations: default value of pos_prevalence changed to None --- .../metrics/_plot/precision_recall_curve.py | 32 ++++++++++++------- .../tests/test_precision_recall_display.py | 2 +- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index d507ef96b3d0d..d7ee009e2e7d3 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -120,7 +120,7 @@ def plot( *, name=None, plot_chance_level=False, - pos_prevalence=0, + pos_prevalence=None, chance_level_kw=None, **kwargs, ): @@ -143,9 +143,10 @@ def plot( .. versionadded:: 1.3 - pos_prevalence : float, default=0 + pos_prevalence : float, default=None The prevalence of the positive label. It is used for plotting the - chance level line. + chance level line. If `plot_chance_level=True`, it must be provided + as a float between 0 and 1. .. 
versionadded:: 1.3 @@ -187,14 +188,21 @@ def plot( line_kwargs["label"] = name line_kwargs.update(**kwargs) - chance_level_line_kw = { - "label": f"Chance level (AP = {pos_prevalence:0.2f})", - "color": "k", - "linestyle": "--", - } - - if chance_level_kw is not None: - chance_level_line_kw.update(chance_level_kw) + if plot_chance_level: + if pos_prevalence is None: + raise ValueError( + "pos_prevalence must be provided if plot_chance_level=True" + ) + elif pos_prevalence < 0 or pos_prevalence > 1: + raise ValueError("pos_prevalence has value outside [0, 1]") + else: + chance_level_line_kw = { + "label": f"Chance level (AP = {pos_prevalence:0.2f})", + "color": "k", + "linestyle": "--", + } + if chance_level_kw is not None: + chance_level_line_kw.update(chance_level_kw) import matplotlib.pyplot as plt @@ -482,7 +490,7 @@ def from_predictions( if plot_chance_level: pos_prevalence = Counter(y_true)[pos_label] / len(y_true) else: - pos_prevalence = 0 + pos_prevalence = None viz = PrecisionRecallDisplay( precision=precision, diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py index 6b1c484f0a950..fd161aef876fe 100644 --- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -137,7 +137,7 @@ def test_precision_recall_chance_level_line( chance_level_kw, constructor_name, ): - """Check the chance leve line plotting behaviour.""" + """Check the chance level line plotting behavior.""" X, y = make_classification(n_classes=2, n_samples=50, random_state=0) pos_prevalence = Counter(y)[1] / len(y) From 206a926fff4e8f038d65c5799362c3583436413d Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Wed, 12 Apr 2023 01:58:29 +0800 Subject: [PATCH 06/20] added check and solved linting issues --- sklearn/metrics/_plot/precision_recall_curve.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git 
a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index e3302bf71b241..4379a23621e49 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -3,6 +3,8 @@ from ...utils._plotting import _BinaryClassifierCurveDisplayMixin from collections import Counter +from numbers import Real + class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin): """Precision Recall visualization. @@ -183,9 +185,12 @@ def plot( if plot_chance_level: if pos_prevalence is None: - raise ValueError( - "pos_prevalence must be provided if plot_chance_level=True" + raise TypeError( + "pos_prevalence must be provided as a real number between " + "0 and 1 if plot_chance_level=True" ) + elif not isinstance(pos_prevalence, Real): + raise TypeError("pos_prevalence must be a real number between 0 and 1") elif pos_prevalence < 0 or pos_prevalence > 1: raise ValueError("pos_prevalence has value outside [0, 1]") else: @@ -472,7 +477,7 @@ def from_predictions( average_precision = average_precision_score( y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight ) - + name = name if name is not None else "Classifier" if plot_chance_level: From edbbdc597964997867a0f8d401c2b655214e74f1 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Wed, 12 Apr 2023 21:50:17 +0800 Subject: [PATCH 07/20] improved test coverage --- .../tests/test_precision_recall_display.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py index 90e3f958772b6..1ea5fac04277d 100644 --- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -11,6 +11,7 @@ from sklearn.model_selection import train_test_split from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler +from 
sklearn.svm import SVC from sklearn.utils import shuffle from sklearn.metrics import PrecisionRecallDisplay @@ -332,3 +333,31 @@ def test_plot_precision_recall_pos_label(pyplot, constructor_name, response_meth avg_prec_limit = 0.95 assert display.average_precision > avg_prec_limit assert -np.trapz(display.precision, display.recall) > avg_prec_limit + + +def test_precision_recall_pos_prevalence_error(pyplot): + # Check that when plot_chance_level is True + # If pos_prevalence is not given as a real number between 0 and 1 + # We raise the correct exceptions + X, y = make_classification(random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + clf = SVC(random_state=0) + clf.fit(X_train, y_train) + predictions = clf.predict(X_test) + precision, recall, _ = precision_recall_curve(y_test, predictions) + disp = PrecisionRecallDisplay(precision=precision, recall=recall) + + msg = ( + "pos_prevalence must be provided as a real number between 0 and 1 " + "if plot_chance_level=True" + ) + with pytest.raises(TypeError, match=msg): + disp.plot(plot_chance_level=True) + + msg = "pos_prevalence must be a real number between 0 and 1" + with pytest.raises(TypeError, match=msg): + disp.plot(plot_chance_level=True, pos_prevalence="0.5") + + msg = "pos_prevalence has value outside \\[0, 1\\]" + with pytest.raises(ValueError, match=msg): + disp.plot(plot_chance_level=True, pos_prevalence=1.5) From 54b5eae817012fcf310fe2fe6ae404cec36bafa8 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Fri, 14 Apr 2023 14:20:12 +0800 Subject: [PATCH 08/20] resolved conversations --- .../metrics/_plot/precision_recall_curve.py | 64 ++++++-------- .../tests/test_precision_recall_display.py | 85 +++---------------- 2 files changed, 40 insertions(+), 109 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index 4379a23621e49..6d779dc9a0d10 100644 --- 
a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -1,10 +1,9 @@ +from collections import Counter + from .. import average_precision_score from .. import precision_recall_curve from ...utils._plotting import _BinaryClassifierCurveDisplayMixin -from collections import Counter -from numbers import Real - class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin): """Precision Recall visualization. @@ -37,6 +36,13 @@ class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin): .. versionadded:: 0.24 + prevalence_pos_label : float, default=None + The prevalence of the positive label. It is used for plotting the + chance level line. If None, the chance level line will not be plotted + even if plot_chance_level is set to True when plotting. + + .. versionadded:: 1.3 + Attributes ---------- line_ : matplotlib Artist @@ -104,12 +110,14 @@ def __init__( average_precision=None, estimator_name=None, pos_label=None, + prevalence_pos_label=None, ): self.estimator_name = estimator_name self.precision = precision self.recall = recall self.average_precision = average_precision self.pos_label = pos_label + self.prevalence_pos_label = prevalence_pos_label def plot( self, @@ -117,7 +125,6 @@ def plot( *, name=None, plot_chance_level=False, - pos_prevalence=None, chance_level_kw=None, **kwargs, ): @@ -140,13 +147,6 @@ def plot( .. versionadded:: 1.3 - pos_prevalence : float, default=None - The prevalence of the positive label. It is used for plotting the - chance level line. If `plot_chance_level=True`, it must be provided - as a float between 0 and 1. - - .. versionadded:: 1.3 - chance_level_kw : dict, default=None Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. 
@@ -183,29 +183,19 @@ def plot( line_kwargs["label"] = name line_kwargs.update(**kwargs) + # If prevalence_pos_label is not provided, even if plot_chance_level + # is set to True, we do not plot the chance level line + if self.prevalence_pos_label is None: + plot_chance_level = False + if plot_chance_level: - if pos_prevalence is None: - raise TypeError( - "pos_prevalence must be provided as a real number between " - "0 and 1 if plot_chance_level=True" - ) - elif not isinstance(pos_prevalence, Real): - raise TypeError("pos_prevalence must be a real number between 0 and 1") - elif pos_prevalence < 0 or pos_prevalence > 1: - raise ValueError("pos_prevalence has value outside [0, 1]") - else: - chance_level_line_kw = { - "label": f"Chance level (AP = {pos_prevalence:0.2f})", - "color": "k", - "linestyle": "--", - } - if chance_level_kw is not None: - chance_level_line_kw.update(chance_level_kw) - - import matplotlib.pyplot as plt - - if ax is None: - fig, ax = plt.subplots() + chance_level_line_kw = { + "label": f"Chance level (AP = {self.prevalence_pos_label:0.2f})", + "color": "k", + "linestyle": "--", + } + if chance_level_kw is not None: + chance_level_line_kw.update(chance_level_kw) (self.line_,) = self.ax_.plot(self.recall, self.precision, **line_kwargs) @@ -218,9 +208,9 @@ def plot( self.ax_.set(xlabel=xlabel, ylabel=ylabel) if plot_chance_level: - (self.chance_level_,) = ax.plot( + (self.chance_level_,) = self.ax_.plot( (0, 1), - (pos_prevalence, pos_prevalence), + (self.prevalence_pos_label, self.prevalence_pos_label), **chance_level_line_kw, ) else: @@ -478,8 +468,6 @@ def from_predictions( y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight ) - name = name if name is not None else "Classifier" - if plot_chance_level: pos_prevalence = Counter(y_true)[pos_label] / len(y_true) else: @@ -491,13 +479,13 @@ def from_predictions( average_precision=average_precision, estimator_name=name, pos_label=pos_label, + prevalence_pos_label=pos_prevalence, ) 
return viz.plot( ax=ax, name=name, plot_chance_level=plot_chance_level, - pos_prevalence=pos_prevalence, chance_level_kw=chance_level_kw, **kwargs, ) diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py index 1ea5fac04277d..4ad59b690cd46 100644 --- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -11,7 +11,6 @@ from sklearn.model_selection import train_test_split from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler -from sklearn.svm import SVC from sklearn.utils import shuffle from sklearn.metrics import PrecisionRecallDisplay @@ -79,19 +78,14 @@ def test_precision_recall_display_plotting( assert display.line_.get_label() == expected_label assert display.line_.get_alpha() == pytest.approx(0.8) + # Check that the chance level line is not plotted + assert display.chance_level_ is None -@pytest.mark.parametrize("plot_chance_level", [True, False]) -@pytest.mark.parametrize( - "chance_level_kw", - [None, {"linewidth": 1, "color": "red", "label": "DummyEstimator"}], -) -@pytest.mark.parametrize( - "constructor_name", - ["from_estimator", "from_predictions"], -) + +@pytest.mark.parametrize("chance_level_kw", [None, {"color": "r"}]) +@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) def test_precision_recall_chance_level_line( pyplot, - plot_chance_level, chance_level_kw, constructor_name, ): @@ -100,58 +94,35 @@ def test_precision_recall_chance_level_line( pos_prevalence = Counter(y)[1] / len(y) lr = LogisticRegression() - lr.fit(X, y) - - y_pred = getattr(lr, "predict_proba")(X) - y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, 1] + y_pred = lr.fit(X, y).predict_proba(X)[:, 1] if constructor_name == "from_estimator": display = PrecisionRecallDisplay.from_estimator( lr, X, y, - alpha=0.8, - plot_chance_level=plot_chance_level, + 
plot_chance_level=True, chance_level_kw=chance_level_kw, ) else: display = PrecisionRecallDisplay.from_predictions( y, y_pred, - alpha=0.8, - plot_chance_level=plot_chance_level, + plot_chance_level=True, chance_level_kw=chance_level_kw, ) - import matplotlib as mpl # noqal + import matplotlib as mpl # noqa - assert isinstance(display.line_, mpl.lines.Line2D) - assert display.line_.get_alpha() == 0.8 - assert isinstance(display.ax_, mpl.axes.Axes) - assert isinstance(display.figure_, mpl.figure.Figure) - - if plot_chance_level: - assert isinstance(display.chance_level_, mpl.lines.Line2D) - assert tuple(display.chance_level_.get_xdata()) == (0, 1) - assert tuple(display.chance_level_.get_ydata()) == ( - pos_prevalence, - pos_prevalence, - ) + assert isinstance(display.chance_level_, mpl.lines.Line2D) + assert tuple(display.chance_level_.get_xdata()) == (0, 1) + assert tuple(display.chance_level_.get_ydata()) == (pos_prevalence, pos_prevalence) # Checking for chance level line styles - if plot_chance_level and chance_level_kw is None: + if chance_level_kw is None: assert display.chance_level_.get_color() == "k" - assert display.chance_level_.get_linestyle() == "--" - assert ( - display.chance_level_.get_label() - == f"Chance level (AP = {pos_prevalence:0.2f})" - ) - elif plot_chance_level: - for k, v in chance_level_kw.items(): - if hasattr(display.chance_level_, "get_" + k): - assert getattr(display.chance_level_, "get_" + k)() == v else: - assert display.chance_level_ is None + assert display.chance_level_.get_color() == "r" @pytest.mark.parametrize( @@ -333,31 +304,3 @@ def test_plot_precision_recall_pos_label(pyplot, constructor_name, response_meth avg_prec_limit = 0.95 assert display.average_precision > avg_prec_limit assert -np.trapz(display.precision, display.recall) > avg_prec_limit - - -def test_precision_recall_pos_prevalence_error(pyplot): - # Check that when plot_chance_level is True - # If pos_prevalence is not given as a real number between 0 and 1 - # 
We raise the correct exceptions - X, y = make_classification(random_state=0) - X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) - clf = SVC(random_state=0) - clf.fit(X_train, y_train) - predictions = clf.predict(X_test) - precision, recall, _ = precision_recall_curve(y_test, predictions) - disp = PrecisionRecallDisplay(precision=precision, recall=recall) - - msg = ( - "pos_prevalence must be provided as a real number between 0 and 1 " - "if plot_chance_level=True" - ) - with pytest.raises(TypeError, match=msg): - disp.plot(plot_chance_level=True) - - msg = "pos_prevalence must be a real number between 0 and 1" - with pytest.raises(TypeError, match=msg): - disp.plot(plot_chance_level=True, pos_prevalence="0.5") - - msg = "pos_prevalence has value outside \\[0, 1\\]" - with pytest.raises(ValueError, match=msg): - disp.plot(plot_chance_level=True, pos_prevalence=1.5) From 39ff08e5ad952b174cf7209ed4034c210776178c Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Fri, 14 Apr 2023 20:50:33 +0800 Subject: [PATCH 09/20] partially resolved conversations, the rest TBD soon --- .../metrics/_plot/precision_recall_curve.py | 44 +++++++++---------- .../tests/test_precision_recall_display.py | 6 +-- 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index 6d779dc9a0d10..92836e8071057 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -39,7 +39,7 @@ class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin): prevalence_pos_label : float, default=None The prevalence of the positive label. It is used for plotting the chance level line. If None, the chance level line will not be plotted - even if plot_chance_level is set to True when plotting. + even if `plot_chance_level` is set to True when plotting. .. 
versionadded:: 1.3 @@ -143,7 +143,9 @@ def plot( `estimator_name` if not `None`, otherwise no labeling is shown. plot_chance_level : bool, default=False - Whether to plot the chance level. + Whether to plot the chance level. The chance level is the prevalence + of the positive label computed from the data passed during + :meth:`from_estimator` or :meth:`from_predictions` call. .. versionadded:: 1.3 @@ -183,20 +185,6 @@ def plot( line_kwargs["label"] = name line_kwargs.update(**kwargs) - # If prevalence_pos_label is not provided, even if plot_chance_level - # is set to True, we do not plot the chance level line - if self.prevalence_pos_label is None: - plot_chance_level = False - - if plot_chance_level: - chance_level_line_kw = { - "label": f"Chance level (AP = {self.prevalence_pos_label:0.2f})", - "color": "k", - "linestyle": "--", - } - if chance_level_kw is not None: - chance_level_line_kw.update(chance_level_kw) - (self.line_,) = self.ax_.plot(self.recall, self.precision, **line_kwargs) info_pos_label = ( @@ -207,7 +195,14 @@ def plot( ylabel = "Precision" + info_pos_label self.ax_.set(xlabel=xlabel, ylabel=ylabel) - if plot_chance_level: + if plot_chance_level and self.prevalence is not None: + chance_level_line_kw = { + "label": f"Chance level (AP = {self.prevalence_pos_label:0.2f})", + "color": "k", + "linestyle": "--", + } + if chance_level_kw is not None: + chance_level_line_kw.update(chance_level_kw) (self.chance_level_,) = self.ax_.plot( (0, 1), (self.prevalence_pos_label, self.prevalence_pos_label), @@ -281,7 +276,9 @@ def from_estimator( Axes object to plot on. If `None`, a new figure and axes is created. plot_chance_level : bool, default=False - Whether to plot the chance level. + Whether to plot the chance level. The chance level is the prevalence + of the positive label computed from the data passed during + :meth:`from_estimator` or :meth:`from_predictions` call. .. versionadded:: 1.3 @@ -401,7 +398,9 @@ def from_predictions( Axes object to plot on. 
If `None`, a new figure and axes is created. plot_chance_level : bool, default=False - Whether to plot the chance level. + Whether to plot the chance level. The chance level is the prevalence + of the positive label computed from the data passed during + :meth:`from_estimator` or :meth:`from_predictions` call. .. versionadded:: 1.3 @@ -468,10 +467,7 @@ def from_predictions( y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight ) - if plot_chance_level: - pos_prevalence = Counter(y_true)[pos_label] / len(y_true) - else: - pos_prevalence = None + prevalence_pos_label = Counter(y_true)[pos_label] / len(y_true) viz = PrecisionRecallDisplay( precision=precision, @@ -479,7 +475,7 @@ def from_predictions( average_precision=average_precision, estimator_name=name, pos_label=pos_label, - prevalence_pos_label=pos_prevalence, + prevalence_pos_label=prevalence_pos_label, ) return viz.plot( diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py index 4ad59b690cd46..aa3b1acbddcf4 100644 --- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -1,8 +1,8 @@ +from collections import Counter + import numpy as np import pytest -from collections import Counter - from sklearn.compose import make_column_transformer from sklearn.datasets import load_breast_cancer, make_classification from sklearn.exceptions import NotFittedError @@ -78,7 +78,7 @@ def test_precision_recall_display_plotting( assert display.line_.get_label() == expected_label assert display.line_.get_alpha() == pytest.approx(0.8) - # Check that the chance level line is not plotted + # Check that the chance level line is not plotted by default assert display.chance_level_ is None From 46a06797ec3e1f50fd157169ec03b0c03c266b76 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Fri, 14 Apr 2023 21:30:09 +0800 Subject: [PATCH 10/20] fixed typo in attribute name 
--- sklearn/metrics/_plot/precision_recall_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index 92836e8071057..c55130f49cae9 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -195,7 +195,7 @@ def plot( ylabel = "Precision" + info_pos_label self.ax_.set(xlabel=xlabel, ylabel=ylabel) - if plot_chance_level and self.prevalence is not None: + if plot_chance_level and self.prevalence_pos_label is not None: chance_level_line_kw = { "label": f"Chance level (AP = {self.prevalence_pos_label:0.2f})", "color": "k", From c6460e58869cbfd964303766d807190bd84b083d Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Fri, 14 Apr 2023 21:50:38 +0800 Subject: [PATCH 11/20] added test to check that prevalence_pos_label is reusable via plot method --- .../tests/test_precision_recall_display.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py index aa3b1acbddcf4..a6696fcba4912 100644 --- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -304,3 +304,32 @@ def test_plot_precision_recall_pos_label(pyplot, constructor_name, response_meth avg_prec_limit = 0.95 assert display.average_precision > avg_prec_limit assert -np.trapz(display.precision, display.recall) > avg_prec_limit + + +@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) +def test_precision_recall_prevalence_pos_label_reusable(pyplot, constructor_name): + # Check that even if one passes plot_chance_level=False the first time + # one can still call disp.plot with plot_chance_level=True and get the + # chance level line + X, y = make_classification(n_classes=2, n_samples=50, 
random_state=0) + + lr = LogisticRegression() + y_pred = lr.fit(X, y).predict_proba(X)[:, 1] + + if constructor_name == "from_estimator": + display = PrecisionRecallDisplay.from_estimator( + lr, X, y, plot_chance_level=False + ) + else: + display = PrecisionRecallDisplay.from_predictions( + y, y_pred, plot_chance_level=False + ) + assert display.chance_level_ is None + + import matplotlib as mpl # noqa + + # When calling from_estimator or from_predictions, + # prevalence_pos_label should have been set, so that directly + # calling plot_chance_level=True should plot the chance level line + display.plot(plot_chance_level=True) + assert isinstance(display.chance_level_, mpl.lines.Line2D) From d729fcb5a0bf8180f6943004f6033bc885f44cd4 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Fri, 14 Apr 2023 22:22:45 +0800 Subject: [PATCH 12/20] added example --- examples/model_selection/plot_precision_recall.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/model_selection/plot_precision_recall.py b/examples/model_selection/plot_precision_recall.py index 4d9ebcdc4abe2..a80af50db6429 100644 --- a/examples/model_selection/plot_precision_recall.py +++ b/examples/model_selection/plot_precision_recall.py @@ -142,7 +142,7 @@ from sklearn.metrics import PrecisionRecallDisplay display = PrecisionRecallDisplay.from_estimator( - classifier, X_test, y_test, name="LinearSVC" + classifier, X_test, y_test, name="LinearSVC", plot_chance_level=True ) _ = display.ax_.set_title("2-class Precision-Recall curve") @@ -152,7 +152,9 @@ # :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions`. 
y_score = classifier.decision_function(X_test) -display = PrecisionRecallDisplay.from_predictions(y_test, y_score, name="LinearSVC") +display = PrecisionRecallDisplay.from_predictions( + y_test, y_score, name="LinearSVC", plot_chance_level=True +) _ = display.ax_.set_title("2-class Precision-Recall curve") # %% @@ -214,12 +216,15 @@ # %% # Plot the micro-averaged Precision-Recall curve # .............................................. +from collections import Counter + display = PrecisionRecallDisplay( recall=recall["micro"], precision=precision["micro"], average_precision=average_precision["micro"], + prevalence_pos_label=Counter(Y_test.ravel())[1] / Y_test.size, ) -display.plot() +display.plot(plot_chance_level=True) _ = display.ax_.set_title("Micro-averaged over all classes") # %% From 46807c3cf4a3a7bfb13de4b2b1bfd0bb7059d47e Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Sat, 15 Apr 2023 00:26:33 +0800 Subject: [PATCH 13/20] 'secretly' fix a changelog typo in my previous contributions --- doc/whats_new/v1.3.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 301b146e0058f..d5d97153adbf0 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -158,7 +158,7 @@ Changelog a DataFrame's dtype when transformed. :pr:`25102` by `Thomas Fan`_. - |Fix| :class:`feature_selection.SequentialFeatureSelector`'s `cv` parameter - now supports generators. :pr:`25973` by `Yao Xiao `. + now supports generators. :pr:`25973` by :user:`Yao Xiao `. :mod:`sklearn.base` ................... 
From ab8873cf89764d11b9af92cf89e3d405da43e452 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 4 May 2023 21:18:22 +0800 Subject: [PATCH 14/20] reverted suspicious additions --- doc/whats_new/v1.3.rst | 9 --------- 1 file changed, 9 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index a3d385ef8c353..fdc97c1801064 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -151,15 +151,6 @@ Changelog :pr:`123456` by :user:`Joe Bloggs `. where 123456 is the *pull request* number, not the issue number. -:mod:`sklearn.feature_selection` -................................ - -- |Enhancement| All selectors in :mod:`sklearn.feature_selection` will preserve - a DataFrame's dtype when transformed. :pr:`25102` by `Thomas Fan`_. - -- |Fix| :class:`feature_selection.SequentialFeatureSelector`'s `cv` parameter - now supports generators. :pr:`25973` by :user:`Yao Xiao `. - :mod:`sklearn.base` ................... From a825b92637ea418c79168212c253ca0dce8fa38e Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 4 May 2023 21:18:42 +0800 Subject: [PATCH 15/20] resolved conversations --- sklearn/metrics/_plot/precision_recall_curve.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index 07289d3703065..c5802bd20b611 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -467,7 +467,8 @@ def from_predictions( y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight ) - prevalence_pos_label = Counter(y_true)[pos_label] / len(y_true) + class_count = Counter(y_true) + prevalence_pos_label = class_count[pos_label] / class_count.total() viz = PrecisionRecallDisplay( precision=precision, From 2afede70ee7bff74083180c304295f555a47c6c6 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 4 May 2023 21:35:12 +0800 Subject: [PATCH 16/20] counter object does not have 
total() --- sklearn/metrics/_plot/precision_recall_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index c5802bd20b611..f3b4f167913bc 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -468,7 +468,7 @@ def from_predictions( ) class_count = Counter(y_true) - prevalence_pos_label = class_count[pos_label] / class_count.total() + prevalence_pos_label = class_count[pos_label] / sum(class_count.values()) viz = PrecisionRecallDisplay( precision=precision, From 7b21accb47bc0ac47197d08fd3b756ed3d40f1f0 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 4 May 2023 21:37:03 +0800 Subject: [PATCH 17/20] reverted unnecessary change --- doc/whats_new/v1.3.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index fdc97c1801064..1d40a1941d421 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -151,6 +151,7 @@ Changelog :pr:`123456` by :user:`Joe Bloggs `. where 123456 is the *pull request* number, not the issue number. + :mod:`sklearn.base` ................... 
From 2457e9012eecee4e67771563ad1411d1fbe20dcc Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Mon, 15 May 2023 03:16:15 +0800 Subject: [PATCH 18/20] minor modification --- sklearn/metrics/_plot/precision_recall_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index f3b4f167913bc..5e6249b1a1982 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -211,7 +211,7 @@ def plot( else: self.chance_level_ = None - if "label" in line_kwargs: + if "label" in line_kwargs or "label" in chance_level_line_kw: self.ax_.legend(loc="lower left") return self From e93a2437e26db700f5d96c444c9e5970905659f1 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Mon, 15 May 2023 04:02:19 +0800 Subject: [PATCH 19/20] minor fix --- sklearn/metrics/_plot/precision_recall_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index 5e6249b1a1982..45429a314ed23 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -211,7 +211,7 @@ def plot( else: self.chance_level_ = None - if "label" in line_kwargs or "label" in chance_level_line_kw: + if "label" in line_kwargs or plot_chance_level: self.ax_.legend(loc="lower left") return self From 833852140ebf6ef021ebb65f59d7e922bbf6c1f4 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Wed, 17 May 2023 00:32:14 +0800 Subject: [PATCH 20/20] raises when plotting chance level but no prevalence level is given --- .../metrics/_plot/precision_recall_curve.py | 13 +++++++++++- .../tests/test_precision_recall_display.py | 20 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py 
index 45429a314ed23..5df70aa75b5fb 100644
--- a/sklearn/metrics/_plot/precision_recall_curve.py
+++ b/sklearn/metrics/_plot/precision_recall_curve.py
@@ -195,7 +195,17 @@ def plot(
         ylabel = "Precision" + info_pos_label
         self.ax_.set(xlabel=xlabel, ylabel=ylabel)
 
-        if plot_chance_level and self.prevalence_pos_label is not None:
+        if plot_chance_level:
+            if self.prevalence_pos_label is None:
+                raise ValueError(
+                    "You must provide prevalence_pos_label when constructing the "
+                    "PrecisionRecallDisplay object in order to plot the chance "
+                    "level line. Alternatively, you may use "
+                    "PrecisionRecallDisplay.from_estimator or "
+                    "PrecisionRecallDisplay.from_predictions "
+                    "to automatically set prevalence_pos_label"
+                )
+
             chance_level_line_kw = {
                 "label": f"Chance level (AP = {self.prevalence_pos_label:0.2f})",
                 "color": "k",
@@ -203,6 +213,7 @@ def plot(
             }
             if chance_level_kw is not None:
                 chance_level_line_kw.update(chance_level_kw)
+
             (self.chance_level_,) = self.ax_.plot(
                 (0, 1),
                 (self.prevalence_pos_label, self.prevalence_pos_label),
diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
index a6696fcba4912..0bb6501dec89a 100644
--- a/sklearn/metrics/_plot/tests/test_precision_recall_display.py
+++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py
@@ -333,3 +333,23 @@ def test_precision_recall_prevalence_pos_label_reusable(pyplot, constructor_name
     # calling plot_chance_level=True should plot the chance level line
     display.plot(plot_chance_level=True)
     assert isinstance(display.chance_level_, mpl.lines.Line2D)
+
+
+def test_precision_recall_raise_no_prevalence(pyplot):
+    # Check that raises correctly when plotting chance level with
+    # no prevalence_pos_label is provided
+    precision = np.array([1, 0.5, 0])
+    recall = np.array([0, 0.5, 1])
+    display = PrecisionRecallDisplay(precision, recall)
+
+    msg = (
+        "You must provide prevalence_pos_label when constructing the "
+        
"PrecisionRecallDisplay object in order to plot the chance " + "level line. Alternatively, you may use " + "PrecisionRecallDisplay.from_estimator or " + "PrecisionRecallDisplay.from_predictions " + "to automatically set prevalence_pos_label" + ) + + with pytest.raises(ValueError, match=msg): + display.plot(plot_chance_level=True)