From 720f9ace581b8a9c9a0af203a0078f1b5eb76574 Mon Sep 17 00:00:00 2001 From: Thomas Fan Date: Tue, 20 Aug 2019 16:07:23 -0400 Subject: [PATCH 01/22] WIP --- sklearn/metrics/__init__.py | 4 + sklearn/metrics/_plot/precision_recall.py | 166 ++++++++++++++++++ .../_plot/tests/test_plot_precision_recall.py | 93 ++++++++++ 3 files changed, 263 insertions(+) create mode 100644 sklearn/metrics/_plot/precision_recall.py create mode 100644 sklearn/metrics/_plot/tests/test_plot_precision_recall.py diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index d0b65ad1f4cfa..a338fa9337e66 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -78,6 +78,8 @@ from ._plot.roc_curve import plot_roc_curve from ._plot.roc_curve import RocCurveDisplay +from ._plot.precision_recall import plot_precision_recall_curve +from ._plot.precision_recall import PrecisionRecallDisplay __all__ = [ @@ -133,7 +135,9 @@ 'pairwise_distances_argmin_min', 'pairwise_distances_chunked', 'pairwise_kernels', + 'plot_precision_recall_curve', 'plot_roc_curve', + 'PrecisionRecallDisplay', 'precision_recall_curve', 'precision_recall_fscore_support', 'precision_score', diff --git a/sklearn/metrics/_plot/precision_recall.py b/sklearn/metrics/_plot/precision_recall.py new file mode 100644 index 0000000000000..b14ec294925b9 --- /dev/null +++ b/sklearn/metrics/_plot/precision_recall.py @@ -0,0 +1,166 @@ +from .. import average_precision_score +from .. import precision_recall_curve + +from ...utils import check_matplotlib_support + + +class PrecisionRecallDisplay: + """Precision Recall visualization. + + It is recommend to use `sklearn.metrics.plot_precision_recall_curve` to + create a visualizer. All parameters are stored as attributes. + + Read more in the :ref:`User Guide `. + + Parameters + ----------- + precision : ndarray + Precision values. + + recall : ndarray + Recall values. + + average_precision : float + Average precision. + + estimator_name : str + Name of estimator. + + Attributes + ---------- + line_ : matplotlib Artist + Precision recall curve. + + ax_ : matplotlib Axes + Axes with precision recall curve. + + figure_ : matplotlib Figure + Figure containing the curve. + """ + + def __init__(self, precision, recall, average_precision, estimator_name): + self.precision = precision + self.recall = recall + self.average_precision = average_precision + self.estimator_name = estimator_name + + def plot(self, ax=None, name=None, **kwargs): + """Plot visualization + + Extra keyword arguments will be passed to matplotlib's ``plot``. + + Parameters + ---------- + ax : Matplotlib Axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + name : str, default=None + Name of precision recall curve for labeling. If `None`, use the + name of the estimator. 
+ """ + check_matplotlib_support("PrecisionRecallDisplay.plot") + import matplotlib.pyplot as plt + + if ax is None: + fig, ax = plt.subplots() + + name = self.estimator_name if name is None else name + + line_kwargs = { + "label": "{} (AP = {:0.2f})".format(name, self.average_precision), + "drawstyle": "steps-post" + } + line_kwargs.update(**kwargs) + + self.line_ = ax.plot(self.recall, self.precision, **line_kwargs)[0] + ax.set(xlabel="Recall", ylabel="Precision", ylim=[0.0, 1.05], + xlim=[0.0, 1.0]) + ax.legend(loc='lower left') + + self.ax_ = ax + self.figure_ = ax.figure + return self + + +def plot_precision_recall_curve(estimator, X, y, pos_label=None, + sample_weight=None, response_method="auto", + name=None, ax=None, **kwargs): + """Plot Precision Recall Curve. + + Extra keyword arguments will be passed to matplotlib's ``plot``. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + estimator : estimator instance + Trained classifier. + + X : {array-like, sparse matrix}, shape (n_samples, n_features) + Input values. + + y : array-like, shape (n_samples, ) + Target values. + + pos_label : int or str, default=None + The label of the positive class. + When `pos_label=None`, if y_true is in {-1, 1} or {0, 1}, + `pos_label` is set to 1, otherwise an error will be raised. + + sample_weight : array-like, shape (n_samples, ), default=None + Sample weights. + + response_method : {'predict_proba', 'decision_function', 'auto'} \ + default='auto' + Specifies whether to use `predict_proba` or `decision_function` as the + target response. If set to 'auto', `predict_proba` is tried first + and if it does not exist `decision_function` is tried next. + + name : str, default=None + Name for labeling curve. If `None`, the name of the + estimator is used. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is created. + + Returns + ------- + viz : :class:`sklearn.metrics.PrecisionRecallDisplay` + Object that stores computed values. 
+ """ + check_matplotlib_support("plot_precision_recall_curve") + + if response_method not in ("predict_proba", "decision_function", "auto"): + raise ValueError("response_method must be 'predict_proba', " + "'decision_function' or 'auto'") + + if response_method != "auto": + prediction_method = getattr(estimator, response_method, None) + if prediction_method is None: + raise ValueError( + "response method {} is not defined".format(response_method)) + else: + predict_proba = getattr(estimator, 'predict_proba', None) + decision_function = getattr(estimator, 'decision_function', None) + prediction_method = predict_proba or decision_function + + if prediction_method is None: + raise ValueError('response methods not defined') + + y_pred = prediction_method(X) + + if y_pred.ndim != 1: + if y_pred.shape[1] > 2: + raise ValueError("Estimator should solve a " + "binary classification problem") + y_pred = y_pred[:, 1] + + precision, recall, _ = precision_recall_curve(y, y_pred, + pos_label=pos_label, + sample_weight=sample_weight) + average_precision = average_precision_score(y, y_pred, + sample_weight=sample_weight) + viz = PrecisionRecallDisplay(precision, recall, average_precision, + estimator.__class__.__name__) + return viz.plot(ax=ax, name=name, **kwargs) diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py new file mode 100644 index 0000000000000..6aa77c37eed7c --- /dev/null +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -0,0 +1,93 @@ +import pytest +import numpy as np +from numpy.testing import assert_allclose + +from sklearn.metrics import plot_precision_recall_curve +from sklearn.metrics import average_precision_score +from sklearn.metrics import precision_recall_curve +from sklearn.datasets import load_breast_cancer +from sklearn.datasets import load_iris +from sklearn.tree import DecisionTreeClassifier +from sklearn.linear_model import LogisticRegression + + +@pytest.fixture(scope="module") +def data_binary(): + return load_breast_cancer(return_X_y=True) + + +def test_error_non_binary(pyplot): + X, y = load_iris(return_X_y=True) + clf = DecisionTreeClassifier() + clf.fit(X, y) + + msg = "Estimator should solve a binary classification problem" + with pytest.raises(ValueError, match=msg): + plot_precision_recall_curve(clf, X, y) + + +@pytest.mark.parametrize( + "response_method, msg", + [("predict_proba", "response method predict_proba is not defined"), + ("decision_function", "response method decision_function is not defined"), + ("auto", "response methods not defined"), + ("bad_method", "response_method must be 'predict_proba', " + "'decision_function' or 'auto'")]) +def test_error_no_response(pyplot, data_binary, response_method, msg): + X, y = data_binary + + class MyClassifier: + pass + + clf = MyClassifier() + + with pytest.raises(ValueError, match=msg): + plot_precision_recall_curve(clf, X, y, response_method=response_method) + + +@pytest.mark.parametrize("response_method", + ["predict_proba", "decision_function"]) +@pytest.mark.parametrize("with_sample_weight", [True, False]) +def test_plot_precision_recall(pyplot, response_method, data_binary, + with_sample_weight): + X, y = data_binary + + lr = LogisticRegression() + lr.fit(X, y) + + if with_sample_weight: + rng = np.random.RandomState(42) + sample_weight = rng.randint(0, 4, size=X.shape[0]) + else: + sample_weight = None + + viz = plot_precision_recall_curve(lr, X, y, alpha=0.8, + sample_weight=sample_weight) + + y_score = 
getattr(lr, response_method)(X) + if y_score.ndim == 2: + y_score = y_score[:, 1] + + prec, recall, _ = precision_recall_curve(y, y_score, + sample_weight=sample_weight) + avg_prec = average_precision_score(y, y_score, sample_weight=sample_weight) + + assert_allclose(viz.precision, prec) + assert_allclose(viz.recall, recall) + assert_allclose(viz.average_precision, avg_prec) + + assert viz.estimator_name == "LogisticRegression" + + # cannot fail thanks to pyplot fixture + import matplotlib as mpl # noqal + assert isinstance(viz.line_, mpl.lines.Line2D) + assert viz.line_.get_alpha() == 0.8 + assert isinstance(viz.ax_, mpl.axes.Axes) + assert isinstance(viz.figure_, mpl.figure.Figure) + + expected_label = "LogisticRegression (AP = {:0.2f})".format(avg_prec) + assert viz.line_.get_label() == expected_label + assert viz.ax_.get_xlabel() == "Recall" + assert viz.ax_.get_ylabel() == "Precision" + assert_allclose(viz.ax_.get_xlim(), [0.0, 1.0]) + assert_allclose(viz.ax_.get_ylim(), [0.0, 1.05]) From 8ac4469774a3386b8ca0679c94ac3cb9245233c4 Mon Sep 17 00:00:00 2001 From: Thomas Fan Date: Thu, 22 Aug 2019 08:54:53 -0400 Subject: [PATCH 02/22] DOC Uses plot_precision_recall in example --- .../model_selection/plot_precision_recall.py | 26 ++++--------------- 1 file changed, 5 insertions(+), 21 deletions(-) diff --git a/examples/model_selection/plot_precision_recall.py b/examples/model_selection/plot_precision_recall.py index 203757e0136fc..d6401b205a571 100644 --- a/examples/model_selection/plot_precision_recall.py +++ b/examples/model_selection/plot_precision_recall.py @@ -134,25 +134,12 @@ # Plot the Precision-Recall curve # ................................ from sklearn.metrics import precision_recall_curve +from sklearn.metrics import plot_precision_recall_curve import matplotlib.pyplot as plt -from inspect import signature -precision, recall, _ = precision_recall_curve(y_test, y_score) - -# In matplotlib < 1.5, plt.fill_between does not have a 'step' argument -step_kwargs = ({'step': 'post'} - if 'step' in signature(plt.fill_between).parameters - else {}) -plt.step(recall, precision, color='b', alpha=0.2, - where='post') -plt.fill_between(recall, precision, alpha=0.2, color='b', **step_kwargs) - -plt.xlabel('Recall') -plt.ylabel('Precision') -plt.ylim([0.0, 1.05]) -plt.xlim([0.0, 1.0]) -plt.title('2-class Precision-Recall curve: AP={0:0.2f}'.format( - average_precision)) +disp = plot_precision_recall_curve(classifier, X_test, y_test, color='b') +disp.ax_.set_title('2-class Precision-Recall curve: ' + 'AP={0:0.2f}'.format(average_precision)) ############################################################################### # In multi-label settings @@ -212,10 +199,7 @@ # plt.figure() -plt.step(recall['micro'], precision['micro'], color='b', alpha=0.2, - where='post') -plt.fill_between(recall["micro"], precision["micro"], alpha=0.2, color='b', - **step_kwargs) +plt.step(recall['micro'], precision['micro'], color='b', where='post') plt.xlabel('Recall') plt.ylabel('Precision') From 0b8138369a6f0c7eadf46e162914ed7f7c67bd2a Mon Sep 17 00:00:00 2001 From: Thomas Fan Date: Thu, 22 Aug 2019 08:56:39 -0400 Subject: [PATCH 03/22] DOC Adds to userguide --- doc/modules/classes.rst | 2 ++ doc/visualizations.rst | 2 ++ 2 files changed, 4 insertions(+) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index f6bd3d995c099..dd5ec45983216 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -1029,12 +1029,14 @@ See the :ref:`visualizations` section of the user guide for further details. 
:toctree: generated/ :template: function.rst + metrics.plot_precision_recall_curve metrics.plot_roc_curve .. autosummary:: :toctree: generated/ :template: class.rst + metrics.PrecisionRecallDisplay metrics.RocCurveDisplay diff --git a/doc/visualizations.rst b/doc/visualizations.rst index d21b90d0b4171..c8248902b3149 100644 --- a/doc/visualizations.rst +++ b/doc/visualizations.rst @@ -70,6 +70,7 @@ Functions .. autosummary:: + metrics.plot_precision_recall_curve metrics.plot_roc_curve @@ -80,4 +81,5 @@ Display Objects .. autosummary:: + metrics.PrecisionRecallDisplay metrics.RocCurveDisplay From e9c8131b2f89df96a3cda2c53655f9561439bb57 Mon Sep 17 00:00:00 2001 From: Thomas Fan Date: Tue, 3 Sep 2019 11:05:26 -0400 Subject: [PATCH 04/22] DOC style --- sklearn/metrics/_plot/precision_recall.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall.py b/sklearn/metrics/_plot/precision_recall.py index b14ec294925b9..0836b18a0e089 100644 --- a/sklearn/metrics/_plot/precision_recall.py +++ b/sklearn/metrics/_plot/precision_recall.py @@ -97,10 +97,10 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, estimator : estimator instance Trained classifier. - X : {array-like, sparse matrix}, shape (n_samples, n_features) + X : {array-like, sparse matrix} of shape (n_samples, n_features) Input values. - y : array-like, shape (n_samples, ) + y : array-like of shape (n_samples, ) Target values. pos_label : int or str, default=None @@ -108,7 +108,7 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, When `pos_label=None`, if y_true is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an error will be raised. - sample_weight : array-like, shape (n_samples, ), default=None + sample_weight : array-like of shape (n_samples, ), default=None Sample weights. response_method : {'predict_proba', 'decision_function', 'auto'} \ From 3d8686709c35d8e2a9f1ddd8e751a3c1952127c9 Mon Sep 17 00:00:00 2001 From: Thomas Fan Date: Thu, 5 Sep 2019 15:54:32 -0400 Subject: [PATCH 05/22] DOC Better docs --- sklearn/metrics/_plot/precision_recall.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall.py b/sklearn/metrics/_plot/precision_recall.py index 0836b18a0e089..4be80c7a93fbe 100644 --- a/sklearn/metrics/_plot/precision_recall.py +++ b/sklearn/metrics/_plot/precision_recall.py @@ -58,6 +58,11 @@ def plot(self, ax=None, name=None, **kwargs): name : str, default=None Name of precision recall curve for labeling. If `None`, use the name of the estimator. + + Returns + ------- + display : :class:`~sklearn.metrics.PrecisionRecallDisplay` + Object that stores computed values. """ check_matplotlib_support("PrecisionRecallDisplay.plot") import matplotlib.pyplot as plt @@ -100,7 +105,7 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, X : {array-like, sparse matrix} of shape (n_samples, n_features) Input values. - y : array-like of shape (n_samples, ) + y : array-like of shape (n_samples,) Target values. pos_label : int or str, default=None @@ -108,14 +113,15 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, When `pos_label=None`, if y_true is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an error will be raised. - sample_weight : array-like of shape (n_samples, ), default=None + sample_weight : array-like of shape (n_samples,), default=None Sample weights. 
response_method : {'predict_proba', 'decision_function', 'auto'} \ default='auto' - Specifies whether to use `predict_proba` or `decision_function` as the - target response. If set to 'auto', `predict_proba` is tried first - and if it does not exist `decision_function` is tried next. + Specifies whether to use :term:`predict_proba` or + :term:`decision_function` as the target response. If set to 'auto', + :term:`predict_proba` is tried first and if it does not exist + :term:`decision_function` is tried next. name : str, default=None Name for labeling curve. If `None`, the name of the @@ -126,7 +132,7 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, Returns ------- - viz : :class:`sklearn.metrics.PrecisionRecallDisplay` + display : :class:`~sklearn.metrics.PrecisionRecallDisplay` Object that stores computed values. """ check_matplotlib_support("plot_precision_recall_curve") From 66deaac142524c41bc3357db5debb168262be683 Mon Sep 17 00:00:00 2001 From: Thomas Fan Date: Mon, 9 Sep 2019 10:24:10 -0400 Subject: [PATCH 06/22] CLN --- sklearn/metrics/_plot/tests/test_plot_precision_recall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py index 6aa77c37eed7c..cc1e1a57b1d51 100644 --- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -79,7 +79,7 @@ def test_plot_precision_recall(pyplot, response_method, data_binary, assert viz.estimator_name == "LogisticRegression" # cannot fail thanks to pyplot fixture - import matplotlib as mpl # noqal + import matplotlib as mpl # noqa assert isinstance(viz.line_, mpl.lines.Line2D) assert viz.line_.get_alpha() == 0.8 assert isinstance(viz.ax_, mpl.axes.Axes) From 10dc97ef594f6d0d0860663d847878da622f6874 Mon Sep 17 00:00:00 2001 From: Thomas Fan Date: Thu, 19 Sep 2019 12:46:20 -0400 Subject: [PATCH 07/22] CLN Address @glemaitre comments --- .../model_selection/plot_precision_recall.py | 4 +- sklearn/metrics/_plot/precision_recall.py | 46 +++++++------ .../_plot/tests/test_plot_precision_recall.py | 65 ++++++++++++------- 3 files changed, 71 insertions(+), 44 deletions(-) diff --git a/examples/model_selection/plot_precision_recall.py b/examples/model_selection/plot_precision_recall.py index d6401b205a571..9b71b85e9b37a 100644 --- a/examples/model_selection/plot_precision_recall.py +++ b/examples/model_selection/plot_precision_recall.py @@ -137,7 +137,7 @@ from sklearn.metrics import plot_precision_recall_curve import matplotlib.pyplot as plt -disp = plot_precision_recall_curve(classifier, X_test, y_test, color='b') +disp = plot_precision_recall_curve(classifier, X_test, y_test) disp.ax_.set_title('2-class Precision-Recall curve: ' 'AP={0:0.2f}'.format(average_precision)) @@ -199,7 +199,7 @@ # plt.figure() -plt.step(recall['micro'], precision['micro'], color='b', where='post') +plt.step(recall['micro'], precision['micro'], where='post') plt.xlabel('Recall') plt.ylabel('Precision') diff --git a/sklearn/metrics/_plot/precision_recall.py b/sklearn/metrics/_plot/precision_recall.py index 4be80c7a93fbe..b1c29ace865bf 100644 --- a/sklearn/metrics/_plot/precision_recall.py +++ b/sklearn/metrics/_plot/precision_recall.py @@ -2,6 +2,7 @@ from .. 
import precision_recall_curve from ...utils import check_matplotlib_support +from ...utils.validation import check_is_fitted class PrecisionRecallDisplay: @@ -14,10 +15,10 @@ class PrecisionRecallDisplay: Parameters ----------- - precision : ndarray + precision : ndarray of shape (n_thresholds + 1, ) Precision values. - recall : ndarray + recall : ndarray of shape (n_thresholds + 1,) Recall values. average_precision : float @@ -44,8 +45,8 @@ def __init__(self, precision, recall, average_precision, estimator_name): self.average_precision = average_precision self.estimator_name = estimator_name - def plot(self, ax=None, name=None, **kwargs): - """Plot visualization + def plot(self, ax=None, label_name=None, **kwargs): + """Plot visualization. Extra keyword arguments will be passed to matplotlib's ``plot``. @@ -55,7 +56,7 @@ def plot(self, ax=None, name=None, **kwargs): Axes object to plot on. If `None`, a new figure and axes is created. - name : str, default=None + label_name : str, default=None Name of precision recall curve for labeling. If `None`, use the name of the estimator. @@ -70,15 +71,16 @@ def plot(self, ax=None, name=None, **kwargs): if ax is None: fig, ax = plt.subplots() - name = self.estimator_name if name is None else name + label_name = self.estimator_name if label_name is None else label_name line_kwargs = { - "label": "{} (AP = {:0.2f})".format(name, self.average_precision), + "label": "{} (AP = {:0.2f})".format(label_name, + self.average_precision), "drawstyle": "steps-post" } line_kwargs.update(**kwargs) - self.line_ = ax.plot(self.recall, self.precision, **line_kwargs)[0] + self.line_, = ax.plot(self.recall, self.precision, **line_kwargs) ax.set(xlabel="Recall", ylabel="Precision", ylim=[0.0, 1.05], xlim=[0.0, 1.0]) ax.legend(loc='lower left') @@ -90,8 +92,8 @@ def plot(self, ax=None, name=None, **kwargs): def plot_precision_recall_curve(estimator, X, y, pos_label=None, sample_weight=None, response_method="auto", - name=None, ax=None, **kwargs): - """Plot Precision Recall Curve. + label_name=None, ax=None, **kwargs): + """Plot Precision Recall Curve for binary classifers. Extra keyword arguments will be passed to matplotlib's ``plot``. @@ -106,7 +108,7 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, Input values. y : array-like of shape (n_samples,) - Target values. + Binary target values. pos_label : int or str, default=None The label of the positive class. @@ -117,13 +119,13 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, Sample weights. response_method : {'predict_proba', 'decision_function', 'auto'} \ - default='auto' + default='auto' Specifies whether to use :term:`predict_proba` or :term:`decision_function` as the target response. If set to 'auto', :term:`predict_proba` is tried first and if it does not exist :term:`decision_function` is tried next. - name : str, default=None + label_name : str, default=None Name for labeling curve. If `None`, the name of the estimator is used. @@ -136,27 +138,33 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, Object that stores computed values. 
""" check_matplotlib_support("plot_precision_recall_curve") + check_is_fitted(estimator) if response_method not in ("predict_proba", "decision_function", "auto"): raise ValueError("response_method must be 'predict_proba', " "'decision_function' or 'auto'") + error_msg = "response method {} not defined for estimator {}" + if response_method != "auto": prediction_method = getattr(estimator, response_method, None) if prediction_method is None: - raise ValueError( - "response method {} is not defined".format(response_method)) + raise ValueError(error_msg.format(response_method, + estimator.__class__.__name__)) + is_predict_proba = response_method == 'predict_proba' else: predict_proba = getattr(estimator, 'predict_proba', None) decision_function = getattr(estimator, 'decision_function', None) prediction_method = predict_proba or decision_function - if prediction_method is None: - raise ValueError('response methods not defined') + raise ValueError(error_msg.format( + "decision_function or predict_proba", + estimator.__class__.__name__)) + is_predict_proba = prediction_method == predict_proba y_pred = prediction_method(X) - if y_pred.ndim != 1: + if is_predict_proba and y_pred.ndim != 1: if y_pred.shape[1] > 2: raise ValueError("Estimator should solve a " "binary classification problem") @@ -169,4 +177,4 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, sample_weight=sample_weight) viz = PrecisionRecallDisplay(precision, recall, average_precision, estimator.__class__.__name__) - return viz.plot(ax=ax, name=name, **kwargs) + return viz.plot(ax=ax, label_name=label_name, **kwargs) diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py index cc1e1a57b1d51..6555b22ce400f 100644 --- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -2,6 +2,7 @@ import numpy as np from numpy.testing import assert_allclose +from sklearn.base import BaseEstimator from sklearn.metrics import plot_precision_recall_curve from sklearn.metrics import average_precision_score from sklearn.metrics import precision_recall_curve @@ -9,6 +10,7 @@ from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier from sklearn.linear_model import LogisticRegression +from sklearn.exceptions import NotFittedError @pytest.fixture(scope="module") @@ -26,20 +28,32 @@ def test_error_non_binary(pyplot): plot_precision_recall_curve(clf, X, y) +def test_unfitted_classifier(pyplot, data_binary): + X, y = data_binary + clf = DecisionTreeClassifier() + with pytest.raises(NotFittedError): + plot_precision_recall_curve(clf, X, y) + + @pytest.mark.parametrize( "response_method, msg", - [("predict_proba", "response method predict_proba is not defined"), - ("decision_function", "response method decision_function is not defined"), - ("auto", "response methods not defined"), + [("predict_proba", "response method predict_proba not defined for " + "estimator MyClassifier"), + ("decision_function", "response method decision_function not defined for " + "estimator MyClassifier"), + ("auto", "response method decision_function or predict_proba not defined " + "for estimator MyClassifier"), ("bad_method", "response_method must be 'predict_proba', " "'decision_function' or 'auto'")]) def test_error_no_response(pyplot, data_binary, response_method, msg): X, y = data_binary - class MyClassifier: - pass + class MyClassifier(BaseEstimator): + def fit(self, X, y): + 
self.fitted_ = True + return self - clf = MyClassifier() + clf = MyClassifier().fit(X, y) with pytest.raises(ValueError, match=msg): plot_precision_recall_curve(clf, X, y, response_method=response_method) @@ -52,8 +66,7 @@ def test_plot_precision_recall(pyplot, response_method, data_binary, with_sample_weight): X, y = data_binary - lr = LogisticRegression() - lr.fit(X, y) + lr = LogisticRegression().fit(X, y) if with_sample_weight: rng = np.random.RandomState(42) @@ -61,8 +74,9 @@ def test_plot_precision_recall(pyplot, response_method, data_binary, else: sample_weight = None - viz = plot_precision_recall_curve(lr, X, y, alpha=0.8, - sample_weight=sample_weight) + disp = plot_precision_recall_curve(lr, X, y, alpha=0.8, + response_method=response_method, + sample_weight=sample_weight) y_score = getattr(lr, response_method)(X) if y_score.ndim == 2: @@ -72,22 +86,27 @@ def test_plot_precision_recall(pyplot, response_method, data_binary, sample_weight=sample_weight) avg_prec = average_precision_score(y, y_score, sample_weight=sample_weight) - assert_allclose(viz.precision, prec) - assert_allclose(viz.recall, recall) - assert_allclose(viz.average_precision, avg_prec) + assert_allclose(disp.precision, prec) + assert_allclose(disp.recall, recall) + assert disp.average_precision == pytest.approx(avg_prec) - assert viz.estimator_name == "LogisticRegression" + assert disp.estimator_name == "LogisticRegression" # cannot fail thanks to pyplot fixture import matplotlib as mpl # noqa - assert isinstance(viz.line_, mpl.lines.Line2D) - assert viz.line_.get_alpha() == 0.8 - assert isinstance(viz.ax_, mpl.axes.Axes) - assert isinstance(viz.figure_, mpl.figure.Figure) + assert isinstance(disp.line_, mpl.lines.Line2D) + assert disp.line_.get_alpha() == 0.8 + assert isinstance(disp.ax_, mpl.axes.Axes) + assert isinstance(disp.figure_, mpl.figure.Figure) expected_label = "LogisticRegression (AP = {:0.2f})".format(avg_prec) - assert viz.line_.get_label() == expected_label - assert viz.ax_.get_xlabel() == "Recall" - assert viz.ax_.get_ylabel() == "Precision" - assert_allclose(viz.ax_.get_xlim(), [0.0, 1.0]) - assert_allclose(viz.ax_.get_ylim(), [0.0, 1.05]) + assert disp.line_.get_label() == expected_label + assert disp.ax_.get_xlabel() == "Recall" + assert disp.ax_.get_ylabel() == "Precision" + assert_allclose(disp.ax_.get_xlim(), [0.0, 1.0]) + assert_allclose(disp.ax_.get_ylim(), [0.0, 1.05]) + + # draw again with another label + disp.plot(label_name="MySpecialEstimator") + expected_label = "MySpecialEstimator (AP = {:0.2f})".format(avg_prec) + assert disp.line_.get_label() == expected_label From affec162c79b88951141cc3818c671efac777df9 Mon Sep 17 00:00:00 2001 From: Thomas Fan Date: Fri, 20 Sep 2019 14:53:25 -0400 Subject: [PATCH 08/22] CLN Address @glemaitre comments --- doc/whats_new/v0.22.rst | 3 +++ sklearn/metrics/_plot/precision_recall.py | 18 ++++++++++++++---- .../_plot/tests/test_plot_precision_recall.py | 7 ++++++- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index b99c9b0e3f334..4dd02d417ce10 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -286,6 +286,9 @@ Changelog curves. This function introduces the visualization API described in the :ref:`User Guide `. :pr:`14357` by `Thomas Fan`_. +- |Feature| :func:`metrics.plot_precision_recall_curve` has been added to plot + precision recall curves. :pr:`14936` by `Thomas Fan`_. + - |Feature| Added multiclass support to :func:`metrics.roc_auc_score`. 
:issue:`12789` by :user:`Kathy Chen `, :user:`Mohamed Maskani `, and :user:`Thomas Fan `. diff --git a/sklearn/metrics/_plot/precision_recall.py b/sklearn/metrics/_plot/precision_recall.py index b1c29ace865bf..6321df620d40e 100644 --- a/sklearn/metrics/_plot/precision_recall.py +++ b/sklearn/metrics/_plot/precision_recall.py @@ -2,6 +2,7 @@ from .. import precision_recall_curve from ...utils import check_matplotlib_support +from ...utils.multiclass import type_of_target from ...utils.validation import check_is_fitted @@ -15,7 +16,7 @@ class PrecisionRecallDisplay: Parameters ----------- - precision : ndarray of shape (n_thresholds + 1, ) + precision : ndarray of shape (n_thresholds + 1,) Precision values. recall : ndarray of shape (n_thresholds + 1,) @@ -48,7 +49,7 @@ def __init__(self, precision, recall, average_precision, estimator_name): def plot(self, ax=None, label_name=None, **kwargs): """Plot visualization. - Extra keyword arguments will be passed to matplotlib's ``plot``. + Extra keyword arguments will be passed to matplotlib's `plot`. Parameters ---------- @@ -60,6 +61,9 @@ def plot(self, ax=None, label_name=None, **kwargs): Name of precision recall curve for labeling. If `None`, use the name of the estimator. + **kwargs : dict + Keyword arguments to be passed to matplotlib's `plot`. + Returns ------- display : :class:`~sklearn.metrics.PrecisionRecallDisplay` @@ -95,7 +99,7 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, label_name=None, ax=None, **kwargs): """Plot Precision Recall Curve for binary classifers. - Extra keyword arguments will be passed to matplotlib's ``plot``. + Extra keyword arguments will be passed to matplotlib's `plot`. Read more in the :ref:`User Guide `. @@ -132,6 +136,9 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, ax : matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is created. + **kwargs : dict + Keyword arguments to be passed to matplotlib's `plot`. 
+ Returns ------- display : :class:`~sklearn.metrics.PrecisionRecallDisplay` @@ -144,8 +151,11 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, raise ValueError("response_method must be 'predict_proba', " "'decision_function' or 'auto'") - error_msg = "response method {} not defined for estimator {}" + type_y = type_of_target(y) + if type_y != 'binary': + raise ValueError("{} format is not supported".format(type_y)) + error_msg = "response method {} not defined for estimator {}" if response_method != "auto": prediction_method = getattr(estimator, response_method, None) if prediction_method is None: diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py index 6555b22ce400f..e6d199952d5a6 100644 --- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -23,10 +23,15 @@ def test_error_non_binary(pyplot): clf = DecisionTreeClassifier() clf.fit(X, y) - msg = "Estimator should solve a binary classification problem" + msg = "multiclass format is not supported" with pytest.raises(ValueError, match=msg): plot_precision_recall_curve(clf, X, y) + msg = "Estimator should solve a binary classification problem" + y_binary = y == 1 + with pytest.raises(ValueError, match=msg): + plot_precision_recall_curve(clf, X, y_binary) + def test_unfitted_classifier(pyplot, data_binary): X, y = data_binary From 179602041e3b37c573cad01842d98f7c9f92bbaa Mon Sep 17 00:00:00 2001 From: Thomas Fan Date: Tue, 24 Sep 2019 11:24:29 -0400 Subject: [PATCH 09/22] DOC Remove whatsnew --- doc/whats_new/v0.22.rst | 4 ---- 1 file changed, 4 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 8ff4bf9c70ebd..e1a06085ad0a7 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -379,10 +379,6 @@ Changelog Gain and Normalized Discounted Cumulative Gain. :pr:`9951` by :user:`Jérôme Dockès `. -- |MajorFeature| :func:`metrics.plot_roc_curve` has been added to plot roc - curves. This function introduces the visualization API described in - the :ref:`User Guide `. :pr:`14357` by `Thomas Fan`_. - - |Feature| :func:`metrics.plot_precision_recall_curve` has been added to plot precision recall curves. :pr:`14936` by `Thomas Fan`_. From d7d448fe7d0846f0970bba170a1db3feed65bf30 Mon Sep 17 00:00:00 2001 From: Thomas Fan Date: Tue, 24 Sep 2019 11:35:06 -0400 Subject: [PATCH 10/22] DOC Style --- sklearn/metrics/_plot/precision_recall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/precision_recall.py b/sklearn/metrics/_plot/precision_recall.py index 6321df620d40e..f69bccb68ee7f 100644 --- a/sklearn/metrics/_plot/precision_recall.py +++ b/sklearn/metrics/_plot/precision_recall.py @@ -122,7 +122,7 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, sample_weight : array-like of shape (n_samples,), default=None Sample weights. - response_method : {'predict_proba', 'decision_function', 'auto'} \ + response_method : {'predict_proba', 'decision_function', 'auto'}, \ default='auto' Specifies whether to use :term:`predict_proba` or :term:`decision_function` as the target response. 
If set to 'auto', From 294c29a938d6968efd53907edd5788859c8c4356 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Wed, 25 Sep 2019 15:35:16 -0400 Subject: [PATCH 11/22] CLN Addresses @amuller comments --- sklearn/metrics/_plot/precision_recall.py | 2 +- sklearn/metrics/_plot/tests/test_plot_precision_recall.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall.py b/sklearn/metrics/_plot/precision_recall.py index f69bccb68ee7f..5ebe824cc1148 100644 --- a/sklearn/metrics/_plot/precision_recall.py +++ b/sklearn/metrics/_plot/precision_recall.py @@ -174,7 +174,7 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, y_pred = prediction_method(X) - if is_predict_proba and y_pred.ndim != 1: + if is_predict_proba: if y_pred.shape[1] > 2: raise ValueError("Estimator should solve a " "binary classification problem") diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py index e6d199952d5a6..19e88041bdfcb 100644 --- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -28,7 +28,7 @@ def test_error_non_binary(pyplot): plot_precision_recall_curve(clf, X, y) msg = "Estimator should solve a binary classification problem" - y_binary = y == 1 + y_binary = y >= 1 with pytest.raises(ValueError, match=msg): plot_precision_recall_curve(clf, X, y_binary) From dbe9a3a6d9424626c44d2a039842ee20477bc15a Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Wed, 25 Sep 2019 16:56:07 -0400 Subject: [PATCH 12/22] CLN Addresses @amuller comments --- sklearn/metrics/_plot/precision_recall.py | 3 ++- .../_plot/tests/test_plot_precision_recall.py | 19 ++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall.py b/sklearn/metrics/_plot/precision_recall.py index 5ebe824cc1148..c1764dd30c453 100644 --- a/sklearn/metrics/_plot/precision_recall.py +++ b/sklearn/metrics/_plot/precision_recall.py @@ -153,7 +153,8 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, type_y = type_of_target(y) if type_y != 'binary': - raise ValueError("{} format is not supported".format(type_y)) + raise ValueError( + "only binary format is not supported, got {}".format(type_y)) error_msg = "response method {} not defined for estimator {}" if response_method != "auto": diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py index 19e88041bdfcb..220efcbe943f1 100644 --- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -8,6 +8,7 @@ from sklearn.metrics import precision_recall_curve from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_iris +from sklearn.datasets import make_classification from sklearn.tree import DecisionTreeClassifier from sklearn.linear_model import LogisticRegression from sklearn.exceptions import NotFittedError @@ -23,22 +24,26 @@ def test_error_non_binary(pyplot): clf = DecisionTreeClassifier() clf.fit(X, y) - msg = "multiclass format is not supported" - with pytest.raises(ValueError, match=msg): - plot_precision_recall_curve(clf, X, y) - msg = "Estimator should solve a binary classification problem" - y_binary = y >= 1 with pytest.raises(ValueError, match=msg): - plot_precision_recall_curve(clf, X, y_binary) + plot_precision_recall_curve(clf, X, y) -def 
test_unfitted_classifier(pyplot, data_binary): +def test_error_binary(pyplot, data_binary): X, y = data_binary clf = DecisionTreeClassifier() with pytest.raises(NotFittedError): plot_precision_recall_curve(clf, X, y) + n_samples = X.shape[0] + _, y_multiclass = make_classification(n_samples=n_samples, + n_informative=3, + n_classes=3) + clf.fit(X, y) + msg = "only binary format is not supported, got multiclass" + with pytest.raises(ValueError, match=msg): + plot_precision_recall_curve(clf, X, y_multiclass) + @pytest.mark.parametrize( "response_method, msg", From fd1cc4429be1e1e0853aabd4248d173ea36aacc1 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Wed, 2 Oct 2019 09:50:31 -0400 Subject: [PATCH 13/22] TST Clearier error messages --- sklearn/metrics/_plot/precision_recall.py | 2 +- .../_plot/tests/test_plot_precision_recall.py | 50 ++++++++----------- 2 files changed, 22 insertions(+), 30 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall.py b/sklearn/metrics/_plot/precision_recall.py index c1764dd30c453..7ff809c629a9f 100644 --- a/sklearn/metrics/_plot/precision_recall.py +++ b/sklearn/metrics/_plot/precision_recall.py @@ -154,7 +154,7 @@ def plot_precision_recall_curve(estimator, X, y, pos_label=None, type_y = type_of_target(y) if type_y != 'binary': raise ValueError( - "only binary format is not supported, got {}".format(type_y)) + "Only binary format is supported, got {}".format(type_y)) error_msg = "response method {} not defined for estimator {}" if response_method != "auto": diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py index 220efcbe943f1..0afec5e4a1991 100644 --- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -6,43 +6,36 @@ from sklearn.metrics import plot_precision_recall_curve from sklearn.metrics import average_precision_score from sklearn.metrics import precision_recall_curve -from sklearn.datasets import load_breast_cancer -from sklearn.datasets import load_iris from sklearn.datasets import make_classification from sklearn.tree import DecisionTreeClassifier from sklearn.linear_model import LogisticRegression from sklearn.exceptions import NotFittedError -@pytest.fixture(scope="module") -def data_binary(): - return load_breast_cancer(return_X_y=True) +def test_errors(pyplot): + X, y_binary = make_classification(n_classes=2, n_samples=50, + random_state=0) + # Unfitted classifer + binary_clf = DecisionTreeClassifier() + with pytest.raises(NotFittedError): + plot_precision_recall_curve(binary_clf, X, y_binary) + binary_clf.fit(X, y_binary) -def test_error_non_binary(pyplot): - X, y = load_iris(return_X_y=True) - clf = DecisionTreeClassifier() - clf.fit(X, y) + _, y_multiclass = make_classification(n_samples=X.shape[0], + n_informative=3, + n_classes=3) + multi_clf = DecisionTreeClassifier().fit(X, y_multiclass) + # Fitted multiclass classifier with binary data msg = "Estimator should solve a binary classification problem" with pytest.raises(ValueError, match=msg): - plot_precision_recall_curve(clf, X, y) - + plot_precision_recall_curve(multi_clf, X, y_binary) -def test_error_binary(pyplot, data_binary): - X, y = data_binary - clf = DecisionTreeClassifier() - with pytest.raises(NotFittedError): - plot_precision_recall_curve(clf, X, y) - - n_samples = X.shape[0] - _, y_multiclass = make_classification(n_samples=n_samples, - n_informative=3, - n_classes=3) - clf.fit(X, y) - msg = "only binary format is 
not supported, got multiclass" + # Fitted binary classifier with multiclass data + msg = "Only binary format is supported, got multiclass" with pytest.raises(ValueError, match=msg): - plot_precision_recall_curve(clf, X, y_multiclass) + plot_precision_recall_curve(binary_clf, X, y_multiclass) @pytest.mark.parametrize( @@ -55,8 +48,8 @@ def test_error_binary(pyplot, data_binary): "for estimator MyClassifier"), ("bad_method", "response_method must be 'predict_proba', " "'decision_function' or 'auto'")]) -def test_error_no_response(pyplot, data_binary, response_method, msg): - X, y = data_binary +def test_error_no_response(pyplot, response_method, msg): + X, y = make_classification(n_classes=2, n_samples=50, random_state=0) class MyClassifier(BaseEstimator): def fit(self, X, y): @@ -72,9 +65,8 @@ def fit(self, X, y): @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @pytest.mark.parametrize("with_sample_weight", [True, False]) -def test_plot_precision_recall(pyplot, response_method, data_binary, - with_sample_weight): - X, y = data_binary +def test_plot_precision_recall(pyplot, response_method, with_sample_weight): + X, y = make_classification(n_classes=2, n_samples=50, random_state=0) lr = LogisticRegression().fit(X, y) From fdb60aecc4e114e01a8d189bd05b78926beae9b1 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Wed, 2 Oct 2019 09:56:21 -0400 Subject: [PATCH 14/22] TST Modify test name --- sklearn/metrics/_plot/tests/test_plot_precision_recall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py index 0afec5e4a1991..7ff11bd7b4e2a 100644 --- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -48,7 +48,7 @@ def test_errors(pyplot): "for estimator MyClassifier"), ("bad_method", "response_method must be 'predict_proba', " "'decision_function' or 'auto'")]) -def test_error_no_response(pyplot, response_method, msg): +def test_error_bad_response(pyplot, response_method, msg): X, y = make_classification(n_classes=2, n_samples=50, random_state=0) class MyClassifier(BaseEstimator): From abbbb9d6ae3f32388075cf9f1aa7795f2c4197b8 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Wed, 6 Nov 2019 16:02:21 -0500 Subject: [PATCH 15/22] BUG Quick fix --- sklearn/metrics/_plot/precision_recall.py | 13 +++++++++---- .../_plot/tests/test_plot_precision_recall.py | 19 +++++++++++++++---- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall.py b/sklearn/metrics/_plot/precision_recall.py index 2878da9d4c86d..3842627c3ac3a 100644 --- a/sklearn/metrics/_plot/precision_recall.py +++ b/sklearn/metrics/_plot/precision_recall.py @@ -2,7 +2,6 @@ from .. 
import precision_recall_curve from ...utils import check_matplotlib_support -from ...utils.multiclass import type_of_target from ...utils.validation import check_is_fitted from ...base import is_classifier @@ -143,9 +142,14 @@ def plot_precision_recall_curve(estimator, X, y, raise ValueError("response_method must be 'predict_proba', " "'decision_function' or 'auto'") - if not is_classifier(estimator): - raise ValueError("{} should solve a binary classification " - "problem".format(estimator.__class__.__name__)) + classificaiton_error = ("{} should solve a binary classification " + "problem".format(estimator.__class__.__name__)) + if is_classifier(estimator): + if len(estimator.classes_) != 2: + raise ValueError(classificaiton_error) + pos_label = estimator.classes_[1] + else: + raise ValueError(classificaiton_error) error_msg = "response method {} not defined for estimator {}" if response_method != "auto": @@ -170,6 +174,7 @@ def plot_precision_recall_curve(estimator, X, y, y_pred = y_pred[:, 1] precision, recall, _ = precision_recall_curve(y, y_pred, + pos_label=pos_label, sample_weight=sample_weight) average_precision = average_precision_score(y, y_pred, sample_weight=sample_weight) diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py index 4f08214ab34c4..5eee755af8472 100644 --- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -2,12 +2,12 @@ import numpy as np from numpy.testing import assert_allclose -from sklearn.base import BaseEstimator +from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.metrics import plot_precision_recall_curve from sklearn.metrics import average_precision_score from sklearn.metrics import precision_recall_curve from sklearn.datasets import make_classification -from sklearn.tree import DecisionTreeClassifier +from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.linear_model import LogisticRegression from sklearn.exceptions import NotFittedError @@ -27,10 +27,20 @@ def test_errors(pyplot): multi_clf = DecisionTreeClassifier().fit(X, y_multiclass) # Fitted multiclass classifier with binary data - msg = "Estimator should solve a binary classification problem" + msg = "DecisionTreeClassifier should solve a binary classification problem" with pytest.raises(ValueError, match=msg): plot_precision_recall_curve(multi_clf, X, y_binary) + # Fitted binary classifier with multiclass data + msg = "DecisionTreeClassifier should solve a binary classification problem" + with pytest.raises(ValueError, match=msg): + plot_precision_recall_curve(binary_clf, X, y_multiclass) + + reg = DecisionTreeRegressor().fit(X, y_multiclass) + msg = "DecisionTreeRegressor should solve a binary classification problem" + with pytest.raises(ValueError, match=msg): + plot_precision_recall_curve(reg, X, y_binary) + @pytest.mark.parametrize( "response_method, msg", @@ -45,9 +55,10 @@ def test_errors(pyplot): def test_error_bad_response(pyplot, response_method, msg): X, y = make_classification(n_classes=2, n_samples=50, random_state=0) - class MyClassifier(BaseEstimator): + class MyClassifier(BaseEstimator, ClassifierMixin): def fit(self, X, y): self.fitted_ = True + self.classes_ = [0, 1] return self clf = MyClassifier().fit(X, y) From a589f92deaccf433207ce3f7afb3422cc8b2070e Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Wed, 6 Nov 2019 16:07:23 -0500 Subject: [PATCH 16/22] BUG Fix test --- 
sklearn/metrics/_plot/tests/test_plot_precision_recall.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py index 5eee755af8472..c4cd897805375 100644 --- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -31,11 +31,6 @@ def test_errors(pyplot): with pytest.raises(ValueError, match=msg): plot_precision_recall_curve(multi_clf, X, y_binary) - # Fitted binary classifier with multiclass data - msg = "DecisionTreeClassifier should solve a binary classification problem" - with pytest.raises(ValueError, match=msg): - plot_precision_recall_curve(binary_clf, X, y_multiclass) - reg = DecisionTreeRegressor().fit(X, y_multiclass) msg = "DecisionTreeRegressor should solve a binary classification problem" with pytest.raises(ValueError, match=msg): @@ -86,7 +81,7 @@ def test_plot_precision_recall(pyplot, response_method, with_sample_weight): sample_weight=sample_weight) y_score = getattr(lr, response_method)(X) - if y_score.ndim == 2: + if response_method == 'predict_proba': y_score = y_score[:, 1] prec, recall, _ = precision_recall_curve(y, y_score, From c9b3d602c29d2a64acffefdc52c81061ba6c8f5e Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Wed, 6 Nov 2019 18:30:52 -0500 Subject: [PATCH 17/22] ENH Better error message --- sklearn/metrics/_plot/precision_recall.py | 4 ++-- sklearn/metrics/_plot/tests/test_plot_precision_recall.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall.py b/sklearn/metrics/_plot/precision_recall.py index 3842627c3ac3a..56841dbc81484 100644 --- a/sklearn/metrics/_plot/precision_recall.py +++ b/sklearn/metrics/_plot/precision_recall.py @@ -142,8 +142,8 @@ def plot_precision_recall_curve(estimator, X, y, raise ValueError("response_method must be 'predict_proba', " "'decision_function' or 'auto'") - classificaiton_error = ("{} should solve a binary classification " - "problem".format(estimator.__class__.__name__)) + classificaiton_error = ("{} should be a binary classifer".format( + estimator.__class__.__name__)) if is_classifier(estimator): if len(estimator.classes_) != 2: raise ValueError(classificaiton_error) diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py index c4cd897805375..47b2468567d98 100644 --- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -27,12 +27,12 @@ def test_errors(pyplot): multi_clf = DecisionTreeClassifier().fit(X, y_multiclass) # Fitted multiclass classifier with binary data - msg = "DecisionTreeClassifier should solve a binary classification problem" + msg = "DecisionTreeClassifier should be a binary classifer" with pytest.raises(ValueError, match=msg): plot_precision_recall_curve(multi_clf, X, y_binary) reg = DecisionTreeRegressor().fit(X, y_multiclass) - msg = "DecisionTreeRegressor should solve a binary classification problem" + msg = "DecisionTreeRegressor should be a binary classifer" with pytest.raises(ValueError, match=msg): plot_precision_recall_curve(reg, X, y_binary) From bfd5634147cda1371c062c189f0d7dee3295db96 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Thu, 7 Nov 2019 16:42:31 -0500 Subject: [PATCH 18/22] CLN Address comments --- doc/modules/model_evaluation.rst | 10 ++++++-- doc/visualizations.rst | 4 +-- 
sklearn/metrics/__init__.py | 4 +-- ...on_recall.py => precision_recall_curve.py} | 25 +++++++++++-------- .../_plot/tests/test_plot_precision_recall.py | 2 +- 5 files changed, 27 insertions(+), 18 deletions(-) rename sklearn/metrics/_plot/{precision_recall.py => precision_recall_curve.py} (90%) diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index b80549db933ef..3f5999346401a 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -744,8 +744,14 @@ score: Note that the :func:`precision_recall_curve` function is restricted to the binary case. The :func:`average_precision_score` function works only in -binary classification and multilabel indicator format. - +binary classification and multilabel indicator format. The +:func:`plot_precision_recall_curve` function plots the precision recall as +follows. + +.. image:: ../auto_examples/model_selection/images/sphx_glr_plot_precision_recall_001.png + :target: ../auto_examples/model_selection/plot_precision_recall.html#plot-the-precision-recall-curve + :scale: 75 + :align: center .. topic:: Examples: diff --git a/doc/visualizations.rst b/doc/visualizations.rst index 8fe7fc5b6b29f..4b6f7ea34febb 100644 --- a/doc/visualizations.rst +++ b/doc/visualizations.rst @@ -71,8 +71,8 @@ Functions .. autosummary:: - metrics.plot_precision_recall_curve inspection.plot_partial_dependence + metrics.plot_precision_recall_curve metrics.plot_roc_curve @@ -83,6 +83,6 @@ Display Objects .. autosummary:: - metrics.PrecisionRecallDisplay inspection.PartialDependenceDisplay + metrics.PrecisionRecallDisplay metrics.RocCurveDisplay diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index 6ae1d41fd45fd..ac6162b924a90 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -79,8 +79,8 @@ from ._plot.roc_curve import plot_roc_curve from ._plot.roc_curve import RocCurveDisplay -from ._plot.precision_recall import plot_precision_recall_curve -from ._plot.precision_recall import PrecisionRecallDisplay +from ._plot.precision_recall_curve import plot_precision_recall_curve +from ._plot.precision_recall_curve import PrecisionRecallDisplay __all__ = [ diff --git a/sklearn/metrics/_plot/precision_recall.py b/sklearn/metrics/_plot/precision_recall_curve.py similarity index 90% rename from sklearn/metrics/_plot/precision_recall.py rename to sklearn/metrics/_plot/precision_recall_curve.py index 56841dbc81484..733f748d767df 100644 --- a/sklearn/metrics/_plot/precision_recall.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -12,12 +12,14 @@ class PrecisionRecallDisplay: It is recommend to use :func:`~sklearn.metrics.plot_precision_recall_curve` to create a visualizer. All parameters are stored as attributes. + Read more in the :ref:`User Guide `. + Parameters ----------- - precision : ndarray of shape (n_thresholds + 1,) + precision : ndarray Precision values. - recall : ndarray of shape (n_thresholds + 1,) + recall : ndarray Recall values. average_precision : float @@ -44,7 +46,7 @@ def __init__(self, precision, recall, average_precision, estimator_name): self.average_precision = average_precision self.estimator_name = estimator_name - def plot(self, ax=None, label_name=None, **kwargs): + def plot(self, ax=None, name=None, **kwargs): """Plot visualization. Extra keyword arguments will be passed to matplotlib's `plot`. @@ -55,7 +57,7 @@ def plot(self, ax=None, label_name=None, **kwargs): Axes object to plot on. If `None`, a new figure and axes is created. 
- label_name : str, default=None + name : str, default=None Name of precision recall curve for labeling. If `None`, use the name of the estimator. @@ -73,18 +75,17 @@ def plot(self, ax=None, label_name=None, **kwargs): if ax is None: fig, ax = plt.subplots() - label_name = self.estimator_name if label_name is None else label_name + name = self.estimator_name if name is None else name line_kwargs = { - "label": "{} (AP = {:0.2f})".format(label_name, + "label": "{} (AP = {:0.2f})".format(name, self.average_precision), "drawstyle": "steps-post" } line_kwargs.update(**kwargs) self.line_, = ax.plot(self.recall, self.precision, **line_kwargs) - ax.set(xlabel="Recall", ylabel="Precision", ylim=[0.0, 1.05], - xlim=[0.0, 1.0]) + ax.set(xlabel="Recall", ylabel="Precision") ax.legend(loc='lower left') self.ax_ = ax @@ -94,11 +95,13 @@ def plot(self, ax=None, label_name=None, **kwargs): def plot_precision_recall_curve(estimator, X, y, sample_weight=None, response_method="auto", - label_name=None, ax=None, **kwargs): + name=None, ax=None, **kwargs): """Plot Precision Recall Curve for binary classifers. Extra keyword arguments will be passed to matplotlib's `plot`. + Read more in the :ref:`User Guide `. + Parameters ---------- estimator : estimator instance @@ -120,7 +123,7 @@ def plot_precision_recall_curve(estimator, X, y, :term:`predict_proba` is tried first and if it does not exist :term:`decision_function` is tried next. - label_name : str, default=None + name : str, default=None Name for labeling curve. If `None`, the name of the estimator is used. @@ -180,4 +183,4 @@ def plot_precision_recall_curve(estimator, X, y, sample_weight=sample_weight) viz = PrecisionRecallDisplay(precision, recall, average_precision, estimator.__class__.__name__) - return viz.plot(ax=ax, label_name=label_name, **kwargs) + return viz.plot(ax=ax, name=name, **kwargs) diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py index 47b2468567d98..ce6b490b154b3 100644 --- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -109,6 +109,6 @@ def test_plot_precision_recall(pyplot, response_method, with_sample_weight): assert_allclose(disp.ax_.get_ylim(), [0.0, 1.05]) # draw again with another label - disp.plot(label_name="MySpecialEstimator") + disp.plot(name="MySpecialEstimator") expected_label = "MySpecialEstimator (AP = {:0.2f})".format(avg_prec) assert disp.line_.get_label() == expected_label From 2c3d78d8f9d6644bd49073d2696f3085c642517d Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Fri, 8 Nov 2019 14:10:30 -0500 Subject: [PATCH 19/22] CLN Address comments --- sklearn/metrics/_plot/__init__.py | 40 +++++++++++++++++++ .../metrics/_plot/precision_recall_curve.py | 26 ++---------- .../_plot/tests/test_plot_precision_recall.py | 14 +++---- 3 files changed, 50 insertions(+), 30 deletions(-) diff --git a/sklearn/metrics/_plot/__init__.py b/sklearn/metrics/_plot/__init__.py index e69de29bb2d1d..f0518be38f264 100644 --- a/sklearn/metrics/_plot/__init__.py +++ b/sklearn/metrics/_plot/__init__.py @@ -0,0 +1,40 @@ +def _check_classifer_response_method(estimator, response_method): + """Return prediction method from the response_method + + Parameters + ---------- + estimator: object + Classifier to check + + response_method: {'auto', 'predict_proba', 'decision_function'} + Specifies whether to use :term:`predict_proba` or + :term:`decision_function` as the target response. 
diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py
index 733f748d767df..94c6232e566e9 100644
--- a/sklearn/metrics/_plot/precision_recall_curve.py
+++ b/sklearn/metrics/_plot/precision_recall_curve.py
@@ -1,5 +1,6 @@
 from .. import average_precision_score
 from .. import precision_recall_curve
+from . import _check_classifier_response_method

 from ...utils import check_matplotlib_support
 from ...utils.validation import check_is_fitted
@@ -141,10 +142,6 @@ def plot_precision_recall_curve(estimator, X, y,
     check_matplotlib_support("plot_precision_recall_curve")
     check_is_fitted(estimator)

-    if response_method not in ("predict_proba", "decision_function", "auto"):
-        raise ValueError("response_method must be 'predict_proba', "
-                         "'decision_function' or 'auto'")
-
     classification_error = ("{} should be a binary classifier".format(
         estimator.__class__.__name__))
     if is_classifier(estimator):
@@ -154,26 +151,11 @@ def plot_precision_recall_curve(estimator, X, y,
     else:
         raise ValueError(classification_error)

-    error_msg = "response method {} not defined for estimator {}"
-    if response_method != "auto":
-        prediction_method = getattr(estimator, response_method, None)
-        if prediction_method is None:
-            raise ValueError(error_msg.format(response_method,
-                                              estimator.__class__.__name__))
-        is_predict_proba = response_method == 'predict_proba'
-    else:
-        predict_proba = getattr(estimator, 'predict_proba', None)
-        decision_function = getattr(estimator, 'decision_function', None)
-        prediction_method = predict_proba or decision_function
-        if prediction_method is None:
-            raise ValueError(error_msg.format(
-                "decision_function or predict_proba",
-                estimator.__class__.__name__))
-        is_predict_proba = prediction_method == predict_proba
-
+    prediction_method = _check_classifier_response_method(estimator,
+                                                          response_method)
     y_pred = prediction_method(X)

-    if is_predict_proba:
+    if y_pred.ndim != 1:
        y_pred = y_pred[:, 1]

     precision, recall, _ = precision_recall_curve(y, y_pred,
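The switch from the removed ``is_predict_proba`` flag to ``y_pred.ndim != 1``
works because, for a binary problem, ``predict_proba`` returns an
(n_samples, 2) array whose second column is the positive-class probability,
while ``decision_function`` returns a flat (n_samples,) score array that can
be used as-is. A self-contained illustration (an editor's sketch, not taken
from the patch):

    import numpy as np

    proba = np.array([[0.9, 0.1],
                      [0.2, 0.8]])    # predict_proba output, ndim == 2
    scores = np.array([-1.3, 0.7])    # decision_function output, ndim == 1

    for y_pred in (proba, scores):
        if y_pred.ndim != 1:
            y_pred = y_pred[:, 1]     # keep the positive-class column
        print(y_pred)                 # [0.1 0.8], then [-1.3  0.7]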
MyClassifier"), - ("decision_function", "response method decision_function not defined for " - "estimator MyClassifier"), - ("auto", "response method decision_function or predict_proba not defined " - "for estimator MyClassifier"), + [("predict_proba", "response method predict_proba is not defined in " + "MyClassifier"), + ("decision_function", "response method decision_function is not defined " + "in MyClassifier"), + ("auto", "response method decision_function or predict_proba is not " + "defined in MyClassifier"), ("bad_method", "response_method must be 'predict_proba', " "'decision_function' or 'auto'")]) def test_error_bad_response(pyplot, response_method, msg): @@ -105,8 +105,6 @@ def test_plot_precision_recall(pyplot, response_method, with_sample_weight): assert disp.line_.get_label() == expected_label assert disp.ax_.get_xlabel() == "Recall" assert disp.ax_.get_ylabel() == "Precision" - assert_allclose(disp.ax_.get_xlim(), [0.0, 1.0]) - assert_allclose(disp.ax_.get_ylim(), [0.0, 1.05]) # draw again with another label disp.plot(name="MySpecialEstimator") From 7736f77aca8ac62a2c356c46348bb150026d2ac7 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Sat, 9 Nov 2019 20:43:31 -0800 Subject: [PATCH 20/22] CLN Move to base --- sklearn/metrics/_plot/__init__.py | 40 ------------------- sklearn/metrics/_plot/base.py | 40 +++++++++++++++++++ .../metrics/_plot/precision_recall_curve.py | 3 +- 3 files changed, 42 insertions(+), 41 deletions(-) create mode 100644 sklearn/metrics/_plot/base.py diff --git a/sklearn/metrics/_plot/__init__.py b/sklearn/metrics/_plot/__init__.py index f0518be38f264..e69de29bb2d1d 100644 --- a/sklearn/metrics/_plot/__init__.py +++ b/sklearn/metrics/_plot/__init__.py @@ -1,40 +0,0 @@ -def _check_classifer_response_method(estimator, response_method): - """Return prediction method from the response_method - - Parameters - ---------- - estimator: object - Classifier to check - - response_method: {'auto', 'predict_proba', 'decision_function'} - Specifies whether to use :term:`predict_proba` or - :term:`decision_function` as the target response. If set to 'auto', - :term:`predict_proba` is tried first and if it does not exist - :term:`decision_function` is tried next. 
From 7736f77aca8ac62a2c356c46348bb150026d2ac7 Mon Sep 17 00:00:00 2001
From: Thomas J Fan
Date: Sat, 9 Nov 2019 20:43:31 -0800
Subject: [PATCH 20/22] CLN Move to base

---
 sklearn/metrics/_plot/__init__.py             | 40 -------------------
 sklearn/metrics/_plot/base.py                 | 40 +++++++++++++++++++
 .../metrics/_plot/precision_recall_curve.py   |  3 +-
 3 files changed, 42 insertions(+), 41 deletions(-)
 create mode 100644 sklearn/metrics/_plot/base.py

diff --git a/sklearn/metrics/_plot/__init__.py b/sklearn/metrics/_plot/__init__.py
index f0518be38f264..e69de29bb2d1d 100644
--- a/sklearn/metrics/_plot/__init__.py
+++ b/sklearn/metrics/_plot/__init__.py
@@ -1,40 +0,0 @@
-def _check_classifier_response_method(estimator, response_method):
-    """Return the prediction method specified by response_method.
-
-    Parameters
-    ----------
-    estimator: object
-        Classifier to check.
-
-    response_method: {'auto', 'predict_proba', 'decision_function'}
-        Specifies whether to use :term:`predict_proba` or
-        :term:`decision_function` as the target response. If set to 'auto',
-        :term:`predict_proba` is tried first and if it does not exist
-        :term:`decision_function` is tried next.
-
-    Returns
-    -------
-    prediction_method: callable
-        Prediction method of estimator.
-    """
-
-    if response_method not in ("predict_proba", "decision_function", "auto"):
-        raise ValueError("response_method must be 'predict_proba', "
-                         "'decision_function' or 'auto'")
-
-    error_msg = "response method {} is not defined in {}"
-    if response_method != "auto":
-        prediction_method = getattr(estimator, response_method, None)
-        if prediction_method is None:
-            raise ValueError(error_msg.format(response_method,
-                                              estimator.__class__.__name__))
-    else:
-        predict_proba = getattr(estimator, 'predict_proba', None)
-        decision_function = getattr(estimator, 'decision_function', None)
-        prediction_method = predict_proba or decision_function
-        if prediction_method is None:
-            raise ValueError(error_msg.format(
-                "decision_function or predict_proba",
-                estimator.__class__.__name__))
-
-    return prediction_method
diff --git a/sklearn/metrics/_plot/base.py b/sklearn/metrics/_plot/base.py
new file mode 100644
index 0000000000000..f0518be38f264
--- /dev/null
+++ b/sklearn/metrics/_plot/base.py
@@ -0,0 +1,40 @@
+def _check_classifier_response_method(estimator, response_method):
+    """Return the prediction method specified by response_method.
+
+    Parameters
+    ----------
+    estimator: object
+        Classifier to check.
+
+    response_method: {'auto', 'predict_proba', 'decision_function'}
+        Specifies whether to use :term:`predict_proba` or
+        :term:`decision_function` as the target response. If set to 'auto',
+        :term:`predict_proba` is tried first and if it does not exist
+        :term:`decision_function` is tried next.
+
+    Returns
+    -------
+    prediction_method: callable
+        Prediction method of estimator.
+    """
+
+    if response_method not in ("predict_proba", "decision_function", "auto"):
+        raise ValueError("response_method must be 'predict_proba', "
+                         "'decision_function' or 'auto'")
+
+    error_msg = "response method {} is not defined in {}"
+    if response_method != "auto":
+        prediction_method = getattr(estimator, response_method, None)
+        if prediction_method is None:
+            raise ValueError(error_msg.format(response_method,
+                                              estimator.__class__.__name__))
+    else:
+        predict_proba = getattr(estimator, 'predict_proba', None)
+        decision_function = getattr(estimator, 'decision_function', None)
+        prediction_method = predict_proba or decision_function
+        if prediction_method is None:
+            raise ValueError(error_msg.format(
+                "decision_function or predict_proba",
+                estimator.__class__.__name__))
+
+    return prediction_method
diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py
index 94c6232e566e9..17364481d30ce 100644
--- a/sklearn/metrics/_plot/precision_recall_curve.py
+++ b/sklearn/metrics/_plot/precision_recall_curve.py
@@ -1,5 +1,6 @@
+from .base import _check_classifier_response_method
+
 from .. import average_precision_score
 from .. import precision_recall_curve
-from . import _check_classifier_response_method

 from ...utils import check_matplotlib_support
 from ...utils.validation import check_is_fitted
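Moving the helper out of ``_plot/__init__.py`` keeps the package initializer
empty, so importing one plotting module no longer drags helper definitions
through package initialization, and ``base.py`` gives the shared code a
conventional home. A quick check of the new location (an illustration, not
part of the patch; this is a private path, subject to change):

    from sklearn.metrics._plot.base import _check_classifier_response_method

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(random_state=0)
    clf = LogisticRegression().fit(X, y)

    # 'auto' resolves to the classifier's predict_proba when it exists.
    assert _check_classifier_response_method(clf, "auto") == clf.predict_proba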
From 91f0d059518c3100278826f917237f1844536bae Mon Sep 17 00:00:00 2001
From: Thomas J Fan
Date: Sat, 9 Nov 2019 20:45:06 -0800
Subject: [PATCH 21/22] CLN Unify response detection

---
 sklearn/metrics/_plot/roc_curve.py             | 15 +++------------
 .../metrics/_plot/tests/test_plot_roc_curve.py |  9 ++++++---
 2 files changed, 9 insertions(+), 15 deletions(-)

diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py
index e247da068a8c1..ad34b3e6f6da3 100644
--- a/sklearn/metrics/_plot/roc_curve.py
+++ b/sklearn/metrics/_plot/roc_curve.py
@@ -1,6 +1,7 @@
 from .. import auc
 from .. import roc_curve

+from .base import _check_classifier_response_method
 from ...utils import check_matplotlib_support
 from ...base import is_classifier
 from ...utils.validation import check_is_fitted
@@ -180,18 +181,8 @@ def plot_roc_curve(estimator, X, y, sample_weight=None,
     else:
         raise ValueError(classification_error)

-    if response_method != "auto":
-        prediction_method = getattr(estimator, response_method, None)
-        if prediction_method is None:
-            raise ValueError(
-                "response method {} is not defined".format(response_method))
-    else:
-        predict_proba = getattr(estimator, 'predict_proba', None)
-        decision_function = getattr(estimator, 'decision_function', None)
-        prediction_method = predict_proba or decision_function
-
-        if prediction_method is None:
-            raise ValueError('response methods not defined')
+    prediction_method = _check_classifier_response_method(estimator,
+                                                          response_method)

     y_pred = prediction_method(X)

diff --git a/sklearn/metrics/_plot/tests/test_plot_roc_curve.py b/sklearn/metrics/_plot/tests/test_plot_roc_curve.py
index ad054e512eec6..a7535522cf738 100644
--- a/sklearn/metrics/_plot/tests/test_plot_roc_curve.py
+++ b/sklearn/metrics/_plot/tests/test_plot_roc_curve.py
@@ -33,9 +33,12 @@ def test_plot_roc_curve_error_non_binary(pyplot, data):

 @pytest.mark.parametrize(
     "response_method, msg",
-    [("predict_proba", "response method predict_proba is not defined"),
-     ("decision_function", "response method decision_function is not defined"),
-     ("auto", "response methods not defined"),
+    [("predict_proba", "response method predict_proba is not defined in "
+                       "MyClassifier"),
+     ("decision_function", "response method decision_function is not defined "
+                           "in MyClassifier"),
+     ("auto", "response method decision_function or predict_proba is not "
+              "defined in MyClassifier"),
      ("bad_method", "response_method must be 'predict_proba', "
                     "'decision_function' or 'auto'")])
 def test_plot_roc_curve_error_no_response(pyplot, data_binary, response_method,
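With both plot functions delegating to the shared helper, response-method
validation and error text are now identical for ROC and precision-recall
plotting. A behaviour sketch (an illustration, not part of the patch;
``MyClassifier`` is a hypothetical fitted-looking stub with neither response
method, and the exact validation order inside the plot functions is assumed):

    from sklearn.datasets import make_classification
    from sklearn.metrics import plot_precision_recall_curve, plot_roc_curve

    class MyClassifier:
        _estimator_type = "classifier"
        def __init__(self):
            self.classes_ = [0, 1]  # trailing underscore: looks fitted

    X, y = make_classification(random_state=0)
    for plot_func in (plot_precision_recall_curve, plot_roc_curve):
        try:
            plot_func(MyClassifier(), X, y)
        except ValueError as exc:
            print(exc)  # same message from both plot functions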
classifer".format( estimator.__class__.__name__))