From d098d1223e09777da8ce379647152653cedf574c Mon Sep 17 00:00:00 2001 From: Bagus Tris Atmaja Date: Tue, 17 Sep 2024 12:10:46 +0900 Subject: [PATCH 01/14] REFAC update ROC curve plotting to use y_score instead of y_pred; deprecate y_pred --- sklearn/metrics/_plot/roc_curve.py | 38 +++++++++++++++++++++++------- sklearn/metrics/_ranking.py | 8 +++---- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index e9d4ca5d5672d..8ff745115bb5a 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,6 +1,7 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause +import warnings from ...utils._plotting import _BinaryClassifierCurveDisplayMixin from .._ranking import auc, roc_curve @@ -275,7 +276,7 @@ def from_estimator( <...> >>> plt.show() """ - y_pred, pos_label, name = cls._validate_and_get_response_values( + y_score, pos_label, name = cls._validate_and_get_response_values( estimator, X, y, @@ -286,7 +287,7 @@ def from_estimator( return cls.from_predictions( y_true=y, - y_pred=y_pred, + y_score=y_score, sample_weight=sample_weight, drop_intermediate=drop_intermediate, name=name, @@ -301,7 +302,7 @@ def from_estimator( def from_predictions( cls, y_true, - y_pred, + y_score, *, sample_weight=None, drop_intermediate=True, @@ -310,6 +311,7 @@ def from_predictions( ax=None, plot_chance_level=False, chance_level_kw=None, + y_pred="deprecated", **kwargs, ): """Plot ROC curve given the true and predicted values. @@ -323,11 +325,15 @@ def from_predictions( y_true : array-like of shape (n_samples,) True labels. - y_pred : array-like of shape (n_samples,) + y_score : array-like of shape (n_samples,) Target scores, can either be probability estimates of the positive class, confidence values, or non-thresholded measure of decisions (as returned by “decision_function” on some classifiers). + .. deprecated:: 1.6 + `y_pred` is deprecated and will be removed in 1.8. Use + `y_score` instead. + sample_weight : array-like of shape (n_samples,), default=None Sample weights. @@ -386,19 +392,35 @@ def from_predictions( >>> X_train, X_test, y_train, y_test = train_test_split( ... X, y, random_state=0) >>> clf = SVC(random_state=0).fit(X_train, y_train) - >>> y_pred = clf.decision_function(X_test) + >>> y_score = clf.decision_function(X_test) >>> RocCurveDisplay.from_predictions( - ... y_test, y_pred) + ... y_test, y_score) <...> >>> plt.show() """ + # TODO(1.8): Remove y_pred support and use y_score only + if y_score is not None and not isinstance(y_pred, str): + raise ValueError( + "`y_pred` and `y_score` cannot be both specified. Please use `y_score`" + " only as `y_pred` is deprecated in v1.6 and will be removed in v1.8." + ) + if y_score is None: + warnings.warn( + ( + "y_pred was deprecated in version 1.6 and will be removed in 1.8. " + "Please use ``y_score`` instead." 
+ ), + FutureWarning, + ) + y_score = y_pred + pos_label_validated, name = cls._validate_from_predictions_params( - y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name + y_true, y_score, sample_weight=sample_weight, pos_label=pos_label, name=name ) fpr, tpr, _ = roc_curve( y_true, - y_pred, + y_score, pos_label=pos_label, sample_weight=sample_weight, drop_intermediate=drop_intermediate, diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index fd5c30805a0c0..5b2f707bcc153 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -73,9 +73,9 @@ def auc(x, y): -------- >>> import numpy as np >>> from sklearn import metrics - >>> y = np.array([1, 1, 2, 2]) - >>> pred = np.array([0.1, 0.4, 0.35, 0.8]) - >>> fpr, tpr, thresholds = metrics.roc_curve(y, pred, pos_label=2) + >>> y_true = np.array([1, 1, 2, 2]) + >>> y_score = np.array([0.1, 0.4, 0.35, 0.8]) + >>> fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score, pos_label=2) >>> metrics.auc(fpr, tpr) np.float64(0.75) """ @@ -985,7 +985,7 @@ def precision_recall_curve( warnings.warn( ( "probas_pred was deprecated in version 1.5 and will be removed in 1.7." - "Please use ``y_score`` instead." + "Please use `y_score` instead." ), FutureWarning, ) From cb2836690481592ca86c3b096bf078da5e5b4c82 Mon Sep 17 00:00:00 2001 From: Bagus Tris Atmaja Date: Fri, 20 Sep 2024 11:18:00 +0900 Subject: [PATCH 02/14] change y_pred to y_score in roc_curve.py --- doc/whats_new/v1.6.rst | 2 ++ sklearn/metrics/_plot/roc_curve.py | 10 +++++----- sklearn/metrics/_ranking.py | 6 +++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 22a0d7acfd24e..b8792d7080571 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -277,6 +277,8 @@ Changelog :mod:`sklearn.metrics` ...................... +- |API| In :func:`sklearn.metrics.roc_curve`, the argument `y_pred` has been renamed to `y_score` to better reflect its purpose. `y_pred` will be removed in 1.8. :pr:`29865` by :user:`Bagus Tris Atmaja `. + - |Enhancement| :func:`sklearn.metrics.check_scoring` now accepts `raise_exc` to specify whether to raise an exception if a subset of the scorers in multimetric scoring fails or to return an error code. :pr:`28992` by :user:`Stefanie Senger `. diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 8ff745115bb5a..342cde1b66432 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -68,9 +68,9 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): >>> import matplotlib.pyplot as plt >>> import numpy as np >>> from sklearn import metrics - >>> y = np.array([0, 0, 1, 1]) - >>> pred = np.array([0.1, 0.4, 0.35, 0.8]) - >>> fpr, tpr, thresholds = metrics.roc_curve(y, pred) + >>> y_true = np.array([0, 0, 1, 1]) + >>> y_score = np.array([0.1, 0.4, 0.35, 0.8]) + >>> fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score) >>> roc_auc = metrics.auc(fpr, tpr) >>> display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, ... estimator_name='example estimator') @@ -333,7 +333,7 @@ def from_predictions( .. deprecated:: 1.6 `y_pred` is deprecated and will be removed in 1.8. Use `y_score` instead. - + sample_weight : array-like of shape (n_samples,), default=None Sample weights. 
@@ -413,7 +413,7 @@ def from_predictions( FutureWarning, ) y_score = y_pred - + pos_label_validated, name = cls._validate_from_predictions_params( y_true, y_score, sample_weight=sample_weight, pos_label=pos_label, name=name ) diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index 5b2f707bcc153..dc8e0a4d2cd2e 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -595,10 +595,10 @@ class scores must correspond to the order of ``labels``, >>> clf = MultiOutputClassifier(clf).fit(X, y) >>> # get a list of n_output containing probability arrays of shape >>> # (n_samples, n_classes) - >>> y_pred = clf.predict_proba(X) + >>> y_score = clf.predict_proba(X) >>> # extract the positive columns for each output - >>> y_pred = np.transpose([pred[:, 1] for pred in y_pred]) - >>> roc_auc_score(y, y_pred, average=None) + >>> y_score = np.transpose([score[:, 1] for score in y_score]) + >>> roc_auc_score(y, y_score, average=None) array([0.82..., 0.86..., 0.94..., 0.85... , 0.94...]) >>> from sklearn.linear_model import RidgeClassifierCV >>> clf = RidgeClassifierCV().fit(X, y) From 67b15b7f7cc26dd9c9f0caaa03f59419ae540af2 Mon Sep 17 00:00:00 2001 From: Bagus Tris Atmaja Date: Fri, 20 Sep 2024 11:23:49 +0900 Subject: [PATCH 03/14] modified: sklearn/metrics/_plot/roc_curve.py --- sklearn/metrics/_plot/roc_curve.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 342cde1b66432..4f90518d6db87 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import warnings + from ...utils._plotting import _BinaryClassifierCurveDisplayMixin from .._ranking import auc, roc_curve From 55e172031a2055a18273009203b79e67a848706f Mon Sep 17 00:00:00 2001 From: Bagus Tris Atmaja Date: Tue, 24 Sep 2024 17:03:03 +0900 Subject: [PATCH 04/14] move what's new, revert warn, add condition to raise warning --- doc/whats_new/v1.6.rst | 5 +++-- sklearn/metrics/_plot/roc_curve.py | 15 ++++++++++----- sklearn/metrics/_ranking.py | 2 +- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index b8792d7080571..f9e189f649da1 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -277,8 +277,6 @@ Changelog :mod:`sklearn.metrics` ...................... -- |API| In :func:`sklearn.metrics.roc_curve`, the argument `y_pred` has been renamed to `y_score` to better reflect its purpose. `y_pred` will be removed in 1.8. :pr:`29865` by :user:`Bagus Tris Atmaja `. - - |Enhancement| :func:`sklearn.metrics.check_scoring` now accepts `raise_exc` to specify whether to raise an exception if a subset of the scorers in multimetric scoring fails or to return an error code. :pr:`28992` by :user:`Stefanie Senger `. @@ -301,6 +299,9 @@ Changelog is renamed into `ensure_all_finite`. `force_all_finite` will be removed in 1.8. :pr:`29404` by :user:`Jérémie du Boisberranger `. +- |API| In :func:`sklearn.metrics.RocCurveDisplay.from_predictions`, the argument `y_pred` has been renamed to `y_score` to better reflect its purpose. `y_pred` will be removed in 1.8. + :pr:`29865` by :user:`Bagus Tris Atmaja `. + :mod:`sklearn.model_selection` .............................. 
diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 4f90518d6db87..f7443ae5be34b 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -331,10 +331,6 @@ def from_predictions( class, confidence values, or non-thresholded measure of decisions (as returned by “decision_function” on some classifiers). - .. deprecated:: 1.6 - `y_pred` is deprecated and will be removed in 1.8. Use - `y_score` instead. - sample_weight : array-like of shape (n_samples,), default=None Sample weights. @@ -367,6 +363,15 @@ def from_predictions( .. versionadded:: 1.3 + y_pred : array-like of shape (n_samples,) + Target scores, can either be probability estimates of the positive + class, confidence values, or non-thresholded measure of decisions + (as returned by “decision_function” on some classifiers). + + .. deprecated:: 1.6 + `y_pred` is deprecated and will be removed in 1.8. Use + `y_score` instead. + **kwargs : dict Additional keywords arguments passed to matplotlib `plot` function. @@ -405,7 +410,7 @@ def from_predictions( "`y_pred` and `y_score` cannot be both specified. Please use `y_score`" " only as `y_pred` is deprecated in v1.6 and will be removed in v1.8." ) - if y_score is None: + if y_pred is not None and y_pred != "deprecated": warnings.warn( ( "y_pred was deprecated in version 1.6 and will be removed in 1.8. " diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py index dc8e0a4d2cd2e..344cd9b077de0 100644 --- a/sklearn/metrics/_ranking.py +++ b/sklearn/metrics/_ranking.py @@ -985,7 +985,7 @@ def precision_recall_curve( warnings.warn( ( "probas_pred was deprecated in version 1.5 and will be removed in 1.7." - "Please use `y_score` instead." + "Please use ``y_score`` instead." 
), FutureWarning, ) From 9f154218890dc7e5953e6b6af22e83c9fade2a6c Mon Sep 17 00:00:00 2001 From: Bagus Tris Atmaja Date: Fri, 11 Oct 2024 10:06:51 +0900 Subject: [PATCH 05/14] add unit test --- .../_plot/tests/test_roc_curve_display.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index a4f4d81fb9ded..1f8055dfc0553 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -13,6 +13,8 @@ from sklearn.preprocessing import StandardScaler from sklearn.utils import shuffle +import warnings + @pytest.fixture(scope="module") def data(): @@ -313,3 +315,41 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): assert display.roc_auc == pytest.approx(roc_auc_limit) assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) + + +# TODO(1.8): remove +def test_y_pred_and_y_score_both_specified(): + with pytest.raises( + ValueError, match="`y_pred` and `y_score` cannot be both specified" + ): + RocCurveDisplay.from_predictions( + y_true=np.array([0, 1, 1, 0]), + y_pred=np.array([0.1, 0.4, 0.35, 0.8]), + y_score=np.array([0.2, 0.3, 0.5, 0.1]), + ) + + +def test_y_pred_deprecation_warning(): + with pytest.warns(FutureWarning, match="y_pred was deprecated in version 1.6"): + RocCurveDisplay.from_predictions( + y_true=np.array([0, 1, 1, 0]), y_pred=np.array([0.1, 0.4, 0.35, 0.8]) + ) + + +def test_y_pred_deprecated_value(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + display = RocCurveDisplay.from_predictions( + y_true=np.array([0, 1, 1, 0]), + y_pred="deprecated", + y_score=np.array([0.1, 0.4, 0.35, 0.8]), + ) + assert np.array_equal(display.y_score, np.array([0.1, 0.4, 0.35, 0.8])) + + +def test_y_score_used_when_y_pred_not_specified(): + y_score = np.array([0.1, 0.4, 0.35, 0.8]) + display = RocCurveDisplay.from_predictions( + y_true=np.array([0, 1, 1, 0]), y_score=y_score + ) + assert np.array_equal(display.y_score, y_score) From eaeece68bd9a83da6226d5491890ddb5c26f686b Mon Sep 17 00:00:00 2001 From: Bagus Tris Atmaja Date: Fri, 11 Oct 2024 11:16:50 +0900 Subject: [PATCH 06/14] update and fix unit test --- .../_plot/tests/test_roc_curve_display.py | 69 ++++++++----------- 1 file changed, 27 insertions(+), 42 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 1f8055dfc0553..a45b9e3e3591e 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -13,8 +13,6 @@ from sklearn.preprocessing import StandardScaler from sklearn.utils import shuffle -import warnings - @pytest.fixture(scope="module") def data(): @@ -65,8 +63,8 @@ def test_roc_curve_display_plotting( lr = LogisticRegression() lr.fit(X, y) - y_pred = getattr(lr, response_method)(X) - y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, 1] + y_score = getattr(lr, response_method)(X) + y_score = y_score if y_score.ndim == 1 else y_score[:, 1] if constructor_name == "from_estimator": display = RocCurveDisplay.from_estimator( @@ -81,7 +79,7 @@ def test_roc_curve_display_plotting( else: display = RocCurveDisplay.from_predictions( y, - y_pred, + y_score, sample_weight=sample_weight, drop_intermediate=drop_intermediate, pos_label=pos_label, @@ -90,7 +88,7 @@ def test_roc_curve_display_plotting( fpr, tpr, _ = roc_curve( y, - 
y_pred, + y_score, sample_weight=sample_weight, drop_intermediate=drop_intermediate, pos_label=pos_label, @@ -145,8 +143,8 @@ def test_roc_curve_chance_level_line( lr = LogisticRegression() lr.fit(X, y) - y_pred = getattr(lr, "predict_proba")(X) - y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, 1] + y_score = getattr(lr, "predict_proba")(X) + y_score = y_score if y_score.ndim == 1 else y_score[:, 1] if constructor_name == "from_estimator": display = RocCurveDisplay.from_estimator( @@ -160,7 +158,7 @@ def test_roc_curve_chance_level_line( else: display = RocCurveDisplay.from_predictions( y, - y_pred, + y_score, alpha=0.8, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, @@ -272,11 +270,11 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): # are betrayed by the class imbalance assert classifier.classes_.tolist() == ["cancer", "not cancer"] - y_pred = getattr(classifier, response_method)(X_test) + y_score = getattr(classifier, response_method)(X_test) # we select the corresponding probability columns or reverse the decision # function otherwise - y_pred_cancer = -1 * y_pred if y_pred.ndim == 1 else y_pred[:, 0] - y_pred_not_cancer = y_pred if y_pred.ndim == 1 else y_pred[:, 1] + y_score_cancer = -1 * y_score if y_score.ndim == 1 else y_score[:, 0] + y_score_not_cancer = y_score if y_score.ndim == 1 else y_score[:, 1] if constructor_name == "from_estimator": display = RocCurveDisplay.from_estimator( @@ -289,7 +287,7 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): else: display = RocCurveDisplay.from_predictions( y_test, - y_pred_cancer, + y_score_cancer, pos_label="cancer", ) @@ -309,7 +307,7 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): else: display = RocCurveDisplay.from_predictions( y_test, - y_pred_not_cancer, + y_score_not_cancer, pos_label="not cancer", ) @@ -318,38 +316,25 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): # TODO(1.8): remove -def test_y_pred_and_y_score_both_specified(): +def test_y_score_and_y_pred_deprecation(): + y_true = np.array([0, 1, 1, 0]) + y_score = np.array([0.1, 0.4, 0.35, 0.8]) + y_pred = np.array([0.2, 0.3, 0.5, 0.1]) + + # Test that an error is raised when both y_score and y_pred are specified with pytest.raises( ValueError, match="`y_pred` and `y_score` cannot be both specified" ): - RocCurveDisplay.from_predictions( - y_true=np.array([0, 1, 1, 0]), - y_pred=np.array([0.1, 0.4, 0.35, 0.8]), - y_score=np.array([0.2, 0.3, 0.5, 0.1]), - ) - - -def test_y_pred_deprecation_warning(): - with pytest.warns(FutureWarning, match="y_pred was deprecated in version 1.6"): - RocCurveDisplay.from_predictions( - y_true=np.array([0, 1, 1, 0]), y_pred=np.array([0.1, 0.4, 0.35, 0.8]) - ) - - -def test_y_pred_deprecated_value(): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - display = RocCurveDisplay.from_predictions( - y_true=np.array([0, 1, 1, 0]), - y_pred="deprecated", - y_score=np.array([0.1, 0.4, 0.35, 0.8]), - ) - assert np.array_equal(display.y_score, np.array([0.1, 0.4, 0.35, 0.8])) + RocCurveDisplay.from_predictions(y_true, y_score=y_score, y_pred=y_pred) + # Test that y_score is used when y_pred is not specified + display = RocCurveDisplay.from_predictions(y_true, y_score=y_score) + assert_allclose(display.fpr, [0, 0.5, 0.5, 1]) + assert_allclose(display.tpr, [0, 0, 1, 1]) -def test_y_score_used_when_y_pred_not_specified(): - y_score = np.array([0.1, 0.4, 0.35, 0.8]) + # Test that y_score 
is used when y_pred is "deprecated" display = RocCurveDisplay.from_predictions( - y_true=np.array([0, 1, 1, 0]), y_score=y_score + y_true, y_score=y_score, y_pred="deprecated" ) - assert np.array_equal(display.y_score, y_score) + assert_allclose(display.fpr, [0, 0.5, 0.5, 1]) + assert_allclose(display.tpr, [0, 0, 1, 1]) From fd381db4d5826960a12d11a969e7eb4badeace6d Mon Sep 17 00:00:00 2001 From: Bagus Tris Atmaja Date: Fri, 11 Oct 2024 13:29:20 +0900 Subject: [PATCH 07/14] add test for y_pred depreciation warning --- .../_plot/tests/test_roc_curve_display.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index a45b9e3e3591e..2e43a07869c8f 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -338,3 +338,23 @@ def test_y_score_and_y_pred_deprecation(): ) assert_allclose(display.fpr, [0, 0.5, 0.5, 1]) assert_allclose(display.tpr, [0, 0, 1, 1]) + + +def test_y_pred_deprecation_warning(): + y_true = np.array([0, 1, 1, 0]) + y_pred = np.array([0.1, 0.4, 0.35, 0.8]) + + with pytest.warns(FutureWarning, match="y_pred was deprecated in version 1.6"): + display = RocCurveDisplay.from_predictions(y_true, y_pred=y_pred) + + assert_allclose(display.fpr, [0, 0.5, 0.5, 1]) + assert_allclose(display.tpr, [0, 0, 1, 1]) + + # Test that y_score is used when y_pred is "deprecated" + y_score = np.array([0.2, 0.3, 0.5, 0.1]) + display = RocCurveDisplay.from_predictions( + y_true, y_score=y_score, y_pred="deprecated" + ) + assert_allclose(display.fpr, [0, 0.5, 0.5, 1]) + assert_allclose(display.tpr, [0, 0, 1, 1]) + assert np.array_equal(display.y_score, y_score) From 1b402c53f90734e7713571cea54277353e2209e3 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 11 Oct 2024 17:53:14 +0200 Subject: [PATCH 08/14] remove trailing spaces --- doc/whats_new/v1.6.rst | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 2349ca464503b..53bed57a01503 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -316,9 +316,9 @@ Changelog is renamed into `ensure_all_finite`. `force_all_finite` will be removed in 1.8. :pr:`29404` by :user:`Jérémie du Boisberranger `. -- |API| In :func:`sklearn.metrics.RocCurveDisplay.from_predictions`, - the argument `y_pred` has been renamed to `y_score` to better reflect its purpose. - `y_pred` will be removed in 1.8. +- |API| In :func:`sklearn.metrics.RocCurveDisplay.from_predictions`, + the argument `y_pred` has been renamed to `y_score` to better reflect its purpose. + `y_pred` will be removed in 1.8. :pr:`29865` by :user:`Bagus Tris Atmaja `. - |Fix| the functions :func:`metrics.mean_squared_log_error` and @@ -333,8 +333,6 @@ Changelog will be returned when `multioutput=uniform_average`. :pr:`29709` by :user:`Virgil Chan `. - - :mod:`sklearn.model_selection` .............................. 
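The fixes that follow settle on scikit-learn's usual string-sentinel idiom for renaming a keyword argument: the old name defaults to the literal string "deprecated", so the code can tell "argument not passed" apart from a user-supplied score array without ever comparing an array against a string. A self-contained sketch of that idiom (simplified signature and messages for illustration only, not the actual library code):

import warnings

import numpy as np


def from_predictions(y_true, y_score=None, *, y_pred="deprecated"):
    # The sentinel is a plain string, so a genuine score array never matches it
    # and the checks below never have to compare an array to a string.
    if y_score is not None and not (
        isinstance(y_pred, str) and y_pred == "deprecated"
    ):
        raise ValueError("`y_pred` and `y_score` cannot be both specified.")
    if not (isinstance(y_pred, str) and y_pred == "deprecated"):
        warnings.warn(
            "y_pred is deprecated; please use `y_score` instead.", FutureWarning
        )
        y_score = y_pred
    return np.asarray(y_true), np.asarray(y_score)


from_predictions([0, 1, 1, 0], y_score=[0.1, 0.4, 0.35, 0.8])  # new name, silent
from_predictions([0, 1, 1, 0], y_pred=[0.1, 0.4, 0.35, 0.8])  # old name, FutureWarning

The same two checks appear in the real method: one to reject calls that pass both names, one to warn and fall back to the old name.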
From 389aa8ee275e81b105e490866713392eef4375bb Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 2 Jan 2025 13:43:30 +0100 Subject: [PATCH 09/14] fix --- sklearn/metrics/_plot/roc_curve.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index f7443ae5be34b..6042639f8c5fe 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -303,7 +303,7 @@ def from_estimator( def from_predictions( cls, y_true, - y_score, + y_score=None, *, sample_weight=None, drop_intermediate=True, @@ -331,6 +331,9 @@ def from_predictions( class, confidence values, or non-thresholded measure of decisions (as returned by “decision_function” on some classifiers). + .. versionadded:: 1.7 + `y_score` is now used instead of `y_pred`. + sample_weight : array-like of shape (n_samples,), default=None Sample weights. @@ -368,8 +371,8 @@ def from_predictions( class, confidence values, or non-thresholded measure of decisions (as returned by “decision_function” on some classifiers). - .. deprecated:: 1.6 - `y_pred` is deprecated and will be removed in 1.8. Use + .. deprecated:: 1.7 + `y_pred` is deprecated and will be removed in 1.9. Use `y_score` instead. **kwargs : dict @@ -399,22 +402,23 @@ def from_predictions( ... X, y, random_state=0) >>> clf = SVC(random_state=0).fit(X_train, y_train) >>> y_score = clf.decision_function(X_test) - >>> RocCurveDisplay.from_predictions( - ... y_test, y_score) + >>> RocCurveDisplay.from_predictions(y_test, y_score) <...> >>> plt.show() """ - # TODO(1.8): Remove y_pred support and use y_score only - if y_score is not None and not isinstance(y_pred, str): + # TODO(1.8): remove after the end of the deprecation period of `y_pred` + if y_score is not None and ( + not isinstance(y_pred, str) or y_pred != "deprecated" + ): raise ValueError( "`y_pred` and `y_score` cannot be both specified. Please use `y_score`" - " only as `y_pred` is deprecated in v1.6 and will be removed in v1.8." + " only as `y_pred` is deprecated in 1.7 and will be removed in 1.9." ) - if y_pred is not None and y_pred != "deprecated": + if not isinstance(y_pred, str) or y_pred != "deprecated": warnings.warn( ( - "y_pred was deprecated in version 1.6 and will be removed in 1.8. " - "Please use ``y_score`` instead." + "y_pred is deprecated in 1.7 and will be removed in 1.9. " + "Please use `y_score` instead." 
), FutureWarning, ) From b0b7b70300abcc9cae4e4b15d90462b3c2b4a05b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 2 Jan 2025 14:28:34 +0100 Subject: [PATCH 10/14] fix test --- .../_plot/tests/test_roc_curve_display.py | 41 ++++++------------- 1 file changed, 13 insertions(+), 28 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 7d3f54494e0f2..8f6e12b54f5f8 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -336,49 +336,34 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) -# TODO(1.8): remove -def test_y_score_and_y_pred_deprecation(): +# TODO(1.9): remove +def test_y_score_and_y_pred_specified_error(): + """Check that an error is raised when both y_score and y_pred are specified.""" y_true = np.array([0, 1, 1, 0]) y_score = np.array([0.1, 0.4, 0.35, 0.8]) y_pred = np.array([0.2, 0.3, 0.5, 0.1]) - # Test that an error is raised when both y_score and y_pred are specified with pytest.raises( ValueError, match="`y_pred` and `y_score` cannot be both specified" ): RocCurveDisplay.from_predictions(y_true, y_score=y_score, y_pred=y_pred) - # Test that y_score is used when y_pred is not specified - display = RocCurveDisplay.from_predictions(y_true, y_score=y_score) - assert_allclose(display.fpr, [0, 0.5, 0.5, 1]) - assert_allclose(display.tpr, [0, 0, 1, 1]) - - # Test that y_score is used when y_pred is "deprecated" - display = RocCurveDisplay.from_predictions( - y_true, y_score=y_score, y_pred="deprecated" - ) - assert_allclose(display.fpr, [0, 0.5, 0.5, 1]) - assert_allclose(display.tpr, [0, 0, 1, 1]) - +# TODO(1.9): remove def test_y_pred_deprecation_warning(): + """Check that a warning is raised when y_pred is specified.""" y_true = np.array([0, 1, 1, 0]) - y_pred = np.array([0.1, 0.4, 0.35, 0.8]) + y_score = np.array([0.1, 0.4, 0.35, 0.8]) - with pytest.warns(FutureWarning, match="y_pred was deprecated in version 1.6"): - display = RocCurveDisplay.from_predictions(y_true, y_pred=y_pred) + with pytest.warns(FutureWarning, match="y_pred is deprecated in 1.7"): + display_y_pred = RocCurveDisplay.from_predictions(y_true, y_pred=y_score) - assert_allclose(display.fpr, [0, 0.5, 0.5, 1]) - assert_allclose(display.tpr, [0, 0, 1, 1]) + assert_allclose(display_y_pred.fpr, [0, 0.5, 0.5, 1]) + assert_allclose(display_y_pred.tpr, [0, 0, 1, 1]) - # Test that y_score is used when y_pred is "deprecated" - y_score = np.array([0.2, 0.3, 0.5, 0.1]) - display = RocCurveDisplay.from_predictions( - y_true, y_score=y_score, y_pred="deprecated" - ) - assert_allclose(display.fpr, [0, 0.5, 0.5, 1]) - assert_allclose(display.tpr, [0, 0, 1, 1]) - assert np.array_equal(display.y_score, y_score) + display_y_score = RocCurveDisplay.from_predictions(y_true, y_score) + assert_allclose(display_y_score.fpr, [0, 0.5, 0.5, 1]) + assert_allclose(display_y_score.tpr, [0, 0, 1, 1]) @pytest.mark.parametrize("despine", [True, False]) From f809813859fe1f7be29001ca66f456b862e60486 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 2 Jan 2025 14:31:16 +0100 Subject: [PATCH 11/14] remove tag changelog --- doc/whats_new/upcoming_changes/sklearn.metrics/29865.api.rst | 2 +- doc/whats_new/v1.6.rst | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/doc/whats_new/upcoming_changes/sklearn.metrics/29865.api.rst 
b/doc/whats_new/upcoming_changes/sklearn.metrics/29865.api.rst index 8e171a1167b3a..60ea7d83de71f 100644 --- a/doc/whats_new/upcoming_changes/sklearn.metrics/29865.api.rst +++ b/doc/whats_new/upcoming_changes/sklearn.metrics/29865.api.rst @@ -1,4 +1,4 @@ -- |API| In :meth:`sklearn.metrics.RocCurveDisplay.from_predictions`, +- In :meth:`sklearn.metrics.RocCurveDisplay.from_predictions`, the argument `y_pred` has been renamed to `y_score` to better reflect its purpose. `y_pred` will be removed in 1.9. By :user:`Bagus Tris Atmaja ` in diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 4a1c9e11a9892..47f8260820987 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -724,4 +724,3 @@ Søren Bredlund Caspersen, Stefanie Senger, Steffen Schneider, Štěpán Sršeň, Sylvain Combettes, Tamara, Thomas, Thomas Gessey-Jones, Thomas J. Fan, Thomas Li, Tialo, Tim Head, Tuhin Sharma, Tushar Parimi, vedpawar2254, Victoria Shevchenko, viktor765, Vince Carey, Virgil Chan, Wang Jiayi, Xiao Yuan, Xuefeng -Xu, Yao Xiao, yareyaredesuyo, Zachary Vealey, Ziad Amerr From 7721bb10c2a5ee11fb0012bfc5e2066fc3776d4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Sat, 12 Apr 2025 15:24:34 +0200 Subject: [PATCH 12/14] replace in examples + nitpicks --- .../plot_outlier_detection_bench.py | 30 +++++++++---------- sklearn/metrics/_plot/roc_curve.py | 10 +++---- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/examples/miscellaneous/plot_outlier_detection_bench.py b/examples/miscellaneous/plot_outlier_detection_bench.py index cef58e20d75eb..600eceb1a06b3 100644 --- a/examples/miscellaneous/plot_outlier_detection_bench.py +++ b/examples/miscellaneous/plot_outlier_detection_bench.py @@ -88,12 +88,12 @@ def fit_predict(estimator, X): tic = perf_counter() if estimator[-1].__class__.__name__ == "LocalOutlierFactor": estimator.fit(X) - y_pred = estimator[-1].negative_outlier_factor_ + y_score = estimator[-1].negative_outlier_factor_ else: # "IsolationForest" - y_pred = estimator.fit(X).decision_function(X) + y_score = estimator.fit(X).decision_function(X) toc = perf_counter() print(f"Duration for {model_name}: {toc - tic:.2f} s") - return y_pred + return y_score # %% @@ -138,7 +138,7 @@ def fit_predict(estimator, X): # %% y_true = {} -y_pred = {"LOF": {}, "IForest": {}} +y_score = {"LOF": {}, "IForest": {}} model_names = ["LOF", "IForest"] cat_columns = ["protocol_type", "service", "flag"] @@ -150,7 +150,7 @@ def fit_predict(estimator, X): lof_kw={"n_neighbors": int(n_samples * anomaly_frac)}, iforest_kw={"random_state": 42}, ) - y_pred[model_name]["KDDCup99 - SA"] = fit_predict(model, X) + y_score[model_name]["KDDCup99 - SA"] = fit_predict(model, X) # %% # Forest covertypes dataset @@ -185,7 +185,7 @@ def fit_predict(estimator, X): lof_kw={"n_neighbors": int(n_samples * anomaly_frac)}, iforest_kw={"random_state": 42}, ) - y_pred[model_name]["forestcover"] = fit_predict(model, X) + y_score[model_name]["forestcover"] = fit_predict(model, X) # %% # Ames Housing dataset @@ -242,7 +242,7 @@ def fit_predict(estimator, X): lof_kw={"n_neighbors": int(n_samples * anomaly_frac)}, iforest_kw={"random_state": 42}, ) - y_pred[model_name]["ames_housing"] = fit_predict(model, X) + y_score[model_name]["ames_housing"] = fit_predict(model, X) # %% # Cardiotocography dataset @@ -271,7 +271,7 @@ def fit_predict(estimator, X): lof_kw={"n_neighbors": int(n_samples * anomaly_frac)}, iforest_kw={"random_state": 42}, ) - y_pred[model_name]["cardiotocography"] = 
fit_predict(model, X) + y_score[model_name]["cardiotocography"] = fit_predict(model, X) # %% # Plot and interpret results @@ -299,7 +299,7 @@ def fit_predict(estimator, X): for model_idx, model_name in enumerate(model_names): display = RocCurveDisplay.from_predictions( y_true[dataset_name], - y_pred[model_name][dataset_name], + y_score[model_name][dataset_name], pos_label=pos_label, name=model_name, ax=ax, @@ -346,10 +346,10 @@ def fit_predict(estimator, X): for model_idx, (linestyle, n_neighbors) in enumerate(zip(linestyles, n_neighbors_list)): model.set_params(localoutlierfactor__n_neighbors=n_neighbors) model.fit(X) - y_pred = model[-1].negative_outlier_factor_ + y_score = model[-1].negative_outlier_factor_ display = RocCurveDisplay.from_predictions( y, - y_pred, + y_score, pos_label=pos_label, name=f"n_neighbors = {n_neighbors}", ax=ax, @@ -386,10 +386,10 @@ def fit_predict(estimator, X): ): model = make_pipeline(preprocessor, lof) model.fit(X) - y_pred = model[-1].negative_outlier_factor_ + y_score = model[-1].negative_outlier_factor_ display = RocCurveDisplay.from_predictions( y, - y_pred, + y_score, pos_label=pos_label, name=str(preprocessor).split("(")[0], ax=ax, @@ -438,10 +438,10 @@ def fit_predict(estimator, X): ): model = make_pipeline(preprocessor, lof) model.fit(X) - y_pred = model[-1].negative_outlier_factor_ + y_score = model[-1].negative_outlier_factor_ display = RocCurveDisplay.from_predictions( y, - y_pred, + y_score, pos_label=pos_label, name=str(preprocessor).split("(")[0], ax=ax, diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index ee3b5964f8b01..cc467296cfed1 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -358,7 +358,7 @@ def from_predictions( (as returned by “decision_function” on some classifiers). .. versionadded:: 1.7 - `y_score` is now used instead of `y_pred`. + `y_pred` has been renamed to `y_score`. sample_weight : array-like of shape (n_samples,), default=None Sample weights. @@ -437,15 +437,15 @@ def from_predictions( <...> >>> plt.show() """ - # TODO(1.8): remove after the end of the deprecation period of `y_pred` - if y_score is not None and ( - not isinstance(y_pred, str) or y_pred != "deprecated" + # TODO(1.9): remove after the end of the deprecation period of `y_pred` + if y_score is not None and not ( + isinstance(y_pred, str) and y_pred == "deprecated" ): raise ValueError( "`y_pred` and `y_score` cannot be both specified. Please use `y_score`" " only as `y_pred` is deprecated in 1.7 and will be removed in 1.9." ) - if not isinstance(y_pred, str) or y_pred != "deprecated": + if not (isinstance(y_pred, str) and y_pred == "deprecated"): warnings.warn( ( "y_pred is deprecated in 1.7 and will be removed in 1.9. 
" From fc8df122c316cff481ed46c40fa75cbc08f0583e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Sat, 12 Apr 2025 15:25:17 +0200 Subject: [PATCH 13/14] [doc build] From 4435e091ad4f8420997b0adf58d77a8cdedc13e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Sat, 12 Apr 2025 16:01:23 +0200 Subject: [PATCH 14/14] iter [doc build] [azure parallel] --- sklearn/metrics/_plot/tests/test_roc_curve_display.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 93ef019a51602..c2e6c865fa9a9 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -365,7 +365,7 @@ def test_y_score_and_y_pred_specified_error(): # TODO(1.9): remove -def test_y_pred_deprecation_warning(): +def test_y_pred_deprecation_warning(pyplot): """Check that a warning is raised when y_pred is specified.""" y_true = np.array([0, 1, 1, 0]) y_score = np.array([0.1, 0.4, 0.35, 0.8])