diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst
index b1ef50dafbaa9..e4df915d3eb33 100644
--- a/doc/modules/model_evaluation.rst
+++ b/doc/modules/model_evaluation.rst
@@ -796,9 +796,10 @@ score:
 
 Note that the :func:`precision_recall_curve` function is restricted to the
 binary case. The :func:`average_precision_score` function works only in
-binary classification and multilabel indicator format. The
-:func:`plot_precision_recall_curve` function plots the precision recall as
-follows.
+binary classification and multilabel indicator format.
+The :func:`PrecisionRecallDisplay.from_estimator` and
+:func:`PrecisionRecallDisplay.from_predictions` functions will plot the
+precision-recall curve as follows.
 
 .. image:: ../auto_examples/model_selection/images/sphx_glr_plot_precision_recall_001.png
         :target: ../auto_examples/model_selection/plot_precision_recall.html#plot-the-precision-recall-curve
diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index 0097b1816213e..87d5586f45791 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -552,6 +552,14 @@ Changelog
   class methods and will be removed in 1.2.
   :pr:`18543` by `Guillaume Lemaitre`_.
 
+- |API| :class:`metrics.PrecisionRecallDisplay` exposes two class methods
+  :func:`~metrics.PrecisionRecallDisplay.from_estimator` and
+  :func:`~metrics.PrecisionRecallDisplay.from_predictions` allowing to create
+  a precision-recall curve using an estimator or the predictions.
+  :func:`metrics.plot_precision_recall_curve` is deprecated in favor of these
+  two class methods and will be removed in 1.2.
+  :pr:`20552` by `Guillaume Lemaitre`_.
+
 - |API| :class:`metrics.DetCurveDisplay` exposes two class methods
   :func:`~metrics.DetCurveDisplay.from_estimator` and
   :func:`~metrics.DetCurveDisplay.from_predictions` allowing to create
diff --git a/examples/model_selection/plot_precision_recall.py b/examples/model_selection/plot_precision_recall.py
index 83493c44c7847..c0f0a97dd44ce 100644
--- a/examples/model_selection/plot_precision_recall.py
+++ b/examples/model_selection/plot_precision_recall.py
@@ -92,64 +92,80 @@
 """
 # %%
 # In binary classification settings
-# --------------------------------------------------------
+# ---------------------------------
 #
-# Create simple data
-# ..................
+# Dataset and model
+# .................
 #
-# Try to differentiate the two first classes of the iris data
-from sklearn import svm, datasets
-from sklearn.model_selection import train_test_split
+# We will use a Linear SVC classifier to differentiate two types of irises.
 import numpy as np
+from sklearn.datasets import load_iris
+from sklearn.model_selection import train_test_split
 
-iris = datasets.load_iris()
-X = iris.data
-y = iris.target
+X, y = load_iris(return_X_y=True)
 
 # Add noisy features
 random_state = np.random.RandomState(0)
 n_samples, n_features = X.shape
-X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]
+X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)
 
 # Limit to the two first classes, and split into training and test
-X_train, X_test, y_train, y_test = train_test_split(X[y < 2], y[y < 2],
-                                                    test_size=.5,
-                                                    random_state=random_state)
+X_train, X_test, y_train, y_test = train_test_split(
+    X[y < 2], y[y < 2], test_size=0.5, random_state=random_state
+)
 
-# Create a simple classifier
-classifier = svm.LinearSVC(random_state=random_state)
+# %%
+# Linear SVC will expect each feature to have a similar range of values. Thus,
+# we will first scale the data using a
+# :class:`~sklearn.preprocessing.StandardScaler`.
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
+from sklearn.svm import LinearSVC
+
+classifier = make_pipeline(StandardScaler(), LinearSVC(random_state=random_state))
 classifier.fit(X_train, y_train)
-y_score = classifier.decision_function(X_test)
 
 # %%
-# Compute the average precision score
-# ...................................
-from sklearn.metrics import average_precision_score
-average_precision = average_precision_score(y_test, y_score)
+# Plot the Precision-Recall curve
+# ...............................
+#
+# To plot the precision-recall curve, you should use
+# :class:`~sklearn.metrics.PrecisionRecallDisplay`. There are two methods
+# available, depending on whether you have already computed the predictions
+# of the classifier or not.
+#
+# Let's first plot the precision-recall curve without computing the
+# predictions ourselves. We use
+# :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator`, which
+# computes the predictions for us before plotting the curve.
+from sklearn.metrics import PrecisionRecallDisplay
 
-print('Average precision-recall score: {0:0.2f}'.format(
-      average_precision))
+display = PrecisionRecallDisplay.from_estimator(
+    classifier, X_test, y_test, name="LinearSVC"
+)
+_ = display.ax_.set_title("2-class Precision-Recall curve")
 
 # %%
-# Plot the Precision-Recall curve
-# ................................
-from sklearn.metrics import precision_recall_curve
-from sklearn.metrics import plot_precision_recall_curve
-import matplotlib.pyplot as plt
+# If we already have the estimated probabilities or scores for
+# our model, then we can use
+# :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions`.
+y_score = classifier.decision_function(X_test)
 
-disp = plot_precision_recall_curve(classifier, X_test, y_test)
-disp.ax_.set_title('2-class Precision-Recall curve: '
-                   'AP={0:0.2f}'.format(average_precision))
+display = PrecisionRecallDisplay.from_predictions(y_test, y_score, name="LinearSVC")
+_ = display.ax_.set_title("2-class Precision-Recall curve")
 
 # %%
 # In multi-label settings
-# ------------------------
+# -----------------------
+#
+# The precision-recall curve does not support the multilabel setting. However,
+# one can decide how to handle this case. We show such an example below.
 #
 # Create multi-label data, fit, and predict
-# ...........................................
+# .........................................
 #
 # We create a multi-label dataset, to illustrate the precision-recall in
-# multi-label settings
+# multi-label settings.
 
 from sklearn.preprocessing import label_binarize
 
@@ -158,21 +174,26 @@
 n_classes = Y.shape[1]
 
 # Split into training and test
-X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.5,
-                                                    random_state=random_state)
+X_train, X_test, Y_train, Y_test = train_test_split(
+    X, Y, test_size=0.5, random_state=random_state
+)
 
-# We use OneVsRestClassifier for multi-label prediction
+# %%
+# We use :class:`~sklearn.multiclass.OneVsRestClassifier` for multi-label
+# prediction.
 from sklearn.multiclass import OneVsRestClassifier
 
-# Run classifier
-classifier = OneVsRestClassifier(svm.LinearSVC(random_state=random_state))
+classifier = OneVsRestClassifier(
+    make_pipeline(StandardScaler(), LinearSVC(random_state=random_state))
+)
 classifier.fit(X_train, Y_train)
 y_score = classifier.decision_function(X_test)
 
 # %%
 # The average precision score in multi-label settings
-# ....................................................
+# ...................................................
+from sklearn.metrics import precision_recall_curve
 from sklearn.metrics import average_precision_score
 
 # For each class
@@ -180,73 +201,68 @@
 recall = dict()
 average_precision = dict()
 for i in range(n_classes):
-    precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i],
-                                                        y_score[:, i])
+    precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i], y_score[:, i])
     average_precision[i] = average_precision_score(Y_test[:, i], y_score[:, i])
 
 # A "micro-average": quantifying score on all classes jointly
-precision["micro"], recall["micro"], _ = precision_recall_curve(Y_test.ravel(),
-                                                                y_score.ravel())
-average_precision["micro"] = average_precision_score(Y_test, y_score,
-                                                     average="micro")
-print('Average precision score, micro-averaged over all classes: {0:0.2f}'
-      .format(average_precision["micro"]))
+precision["micro"], recall["micro"], _ = precision_recall_curve(
+    Y_test.ravel(), y_score.ravel()
+)
+average_precision["micro"] = average_precision_score(Y_test, y_score, average="micro")
 
 # %%
 # Plot the micro-averaged Precision-Recall curve
-# ...............................................
-#
-
-plt.figure()
-plt.step(recall['micro'], precision['micro'], where='post')
-
-plt.xlabel('Recall')
-plt.ylabel('Precision')
-plt.ylim([0.0, 1.05])
-plt.xlim([0.0, 1.0])
-plt.title(
-    'Average precision score, micro-averaged over all classes: AP={0:0.2f}'
-    .format(average_precision["micro"]))
+# ..............................................
+display = PrecisionRecallDisplay(
+    recall=recall["micro"],
+    precision=precision["micro"],
+    average_precision=average_precision["micro"],
+)
+display.plot()
+_ = display.ax_.set_title("Micro-averaged over all classes")
 
 # %%
 # Plot Precision-Recall curve for each class and iso-f1 curves
-# .............................................................
-#
+# ............................................................
+import matplotlib.pyplot as plt
 from itertools import cycle
+
 # setup plot details
-colors = cycle(['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])
+colors = cycle(["navy", "turquoise", "darkorange", "cornflowerblue", "teal"])
+
+_, ax = plt.subplots(figsize=(7, 8))
 
-plt.figure(figsize=(7, 8))
 f_scores = np.linspace(0.2, 0.8, num=4)
-lines = []
-labels = []
+lines, labels = [], []
 for f_score in f_scores:
     x = np.linspace(0.01, 1)
     y = f_score * x / (2 * x - f_score)
-    l, = plt.plot(x[y >= 0], y[y >= 0], color='gray', alpha=0.2)
-    plt.annotate('f1={0:0.1f}'.format(f_score), xy=(0.9, y[45] + 0.02))
+    (l,) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2)
+    plt.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02))
 
-lines.append(l)
-labels.append('iso-f1 curves')
-l, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=2)
-lines.append(l)
-labels.append('micro-average Precision-recall (area = {0:0.2f})'
-              ''.format(average_precision["micro"]))
+display = PrecisionRecallDisplay(
+    recall=recall["micro"],
+    precision=precision["micro"],
+    average_precision=average_precision["micro"],
+)
+display.plot(ax=ax, name="Micro-average precision-recall", color="gold")
 
 for i, color in zip(range(n_classes), colors):
-    l, = plt.plot(recall[i], precision[i], color=color, lw=2)
-    lines.append(l)
-    labels.append('Precision-recall for class {0} (area = {1:0.2f})'
-                  ''.format(i, average_precision[i]))
-
-fig = plt.gcf()
-fig.subplots_adjust(bottom=0.25)
-plt.xlim([0.0, 1.0])
-plt.ylim([0.0, 1.05])
-plt.xlabel('Recall')
-plt.ylabel('Precision')
-plt.title('Extension of Precision-Recall curve to multi-class')
-plt.legend(lines, labels, loc=(0, -.38), prop=dict(size=14))
-
+    display = PrecisionRecallDisplay(
+        recall=recall[i],
+        precision=precision[i],
+        average_precision=average_precision[i],
+    )
+    display.plot(ax=ax, name=f"Precision-recall for class {i}", color=color)
+
+# add the legend for the iso-f1 curves
+handles, labels = display.ax_.get_legend_handles_labels()
+handles.extend([l])
+labels.extend(["iso-f1 curves"])
+# set the legend and the axes
+ax.set_xlim([0.0, 1.0])
+ax.set_ylim([0.0, 1.05])
+ax.legend(handles=handles, labels=labels, loc="best")
+ax.set_title("Extension of Precision-Recall curve to multi-class")
 
 plt.show()
diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py
index 00937950a40e9..c8f45b10fa343 100644
--- a/sklearn/metrics/_plot/precision_recall_curve.py
+++ b/sklearn/metrics/_plot/precision_recall_curve.py
@@ -1,16 +1,22 @@
+from sklearn.base import is_classifier
 from .base import _get_response
 
 from .. import average_precision_score
 from .. import precision_recall_curve
+from .._base import _check_pos_label_consistency
+from .._classification import check_consistent_length
 
-from ...utils import check_matplotlib_support
+from ...utils import check_matplotlib_support, deprecated
 
 
 class PrecisionRecallDisplay:
     """Precision Recall visualization.
 
-    It is recommend to use :func:`~sklearn.metrics.plot_precision_recall_curve`
-    to create a visualizer. All parameters are stored as attributes.
+    It is recommended to use
+    :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator` or
+    :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions` to create
+    a :class:`~sklearn.metrics.PrecisionRecallDisplay`. All parameters are
+    stored as attributes.
 
     Read more in the :ref:`User Guide `.
@@ -49,8 +55,10 @@ class PrecisionRecallDisplay: -------- precision_recall_curve : Compute precision-recall pairs for different probability thresholds. - plot_precision_recall_curve : Plot Precision Recall Curve for binary - classifiers. + PrecisionRecallDisplay.from_estimator : Plot Precision Recall Curve given + a binary classifier. + PrecisionRecallDisplay.from_predictions : Plot Precision Recall Curve + using predictions from a binary classifier. Examples -------- @@ -144,7 +152,206 @@ def plot(self, ax=None, *, name=None, **kwargs): self.figure_ = ax.figure return self + @classmethod + def from_estimator( + cls, + estimator, + X, + y, + *, + sample_weight=None, + pos_label=None, + response_method="auto", + name=None, + ax=None, + **kwargs, + ): + """Plot precision-recall curve given an estimator and some data. + + Parameters + ---------- + estimator : estimator instance + Fitted classifier or a fitted :class:`~sklearn.pipeline.Pipeline` + in which the last estimator is a classifier. + + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Input values. + + y : array-like of shape (n_samples,) + Target values. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + pos_label : str or int, default=None + The class considered as the positive class when computing the + precision and recall metrics. By default, `estimators.classes_[1]` + is considered as the positive class. + + response_method : {'predict_proba', 'decision_function', 'auto'}, \ + default='auto' + Specifies whether to use :term:`predict_proba` or + :term:`decision_function` as the target response. If set to 'auto', + :term:`predict_proba` is tried first and if it does not exist + :term:`decision_function` is tried next. + + name : str, default=None + Name for labeling curve. If `None`, no name is used. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is created. + + **kwargs : dict + Keyword arguments to be passed to matplotlib's `plot`. + + Returns + ------- + display : :class:`~sklearn.metrics.PrecisionRecallDisplay` + See Also + -------- + PrecisionRecallDisplay.from_predictions : Plot precision-recall curve + using estimated probabilities or output of decision function. + + Examples + -------- + >>> import matplotlib.pyplot as plt + >>> from sklearn.datasets import make_classification + >>> from sklearn.metrics import PrecisionRecallDisplay + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.linear_model import LogisticRegression + >>> X, y = make_classification(random_state=0) + >>> X_train, X_test, y_train, y_test = train_test_split( + ... X, y, random_state=0) + >>> clf = LogisticRegression() + >>> clf.fit(X_train, y_train) + LogisticRegression() + >>> PrecisionRecallDisplay.from_estimator( + ... 
clf, X_test, y_test) + <...> + >>> plt.show() + """ + method_name = f"{cls.__name__}.from_estimator" + check_matplotlib_support(method_name) + if not is_classifier(estimator): + raise ValueError(f"{method_name} only supports classifiers") + y_pred, pos_label = _get_response( + X, + estimator, + response_method, + pos_label=pos_label, + ) + + name = name if name is not None else estimator.__class__.__name__ + + return cls.from_predictions( + y, + y_pred, + sample_weight=sample_weight, + name=name, + pos_label=pos_label, + ax=ax, + **kwargs, + ) + + @classmethod + def from_predictions( + cls, + y_true, + y_pred, + *, + sample_weight=None, + pos_label=None, + name=None, + ax=None, + **kwargs, + ): + """Plot precision-recall curve given binary class predictions. + + Parameters + ---------- + y_true : array-like of shape (n_samples,) + True binary labels. + + y_pred : array-like of shape (n_samples,) + Estimated probabilities or output of decision function. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + pos_label : str or int, default=None + The class considered as the positive class when computing the + precision and recall metrics. + + name : str, default=None + Name for labeling curve. If `None`, name will be set to + `"Classifier"`. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is created. + + **kwargs : dict + Keyword arguments to be passed to matplotlib's `plot`. + + Returns + ------- + display : :class:`~sklearn.metrics.PrecisionRecallDisplay` + + See Also + -------- + PrecisionRecallDisplay.from_estimator : Plot precision-recall curve + using an estimator. + + Examples + -------- + >>> import matplotlib.pyplot as plt + >>> from sklearn.datasets import make_classification + >>> from sklearn.metrics import PrecisionRecallDisplay + >>> from sklearn.model_selection import train_test_split + >>> from sklearn.linear_model import LogisticRegression + >>> X, y = make_classification(random_state=0) + >>> X_train, X_test, y_train, y_test = train_test_split( + ... X, y, random_state=0) + >>> clf = LogisticRegression() + >>> clf.fit(X_train, y_train) + LogisticRegression() + >>> y_pred = clf.predict_proba(X_test)[:, 1] + >>> PrecisionRecallDisplay.from_predictions( + ... y_test, y_pred) + <...> + >>> plt.show() + """ + check_matplotlib_support(f"{cls.__name__}.from_predictions") + + check_consistent_length(y_true, y_pred, sample_weight) + pos_label = _check_pos_label_consistency(pos_label, y_true) + + precision, recall, _ = precision_recall_curve( + y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight + ) + average_precision = average_precision_score( + y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight + ) + + name = name if name is not None else "Classifier" + + viz = PrecisionRecallDisplay( + precision=precision, + recall=recall, + average_precision=average_precision, + estimator_name=name, + pos_label=pos_label, + ) + + return viz.plot(ax=ax, name=name, **kwargs) + + +@deprecated( + "Function `plot_precision_recall_curve` is deprecated in 1.0 and will be " + "removed in 1.2. Use one of the class methods: " + "PrecisionRecallDisplay.from_predictions or " + "PrecisionRecallDisplay.from_estimator." +) def plot_precision_recall_curve( estimator, X, @@ -163,6 +370,12 @@ def plot_precision_recall_curve( Read more in the :ref:`User Guide `. + .. deprecated:: 1.0 + `plot_precision_recall_curve` is deprecated in 1.0 and will be removed in + 1.2. 
Use one of the following class methods: + :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions` or + :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator`. + Parameters ---------- estimator : estimator instance diff --git a/sklearn/metrics/_plot/tests/test_common_curve_display.py b/sklearn/metrics/_plot/tests/test_common_curve_display.py index 5b3e5541fb4b2..483dc0710e82e 100644 --- a/sklearn/metrics/_plot/tests/test_common_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_common_curve_display.py @@ -9,7 +9,10 @@ from sklearn.preprocessing import StandardScaler from sklearn.tree import DecisionTreeClassifier -from sklearn.metrics import DetCurveDisplay +from sklearn.metrics import ( + DetCurveDisplay, + PrecisionRecallDisplay, +) @pytest.fixture(scope="module") @@ -23,7 +26,7 @@ def data_binary(data): return X[y < 2], y[y < 2] -@pytest.mark.parametrize("Display", [DetCurveDisplay]) +@pytest.mark.parametrize("Display", [DetCurveDisplay, PrecisionRecallDisplay]) def test_display_curve_error_non_binary(pyplot, data, Display): """Check that a proper error is raised when only binary classification is supported.""" @@ -59,7 +62,7 @@ def test_display_curve_error_non_binary(pyplot, data, Display): ), ], ) -@pytest.mark.parametrize("Display", [DetCurveDisplay]) +@pytest.mark.parametrize("Display", [DetCurveDisplay, PrecisionRecallDisplay]) def test_display_curve_error_no_response( pyplot, data_binary, @@ -82,7 +85,7 @@ def fit(self, X, y): Display.from_estimator(clf, X, y, response_method=response_method) -@pytest.mark.parametrize("Display", [DetCurveDisplay]) +@pytest.mark.parametrize("Display", [DetCurveDisplay, PrecisionRecallDisplay]) @pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) def test_display_curve_estimator_name_multiple_calls( pyplot, @@ -124,7 +127,7 @@ def test_display_curve_estimator_name_multiple_calls( ), ], ) -@pytest.mark.parametrize("Display", [DetCurveDisplay]) +@pytest.mark.parametrize("Display", [DetCurveDisplay, PrecisionRecallDisplay]) def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display): """Check that a proper error is raised when the classifier is not fitted.""" diff --git a/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py b/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py index 43d4171b42a05..8db971fb26971 100644 --- a/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py +++ b/sklearn/metrics/_plot/tests/test_confusion_matrix_display.py @@ -11,8 +11,7 @@ from sklearn.linear_model import LogisticRegression from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler -from sklearn.svm import SVC -from sklearn.svm import SVR +from sklearn.svm import SVC, SVR from sklearn.metrics import ConfusionMatrixDisplay from sklearn.metrics import confusion_matrix @@ -31,6 +30,9 @@ def test_confusion_matrix_display_validation(pyplot): n_samples=100, n_informative=5, n_classes=5, random_state=0 ) + with pytest.raises(NotFittedError): + ConfusionMatrixDisplay.from_estimator(SVC(), X, y) + regressor = SVR().fit(X, y) y_pred_regressor = regressor.predict(X) y_pred_classifier = SVC().fit(X, y).predict(X) diff --git a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py index 6569605e0226d..c77e45177cc15 100644 --- a/sklearn/metrics/_plot/tests/test_plot_precision_recall.py +++ b/sklearn/metrics/_plot/tests/test_plot_precision_recall.py @@ -4,7 +4,6 @@ from sklearn.base import 
BaseEstimator, ClassifierMixin from sklearn.metrics import plot_precision_recall_curve -from sklearn.metrics import PrecisionRecallDisplay from sklearn.metrics import average_precision_score from sklearn.metrics import precision_recall_curve from sklearn.datasets import make_classification @@ -18,10 +17,12 @@ from sklearn.utils import shuffle from sklearn.compose import make_column_transformer -# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved pytestmark = pytest.mark.filterwarnings( + # TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:" - "matplotlib.*" + "matplotlib.*", + # TODO: Remove in 1.2 (as well as all the tests below) + "ignore:Function plot_precision_recall_curve is deprecated", ) @@ -199,24 +200,6 @@ def test_plot_precision_recall_curve_estimator_name_multiple_calls(pyplot): assert clf_name in disp.line_.get_label() -@pytest.mark.parametrize( - "average_precision, estimator_name, expected_label", - [ - (0.9, None, "AP = 0.90"), - (None, "my_est", "my_est"), - (0.8, "my_est2", "my_est2 (AP = 0.80)"), - ], -) -def test_default_labels(pyplot, average_precision, estimator_name, expected_label): - prec = np.array([1, 0.5, 0]) - recall = np.array([0, 0.5, 1]) - disp = PrecisionRecallDisplay( - prec, recall, average_precision=average_precision, estimator_name=estimator_name - ) - disp.plot() - assert disp.line_.get_label() == expected_label - - @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) def test_plot_precision_recall_pos_label(pyplot, response_method): # check that we can provide the positive label and display the proper diff --git a/sklearn/metrics/_plot/tests/test_precision_recall_display.py b/sklearn/metrics/_plot/tests/test_precision_recall_display.py new file mode 100644 index 0000000000000..165e2b75df36e --- /dev/null +++ b/sklearn/metrics/_plot/tests/test_precision_recall_display.py @@ -0,0 +1,304 @@ +import numpy as np +import pytest + +from sklearn.compose import make_column_transformer +from sklearn.datasets import load_breast_cancer, make_classification +from sklearn.exceptions import NotFittedError +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import average_precision_score, precision_recall_curve +from sklearn.model_selection import train_test_split +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.svm import SVC, SVR +from sklearn.utils import shuffle + +from sklearn.metrics import PrecisionRecallDisplay, plot_precision_recall_curve + +# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved +pytestmark = pytest.mark.filterwarnings( + "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:" + "matplotlib.*" +) + + +def test_precision_recall_display_validation(pyplot): + """Check that we raise the proper error when validating parameters.""" + X, y = make_classification( + n_samples=100, n_informative=5, n_classes=5, random_state=0 + ) + + with pytest.raises(NotFittedError): + PrecisionRecallDisplay.from_estimator(SVC(), X, y) + + regressor = SVR().fit(X, y) + y_pred_regressor = regressor.predict(X) + classifier = SVC(probability=True).fit(X, y) + y_pred_classifier = classifier.predict_proba(X)[:, -1] + + err_msg = "PrecisionRecallDisplay.from_estimator only supports classifiers" + with pytest.raises(ValueError, match=err_msg): + PrecisionRecallDisplay.from_estimator(regressor, X, y) + + 
err_msg = "Expected 'estimator' to be a binary classifier, but got SVC" + with pytest.raises(ValueError, match=err_msg): + PrecisionRecallDisplay.from_estimator(classifier, X, y) + + err_msg = "{} format is not supported" + with pytest.raises(ValueError, match=err_msg.format("continuous")): + # Force `y_true` to be seen as a regression problem + PrecisionRecallDisplay.from_predictions(y + 0.5, y_pred_classifier, pos_label=1) + with pytest.raises(ValueError, match=err_msg.format("multiclass")): + PrecisionRecallDisplay.from_predictions(y, y_pred_regressor, pos_label=1) + + err_msg = "Found input variables with inconsistent numbers of samples" + with pytest.raises(ValueError, match=err_msg): + PrecisionRecallDisplay.from_predictions(y, y_pred_classifier[::2]) + + X, y = make_classification(n_classes=2, n_samples=50, random_state=0) + y += 10 + classifier.fit(X, y) + y_pred_classifier = classifier.predict_proba(X)[:, -1] + err_msg = r"y_true takes value in {10, 11} and pos_label is not specified" + with pytest.raises(ValueError, match=err_msg): + PrecisionRecallDisplay.from_predictions(y, y_pred_classifier) + + +# FIXME: Remove in 1.2 +def test_plot_precision_recall_curve_deprecation(pyplot): + """Check that we raise a FutureWarning when calling + `plot_precision_recall_curve`.""" + + X, y = make_classification(random_state=0) + clf = LogisticRegression().fit(X, y) + deprecation_warning = "Function plot_precision_recall_curve is deprecated" + with pytest.warns(FutureWarning, match=deprecation_warning): + plot_precision_recall_curve(clf, X, y) + + +@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) +@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) +def test_precision_recall_display_plotting(pyplot, constructor_name, response_method): + """Check the overall plotting rendering.""" + X, y = make_classification(n_classes=2, n_samples=50, random_state=0) + pos_label = 1 + + classifier = LogisticRegression().fit(X, y) + classifier.fit(X, y) + + y_pred = getattr(classifier, response_method)(X) + y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, pos_label] + + # safe guard for the binary if/else construction + assert constructor_name in ("from_estimator", "from_predictions") + + if constructor_name == "from_estimator": + display = PrecisionRecallDisplay.from_estimator( + classifier, X, y, response_method=response_method + ) + else: + display = PrecisionRecallDisplay.from_predictions( + y, y_pred, pos_label=pos_label + ) + + precision, recall, _ = precision_recall_curve(y, y_pred, pos_label=pos_label) + average_precision = average_precision_score(y, y_pred, pos_label=pos_label) + + np.testing.assert_allclose(display.precision, precision) + np.testing.assert_allclose(display.recall, recall) + assert display.average_precision == pytest.approx(average_precision) + + import matplotlib as mpl + + assert isinstance(display.line_, mpl.lines.Line2D) + assert isinstance(display.ax_, mpl.axes.Axes) + assert isinstance(display.figure_, mpl.figure.Figure) + + assert display.ax_.get_xlabel() == "Recall (Positive label: 1)" + assert display.ax_.get_ylabel() == "Precision (Positive label: 1)" + + # plotting passing some new parameters + display.plot(alpha=0.8, name="MySpecialEstimator") + expected_label = f"MySpecialEstimator (AP = {average_precision:0.2f})" + assert display.line_.get_label() == expected_label + assert display.line_.get_alpha() == pytest.approx(0.8) + + +@pytest.mark.parametrize( + "constructor_name, default_label", + [ + 
("from_estimator", "LogisticRegression (AP = {:.2f})"), + ("from_predictions", "Classifier (AP = {:.2f})"), + ], +) +def test_precision_recall_display_name(pyplot, constructor_name, default_label): + """Check the behaviour of the name parameters""" + X, y = make_classification(n_classes=2, n_samples=100, random_state=0) + pos_label = 1 + + classifier = LogisticRegression().fit(X, y) + classifier.fit(X, y) + + y_pred = classifier.predict_proba(X)[:, pos_label] + + # safe guard for the binary if/else construction + assert constructor_name in ("from_estimator", "from_predictions") + + if constructor_name == "from_estimator": + display = PrecisionRecallDisplay.from_estimator(classifier, X, y) + else: + display = PrecisionRecallDisplay.from_predictions( + y, y_pred, pos_label=pos_label + ) + + average_precision = average_precision_score(y, y_pred, pos_label=pos_label) + + # check that the default name is used + assert display.line_.get_label() == default_label.format(average_precision) + + # check that the name can be set + display.plot(name="MySpecialEstimator") + assert ( + display.line_.get_label() + == f"MySpecialEstimator (AP = {average_precision:.2f})" + ) + + +@pytest.mark.parametrize( + "clf", + [ + make_pipeline(StandardScaler(), LogisticRegression()), + make_pipeline( + make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression() + ), + ], +) +def test_precision_recall_display_pipeline(pyplot, clf): + X, y = make_classification(n_classes=2, n_samples=50, random_state=0) + with pytest.raises(NotFittedError): + PrecisionRecallDisplay.from_estimator(clf, X, y) + clf.fit(X, y) + display = PrecisionRecallDisplay.from_estimator(clf, X, y) + assert display.estimator_name == clf.__class__.__name__ + + +def test_precision_recall_display_string_labels(pyplot): + # regression test #15738 + cancer = load_breast_cancer() + X, y = cancer.data, cancer.target_names[cancer.target] + + lr = make_pipeline(StandardScaler(), LogisticRegression()) + lr.fit(X, y) + for klass in cancer.target_names: + assert klass in lr.classes_ + display = PrecisionRecallDisplay.from_estimator(lr, X, y) + + y_pred = lr.predict_proba(X)[:, 1] + avg_prec = average_precision_score(y, y_pred, pos_label=lr.classes_[1]) + + assert display.average_precision == pytest.approx(avg_prec) + assert display.estimator_name == lr.__class__.__name__ + + err_msg = r"y_true takes value in {'benign', 'malignant'}" + with pytest.raises(ValueError, match=err_msg): + PrecisionRecallDisplay.from_predictions(y, y_pred) + + display = PrecisionRecallDisplay.from_predictions( + y, y_pred, pos_label=lr.classes_[1] + ) + assert display.average_precision == pytest.approx(avg_prec) + + +@pytest.mark.parametrize( + "average_precision, estimator_name, expected_label", + [ + (0.9, None, "AP = 0.90"), + (None, "my_est", "my_est"), + (0.8, "my_est2", "my_est2 (AP = 0.80)"), + ], +) +def test_default_labels(pyplot, average_precision, estimator_name, expected_label): + """Check the default labels used in the display.""" + precision = np.array([1, 0.5, 0]) + recall = np.array([0, 0.5, 1]) + display = PrecisionRecallDisplay( + precision, + recall, + average_precision=average_precision, + estimator_name=estimator_name, + ) + display.plot() + assert display.line_.get_label() == expected_label + + +@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) +@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) +def test_plot_precision_recall_pos_label(pyplot, constructor_name, response_method): + # 
check that we can provide the positive label and display the proper + # statistics + X, y = load_breast_cancer(return_X_y=True) + # create an highly imbalanced version of the breast cancer dataset + idx_positive = np.flatnonzero(y == 1) + idx_negative = np.flatnonzero(y == 0) + idx_selected = np.hstack([idx_negative, idx_positive[:25]]) + X, y = X[idx_selected], y[idx_selected] + X, y = shuffle(X, y, random_state=42) + # only use 2 features to make the problem even harder + X = X[:, :2] + y = np.array(["cancer" if c == 1 else "not cancer" for c in y], dtype=object) + X_train, X_test, y_train, y_test = train_test_split( + X, + y, + stratify=y, + random_state=0, + ) + + classifier = LogisticRegression() + classifier.fit(X_train, y_train) + + # sanity check to be sure the positive class is classes_[0] and that we + # are betrayed by the class imbalance + assert classifier.classes_.tolist() == ["cancer", "not cancer"] + + y_pred = getattr(classifier, response_method)(X_test) + # we select the correcponding probability columns or reverse the decision + # function otherwise + y_pred_cancer = -1 * y_pred if y_pred.ndim == 1 else y_pred[:, 0] + y_pred_not_cancer = y_pred if y_pred.ndim == 1 else y_pred[:, 1] + + if constructor_name == "from_estimator": + display = PrecisionRecallDisplay.from_estimator( + classifier, + X_test, + y_test, + pos_label="cancer", + response_method=response_method, + ) + else: + display = PrecisionRecallDisplay.from_predictions( + y_test, + y_pred_cancer, + pos_label="cancer", + ) + # we should obtain the statistics of the "cancer" class + avg_prec_limit = 0.65 + assert display.average_precision < avg_prec_limit + assert -np.trapz(display.precision, display.recall) < avg_prec_limit + + # otherwise we should obtain the statistics of the "not cancer" class + if constructor_name == "from_estimator": + display = PrecisionRecallDisplay.from_estimator( + classifier, + X_test, + y_test, + response_method=response_method, + pos_label="not cancer", + ) + else: + display = PrecisionRecallDisplay.from_predictions( + y_test, + y_pred_not_cancer, + pos_label="not cancer", + ) + avg_prec_limit = 0.95 + assert display.average_precision > avg_prec_limit + assert -np.trapz(display.precision, display.recall) > avg_prec_limit diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 2e021acbcc331..2c61c61ee10d1 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -827,9 +827,10 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight See Also -------- - plot_precision_recall_curve : Plot Precision Recall Curve for binary - classifiers. - PrecisionRecallDisplay : Precision Recall visualization. + PrecisionRecallDisplay.from_estimator : Plot Precision Recall Curve given + a binary classifier. + PrecisionRecallDisplay.from_predictions : Plot Precision Recall Curve + using predictions from a binary classifier. average_precision_score : Compute average precision from prediction scores. det_curve: Compute error rates for different probability thresholds. roc_curve : Compute Receiver operating characteristic (ROC) curve.
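
Reviewer note (not part of the patch): the following is a minimal, self-contained sketch of how the two class methods introduced by this diff are intended to be used, assuming a scikit-learn build that includes this change (>= 1.0). The dataset and the LogisticRegression estimator are only illustrative choices, not part of the PR.

# Hedged usage sketch for PrecisionRecallDisplay.from_estimator /
# from_predictions (assumes scikit-learn >= 1.0 with this patch applied).
import matplotlib.pyplot as plt

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import PrecisionRecallDisplay
from sklearn.model_selection import train_test_split

# Illustrative binary classification problem and fitted classifier.
X, y = make_classification(n_samples=1000, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)

# Option 1: let the display call the estimator and compute the predictions.
PrecisionRecallDisplay.from_estimator(clf, X_test, y_test, name="LogisticRegression")

# Option 2: reuse scores that were already computed (probability of the
# positive class, or the decision_function output).
y_score = clf.predict_proba(X_test)[:, 1]
PrecisionRecallDisplay.from_predictions(y_test, y_score, name="LogisticRegression")

plt.show()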