diff --git a/doc/whats_new/upcoming_changes/sklearn.metrics/30399.feature.rst b/doc/whats_new/upcoming_changes/sklearn.metrics/30399.feature.rst
new file mode 100644
index 0000000000000..c3b6d77c5aefb
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/sklearn.metrics/30399.feature.rst
@@ -0,0 +1,4 @@
+- Add class method `from_cv_results` to :class:`metrics.RocCurveDisplay`, which allows
+  easy plotting of multiple ROC curves from :func:`model_selection.cross_validate`
+  results.
+  By :user:`Lucy Liu `
diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py
index cc467296cfed1..586366dfbf2f4 100644
--- a/sklearn/metrics/_plot/roc_curve.py
+++ b/sklearn/metrics/_plot/roc_curve.py
@@ -1,13 +1,21 @@
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
+
 import warnings
+import numpy as np
+
+from ...utils import _safe_indexing
 from ...utils._plotting import (
     _BinaryClassifierCurveDisplayMixin,
+    _check_param_lengths,
+    _convert_to_list_leaving_none,
+    _deprecate_estimator_name,
     _despine,
     _validate_style_kwargs,
 )
+from ...utils._response import _get_response_values_binary
 from .._ranking import auc, roc_curve
@@ -16,25 +24,50 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin):
     It is recommended to use
     :func:`~sklearn.metrics.RocCurveDisplay.from_estimator` or
-    :func:`~sklearn.metrics.RocCurveDisplay.from_predictions` to create
+    :func:`~sklearn.metrics.RocCurveDisplay.from_predictions` or
+    :func:`~sklearn.metrics.RocCurveDisplay.from_cv_results` to create
     a :class:`~sklearn.metrics.RocCurveDisplay`. All parameters are
     stored as attributes.

-    Read more in the :ref:`User Guide `.
+    For more about the ROC metric, see :ref:`roc_metrics`.
+    For more about scikit-learn visualization classes, see :ref:`visualizations`.

     Parameters
     ----------
-    fpr : ndarray
-        False positive rate.
+    fpr : ndarray or list of ndarrays
+        False positive rates. Each ndarray should contain values for a single curve.
+        If plotting multiple curves, the list should be the same length as `tpr`.

-    tpr : ndarray
-        True positive rate.
+        .. versionchanged:: 1.7
+            Now accepts a list for plotting multiple curves.

-    roc_auc : float, default=None
-        Area under ROC curve. If None, the roc_auc score is not shown.
+    tpr : ndarray or list of ndarrays
+        True positive rates. Each ndarray should contain values for a single curve.
+        If plotting multiple curves, the list should be the same length as `fpr`.

-    estimator_name : str, default=None
-        Name of estimator. If None, the estimator name is not shown.
+        .. versionchanged:: 1.7
+            Now accepts a list for plotting multiple curves.
+
+    roc_auc : float or list of floats, default=None
+        Area under ROC curve, used for labeling each curve in the legend.
+        If plotting multiple curves, should be a list of the same length as `fpr`
+        and `tpr`. If `None`, ROC AUC scores are not shown in the legend.
+
+        .. versionchanged:: 1.7
+            Now accepts a list for plotting multiple curves.
+
+    name : str or list of str, default=None
+        Name for labeling legend entries. The number of legend entries
+        is determined by the `curve_kwargs` passed to `plot`.
+        To label each curve, provide a list of strings. To avoid labeling
+        individual curves that have the same appearance, this cannot be used in
+        conjunction with `curve_kwargs` being a dictionary or None. If a
+        string is provided, it will be used either to label the single legend
+        entry, or, if there are multiple legend entries, to label each
+        individual curve with the same name.
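+        For example, `name=["fold 0", "fold 1", "fold 2"]` labels the three
+        curves individually, which requires `curve_kwargs` to also be a list
+        of three dictionaries.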
If `None`, set to `name` provided at `RocCurveDisplay` + initialization. If still `None`, no name is shown in the legend. + + .. versionadded:: 1.7 pos_label : int, float, bool or str, default=None The class considered as the positive class when computing the roc auc @@ -43,10 +76,21 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): .. versionadded:: 0.24 + estimator_name : str, default=None + Name of estimator. If None, the estimator name is not shown. + + .. deprecated:: 1.7 + `estimator_name` is deprecated and will be removed in 1.9. Use `name` + instead. + Attributes ---------- - line_ : matplotlib Artist - ROC Curve. + line_ : matplotlib Artist or list of matplotlib Artists + ROC Curves. + + .. versionchanged:: 1.7 + This attribute can now be a list of Artists, for when multiple curves are + plotted. chance_level_ : matplotlib Artist or None The chance level line. It is `None` if the chance level is not plotted. @@ -78,24 +122,52 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): >>> fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score) >>> roc_auc = metrics.auc(fpr, tpr) >>> display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, - ... estimator_name='example estimator') + ... name='example estimator') >>> display.plot() <...> >>> plt.show() """ - def __init__(self, *, fpr, tpr, roc_auc=None, estimator_name=None, pos_label=None): - self.estimator_name = estimator_name + def __init__( + self, + *, + fpr, + tpr, + roc_auc=None, + name=None, + pos_label=None, + estimator_name="deprecated", + ): self.fpr = fpr self.tpr = tpr self.roc_auc = roc_auc + self.name = _deprecate_estimator_name(estimator_name, name, "1.7") self.pos_label = pos_label + def _validate_plot_params(self, *, ax, name): + self.ax_, self.figure_, name = super()._validate_plot_params(ax=ax, name=name) + + fpr = _convert_to_list_leaving_none(self.fpr) + tpr = _convert_to_list_leaving_none(self.tpr) + roc_auc = _convert_to_list_leaving_none(self.roc_auc) + name = _convert_to_list_leaving_none(name) + + optional = {"self.roc_auc": roc_auc} + if isinstance(name, list) and len(name) != 1: + optional.update({"'name' (or self.name)": name}) + _check_param_lengths( + required={"self.fpr": fpr, "self.tpr": tpr}, + optional=optional, + class_name="RocCurveDisplay", + ) + return fpr, tpr, roc_auc, name + def plot( self, ax=None, *, name=None, + curve_kwargs=None, plot_chance_level=False, chance_level_kw=None, despine=False, @@ -103,17 +175,36 @@ def plot( ): """Plot visualization. - Extra keyword arguments will be passed to matplotlib's ``plot``. - Parameters ---------- ax : matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is created. - name : str, default=None - Name of ROC Curve for labeling. If `None`, use `estimator_name` if - not `None`, otherwise no labeling is shown. + name : str or list of str, default=None + Name for labeling legend entries. The number of legend entries + is determined by `curve_kwargs`. + To label each curve, provide a list of strings. To avoid labeling + individual curves that have the same appearance, this cannot be used in + conjunction with `curve_kwargs` being a dictionary or None. If a + string is provided, it will be used to either label the single legend entry + or if there are multiple legend entries, label each individual curve with + the same name. If `None`, set to `name` provided at `RocCurveDisplay` + initialization. If still `None`, no name is shown in the legend. + + .. 
versionadded:: 1.7
+
+        curve_kwargs : dict or list of dict, default=None
+            Keyword arguments to be passed to matplotlib's `plot` function
+            to draw individual ROC curves. For single curve plotting, should be
+            a dictionary. For multi-curve plotting, if a list is provided the
+            parameters are applied to the ROC curves of each CV fold
+            sequentially and a legend entry is added for each curve.
+            If a single dictionary is provided, the same parameters are applied
+            to all ROC curves and a single legend entry for all curves is added,
+            labeled with the mean ROC AUC score.
+
+            .. versionadded:: 1.7

         plot_chance_level : bool, default=False
             Whether to plot the chance level.
@@ -134,22 +225,34 @@ def plot(
         **kwargs : dict
             Keyword arguments to be passed to matplotlib's `plot`.

+            .. deprecated:: 1.7
+                kwargs is deprecated and will be removed in 1.9. Pass matplotlib
+                arguments to `curve_kwargs` as a dictionary instead.
+
         Returns
         -------
         display : :class:`~sklearn.metrics.RocCurveDisplay`
             Object that stores computed values.
         """
-        self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
-
-        default_line_kwargs = {}
-        if self.roc_auc is not None and name is not None:
-            default_line_kwargs["label"] = f"{name} (AUC = {self.roc_auc:0.2f})"
-        elif self.roc_auc is not None:
-            default_line_kwargs["label"] = f"AUC = {self.roc_auc:0.2f}"
-        elif name is not None:
-            default_line_kwargs["label"] = name
-
-        line_kwargs = _validate_style_kwargs(default_line_kwargs, kwargs)
+        fpr, tpr, roc_auc, name = self._validate_plot_params(ax=ax, name=name)
+        n_curves = len(fpr)
+        if not isinstance(curve_kwargs, list) and n_curves > 1:
+            if roc_auc:
+                legend_metric = {"mean": np.mean(roc_auc), "std": np.std(roc_auc)}
+            else:
+                legend_metric = {"mean": None, "std": None}
+        else:
+            roc_auc = roc_auc if roc_auc is not None else [None] * n_curves
+            legend_metric = {"metric": roc_auc}
+
+        curve_kwargs = self._validate_curve_kwargs(
+            n_curves,
+            name,
+            legend_metric,
+            "AUC",
+            curve_kwargs=curve_kwargs,
+            **kwargs,
+        )

         default_chance_level_line_kw = {
             "label": "Chance level (AUC = 0.5)",
@@ -164,7 +267,13 @@ def plot(
                 default_chance_level_line_kw, chance_level_kw
             )

-        (self.line_,) = self.ax_.plot(self.fpr, self.tpr, **line_kwargs)
+        self.line_ = []
+        for fpr, tpr, line_kw in zip(fpr, tpr, curve_kwargs):
+            self.line_.extend(self.ax_.plot(fpr, tpr, **line_kw))
+        # Return single artist if only one curve is plotted
+        if len(self.line_) == 1:
+            self.line_ = self.line_[0]
+
         info_pos_label = (
             f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
         )
@@ -187,9 +296,8 @@ def plot(
         if despine:
             _despine(self.ax_)

-        if (
-            line_kwargs.get("label") is not None
-            or chance_level_kw.get("label") is not None
+        if curve_kwargs[0].get("label") is not None or (
+            plot_chance_level and chance_level_kw.get("label") is not None
         ):
             self.ax_.legend(loc="lower right")

@@ -208,6 +316,7 @@ def from_estimator(
         pos_label=None,
         name=None,
         ax=None,
+        curve_kwargs=None,
         plot_chance_level=False,
         chance_level_kw=None,
         despine=False,
@@ -243,8 +352,8 @@ def from_estimator(
             :term:`decision_function` is tried next.

         pos_label : int, float, bool or str, default=None
-            The class considered as the positive class when computing the roc auc
-            metrics. By default, `estimators.classes_[1]` is considered
+            The class considered as the positive class when computing the ROC AUC.
+            By default, `estimator.classes_[1]` is considered
             as the positive class.
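+            For example, with classes `"cancer"` and `"not cancer"`, pass
+            `pos_label="cancer"` to compute the ROC statistics with respect
+            to the `"cancer"` class.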
name : str, default=None @@ -254,6 +363,11 @@ def from_estimator( ax : matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is created. + curve_kwargs : dict, default=None + Keywords arguments to be passed to matplotlib's `plot` function. + + .. versionadded:: 1.7 + plot_chance_level : bool, default=False Whether to plot the chance level. @@ -273,6 +387,10 @@ def from_estimator( **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. + .. deprecated:: 1.7 + kwargs is deprecated and will be removed in 1.9. Pass matplotlib + arguments to `curve_kwargs` as a dictionary instead. + Returns ------- display : :class:`~sklearn.metrics.RocCurveDisplay` @@ -318,6 +436,7 @@ def from_estimator( name=name, ax=ax, pos_label=pos_label, + curve_kwargs=curve_kwargs, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, despine=despine, @@ -335,6 +454,7 @@ def from_predictions( pos_label=None, name=None, ax=None, + curve_kwargs=None, plot_chance_level=False, chance_level_kw=None, despine=False, @@ -369,18 +489,23 @@ def from_predictions( ROC curves. pos_label : int, float, bool or str, default=None - The label of the positive class. When `pos_label=None`, if `y_true` - is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an - error will be raised. + The label of the positive class when computing the ROC AUC. + When `pos_label=None`, if `y_true` is in {-1, 1} or {0, 1}, `pos_label` + is set to 1, otherwise an error will be raised. name : str, default=None - Name of ROC curve for labeling. If `None`, name will be set to + Name of ROC curve for legend labeling. If `None`, name will be set to `"Classifier"`. ax : matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is created. + curve_kwargs : dict, default=None + Keywords arguments to be passed to matplotlib's `plot` function. + + .. versionadded:: 1.7 + plot_chance_level : bool, default=False Whether to plot the chance level. @@ -409,6 +534,10 @@ def from_predictions( **kwargs : dict Additional keywords arguments passed to matplotlib `plot` function. + .. deprecated:: 1.7 + kwargs is deprecated and will be removed in 1.9. Pass matplotlib + arguments to `curve_kwargs` as a dictionary instead. + Returns ------- display : :class:`~sklearn.metrics.RocCurveDisplay` @@ -472,15 +601,184 @@ def from_predictions( fpr=fpr, tpr=tpr, roc_auc=roc_auc, - estimator_name=name, + name=name, pos_label=pos_label_validated, ) return viz.plot( ax=ax, - name=name, + curve_kwargs=curve_kwargs, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, despine=despine, **kwargs, ) + + @classmethod + def from_cv_results( + cls, + cv_results, + X, + y, + *, + sample_weight=None, + drop_intermediate=True, + response_method="auto", + pos_label=None, + ax=None, + name=None, + curve_kwargs=None, + plot_chance_level=False, + chance_level_kwargs=None, + despine=False, + ): + """Create a multi-fold ROC curve display given cross-validation results. + + .. versionadded:: 1.7 + + Parameters + ---------- + cv_results : dict + Dictionary as returned by :func:`~sklearn.model_selection.cross_validate` + using `return_estimator=True` and `return_indices=True` (i.e., dictionary + should contain the keys "estimator" and "indices"). + + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Input values. + + y : array-like of shape (n_samples,) + Target values. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. 
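+            If provided, the weights are subset with each fold's test indices
+            when computing that fold's ROC curve.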
+ + drop_intermediate : bool, default=True + Whether to drop some suboptimal thresholds which would not appear + on a plotted ROC curve. This is useful in order to create lighter + ROC curves. + + response_method : {'predict_proba', 'decision_function', 'auto'} \ + default='auto' + Specifies whether to use :term:`predict_proba` or + :term:`decision_function` as the target response. If set to 'auto', + :term:`predict_proba` is tried first and if it does not exist + :term:`decision_function` is tried next. + + pos_label : int, float, bool or str, default=None + The class considered as the positive class when computing the ROC AUC + metrics. By default, `estimators.classes_[1]` is considered + as the positive class. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + name : str or list of str, default=None + Name for labeling legend entries. The number of legend entries + is determined by `curve_kwargs`. + To label each curve, provide a list of strings. To avoid labeling + individual curves that have the same appearance, this cannot be used in + conjunction with `curve_kwargs` being a dictionary or None. If a + string is provided, it will be used to either label the single legend entry + or if there are multiple legend entries, label each individual curve with + the same name. If `None`, no name is shown in the legend. + + curve_kwargs : dict or list of dict, default=None + Keywords arguments to be passed to matplotlib's `plot` function + to draw individual ROC curves. If a list is provided the + parameters are applied to the ROC curves of each CV fold + sequentially and a legend entry is added for each curve. + If a single dictionary is provided, the same parameters are applied + to all ROC curves and a single legend entry for all curves is added, + labeled with the mean ROC AUC score. + + plot_chance_level : bool, default=False + Whether to plot the chance level. + + chance_level_kwargs : dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + despine : bool, default=False + Whether to remove the top and right spines from the plot. + + Returns + ------- + display : :class:`~sklearn.metrics.RocCurveDisplay` + The multi-fold ROC curve display. + + See Also + -------- + roc_curve : Compute Receiver operating characteristic (ROC) curve. + RocCurveDisplay.from_estimator : ROC Curve visualization given an + estimator and some data. + RocCurveDisplay.from_predictions : ROC Curve visualization given the + probabilities of scores of a classifier. + roc_auc_score : Compute the area under the ROC curve. + + Examples + -------- + >>> import matplotlib.pyplot as plt + >>> from sklearn.datasets import make_classification + >>> from sklearn.metrics import RocCurveDisplay + >>> from sklearn.model_selection import cross_validate + >>> from sklearn.svm import SVC + >>> X, y = make_classification(random_state=0) + >>> clf = SVC(random_state=0) + >>> cv_results = cross_validate( + ... 
clf, X, y, cv=3, return_estimator=True, return_indices=True) + >>> RocCurveDisplay.from_cv_results(cv_results, X, y) + <...> + >>> plt.show() + """ + pos_label_ = cls._validate_from_cv_results_params( + cv_results, + X, + y, + sample_weight=sample_weight, + pos_label=pos_label, + ) + + fpr_folds, tpr_folds, auc_folds = [], [], [] + for estimator, test_indices in zip( + cv_results["estimator"], cv_results["indices"]["test"] + ): + y_true = _safe_indexing(y, test_indices) + y_pred, _ = _get_response_values_binary( + estimator, + _safe_indexing(X, test_indices), + response_method=response_method, + pos_label=pos_label_, + ) + sample_weight_fold = ( + None + if sample_weight is None + else _safe_indexing(sample_weight, test_indices) + ) + fpr, tpr, _ = roc_curve( + y_true, + y_pred, + pos_label=pos_label_, + sample_weight=sample_weight_fold, + drop_intermediate=drop_intermediate, + ) + roc_auc = auc(fpr, tpr) + + fpr_folds.append(fpr) + tpr_folds.append(tpr) + auc_folds.append(roc_auc) + + viz = cls( + fpr=fpr_folds, + tpr=tpr_folds, + name=name, + roc_auc=auc_folds, + pos_label=pos_label_, + ) + return viz.plot( + ax=ax, + curve_kwargs=curve_kwargs, + plot_chance_level=plot_chance_level, + chance_level_kw=chance_level_kwargs, + despine=despine, + ) diff --git a/sklearn/metrics/_plot/tests/test_common_curve_display.py b/sklearn/metrics/_plot/tests/test_common_curve_display.py index 0014a73055e41..2dde6cc76be97 100644 --- a/sklearn/metrics/_plot/tests/test_common_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_common_curve_display.py @@ -132,9 +132,7 @@ def fit(self, X, y): Display.from_estimator(clf, X, y, response_method=response_method) -@pytest.mark.parametrize( - "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay] -) +@pytest.mark.parametrize("Display", [DetCurveDisplay, PrecisionRecallDisplay]) @pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) def test_display_curve_estimator_name_multiple_calls( pyplot, @@ -166,6 +164,8 @@ def test_display_curve_estimator_name_multiple_calls( assert clf_name in disp.line_.get_label() +# TODO: remove this test once classes moved to using `name_` instead of +# `estimator_name` @pytest.mark.parametrize( "clf", [ @@ -176,10 +176,8 @@ def test_display_curve_estimator_name_multiple_calls( ), ], ) -@pytest.mark.parametrize( - "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay] -) -def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display): +@pytest.mark.parametrize("Display", [DetCurveDisplay, PrecisionRecallDisplay]) +def test_display_curve_not_fitted_errors_old_name(pyplot, data_binary, clf, Display): """Check that a proper error is raised when the classifier is not fitted.""" X, y = data_binary @@ -194,6 +192,31 @@ def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display): assert disp.estimator_name == model.__class__.__name__ +@pytest.mark.parametrize( + "clf", + [ + LogisticRegression(), + make_pipeline(StandardScaler(), LogisticRegression()), + make_pipeline( + make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression() + ), + ], +) +@pytest.mark.parametrize("Display", [RocCurveDisplay]) +def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display): + """Check that a proper error is raised when the classifier not fitted.""" + X, y = data_binary + # clone since we parametrize the test and the classifier will be fitted + # when testing the second and subsequent plotting function + model = clone(clf) + with 
pytest.raises(NotFittedError): + Display.from_estimator(model, X, y) + model.fit(X, y) + disp = Display.from_estimator(model, X, y) + assert model.__class__.__name__ in disp.line_.get_label() + assert disp.name == model.__class__.__name__ + + @pytest.mark.parametrize( "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay] ) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index c2e6c865fa9a9..3f788009a21a9 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -1,3 +1,5 @@ +from collections.abc import Mapping + import numpy as np import pytest from numpy.testing import assert_allclose @@ -9,10 +11,11 @@ from sklearn.exceptions import NotFittedError from sklearn.linear_model import LogisticRegression from sklearn.metrics import RocCurveDisplay, auc, roc_curve -from sklearn.model_selection import train_test_split +from sklearn.model_selection import cross_validate, train_test_split from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler -from sklearn.utils import shuffle +from sklearn.utils import _safe_indexing, shuffle +from sklearn.utils._response import _get_response_values_binary @pytest.fixture(scope="module") @@ -30,6 +33,24 @@ def data_binary(data): return X[y < 2], y[y < 2] +def _check_figure_axes_and_labels(display, pos_label): + """Check mpl axes and figure defaults are correct.""" + import matplotlib as mpl + + assert isinstance(display.ax_, mpl.axes.Axes) + assert isinstance(display.figure_, mpl.figure.Figure) + assert display.ax_.get_adjustable() == "box" + assert display.ax_.get_aspect() in ("equal", 1.0) + assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) + + expected_pos_label = 1 if pos_label is None else pos_label + expected_ylabel = f"True Positive Rate (Positive label: {expected_pos_label})" + expected_xlabel = f"False Positive Rate (Positive label: {expected_pos_label})" + + assert display.ax_.get_ylabel() == expected_ylabel + assert display.ax_.get_xlabel() == expected_xlabel + + @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @pytest.mark.parametrize("with_sample_weight", [True, False]) @pytest.mark.parametrize("drop_intermediate", [True, False]) @@ -51,7 +72,7 @@ def test_roc_curve_display_plotting( constructor_name, default_name, ): - """Check the overall plotting behaviour.""" + """Check the overall plotting behaviour for single curve.""" X, y = data_binary pos_label = None @@ -79,7 +100,7 @@ def test_roc_curve_display_plotting( sample_weight=sample_weight, drop_intermediate=drop_intermediate, pos_label=pos_label, - alpha=0.8, + curve_kwargs={"alpha": 0.8}, ) else: display = RocCurveDisplay.from_predictions( @@ -88,7 +109,7 @@ def test_roc_curve_display_plotting( sample_weight=sample_weight, drop_intermediate=drop_intermediate, pos_label=pos_label, - alpha=0.8, + curve_kwargs={"alpha": 0.8}, ) fpr, tpr, _ = roc_curve( @@ -103,27 +124,504 @@ def test_roc_curve_display_plotting( assert_allclose(display.fpr, fpr) assert_allclose(display.tpr, tpr) - assert display.estimator_name == default_name + assert display.name == default_name import matplotlib as mpl + _check_figure_axes_and_labels(display, pos_label) assert isinstance(display.line_, mpl.lines.Line2D) assert display.line_.get_alpha() == 0.8 - assert isinstance(display.ax_, mpl.axes.Axes) - assert isinstance(display.figure_, mpl.figure.Figure) - assert 
display.ax_.get_adjustable() == "box"
-    assert display.ax_.get_aspect() in ("equal", 1.0)
-    assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01)

     expected_label = f"{default_name} (AUC = {display.roc_auc:.2f})"
     assert display.line_.get_label() == expected_label

-    expected_pos_label = 1 if pos_label is None else pos_label
-    expected_ylabel = f"True Positive Rate (Positive label: {expected_pos_label})"
-    expected_xlabel = f"False Positive Rate (Positive label: {expected_pos_label})"
-    assert display.ax_.get_ylabel() == expected_ylabel
-    assert display.ax_.get_xlabel() == expected_xlabel
+
+
+@pytest.mark.parametrize(
+    "params, err_msg",
+    [
+        (
+            {
+                "fpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
+                "tpr": [np.array([0, 0.5, 1])],
+                "roc_auc": None,
+                "name": None,
+            },
+            "self.fpr and self.tpr from `RocCurveDisplay` initialization,",
+        ),
+        (
+            {
+                "fpr": [np.array([0, 0.5, 1])],
+                "tpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
+                "roc_auc": [0.8, 0.9],
+                "name": None,
+            },
+            "self.fpr, self.tpr and self.roc_auc from `RocCurveDisplay`",
+        ),
+        (
+            {
+                "fpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
+                "tpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
+                "roc_auc": [0.8],
+                "name": None,
+            },
+            "Got: self.fpr: 2, self.tpr: 2, self.roc_auc: 1",
+        ),
+        (
+            {
+                "fpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
+                "tpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
+                "roc_auc": [0.8, 0.9],
+                "name": ["curve1", "curve2", "curve3"],
+            },
+            r"self.fpr, self.tpr, self.roc_auc and 'name' \(or self.name\)",
+        ),
+        (
+            {
+                "fpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
+                "tpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])],
+                "roc_auc": [0.8, 0.9],
+                # List of length 1 is always allowed
+                "name": ["curve1"],
+            },
+            None,
+        ),
+    ],
+)
+def test_roc_curve_plot_parameter_length_validation(pyplot, params, err_msg):
+    """Check `plot` parameter length validation performed correctly."""
+    display = RocCurveDisplay(**params)
+    if err_msg:
+        with pytest.raises(ValueError, match=err_msg):
+            display.plot()
+    else:
+        # No error should be raised
+        display.plot()
+
+
+def test_validate_plot_params(pyplot):
+    """Check `_validate_plot_params` returns the correct variables."""
+    fpr = np.array([0, 0.5, 1])
+    tpr = [np.array([0, 0.5, 1])]
+    roc_auc = None
+    name = "test_curve"
+
+    # Initialize display with test inputs
+    display = RocCurveDisplay(
+        fpr=fpr,
+        tpr=tpr,
+        roc_auc=roc_auc,
+        name=name,
+        pos_label=None,
+    )
+    fpr_out, tpr_out, roc_auc_out, name_out = display._validate_plot_params(
+        ax=None, name=None
+    )
+
+    assert isinstance(fpr_out, list)
+    assert isinstance(tpr_out, list)
+    assert len(fpr_out) == 1
+    assert len(tpr_out) == 1
+    assert roc_auc_out is None
+    assert name_out == ["test_curve"]
+
+
+def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary, data):
+    """Check parameter validation is correct."""
+    X, y = data_binary
+
+    # `cv_results` missing key
+    cv_results_no_est = cross_validate(
+        LogisticRegression(), X, y, cv=3, return_estimator=False, return_indices=True
+    )
+    cv_results_no_indices = cross_validate(
+        LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=False
+    )
+    for cv_results in (cv_results_no_est, cv_results_no_indices):
+        with pytest.raises(
+            ValueError,
+            match="`cv_results` does not contain one of the following required",
+        ):
+            RocCurveDisplay.from_cv_results(cv_results, X, y)
+
+    cv_results = cross_validate(
+        LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True
+    )
+
+    # `X` wrong length
+    with pytest.raises(ValueError, match="`X` does not contain the correct"):
+        RocCurveDisplay.from_cv_results(cv_results, X[:10, :], y)
+
+    # `y` not binary
+    X_multi, y_multi = data
+    with pytest.raises(ValueError, match="The target `y` is not binary."):
+        RocCurveDisplay.from_cv_results(cv_results, X, y_multi)
+
+    # input inconsistent length
+    with pytest.raises(ValueError, match="Found input variables with inconsistent"):
+        RocCurveDisplay.from_cv_results(cv_results, X, y[:10])
+    with pytest.raises(ValueError, match="Found input variables with inconsistent"):
+        RocCurveDisplay.from_cv_results(cv_results, X, y, sample_weight=[1, 2])
+
+    # `pos_label` inconsistency
+    X_bad_pos_label, y_bad_pos_label = X_multi[y_multi > 0], y_multi[y_multi > 0]
+    with pytest.raises(ValueError, match=r"y takes value in \{1, 2\}"):
+        RocCurveDisplay.from_cv_results(cv_results, X_bad_pos_label, y_bad_pos_label)
+
+    # `name` is a list while `curve_kwargs` is None or dict
+    for curve_kwargs in (None, {"alpha": 0.2}):
+        with pytest.raises(ValueError, match="To avoid labeling individual curves"):
+            RocCurveDisplay.from_cv_results(
+                cv_results,
+                X,
+                y,
+                name=["one", "two", "three"],
+                curve_kwargs=curve_kwargs,
+            )
+
+    # `curve_kwargs` incorrect length
+    with pytest.raises(ValueError, match="`curve_kwargs` must be None, a dictionary"):
+        RocCurveDisplay.from_cv_results(cv_results, X, y, curve_kwargs=[{"alpha": 1}])
+
+    # `curve_kwargs` with both color aliases provided
+    with pytest.raises(TypeError, match="Got both c and"):
+        RocCurveDisplay.from_cv_results(
+            cv_results, X, y, curve_kwargs={"c": "blue", "color": "red"}
+        )
+
+
+@pytest.mark.parametrize(
+    "curve_kwargs",
+    [None, {"alpha": 0.2}, [{"alpha": 0.2}, {"alpha": 0.3}, {"alpha": 0.4}]],
+)
+def test_roc_curve_display_from_cv_results_curve_kwargs(
+    pyplot, data_binary, curve_kwargs
+):
+    """Check `curve_kwargs` correctly passed."""
+    X, y = data_binary
+    n_cv = 3
+    cv_results = cross_validate(
+        LogisticRegression(), X, y, cv=n_cv, return_estimator=True, return_indices=True
+    )
+    display = RocCurveDisplay.from_cv_results(
+        cv_results,
+        X,
+        y,
+        curve_kwargs=curve_kwargs,
+    )
+    if curve_kwargs is None:
+        # Default `alpha` used
+        assert all(line.get_alpha() == 0.5 for line in display.line_)
+    elif isinstance(curve_kwargs, Mapping):
+        # `alpha` from dict used for all curves
+        assert all(line.get_alpha() == 0.2 for line in display.line_)
+    else:
+        # Different `alpha` used for each curve
+        assert all(
+            line.get_alpha() == curve_kwargs[i]["alpha"]
+            for i, line in enumerate(display.line_)
+        )
+
+
+# TODO(1.9): Remove in 1.9
+def test_roc_curve_display_estimator_name_deprecation(pyplot):
+    """Check deprecation of `estimator_name`."""
+    fpr = np.array([0, 0.5, 1])
+    tpr = np.array([0, 0.5, 1])
+    with pytest.warns(FutureWarning, match="`estimator_name` is deprecated in"):
+        RocCurveDisplay(fpr=fpr, tpr=tpr, estimator_name="test")
+
+
+# TODO(1.9): Remove in 1.9
+@pytest.mark.parametrize(
+    "constructor_name", ["from_estimator", "from_predictions", "plot"]
+)
+def test_roc_curve_display_kwargs_deprecation(pyplot, data_binary, constructor_name):
+    """Check **kwargs deprecated correctly in favour of `curve_kwargs`."""
+    X, y = data_binary
+    lr = LogisticRegression()
+    lr.fit(X, y)
+    fpr = np.array([0, 0.5, 1])
+    tpr = np.array([0, 0.5, 1])
+
+    # Error when both `curve_kwargs` and `**kwargs` provided
+    with pytest.raises(ValueError, match="Cannot provide both `curve_kwargs`"):
+        if constructor_name == "from_estimator":
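+            # The deprecated `label` kwarg is passed alongside `curve_kwargs`
+            # to trigger the error in each constructor branch below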
+            RocCurveDisplay.from_estimator(
+                lr, X, y, curve_kwargs={"alpha": 1}, label="test"
+            )
+        elif constructor_name == "from_predictions":
+            RocCurveDisplay.from_predictions(
+                y, y, curve_kwargs={"alpha": 1}, label="test"
+            )
+        else:
+            RocCurveDisplay(fpr=fpr, tpr=tpr).plot(
+                curve_kwargs={"alpha": 1}, label="test"
+            )
+
+    # Warning when `**kwargs` provided
+    with pytest.warns(FutureWarning, match=r"`\*\*kwargs` is deprecated and will be"):
+        if constructor_name == "from_estimator":
+            RocCurveDisplay.from_estimator(lr, X, y, label="test")
+        elif constructor_name == "from_predictions":
+            RocCurveDisplay.from_predictions(y, y, label="test")
+        else:
+            RocCurveDisplay(fpr=fpr, tpr=tpr).plot(label="test")
+
+
+@pytest.mark.parametrize(
+    "curve_kwargs",
+    [
+        None,
+        {"color": "blue"},
+        [{"color": "blue"}, {"color": "green"}, {"color": "red"}],
+    ],
+)
+@pytest.mark.parametrize("drop_intermediate", [True, False])
+@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
+@pytest.mark.parametrize("with_sample_weight", [True, False])
+@pytest.mark.parametrize("with_strings", [True, False])
+def test_roc_curve_display_plotting_from_cv_results(
+    pyplot,
+    data_binary,
+    with_strings,
+    with_sample_weight,
+    response_method,
+    drop_intermediate,
+    curve_kwargs,
+):
+    """Check overall plotting of `from_cv_results`."""
+    X, y = data_binary
+
+    pos_label = None
+    if with_strings:
+        y = np.array(["c", "b"])[y]
+        pos_label = "c"
+
+    if with_sample_weight:
+        rng = np.random.RandomState(42)
+        sample_weight = rng.randint(1, 4, size=(X.shape[0]))
+    else:
+        sample_weight = None
+
+    cv_results = cross_validate(
+        LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True
+    )
+    display = RocCurveDisplay.from_cv_results(
+        cv_results,
+        X,
+        y,
+        sample_weight=sample_weight,
+        drop_intermediate=drop_intermediate,
+        response_method=response_method,
+        pos_label=pos_label,
+        curve_kwargs=curve_kwargs,
+    )
+
+    for idx, (estimator, test_indices) in enumerate(
+        zip(cv_results["estimator"], cv_results["indices"]["test"])
+    ):
+        y_true = _safe_indexing(y, test_indices)
+        y_pred = _get_response_values_binary(
+            estimator,
+            _safe_indexing(X, test_indices),
+            response_method=response_method,
+            pos_label=pos_label,
+        )[0]
+        sample_weight_fold = (
+            None
+            if sample_weight is None
+            else _safe_indexing(sample_weight, test_indices)
+        )
+        fpr, tpr, _ = roc_curve(
+            y_true,
+            y_pred,
+            sample_weight=sample_weight_fold,
+            drop_intermediate=drop_intermediate,
+            pos_label=pos_label,
+        )
+        assert_allclose(display.roc_auc[idx], auc(fpr, tpr))
+        assert_allclose(display.fpr[idx], fpr)
+        assert_allclose(display.tpr[idx], tpr)
+
+    assert display.name is None
+
+    import matplotlib as mpl
+
+    _check_figure_axes_and_labels(display, pos_label)
+    aggregate_expected_labels = ["AUC = 1.00 +/- 0.00", "_child1", "_child2"]
+    for idx, line in enumerate(display.line_):
+        assert isinstance(line, mpl.lines.Line2D)
+        if curve_kwargs is None:
+            # Default alpha for `from_cv_results`
+            assert line.get_alpha() == 0.5
+        if isinstance(curve_kwargs, list):
+            # Each individual curve labelled
+            assert line.get_label() == f"AUC = {display.roc_auc[idx]:.2f}"
+        else:
+            # Single aggregate label
+            assert line.get_label() == aggregate_expected_labels[idx]
+
+
+@pytest.mark.parametrize("roc_auc", [[1.0, 1.0, 1.0], None])
+@pytest.mark.parametrize(
+    "curve_kwargs",
+    [None, {"color": "red"}, [{"c": "red"}, {"c": "green"}, {"c": "yellow"}]],
+)
+@pytest.mark.parametrize("name", [None, "single", ["one", "two", "three"]])
+def
test_roc_curve_plot_legend_label(pyplot, data_binary, name, curve_kwargs, roc_auc): + """Check legend label correct with all `curve_kwargs`, `name` combinations.""" + fpr = [np.array([0, 0.5, 1]), np.array([0, 0.5, 1]), np.array([0, 0.5, 1])] + tpr = [np.array([0, 0.5, 1]), np.array([0, 0.5, 1]), np.array([0, 0.5, 1])] + if not isinstance(curve_kwargs, list) and isinstance(name, list): + with pytest.raises(ValueError, match="To avoid labeling individual curves"): + RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot( + name=name, curve_kwargs=curve_kwargs + ) + + else: + display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot( + name=name, curve_kwargs=curve_kwargs + ) + legend = display.ax_.get_legend() + if legend is None: + # No legend is created, exit test early + assert name is None + assert roc_auc is None + return + else: + legend_labels = [text.get_text() for text in legend.get_texts()] + + if isinstance(curve_kwargs, list): + # Multiple labels in legend + assert len(legend_labels) == 3 + for idx, label in enumerate(legend_labels): + if name is None: + expected_label = "AUC = 1.00" if roc_auc else None + assert label == expected_label + elif isinstance(name, str): + expected_label = "single (AUC = 1.00)" if roc_auc else "single" + assert label == expected_label + else: + # `name` is a list of different strings + expected_label = ( + f"{name[idx]} (AUC = 1.00)" if roc_auc else f"{name[idx]}" + ) + assert label == expected_label + else: + # Single label in legend + assert len(legend_labels) == 1 + if name is None: + expected_label = "AUC = 1.00 +/- 0.00" if roc_auc else None + assert legend_labels[0] == expected_label + else: + # name is single string + expected_label = "single (AUC = 1.00 +/- 0.00)" if roc_auc else "single" + assert legend_labels[0] == expected_label + + +@pytest.mark.parametrize( + "curve_kwargs", + [None, {"color": "red"}, [{"c": "red"}, {"c": "green"}, {"c": "yellow"}]], +) +@pytest.mark.parametrize("name", [None, "single", ["one", "two", "three"]]) +def test_roc_curve_from_cv_results_legend_label( + pyplot, data_binary, name, curve_kwargs +): + """Check legend label correct with all `curve_kwargs`, `name` combinations.""" + X, y = data_binary + n_cv = 3 + cv_results = cross_validate( + LogisticRegression(), X, y, cv=n_cv, return_estimator=True, return_indices=True + ) + + if not isinstance(curve_kwargs, list) and isinstance(name, list): + with pytest.raises(ValueError, match="To avoid labeling individual curves"): + RocCurveDisplay.from_cv_results( + cv_results, X, y, name=name, curve_kwargs=curve_kwargs + ) + else: + display = RocCurveDisplay.from_cv_results( + cv_results, X, y, name=name, curve_kwargs=curve_kwargs + ) + + legend = display.ax_.get_legend() + legend_labels = [text.get_text() for text in legend.get_texts()] + if isinstance(curve_kwargs, list): + # Multiple labels in legend + assert len(legend_labels) == 3 + for idx, label in enumerate(legend_labels): + if name is None: + assert label == "AUC = 1.00" + elif isinstance(name, str): + assert label == "single (AUC = 1.00)" + else: + # `name` is a list of different strings + assert label == f"{name[idx]} (AUC = 1.00)" + else: + # Single label in legend + assert len(legend_labels) == 1 + if name is None: + assert legend_labels[0] == "AUC = 1.00 +/- 0.00" + else: + # name is single string + assert legend_labels[0] == "single (AUC = 1.00 +/- 0.00)" + + +@pytest.mark.parametrize( + "curve_kwargs", + [None, {"color": "red"}, [{"c": "red"}, {"c": "green"}, {"c": "yellow"}]], +) +def 
test_roc_curve_from_cv_results_curve_kwargs(pyplot, data_binary, curve_kwargs): + """Check line kwargs passed correctly in `from_cv_results`.""" + + X, y = data_binary + cv_results = cross_validate( + LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True + ) + display = RocCurveDisplay.from_cv_results( + cv_results, X, y, curve_kwargs=curve_kwargs + ) + + for idx, line in enumerate(display.line_): + color = line.get_color() + if curve_kwargs is None: + # Default color + assert color == "blue" + elif isinstance(curve_kwargs, Mapping): + # All curves "red" + assert color == "red" + else: + assert color == curve_kwargs[idx]["c"] + + +def _check_chance_level(plot_chance_level, chance_level_kw, display): + """Check chance level line and line styles correct.""" + import matplotlib as mpl + + if plot_chance_level: + assert isinstance(display.chance_level_, mpl.lines.Line2D) + assert tuple(display.chance_level_.get_xdata()) == (0, 1) + assert tuple(display.chance_level_.get_ydata()) == (0, 1) + else: + assert display.chance_level_ is None + + # Checking for chance level line styles + if plot_chance_level and chance_level_kw is None: + assert display.chance_level_.get_color() == "k" + assert display.chance_level_.get_linestyle() == "--" + assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)" + elif plot_chance_level: + if "c" in chance_level_kw: + assert display.chance_level_.get_color() == chance_level_kw["c"] + else: + assert display.chance_level_.get_color() == chance_level_kw["color"] + if "lw" in chance_level_kw: + assert display.chance_level_.get_linewidth() == chance_level_kw["lw"] + else: + assert display.chance_level_.get_linewidth() == chance_level_kw["linewidth"] + if "ls" in chance_level_kw: + assert display.chance_level_.get_linestyle() == chance_level_kw["ls"] + else: + assert display.chance_level_.get_linestyle() == chance_level_kw["linestyle"] @pytest.mark.parametrize("plot_chance_level", [True, False]) @@ -137,10 +635,7 @@ def test_roc_curve_display_plotting( {"lw": 1, "color": "blue", "ls": "-", "label": None}, ], ) -@pytest.mark.parametrize( - "constructor_name", - ["from_estimator", "from_predictions"], -) +@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) def test_roc_curve_chance_level_line( pyplot, data_binary, @@ -149,7 +644,7 @@ def test_roc_curve_chance_level_line( label, constructor_name, ): - """Check the chance level line plotting behaviour.""" + """Check chance level plotting behavior of `from_predictions`, `from_estimator`.""" X, y = data_binary lr = LogisticRegression() @@ -163,8 +658,7 @@ def test_roc_curve_chance_level_line( lr, X, y, - label=label, - alpha=0.8, + curve_kwargs={"alpha": 0.8, "label": label}, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, ) @@ -172,8 +666,7 @@ def test_roc_curve_chance_level_line( display = RocCurveDisplay.from_predictions( y, y_score, - label=label, - alpha=0.8, + curve_kwargs={"alpha": 0.8, "label": label}, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, ) @@ -185,32 +678,10 @@ def test_roc_curve_chance_level_line( assert isinstance(display.ax_, mpl.axes.Axes) assert isinstance(display.figure_, mpl.figure.Figure) - if plot_chance_level: - assert isinstance(display.chance_level_, mpl.lines.Line2D) - assert tuple(display.chance_level_.get_xdata()) == (0, 1) - assert tuple(display.chance_level_.get_ydata()) == (0, 1) - else: - assert display.chance_level_ is None + _check_chance_level(plot_chance_level, 
chance_level_kw, display) - # Checking for chance level line styles - if plot_chance_level and chance_level_kw is None: - assert display.chance_level_.get_color() == "k" - assert display.chance_level_.get_linestyle() == "--" - assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)" - elif plot_chance_level: - if "c" in chance_level_kw: - assert display.chance_level_.get_color() == chance_level_kw["c"] - else: - assert display.chance_level_.get_color() == chance_level_kw["color"] - if "lw" in chance_level_kw: - assert display.chance_level_.get_linewidth() == chance_level_kw["lw"] - else: - assert display.chance_level_.get_linewidth() == chance_level_kw["linewidth"] - if "ls" in chance_level_kw: - assert display.chance_level_.get_linestyle() == chance_level_kw["ls"] - else: - assert display.chance_level_.get_linestyle() == chance_level_kw["linestyle"] - # Checking for legend behaviour + # Checking for legend behaviour + if plot_chance_level and chance_level_kw is not None: if label is not None or chance_level_kw.get("label") is not None: legend = display.ax_.get_legend() assert legend is not None # Legend should be present if any label is set @@ -223,6 +694,62 @@ def test_roc_curve_chance_level_line( assert display.ax_.get_legend() is None +@pytest.mark.parametrize("plot_chance_level", [True, False]) +@pytest.mark.parametrize( + "chance_level_kw", + [ + None, + {"linewidth": 1, "color": "red", "linestyle": "-", "label": "DummyEstimator"}, + {"lw": 1, "c": "red", "ls": "-", "label": "DummyEstimator"}, + {"lw": 1, "color": "blue", "ls": "-", "label": None}, + ], +) +@pytest.mark.parametrize("curve_kwargs", [None, {"alpha": 0.8}]) +def test_roc_curve_chance_level_line_from_cv_results( + pyplot, + data_binary, + plot_chance_level, + chance_level_kw, + curve_kwargs, +): + """Check chance level plotting behavior with `from_cv_results`.""" + X, y = data_binary + n_cv = 3 + cv_results = cross_validate( + LogisticRegression(), X, y, cv=n_cv, return_estimator=True, return_indices=True + ) + + display = RocCurveDisplay.from_cv_results( + cv_results, + X, + y, + plot_chance_level=plot_chance_level, + chance_level_kwargs=chance_level_kw, + curve_kwargs=curve_kwargs, + ) + + import matplotlib as mpl + + assert all(isinstance(line, mpl.lines.Line2D) for line in display.line_) + # Ensure both curve line kwargs passed correctly as well + if curve_kwargs: + assert all(line.get_alpha() == 0.8 for line in display.line_) + assert isinstance(display.ax_, mpl.axes.Axes) + assert isinstance(display.figure_, mpl.figure.Figure) + + _check_chance_level(plot_chance_level, chance_level_kw, display) + + legend = display.ax_.get_legend() + # There is always a legend, to indicate each 'Fold' curve + assert legend is not None + legend_labels = [text.get_text() for text in legend.get_texts()] + if plot_chance_level and chance_level_kw is not None: + if chance_level_kw.get("label") is not None: + assert chance_level_kw["label"] in legend_labels + else: + assert len(legend_labels) == 1 + + @pytest.mark.parametrize( "clf", [ @@ -254,31 +781,52 @@ def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructo name = "Classifier" assert name in display.line_.get_label() - assert display.estimator_name == name + assert display.name == name @pytest.mark.parametrize( - "roc_auc, estimator_name, expected_label", + "roc_auc, name, curve_kwargs, expected_labels", [ - (0.9, None, "AUC = 0.90"), - (None, "my_est", "my_est"), - (0.8, "my_est2", "my_est2 (AUC = 0.80)"), + ([0.9, 0.8], None, None, ["AUC = 
0.85 +/- 0.05", "_child1"]), + ([0.9, 0.8], "Est name", None, ["Est name (AUC = 0.85 +/- 0.05)", "_child1"]), + ( + [0.8, 0.7], + ["fold1", "fold2"], + [{"c": "blue"}, {"c": "red"}], + ["fold1 (AUC = 0.80)", "fold2 (AUC = 0.70)"], + ), + (None, ["fold1", "fold2"], [{"c": "blue"}, {"c": "red"}], ["fold1", "fold2"]), ], ) def test_roc_curve_display_default_labels( - pyplot, roc_auc, estimator_name, expected_label + pyplot, roc_auc, name, curve_kwargs, expected_labels ): """Check the default labels used in the display.""" - fpr = np.array([0, 0.5, 1]) - tpr = np.array([0, 0.5, 1]) - disp = RocCurveDisplay( - fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=estimator_name - ).plot() - assert disp.line_.get_label() == expected_label + fpr = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])] + tpr = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])] + disp = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, name=name).plot( + curve_kwargs=curve_kwargs + ) + for idx, expected_label in enumerate(expected_labels): + assert disp.line_[idx].get_label() == expected_label + + +def _check_auc(display, constructor_name): + roc_auc_limit = 0.95679 + roc_auc_limit_multi = [0.97007, 0.985915, 0.980952] + + if constructor_name == "from_cv_results": + for idx, roc_auc in enumerate(display.roc_auc): + assert roc_auc == pytest.approx(roc_auc_limit_multi[idx]) + else: + assert display.roc_auc == pytest.approx(roc_auc_limit) + assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) -@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) +@pytest.mark.parametrize( + "constructor_name", ["from_estimator", "from_predictions", "from_cv_results"] +) def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): # check that we can provide the positive label and display the proper # statistics @@ -301,9 +849,13 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): classifier = LogisticRegression() classifier.fit(X_train, y_train) + cv_results = cross_validate( + LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True + ) - # sanity check to be sure the positive class is classes_[0] and that we - # are betrayed by the class imbalance + # Sanity check to be sure the positive class is `classes_[0]` + # Class imbalance ensures a large difference in prediction values between classes, + # allowing us to catch errors when we switch `pos_label` assert classifier.classes_.tolist() == ["cancer", "not cancer"] y_score = getattr(classifier, response_method)(X_test) @@ -312,43 +864,59 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): y_score_cancer = -1 * y_score if y_score.ndim == 1 else y_score[:, 0] y_score_not_cancer = y_score if y_score.ndim == 1 else y_score[:, 1] + pos_label = "cancer" + y_score = y_score_cancer if constructor_name == "from_estimator": display = RocCurveDisplay.from_estimator( classifier, X_test, y_test, - pos_label="cancer", + pos_label=pos_label, response_method=response_method, ) - else: + elif constructor_name == "from_predictions": display = RocCurveDisplay.from_predictions( y_test, - y_score_cancer, - pos_label="cancer", + y_score, + pos_label=pos_label, + ) + else: + display = RocCurveDisplay.from_cv_results( + cv_results, + X, + y, + response_method=response_method, + pos_label=pos_label, ) - roc_auc_limit = 0.95679 - - assert display.roc_auc == pytest.approx(roc_auc_limit) - assert 
trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) + _check_auc(display, constructor_name) + pos_label = "not cancer" + y_score = y_score_not_cancer if constructor_name == "from_estimator": display = RocCurveDisplay.from_estimator( classifier, X_test, y_test, response_method=response_method, - pos_label="not cancer", + pos_label=pos_label, ) - else: + elif constructor_name == "from_predictions": display = RocCurveDisplay.from_predictions( y_test, - y_score_not_cancer, - pos_label="not cancer", + y_score, + pos_label=pos_label, + ) + else: + display = RocCurveDisplay.from_cv_results( + cv_results, + X, + y, + response_method=response_method, + pos_label=pos_label, ) - assert display.roc_auc == pytest.approx(roc_auc_limit) - assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) + _check_auc(display, constructor_name) # TODO(1.9): remove @@ -382,23 +950,30 @@ def test_y_pred_deprecation_warning(pyplot): @pytest.mark.parametrize("despine", [True, False]) -@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) +@pytest.mark.parametrize( + "constructor_name", ["from_estimator", "from_predictions", "from_cv_results"] +) def test_plot_roc_curve_despine(pyplot, data_binary, despine, constructor_name): # Check that the despine keyword is working correctly X, y = data_binary lr = LogisticRegression().fit(X, y) lr.fit(X, y) + cv_results = cross_validate( + LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True + ) y_pred = lr.decision_function(X) - # safe guard for the binary if/else construction - assert constructor_name in ("from_estimator", "from_predictions") + # safe guard for the if/else construction + assert constructor_name in ("from_estimator", "from_predictions", "from_cv_results") if constructor_name == "from_estimator": display = RocCurveDisplay.from_estimator(lr, X, y, despine=despine) - else: + elif constructor_name == "from_predictions": display = RocCurveDisplay.from_predictions(y, y_pred, despine=despine) + else: + display = RocCurveDisplay.from_cv_results(cv_results, X, y, despine=despine) for s in ["top", "right"]: assert display.ax_.spines[s].get_visible() is not despine diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 946c95186374b..ac893282ea6cf 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -1,13 +1,16 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause +import warnings +from collections.abc import Mapping import numpy as np from . import check_consistent_length from ._optional_dependencies import check_matplotlib_support from ._response import _get_response_values_binary +from .fixes import parse_version from .multiclass import type_of_target -from .validation import _check_pos_label_consistency +from .validation import _check_pos_label_consistency, _num_samples class _BinaryClassifierCurveDisplayMixin: @@ -24,7 +27,10 @@ def _validate_plot_params(self, *, ax=None, name=None): if ax is None: _, ax = plt.subplots() - name = self.estimator_name if name is None else name + # Display classes are in process of changing from `estimator_name` to `name`. + # Try old attr name: `estimator_name` first. 
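+        # Fall back to `name` for display classes that have already migrated.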
+ if name is None: + name = getattr(self, "estimator_name", getattr(self, "name", None)) return ax, ax.figure, name @classmethod @@ -63,6 +69,187 @@ def _validate_from_predictions_params( return pos_label, name + @classmethod + def _validate_from_cv_results_params( + cls, + cv_results, + X, + y, + *, + sample_weight, + pos_label, + ): + check_matplotlib_support(f"{cls.__name__}.from_cv_results") + + required_keys = {"estimator", "indices"} + if not all(key in cv_results for key in required_keys): + raise ValueError( + "`cv_results` does not contain one of the following required keys: " + f"{required_keys}. Set explicitly the parameters " + "`return_estimator=True` and `return_indices=True` to the function" + "`cross_validate`." + ) + + train_size, test_size = ( + len(cv_results["indices"]["train"][0]), + len(cv_results["indices"]["test"][0]), + ) + + if _num_samples(X) != train_size + test_size: + raise ValueError( + "`X` does not contain the correct number of samples. " + f"Expected {train_size + test_size}, got {_num_samples(X)}." + ) + + if type_of_target(y) != "binary": + raise ValueError( + f"The target `y` is not binary. Got {type_of_target(y)} type of target." + ) + check_consistent_length(X, y, sample_weight) + + try: + pos_label = _check_pos_label_consistency(pos_label, y) + except ValueError as e: + # Adapt error message + raise ValueError(str(e).replace("y_true", "y")) + + return pos_label + + @staticmethod + def _get_legend_label(curve_legend_metric, curve_name, legend_metric_name): + """Helper to get legend label using `name` and `legend_metric`""" + if curve_legend_metric is not None and curve_name is not None: + label = f"{curve_name} ({legend_metric_name} = {curve_legend_metric:0.2f})" + elif curve_legend_metric is not None: + label = f"{legend_metric_name} = {curve_legend_metric:0.2f}" + elif curve_name is not None: + label = curve_name + else: + label = None + return label + + @staticmethod + def _validate_curve_kwargs( + n_curves, + name, + legend_metric, + legend_metric_name, + curve_kwargs, + **kwargs, + ): + """Get validated line kwargs for each curve. + + Parameters + ---------- + n_curves : int + Number of curves. + + name : list of str or None + Name for labeling legend entries. + + legend_metric : dict + Dictionary with "mean" and "std" keys, or "metric" key of metric + values for each curve. If None, "label" will not contain metric values. + + legend_metric_name : str + Name of the summary value provided in `legend_metrics`. + + curve_kwargs : dict or list of dict or None + Dictionary with keywords passed to the matplotlib's `plot` function + to draw the individual curves. If a list is provided, the + parameters are applied to the curves sequentially. If a single + dictionary is provided, the same parameters are applied to all + curves. + + **kwargs : dict + Deprecated. Keyword arguments to be passed to matplotlib's `plot`. + """ + # TODO(1.9): Remove + # Deprecate **kwargs + if curve_kwargs and kwargs: + raise ValueError( + "Cannot provide both `curve_kwargs` and `kwargs`. `**kwargs` is " + "deprecated in 1.7 and will be removed in 1.9. Pass all matplotlib " + "arguments to `curve_kwargs` as a dictionary." + ) + if kwargs: + warnings.warn( + "`**kwargs` is deprecated and will be removed in 1.9. 
Pass all " + "matplotlib arguments to `curve_kwargs` as a dictionary instead.", + FutureWarning, + ) + curve_kwargs = kwargs + + if isinstance(curve_kwargs, list) and len(curve_kwargs) != n_curves: + raise ValueError( + f"`curve_kwargs` must be None, a dictionary or a list of length " + f"{n_curves}. Got: {curve_kwargs}." + ) + + # Ensure valid `name` and `curve_kwargs` combination. + if ( + isinstance(name, list) + and len(name) != 1 + and not isinstance(curve_kwargs, list) + ): + raise ValueError( + "To avoid labeling individual curves that have the same appearance, " + f"`curve_kwargs` should be a list of {n_curves} dictionaries. " + "Alternatively, set `name` to `None` or a single string to label " + "a single legend entry with mean ROC AUC score of all curves." + ) + + # Ensure `name` is of the correct length + if isinstance(name, str): + name = [name] + if isinstance(name, list) and len(name) == 1: + name = name * n_curves + name = [None] * n_curves if name is None else name + + # Ensure `curve_kwargs` is of correct length + if isinstance(curve_kwargs, Mapping): + curve_kwargs = [curve_kwargs] * n_curves + + default_multi_curve_kwargs = {"alpha": 0.5, "linestyle": "--", "color": "blue"} + if curve_kwargs is None: + if n_curves > 1: + curve_kwargs = [default_multi_curve_kwargs] * n_curves + else: + curve_kwargs = [{}] + + labels = [] + if "mean" in legend_metric: + label_aggregate = _BinaryClassifierCurveDisplayMixin._get_legend_label( + legend_metric["mean"], name[0], legend_metric_name + ) + # Note: "std" always `None` when "mean" is `None` - no metric value added + # to label in this case + if legend_metric["std"] is not None: + # Add the "+/- std" to the end (in brackets if name provided) + if name[0] is not None: + label_aggregate = ( + label_aggregate[:-1] + f" +/- {legend_metric['std']:0.2f})" + ) + else: + label_aggregate = ( + label_aggregate + f" +/- {legend_metric['std']:0.2f}" + ) + # Add `label` for first curve only, set to `None` for remaining curves + labels.extend([label_aggregate] + [None] * (n_curves - 1)) + else: + for curve_legend_metric, curve_name in zip(legend_metric["metric"], name): + labels.append( + _BinaryClassifierCurveDisplayMixin._get_legend_label( + curve_legend_metric, curve_name, legend_metric_name + ) + ) + + curve_kwargs_ = [ + _validate_style_kwargs({"label": label}, curve_kwargs[fold_idx]) + for fold_idx, label in enumerate(labels) + ] + return curve_kwargs_ + def _validate_score_name(score_name, scoring, negate_score): """Validate the `score_name` parameter. @@ -177,3 +364,57 @@ def _despine(ax): ax.spines[s].set_visible(False) for s in ["bottom", "left"]: ax.spines[s].set_bounds(0, 1) + + +def _deprecate_estimator_name(estimator_name, name, version): + """Deprecate `estimator_name` in favour of `name`.""" + version = parse_version(version) + version_remove = f"{version.major}.{version.minor + 2}" + if estimator_name != "deprecated": + if name: + raise ValueError( + "Cannot provide both `estimator_name` and `name`. `estimator_name` " + f"is deprecated in {version} and will be removed in {version_remove}. " + "Use `name` only." + ) + warnings.warn( + f"`estimator_name` is deprecated in {version} and will be removed in " + f"{version_remove}. 
 
 def _validate_score_name(score_name, scoring, negate_score):
     """Validate the `score_name` parameter.
@@ -177,3 +364,57 @@ def _despine(ax):
         ax.spines[s].set_visible(False)
     for s in ["bottom", "left"]:
         ax.spines[s].set_bounds(0, 1)
+
+
+def _deprecate_estimator_name(estimator_name, name, version):
+    """Deprecate `estimator_name` in favour of `name`."""
+    version = parse_version(version)
+    version_remove = f"{version.major}.{version.minor + 2}"
+    if estimator_name != "deprecated":
+        if name:
+            raise ValueError(
+                "Cannot provide both `estimator_name` and `name`. `estimator_name` "
+                f"is deprecated in {version} and will be removed in {version_remove}. "
+                "Use `name` only."
+            )
+        warnings.warn(
+            f"`estimator_name` is deprecated in {version} and will be removed in "
+            f"{version_remove}. Use `name` instead.",
+            FutureWarning,
+        )
+        return estimator_name
+    return name
+
+
+def _convert_to_list_leaving_none(param):
+    """Convert parameters to a list, leaving `None` as is."""
+    if param is None:
+        return None
+    if isinstance(param, list):
+        return param
+    return [param]
+
+
+def _check_param_lengths(required, optional, class_name):
+    """Check that required and optional parameters are of the same length."""
+    optional_provided = {}
+    for name, param in optional.items():
+        if isinstance(param, list):
+            optional_provided[name] = param
+
+    all_params = {**required, **optional_provided}
+    if len({len(param) for param in all_params.values()}) > 1:
+        param_keys = list(all_params.keys())
+        # Note: the code below requires `len(param_keys) >= 2`, which is the case
+        # for all display classes
+        params_formatted = " and ".join([", ".join(param_keys[:-1]), param_keys[-1]])
+        or_plot = ""
+        if "'name' (or self.name)" in param_keys:
+            or_plot = " (or `plot`)"
+        lengths_formatted = ", ".join(
+            f"{key}: {len(value)}" for key, value in all_params.items()
+        )
+        raise ValueError(
+            f"{params_formatted} from `{class_name}` initialization{or_plot}, "
+            f"should all be lists of the same length. Got: {lengths_formatted}"
+        )
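Review note: a small sketch of how the two list helpers above compose in a display's validation path, using names that follow `RocCurveDisplay` (the values are illustrative):

import numpy as np

from sklearn.utils._plotting import (
    _check_param_lengths,
    _convert_to_list_leaving_none,
)

fpr = _convert_to_list_leaving_none(np.array([0.0, 0.5, 1.0]))  # wrapped: [ndarray]
tpr = _convert_to_list_leaving_none([np.array([0.0, 0.8, 1.0])])  # list kept as is
roc_auc = _convert_to_list_leaving_none(None)  # None kept as is

# Passes: `roc_auc` is None, so only the required lists are compared; a length
# mismatch between the lists would raise a ValueError naming each parameter.
_check_param_lengths(
    required={"self.fpr": fpr, "self.tpr": tpr},
    optional={"self.roc_auc": roc_auc},
    class_name="RocCurveDisplay",
)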
"curve_legend_metric, curve_name, expected_label", + [ + (0.85, None, "AUC = 0.85"), + (None, "Model A", "Model A"), + (0.95, "Random Forest", "Random Forest (AUC = 0.95)"), + (None, None, None), + ], +) +def test_get_legend_label(curve_legend_metric, curve_name, expected_label): + """Check `_get_legend_label` returns the correct label.""" + legend_metric_name = "AUC" + label = _BinaryClassifierCurveDisplayMixin._get_legend_label( + curve_legend_metric, curve_name, legend_metric_name + ) + assert label == expected_label + + +# TODO(1.9) : Remove +@pytest.mark.parametrize("curve_kwargs", [{"alpha": 1.0}, None]) +@pytest.mark.parametrize("kwargs", [{}, {"alpha": 1.0}]) +def test_validate_curve_kwargs_deprecate_kwargs(curve_kwargs, kwargs): + """Check `_validate_curve_kwargs` deprecates kwargs correctly.""" + n_curves = 1 + name = None + legend_metric = {"mean": 0.8, "std": 0.1} + legend_metric_name = "AUC" + + if curve_kwargs and kwargs: + with pytest.raises(ValueError, match="Cannot provide both `curve_kwargs`"): + _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves, + name, + legend_metric, + legend_metric_name, + curve_kwargs, + **kwargs, + ) + elif kwargs: + with pytest.warns(FutureWarning, match=r"`\*\*kwargs` is deprecated and"): + _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves, + name, + legend_metric, + legend_metric_name, + curve_kwargs, + **kwargs, + ) + else: + # No warning or error should be raised + _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves, name, legend_metric, legend_metric_name, curve_kwargs, **kwargs + ) + + +def test_validate_curve_kwargs_error(): + """Check `_validate_curve_kwargs` performs parameter validation correctly.""" + n_curves = 3 + legend_metric = {"mean": 0.8, "std": 0.1} + legend_metric_name = "AUC" + with pytest.raises(ValueError, match="`curve_kwargs` must be None"): + _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves=n_curves, + name=None, + legend_metric=legend_metric, + legend_metric_name=legend_metric_name, + curve_kwargs=[{"alpha": 1.0}], + ) + with pytest.raises(ValueError, match="To avoid labeling individual curves"): + name = ["one", "two", "three"] + _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves=n_curves, + name=name, + legend_metric=legend_metric, + legend_metric_name=legend_metric_name, + curve_kwargs=None, + ) + _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves=n_curves, + name=name, + legend_metric=legend_metric, + legend_metric_name=legend_metric_name, + curve_kwargs={"alpha": 1.0}, + ) + + +@pytest.mark.parametrize("name", [None, "curve_name", ["curve_name"]]) +@pytest.mark.parametrize( + "legend_metric", + [ + {"mean": 0.8, "std": 0.2}, + {"mean": None, "std": None}, + ], +) +@pytest.mark.parametrize("legend_metric_name", ["AUC", "AP"]) +@pytest.mark.parametrize( + "curve_kwargs", + [ + None, + {"color": "red"}, + ], +) +def test_validate_curve_kwargs_single_legend( + name, legend_metric, legend_metric_name, curve_kwargs +): + """Check `_validate_curve_kwargs` returns correct kwargs for single legend entry.""" + n_curves = 3 + curve_kwargs_out = _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves=n_curves, + name=name, + legend_metric=legend_metric, + legend_metric_name=legend_metric_name, + curve_kwargs=curve_kwargs, + ) + + assert isinstance(curve_kwargs_out, list) + assert len(curve_kwargs_out) == n_curves + + expected_label = None + if isinstance(name, list): + name = name[0] + if name is 
+
+
+@pytest.mark.parametrize("name", [None, "curve_name", ["curve_name"]])
+@pytest.mark.parametrize(
+    "legend_metric",
+    [
+        {"mean": 0.8, "std": 0.2},
+        {"mean": None, "std": None},
+    ],
+)
+@pytest.mark.parametrize("legend_metric_name", ["AUC", "AP"])
+@pytest.mark.parametrize(
+    "curve_kwargs",
+    [
+        None,
+        {"color": "red"},
+    ],
+)
+def test_validate_curve_kwargs_single_legend(
+    name, legend_metric, legend_metric_name, curve_kwargs
+):
+    """Check `_validate_curve_kwargs` returns correct kwargs for single legend entry."""
+    n_curves = 3
+    curve_kwargs_out = _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs(
+        n_curves=n_curves,
+        name=name,
+        legend_metric=legend_metric,
+        legend_metric_name=legend_metric_name,
+        curve_kwargs=curve_kwargs,
+    )
+
+    assert isinstance(curve_kwargs_out, list)
+    assert len(curve_kwargs_out) == n_curves
+
+    expected_label = None
+    if isinstance(name, list):
+        name = name[0]
+    if name is not None:
+        expected_label = name
+        if legend_metric["mean"] is not None:
+            expected_label += f" ({legend_metric_name} = 0.80 +/- 0.20)"
+    # `name` is None
+    elif legend_metric["mean"] is not None:
+        expected_label = f"{legend_metric_name} = 0.80 +/- 0.20"
+
+    assert curve_kwargs_out[0]["label"] == expected_label
+    # All remaining curves should have `None` as "label"
+    assert curve_kwargs_out[1]["label"] is None
+    assert curve_kwargs_out[2]["label"] is None
+
+    # Default multi-curve kwargs
+    if curve_kwargs is None:
+        assert all(len(kwargs) == 4 for kwargs in curve_kwargs_out)
+        assert all(kwargs["alpha"] == 0.5 for kwargs in curve_kwargs_out)
+        assert all(kwargs["linestyle"] == "--" for kwargs in curve_kwargs_out)
+        assert all(kwargs["color"] == "blue" for kwargs in curve_kwargs_out)
+    else:
+        assert all(len(kwargs) == 2 for kwargs in curve_kwargs_out)
+        assert all(kwargs["color"] == "red" for kwargs in curve_kwargs_out)
+
+
+@pytest.mark.parametrize("name", [None, "curve_name", ["one", "two", "three"]])
+@pytest.mark.parametrize(
+    "legend_metric", [{"metric": [1.0, 1.0, 1.0]}, {"metric": [None, None, None]}]
+)
+@pytest.mark.parametrize("legend_metric_name", ["AUC", "AP"])
+def test_validate_curve_kwargs_multi_legend(name, legend_metric, legend_metric_name):
+    """Check `_validate_curve_kwargs` returns correct kwargs for multi legend entry."""
+    n_curves = 3
+    curve_kwargs = [{"color": "red"}, {"color": "yellow"}, {"color": "blue"}]
+    curve_kwargs_out = _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs(
+        n_curves=n_curves,
+        name=name,
+        legend_metric=legend_metric,
+        legend_metric_name=legend_metric_name,
+        curve_kwargs=curve_kwargs,
+    )
+
+    assert isinstance(curve_kwargs_out, list)
+    assert len(curve_kwargs_out) == n_curves
+
+    expected_labels = [None, None, None]
+    if isinstance(name, str):
+        expected_labels = "curve_name"
+        if legend_metric["metric"][0] is not None:
+            expected_labels = expected_labels + f" ({legend_metric_name} = 1.00)"
+        expected_labels = [expected_labels] * n_curves
+    elif isinstance(name, list) and legend_metric["metric"][0] is None:
+        expected_labels = name
+    elif isinstance(name, list) and legend_metric["metric"][0] is not None:
+        expected_labels = [
+            f"{name_single} ({legend_metric_name} = 1.00)" for name_single in name
+        ]
+    # `name` is None
+    elif legend_metric["metric"][0] is not None:
+        expected_labels = [f"{legend_metric_name} = 1.00"] * n_curves
+
+    for idx, expected_label in enumerate(expected_labels):
+        assert curve_kwargs_out[idx]["label"] == expected_label
+
+    assert all(len(kwargs) == 2 for kwargs in curve_kwargs_out)
+    for curve_kwarg, curve_kwarg_out in zip(curve_kwargs, curve_kwargs_out):
+        assert curve_kwarg_out["color"] == curve_kwarg["color"]
+
+
 def metric():
     pass  # pragma: no cover
@@ -138,3 +407,31 @@ def test_despine(pyplot):
     assert ax.spines["right"].get_visible() is False
     assert ax.spines["bottom"].get_bounds() == (0, 1)
     assert ax.spines["left"].get_bounds() == (0, 1)
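Review note: for contrast with the aggregate-label case sketched earlier, this is the per-curve labeling path asserted above (a `"metric"` key and a list of `curve_kwargs` produce one legend entry per curve; direct call to the private helper):

from sklearn.utils._plotting import _BinaryClassifierCurveDisplayMixin

out = _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs(
    n_curves=2,
    name=["fold 0", "fold 1"],
    legend_metric={"metric": [0.9, 0.8]},
    legend_metric_name="AUC",
    curve_kwargs=[{"color": "red"}, {"color": "blue"}],
)
assert [kw["label"] for kw in out] == ["fold 0 (AUC = 0.90)", "fold 1 (AUC = 0.80)"]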
+
+
+@pytest.mark.parametrize("estimator_name", ["my_est_name", "deprecated"])
+@pytest.mark.parametrize("name", [None, "my_name"])
+def test_deprecate_estimator_name(estimator_name, name):
+    """Check `_deprecate_estimator_name` behaves correctly."""
+    version = "1.7"
+    version_remove = "1.9"
+
+    if estimator_name == "deprecated":
+        name_out = _deprecate_estimator_name(estimator_name, name, version)
+        assert name_out == name
+    # `estimator_name` is provided and `name` is:
+    elif name is None:
+        warning_message = (
+            f"`estimator_name` is deprecated in {version} and will be removed in "
+            f"{version_remove}. Use `name` instead."
+        )
+        with pytest.warns(FutureWarning, match=warning_message):
+            result = _deprecate_estimator_name(estimator_name, name, version)
+        assert result == estimator_name
+    elif name is not None:
+        error_message = (
+            "Cannot provide both `estimator_name` and `name`. `estimator_name` "
+            f"is deprecated in {version} and will be removed in {version_remove}. "
+        )
+        with pytest.raises(ValueError, match=error_message):
+            _deprecate_estimator_name(estimator_name, name, version)
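Review note: the helper's contract, condensed to match the parametrized cases above.

import warnings

from sklearn.utils._plotting import _deprecate_estimator_name

# Sentinel "deprecated" passed through: the new `name` wins.
assert _deprecate_estimator_name("deprecated", "my_name", "1.7") == "my_name"

# A legacy `estimator_name` is still honoured, but with a FutureWarning.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", FutureWarning)
    assert _deprecate_estimator_name("my_est_name", None, "1.7") == "my_est_name"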