From a1442e5dd6444dd0b57beb352804b228d37b6996 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Tue, 3 Dec 2024 22:19:51 +1100 Subject: [PATCH 01/63] first commit --- sklearn/metrics/_plot/roc_curve.py | 320 ++++++++++++++++++++++++++--- sklearn/utils/_plotting.py | 8 +- 2 files changed, 298 insertions(+), 30 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 058b3612baa61..8a9a5b76b2f1f 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,11 +1,15 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause +from collections.abc import Mapping +from ...utils import _safe_indexing from ...utils._plotting import ( _BinaryClassifierCurveDisplayMixin, _despine, _validate_style_kwargs, ) +from ...utils._response import _get_response_values_binary +from ...utils.validation import _num_samples from .._ranking import auc, roc_curve @@ -14,7 +18,8 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): It is recommend to use :func:`~sklearn.metrics.RocCurveDisplay.from_estimator` or - :func:`~sklearn.metrics.RocCurveDisplay.from_predictions` to create + :func:`~sklearn.metrics.RocCurveDisplay.from_predictions` or + :func:`~sklearn.metrics.RocCurveDisplay.from_cv_results` to create a :class:`~sklearn.metrics.RocCurveDisplay`. All parameters are stored as attributes. @@ -22,17 +27,23 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): Parameters ---------- - fpr : ndarray - False positive rate. + fpr : ndarray or list of ndarray + False positive rate. When plotting multiple ROC curves, `fpr` and `tpr` should + lists of the same length. - tpr : ndarray - True positive rate. + tpr : ndarray or list of ndarray + True positive rate. When plotting multiple ROC curves, `fpr` and `tpr` should + lists of the same length. - roc_auc : float, default=None - Area under ROC curve. If None, the roc_auc score is not shown. 
+ roc_auc : float or list of floats, default=None + Area under ROC curve. When plotting multiple ROC curves, can be a list + of the same length as `fpr` and `tpr`. + If None, no roc_auc score is shown. - estimator_name : str, default=None - Name of estimator. If None, the estimator name is not shown. + curve_name : str or list of str, default=None + Label for the ROC curve. For multiple ROC curves, `name` can be a list + of the same length as `tpr` and `fpr`. + If None, no name is shown. pos_label : int, float, bool or str, default=None The class considered as the positive class when computing the roc auc @@ -43,7 +54,7 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): Attributes ---------- - line_ : matplotlib Artist + line_ : matplotlib Artist or list of Artists ROC Curve. chance_level_ : matplotlib Artist or None @@ -76,19 +87,29 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): >>> fpr, tpr, thresholds = metrics.roc_curve(y, pred) >>> roc_auc = metrics.auc(fpr, tpr) >>> display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, - ... estimator_name='example estimator') + ... 
name='example estimator') >>> display.plot() <...> >>> plt.show() """ - def __init__(self, *, fpr, tpr, roc_auc=None, estimator_name=None, pos_label=None): - self.estimator_name = estimator_name + def __init__(self, *, fpr, tpr, roc_auc=None, name=None, pos_label=None): self.fpr = fpr self.tpr = tpr self.roc_auc = roc_auc + self.name = name self.pos_label = pos_label + def _get_default_line_kwargs(self, name): + default_line_kwargs = {} + if self.roc_auc is not None and name is not None: + default_line_kwargs["label"] = f"{name} (AUC = {self.roc_auc:0.2f})" + elif self.roc_auc is not None: + default_line_kwargs["label"] = f"AUC = {self.roc_auc:0.2f}" + elif name is not None: + default_line_kwargs["label"] = name + return default_line_kwargs + def plot( self, ax=None, @@ -97,6 +118,7 @@ def plot( plot_chance_level=False, chance_level_kw=None, despine=False, + fold_line_kw=None, **kwargs, ): """Plot visualization. @@ -109,9 +131,12 @@ def plot( Axes object to plot on. If `None`, a new figure and axes is created. - name : str, default=None - Name of ROC Curve for labeling. If `None`, use `estimator_name` if - not `None`, otherwise no labeling is shown. + name : str or list of str, default=None + Name of ROC Curve(s) for labeling. If `None`: + * for single curve, use `self.name` if not `None`, otherwise + no labeling is shown + * for multiple curves (`self.fpr` and `self.fpr` are both lists), + use 'ROC fold {cv_index}' plot_chance_level : bool, default=False Whether to plot the chance level. @@ -129,25 +154,66 @@ def plot( .. versionadded:: 1.6 + fold_line_kw : dict or list of dict, default=None + Dictionary with keywords passed to the matplotlib's `plot` function + to draw the individual ROC curves. If a list is provided, the + parameters are applied to the ROC curves of each fold + sequentially. If a single dictionary is provided, the same + parameters are applied to all ROC curves. Ignored for single curve + plots (when self.fpr and self.tpr are not lists). 
+ **kwargs : dict - Keyword arguments to be passed to matplotlib's `plot`. + For a single curve plots only, keyword arguments to be passed to + matplotlib's `plot`. Ignored for multi-curve plots. Returns ------- display : :class:`~sklearn.metrics.RocCurveDisplay` Object that stores computed values. """ - self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name) + # If multi-curve, ensure all args are of the right length + multi_params = [self.fpr, self.tpr, self.roc_auc, self.name] + req_multi = [input for input in multi_params[:2] if isinstance(input, list)] + optional_multi = [input for input in multi_params[2:] if isinstance(input, list)] + if req_multi and (len(req_multi) != 2): + raise ValueError( + "When plotting multiple ROC curves, `self.fpr`, `self.tpr`, " + "should both be lists." + ) + if len({len(arg) for arg in req_multi + optional_multi}) > 1: + raise ValueError( + "When plotting multiple ROC curves, `self.fpr`, `self.tpr`, and " + "if provided, `self.roc_auc` and `self.name`, should all be " + "lists of the same length." + ) + + n_multi = len(self.fpr) if req_multi else None + self.ax_, self.figure_, name = self._validate_plot_params( + ax=ax, name=name, n_multi=n_multi, curve_type="ROC", + ) - default_line_kwargs = {} - if self.roc_auc is not None and name is not None: - default_line_kwargs["label"] = f"{name} (AUC = {self.roc_auc:0.2f})" - elif self.roc_auc is not None: - default_line_kwargs["label"] = f"AUC = {self.roc_auc:0.2f}" - elif name is not None: - default_line_kwargs["label"] = name + if n_multi: + if fold_line_kw is None: + fold_line_kw = [ + {"alpha": 0.5, "color": "tab:blue", "linestyle": "--"} + ] * n_multi + elif isinstance(fold_line_kw, Mapping): + fold_line_kw = [fold_line_kw] * n_multi + elif len(fold_line_kw) != n_multi: + raise ValueError( + "When `fold_line_kw` is a list, it must have the same length as " + "the number of ROC curves to be plotted." 
+ ) + line_kwargs = [] + for name_idx, curve_name in enumerate(name): + default_line_kwargs = self._get_default_line_kwargs(curve_name) + line_kwargs.append(_validate_style_kwargs( + default_line_kwargs, fold_line_kw[name_idx] + )) + else: + default_line_kwargs = self._get_default_line_kwargs(name) + line_kwargs = _validate_style_kwargs(default_line_kwargs, kwargs) - line_kwargs = _validate_style_kwargs(default_line_kwargs, kwargs) default_chance_level_line_kw = { "label": "Chance level (AUC = 0.5)", @@ -162,7 +228,13 @@ def plot( default_chance_level_line_kw, chance_level_kw ) - (self.line_,) = self.ax_.plot(self.fpr, self.tpr, **line_kwargs) + if n_multi: + self.line_ = [] + for fpr, tpr, line_kw in zip(self.fpr, self.tpr, line_kwargs): + self.line_.extend(self.ax_.plot(fpr, tpr, **line_kw)) + else: + (self.line_,) = self.ax_.plot(self.fpr, self.tpr, **line_kwargs) + info_pos_label = ( f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" ) @@ -437,7 +509,7 @@ def from_predictions( fpr=fpr, tpr=tpr, roc_auc=roc_auc, - estimator_name=name, + name=name, pos_label=pos_label_validated, ) @@ -449,3 +521,195 @@ def from_predictions( despine=despine, **kwargs, ) + + @classmethod + def from_cv_results( + cls, + cv_results, + X, + y, + *, + sample_weight=None, + drop_intermediate=True, + response_method="auto", + pos_label=None, + ax=None, + fold_name=None, + fold_line_kw=None, + plot_chance_level=False, + chance_level_kw=None, + ): + """Create a multi-fold ROC curve display given cross-validation results. + + .. versionadded:: 1.7 + + Parameters + ---------- + cv_results : dict + Dictionary as returned by :func:`~sklearn.model_selection.cross_validate` + using `return_estimator=True` and `return_indices=True`. + + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Input values. + + y : array-like of shape (n_samples,) + Target values. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. 
+ + drop_intermediate : bool, default=True + Whether to drop some suboptimal thresholds which would not appear + on a plotted ROC curve. This is useful in order to create lighter + ROC curves. + + response_method : {'predict_proba', 'decision_function', 'auto'} \ + default='auto' + Specifies whether to use :term:`predict_proba` or + :term:`decision_function` as the target response. If set to 'auto', + :term:`predict_proba` is tried first and if it does not exist + :term:`decision_function` is tried next. + + pos_label : str or int, default=None + The class considered as the positive class when computing the roc auc + metrics. By default, `estimators.classes_[1]` is considered + as the positive class. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + fold_name : list of str, default=None + Name used in the legend for each individual ROC curve. If `None`, + the name will be set to "ROC fold #N" where N is the index of the + CV fold. + + fold_line_kw : dict or list of dict, default=None + Dictionary with keywords passed to the matplotlib's `plot` function + to draw the individual ROC curves. If a list is provided, the + parameters are applied to the ROC curves of each CV fold + sequentially. If a single dictionary is provided, the same + parameters are applied to all ROC curves. + + plot_chance_level : bool, default=False + Whether to plot the chance level. + + chance_level_kw : dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + Returns + ------- + display : :class:`~sklearn.metrics.MultiRocCurveDisplay` + The multi-fold ROC curve display. + + See Also + -------- + roc_curve : Compute Receiver operating characteristic (ROC) curve. + RocCurveDisplay.from_estimator : ROC Curve visualization given an + estimator and some data. + RocCurveDisplay.from_predictions : ROC Curve visualization given the + probabilities of scores of a classifier. 
+ roc_auc_score : Compute the area under the ROC curve. + + Examples + -------- + >>> import matplotlib.pyplot as plt + >>> from sklearn.datasets import make_classification + >>> from sklearn.metrics import RocCurveDisplay + >>> from sklearn.model_selection import cross_validate + >>> from sklearn.svm import SVC + >>> X, y = make_classification(random_state=0) + >>> clf = SVC(random_state=0) + >>> cv_results = cross_validate( + ... clf, X, y, cv=3, return_estimator=True, return_indices=True) + >>> RocCurveDisplay.from_cv_results(cv_results, X, y, kind="both") + <...> + >>> plt.show() + """ + required_keys = {"estimator", "indices"} + if not all(key in cv_results for key in required_keys): + raise ValueError( + "cv_results does not contain one of the following required keys: " + f"{required_keys}. Set explicitly the parameters return_estimator=True " + "and return_indices=True to the function cross_validate." + ) + + train_size, test_size = ( + len(cv_results["indices"]["train"][0]), + len(cv_results["indices"]["test"][0]), + ) + + if _num_samples(X) != train_size + test_size: + raise ValueError( + "X does not contain the correct number of samples. " + f"Expected {train_size + test_size}, got {_num_samples(X)}." + ) + + if fold_name is None: + # create an iterable of the same length as the number of ROC curves + fold_name_ = [None] * len(cv_results["estimator"]) + elif fold_name is not None and len(fold_name) != len(cv_results["estimator"]): + raise ValueError( + "When `fold_name` is provided, it must have the same length as " + f"the number of ROC curves to be plotted. Got {len(fold_name)} names " + f"instead of {len(cv_results['estimator'])}." 
+ ) + else: + fold_name_ = fold_name + + if fold_line_kw is None: + fold_line_kw = [ + {"alpha": 0.5, "color": "tab:blue", "linestyle": "--"} + ] * len(cv_results["estimator"]) + elif isinstance(fold_line_kw, Mapping): + fold_line_kw = [fold_line_kw] * len(cv_results["estimator"]) + elif len(fold_line_kw) != len(cv_results["estimator"]): + raise ValueError( + "When `fold_line_kw` is a list, it must have the same length as " + "the number of ROC curves to be plotted." + ) + + fpr_all = [] + tpr_all = [] + auc_all = [] + for estimator, test_indices, name in zip( + cv_results["estimator"], cv_results["indices"]["test"], fold_name_ + ): + y_true = _safe_indexing(y, test_indices) + y_pred = _get_response_values_binary( + estimator, + _safe_indexing(X, test_indices), + response_method=response_method, + pos_label=pos_label, + )[0] + # Should we use `_validate_from_predictions_params` here? + # The check would technically only be needed once though + fpr, tpr, _ = roc_curve( + y_true, + y_pred, + pos_label=pos_label, + sample_weight=sample_weight, + drop_intermediate=drop_intermediate, + ) + roc_auc = auc(fpr, tpr) + # Append all + fpr_all.append(fpr) + tpr_all.append(tpr) + auc_all.append(roc_auc) + + + viz = cls( + fpr=fpr_all, + tpr=tpr_all, + roc_auc=auc_all, + name=name, + pos_label=pos_label, + ) + return viz.plot( + ax=ax, + fold_name=fold_name_, + fold_line_kw=fold_line_kw, + plot_chance_level=plot_chance_level, + chance_level_kw=chance_level_kw, + ) \ No newline at end of file diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 946c95186374b..c931f6150b616 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -17,14 +17,18 @@ class _BinaryClassifierCurveDisplayMixin: the target and gather the response of the estimator. 
""" - def _validate_plot_params(self, *, ax=None, name=None): + def _validate_plot_params(self, *, ax=None, name=None, n_multi=None, curve_type=None): check_matplotlib_support(f"{self.__class__.__name__}.plot") import matplotlib.pyplot as plt if ax is None: _, ax = plt.subplots() - name = self.estimator_name if name is None else name + # Not 100% sure on this change + if n_multi is None: + name = self.estimator_name if name is None else name + else: + name = [f"{curve_type} fold {curve_idx}:" for curve_idx in range(n_multi)] return ax, ax.figure, name @classmethod From f4f9a98ae2209d031cb227500ccf83039e9aa6e4 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Tue, 3 Dec 2024 22:29:14 +1100 Subject: [PATCH 02/63] lint --- sklearn/metrics/_plot/roc_curve.py | 19 +++++++++++-------- sklearn/utils/_plotting.py | 4 +++- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 8a9a5b76b2f1f..31d06c069d861 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -174,7 +174,9 @@ def plot( # If multi-curve, ensure all args are of the right length multi_params = [self.fpr, self.tpr, self.roc_auc, self.name] req_multi = [input for input in multi_params[:2] if isinstance(input, list)] - optional_multi = [input for input in multi_params[2:] if isinstance(input, list)] + optional_multi = [ + input for input in multi_params[2:] if isinstance(input, list) + ] if req_multi and (len(req_multi) != 2): raise ValueError( "When plotting multiple ROC curves, `self.fpr`, `self.tpr`, " @@ -189,7 +191,10 @@ def plot( n_multi = len(self.fpr) if req_multi else None self.ax_, self.figure_, name = self._validate_plot_params( - ax=ax, name=name, n_multi=n_multi, curve_type="ROC", + ax=ax, + name=name, + n_multi=n_multi, + curve_type="ROC", ) if n_multi: @@ -207,14 +212,13 @@ def plot( line_kwargs = [] for name_idx, curve_name in enumerate(name): default_line_kwargs = 
self._get_default_line_kwargs(curve_name) - line_kwargs.append(_validate_style_kwargs( - default_line_kwargs, fold_line_kw[name_idx] - )) + line_kwargs.append( + _validate_style_kwargs(default_line_kwargs, fold_line_kw[name_idx]) + ) else: default_line_kwargs = self._get_default_line_kwargs(name) line_kwargs = _validate_style_kwargs(default_line_kwargs, kwargs) - default_chance_level_line_kw = { "label": "Chance level (AUC = 0.5)", "color": "k", @@ -698,7 +702,6 @@ def from_cv_results( tpr_all.append(tpr) auc_all.append(roc_auc) - viz = cls( fpr=fpr_all, tpr=tpr_all, @@ -712,4 +715,4 @@ def from_cv_results( fold_line_kw=fold_line_kw, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, - ) \ No newline at end of file + ) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index c931f6150b616..1057108f2b767 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -17,7 +17,9 @@ class _BinaryClassifierCurveDisplayMixin: the target and gather the response of the estimator. """ - def _validate_plot_params(self, *, ax=None, name=None, n_multi=None, curve_type=None): + def _validate_plot_params( + self, *, ax=None, name=None, n_multi=None, curve_type=None + ): check_matplotlib_support(f"{self.__class__.__name__}.plot") import matplotlib.pyplot as plt From 9e87e13ddb0883400bafdc728bd82ef27b7cce03 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 18 Dec 2024 22:18:13 +1100 Subject: [PATCH 03/63] fixes --- sklearn/metrics/_plot/roc_curve.py | 44 ++++++++++++++++++------------ sklearn/utils/_plotting.py | 2 +- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 31d06c069d861..7b470f982e99e 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -40,7 +40,7 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): of the same length as `fpr` and `tpr`. If None, no roc_auc score is shown. 
- curve_name : str or list of str, default=None + name : str or list of str, default=None Label for the ROC curve. For multiple ROC curves, `name` can be a list of the same length as `tpr` and `fpr`. If None, no name is shown. @@ -100,12 +100,12 @@ def __init__(self, *, fpr, tpr, roc_auc=None, name=None, pos_label=None): self.name = name self.pos_label = pos_label - def _get_default_line_kwargs(self, name): + def _get_default_line_kwargs(self, roc_auc, name): default_line_kwargs = {} - if self.roc_auc is not None and name is not None: - default_line_kwargs["label"] = f"{name} (AUC = {self.roc_auc:0.2f})" - elif self.roc_auc is not None: - default_line_kwargs["label"] = f"AUC = {self.roc_auc:0.2f}" + if roc_auc is not None and name is not None: + default_line_kwargs["label"] = f"{name} (AUC = {roc_auc:0.2f})" + elif roc_auc is not None: + default_line_kwargs["label"] = f"AUC = {roc_auc:0.2f}" elif name is not None: default_line_kwargs["label"] = name return default_line_kwargs @@ -132,11 +132,10 @@ def plot( created. name : str or list of str, default=None - Name of ROC Curve(s) for labeling. If `None`: - * for single curve, use `self.name` if not `None`, otherwise - no labeling is shown - * for multiple curves (`self.fpr` and `self.fpr` are both lists), - use 'ROC fold {cv_index}' + Name of ROC Curve(s) for labeling. If `None`; + * try to use `self.name`, + * if `self.name` also `None`, no labeling is shown for single curves. + For multiple curves use 'ROC fold {cv_index}' plot_chance_level : bool, default=False Whether to plot the chance level. @@ -171,8 +170,9 @@ def plot( display : :class:`~sklearn.metrics.RocCurveDisplay` Object that stores computed values. 
""" + name_ = self.name if name is None else name # If multi-curve, ensure all args are of the right length - multi_params = [self.fpr, self.tpr, self.roc_auc, self.name] + multi_params = [self.fpr, self.tpr, self.roc_auc, name_] req_multi = [input for input in multi_params[:2] if isinstance(input, list)] optional_multi = [ input for input in multi_params[2:] if isinstance(input, list) @@ -185,8 +185,8 @@ def plot( if len({len(arg) for arg in req_multi + optional_multi}) > 1: raise ValueError( "When plotting multiple ROC curves, `self.fpr`, `self.tpr`, and " - "if provided, `self.roc_auc` and `self.name`, should all be " - "lists of the same length." + "if provided, `self.roc_auc` and `name` (or `self.name`), should all " + "be lists of the same length." ) n_multi = len(self.fpr) if req_multi else None @@ -209,11 +209,19 @@ def plot( "When `fold_line_kw` is a list, it must have the same length as " "the number of ROC curves to be plotted." ) + name_ = [name_] * n_multi if name_ is None else name_ + roc_auc_ = ( + [self.roc_auc] * n_multi if self.roc_auc is None else self.roc_auc + ) line_kwargs = [] - for name_idx, curve_name in enumerate(name): - default_line_kwargs = self._get_default_line_kwargs(curve_name) + for fold_idx, (curve_name, curve_roc_auc) in enumerate( + zip(name_, roc_auc_) + ): + default_line_kwargs = self._get__default_line_kwargs( + curve_name, curve_roc_auc + ) line_kwargs.append( - _validate_style_kwargs(default_line_kwargs, fold_line_kw[name_idx]) + _validate_style_kwargs(default_line_kwargs, fold_line_kw[fold_idx]) ) else: default_line_kwargs = self._get_default_line_kwargs(name) @@ -604,7 +612,7 @@ def from_cv_results( Returns ------- - display : :class:`~sklearn.metrics.MultiRocCurveDisplay` + display : :class:`~sklearn.metrics.RocCurveDisplay` The multi-fold ROC curve display. 
See Also diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 1057108f2b767..bc7a2ce680e18 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -28,7 +28,7 @@ def _validate_plot_params( # Not 100% sure on this change if n_multi is None: - name = self.estimator_name if name is None else name + name = self.curve_name if name is None else name else: name = [f"{curve_type} fold {curve_idx}:" for curve_idx in range(n_multi)] return ax, ax.figure, name From f0908e1e7d14d5bf1c557b3b438c7302bed7ca9e Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Tue, 31 Dec 2024 15:10:22 +1100 Subject: [PATCH 04/63] factorize --- sklearn/metrics/_plot/roc_curve.py | 237 +++++++++++++++-------------- sklearn/utils/_plotting.py | 143 ++++++++++++++++- 2 files changed, 254 insertions(+), 126 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 7b470f982e99e..c02016cce2da9 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,11 +1,13 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause -from collections.abc import Mapping from ...utils import _safe_indexing from ...utils._plotting import ( _BinaryClassifierCurveDisplayMixin, + _check_param_lengths, + _deprecate_singular, _despine, + _process_fold_names_line_kwargs, _validate_style_kwargs, ) from ...utils._response import _get_response_values_binary @@ -27,30 +29,64 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): Parameters ---------- + fprs : list of ndarray + False positive rates. Each ndarray should contain values for a single curve. + If plotting multiple curves, list should be of same length as + and `tprs`. + + tprs : list of ndarray + True positive rates. Each ndarray should contain values for a single curve. + If plotting multiple curves, list should be of same length as + and `fprs`. + + roc_aucs : list of floats, default=None + Area under ROC curve. 
Should be list of the same length as `fprs` and + `tprs` or None, in which case no area under ROC curve score is shown. + + names : str or list of str, default=None + Label for the ROC curve. Should be list of the same length as + `fprs` and `tprs` or None, in which case no name is shown. + + pos_label : int, float, bool or str, default=None + The class considered as the positive class when computing the roc auc + metrics. By default, `estimators.classes_[1]` is considered + as the positive class. + + .. versionadded:: 0.24 + fpr : ndarray or list of ndarray False positive rate. When plotting multiple ROC curves, `fpr` and `tpr` should lists of the same length. + .. deprecated:: 1.7 + `fpr` is deprecated in 1.7 and will be removed in 1.9. + Use `fprs` instead. + tpr : ndarray or list of ndarray True positive rate. When plotting multiple ROC curves, `fpr` and `tpr` should lists of the same length. + .. deprecated:: 1.7 + `tpr` is deprecated in 1.7 and will be removed in 1.9. + Use `tprs` instead. + roc_auc : float or list of floats, default=None Area under ROC curve. When plotting multiple ROC curves, can be a list of the same length as `fpr` and `tpr`. If None, no roc_auc score is shown. + .. deprecated:: 1.7 + `roc_auc` is deprecated in 1.7 and will be removed in 1.9. + Use `roc_aucs` instead. + name : str or list of str, default=None Label for the ROC curve. For multiple ROC curves, `name` can be a list of the same length as `tpr` and `fpr`. If None, no name is shown. - pos_label : int, float, bool or str, default=None - The class considered as the positive class when computing the roc auc - metrics. By default, `estimators.classes_[1]` is considered - as the positive class. - - .. versionadded:: 0.24 + .. deprecated:: 1.7 + `name` is deprecated in 1.7 and will be removed in 1.9. + Use `names` instead. 
Attributes ---------- @@ -93,11 +129,23 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): >>> plt.show() """ - def __init__(self, *, fpr, tpr, roc_auc=None, name=None, pos_label=None): - self.fpr = fpr - self.tpr = tpr - self.roc_auc = roc_auc - self.name = name + def __init__( + self, + *, + fprs, + tprs, + roc_aucs=None, + names=None, + pos_label=None, + fpr="deprecated", + tpr="deprecated", + roc_auc="deprecated", + name="deprecated", + ): + self.fprs = _deprecate_singular(fpr, fprs, "fpr") + self.tprs = _deprecate_singular(tpr, tprs, "tpr") + self.roc_aucs = _deprecate_singular(roc_auc, roc_aucs, "roc_auc") + self.names = _deprecate_singular(name, names, "name") self.pos_label = pos_label def _get_default_line_kwargs(self, roc_auc, name): @@ -114,16 +162,18 @@ def plot( self, ax=None, *, - name=None, + names=None, plot_chance_level=False, chance_level_kw=None, despine=False, - fold_line_kw=None, + fold_line_kws=None, + name="deprecated", **kwargs, ): """Plot visualization. - Extra keyword arguments will be passed to matplotlib's ``plot``. + For single curve plots, extra keyword arguments will be passed to + matplotlib's ``plot``. Parameters ---------- @@ -131,11 +181,12 @@ def plot( Axes object to plot on. If `None`, a new figure and axes is created. - name : str or list of str, default=None - Name of ROC Curve(s) for labeling. If `None`; - * try to use `self.name`, - * if `self.name` also `None`, no labeling is shown for single curves. - For multiple curves use 'ROC fold {cv_index}' + names : list of str, default=None + Names of each ROC curve for labeling. If `None`, use + name provided at `RocCurveDisplay` initialization. If not + provided at initialization, no labeling is shown. + + .. versionadded:: 1.7 plot_chance_level : bool, default=False Whether to plot the chance level. @@ -153,13 +204,23 @@ def plot( .. 
versionadded:: 1.6 - fold_line_kw : dict or list of dict, default=None + fold_line_kws : dict or list of dict, default=None Dictionary with keywords passed to the matplotlib's `plot` function to draw the individual ROC curves. If a list is provided, the parameters are applied to the ROC curves of each fold sequentially. If a single dictionary is provided, the same parameters are applied to all ROC curves. Ignored for single curve - plots (when self.fpr and self.tpr are not lists). + plots. + + .. versionadded:: 1.7 + + name : str, default=None + Name of ROC Curve for labeling. If `None`, use `estimator_name` if + not `None`, otherwise no labeling is shown. + + .. deprecated:: 1.7 + `name` is deprecated in 1.7 and will be removed in 1.9. + Use `names` instead. **kwargs : dict For a single curve plots only, keyword arguments to be passed to @@ -170,62 +231,22 @@ def plot( display : :class:`~sklearn.metrics.RocCurveDisplay` Object that stores computed values. """ - name_ = self.name if name is None else name - # If multi-curve, ensure all args are of the right length - multi_params = [self.fpr, self.tpr, self.roc_auc, name_] - req_multi = [input for input in multi_params[:2] if isinstance(input, list)] - optional_multi = [ - input for input in multi_params[2:] if isinstance(input, list) - ] - if req_multi and (len(req_multi) != 2): - raise ValueError( - "When plotting multiple ROC curves, `self.fpr`, `self.tpr`, " - "should both be lists." - ) - if len({len(arg) for arg in req_multi + optional_multi}) > 1: - raise ValueError( - "When plotting multiple ROC curves, `self.fpr`, `self.tpr`, and " - "if provided, `self.roc_auc` and `name` (or `self.name`), should all " - "be lists of the same length." 
- ) + names = _deprecate_singular(name, names, "name") + names_ = self.names if (names[0] is None) else names + _check_param_lengths( + {"self.fprs": self.fprs, "self.tprs": self.tprs}, + {"roc_aucs": self.roc_aucs, "self.names (or names from `plot`)": names_}, + "RocCurveDisplay", + ) - n_multi = len(self.fpr) if req_multi else None - self.ax_, self.figure_, name = self._validate_plot_params( + n_curves = len(self.fprs) + self.ax_, self.figure_, _ = self._validate_plot_params( ax=ax, - name=name, - n_multi=n_multi, - curve_type="ROC", ) - if n_multi: - if fold_line_kw is None: - fold_line_kw = [ - {"alpha": 0.5, "color": "tab:blue", "linestyle": "--"} - ] * n_multi - elif isinstance(fold_line_kw, Mapping): - fold_line_kw = [fold_line_kw] * n_multi - elif len(fold_line_kw) != n_multi: - raise ValueError( - "When `fold_line_kw` is a list, it must have the same length as " - "the number of ROC curves to be plotted." - ) - name_ = [name_] * n_multi if name_ is None else name_ - roc_auc_ = ( - [self.roc_auc] * n_multi if self.roc_auc is None else self.roc_auc - ) - line_kwargs = [] - for fold_idx, (curve_name, curve_roc_auc) in enumerate( - zip(name_, roc_auc_) - ): - default_line_kwargs = self._get__default_line_kwargs( - curve_name, curve_roc_auc - ) - line_kwargs.append( - _validate_style_kwargs(default_line_kwargs, fold_line_kw[fold_idx]) - ) - else: - default_line_kwargs = self._get_default_line_kwargs(name) - line_kwargs = _validate_style_kwargs(default_line_kwargs, kwargs) + line_kwargs = self._get_line_kwargs( + n_curves, names_, self.roc_aucs, fold_line_kws, **kwargs + ) default_chance_level_line_kw = { "label": "Chance level (AUC = 0.5)", @@ -240,12 +261,12 @@ def plot( default_chance_level_line_kw, chance_level_kw ) - if n_multi: - self.line_ = [] - for fpr, tpr, line_kw in zip(self.fpr, self.tpr, line_kwargs): - self.line_.extend(self.ax_.plot(fpr, tpr, **line_kw)) - else: - (self.line_,) = self.ax_.plot(self.fpr, self.tpr, **line_kwargs) + self.line_ = [] 
+ for fpr, tpr, line_kw in zip(self.fprs, self.tprs, line_kwargs): + self.line_.extend(self.ax_.plot(fpr, tpr, **line_kw)) + # Should we do this to be backwards compatible or have `line_` always be list? + if len(self.line_) == 1: + self.line_ = self.line_[0] info_pos_label = ( f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" @@ -269,7 +290,8 @@ def plot( if despine: _despine(self.ax_) - if "label" in line_kwargs or "label" in chance_level_kw: + # Note: if 'label' present in one `line_kwargs`, it should be present in all + if "label" in line_kwargs[0] or chance_level_kw: self.ax_.legend(loc="lower right") return self @@ -546,8 +568,8 @@ def from_cv_results( response_method="auto", pos_label=None, ax=None, - fold_name=None, - fold_line_kw=None, + fold_names=None, + fold_line_kwargs=None, plot_chance_level=False, chance_level_kw=None, ): @@ -591,12 +613,12 @@ def from_cv_results( Axes object to plot on. If `None`, a new figure and axes is created. - fold_name : list of str, default=None - Name used in the legend for each individual ROC curve. If `None`, - the name will be set to "ROC fold #N" where N is the index of the + fold_names : list of str, default=None + Names used in the legend for each individual ROC curve. If `None`, + the name will be set to "ROC fold " where N is the index of the CV fold. - fold_line_kw : dict or list of dict, default=None + fold_line_kwargs : dict or list of dict, default=None Dictionary with keywords passed to the matplotlib's `plot` function to draw the individual ROC curves. If a list is provided, the parameters are applied to the ROC curves of each CV fold @@ -658,35 +680,15 @@ def from_cv_results( f"Expected {train_size + test_size}, got {_num_samples(X)}." 
) - if fold_name is None: - # create an iterable of the same length as the number of ROC curves - fold_name_ = [None] * len(cv_results["estimator"]) - elif fold_name is not None and len(fold_name) != len(cv_results["estimator"]): - raise ValueError( - "When `fold_name` is provided, it must have the same length as " - f"the number of ROC curves to be plotted. Got {len(fold_name)} names " - f"instead of {len(cv_results['estimator'])}." - ) - else: - fold_name_ = fold_name - - if fold_line_kw is None: - fold_line_kw = [ - {"alpha": 0.5, "color": "tab:blue", "linestyle": "--"} - ] * len(cv_results["estimator"]) - elif isinstance(fold_line_kw, Mapping): - fold_line_kw = [fold_line_kw] * len(cv_results["estimator"]) - elif len(fold_line_kw) != len(cv_results["estimator"]): - raise ValueError( - "When `fold_line_kw` is a list, it must have the same length as " - "the number of ROC curves to be plotted." - ) + fold_names_, fold_line_kws_ = _process_fold_names_line_kwargs( + len(cv_results["estimator"]), fold_names, fold_line_kwargs + ) fpr_all = [] tpr_all = [] auc_all = [] - for estimator, test_indices, name in zip( - cv_results["estimator"], cv_results["indices"]["test"], fold_name_ + for estimator, test_indices in zip( + cv_results["estimator"], cv_results["indices"]["test"] ): y_true = _safe_indexing(y, test_indices) y_pred = _get_response_values_binary( @@ -711,16 +713,15 @@ def from_cv_results( auc_all.append(roc_auc) viz = cls( - fpr=fpr_all, - tpr=tpr_all, - roc_auc=auc_all, - name=name, + fprs=fpr_all, + tprs=tpr_all, + roc_aucs=auc_all, pos_label=pos_label, ) return viz.plot( ax=ax, - fold_name=fold_name_, - fold_line_kw=fold_line_kw, + names=fold_names_, + fold_line_kws=fold_line_kws_, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, ) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index bc7a2ce680e18..d5725462074f8 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -1,5 +1,7 @@ # Authors: The 
scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause +import warnings +from collections.abc import Mapping import numpy as np @@ -17,21 +19,69 @@ class _BinaryClassifierCurveDisplayMixin: the target and gather the response of the estimator. """ - def _validate_plot_params( - self, *, ax=None, name=None, n_multi=None, curve_type=None - ): + def _validate_plot_params(self, *, ax=None, names=None): check_matplotlib_support(f"{self.__class__.__name__}.plot") import matplotlib.pyplot as plt if ax is None: _, ax = plt.subplots() - # Not 100% sure on this change - if n_multi is None: - name = self.curve_name if name is None else name + if names is not None: + names = self.names if names[0] is None else names + return ax, ax.figure, names + + @classmethod + def _get_line_kwargs( + cls, + n_curves, + names, + summary_values, + fold_line_kws, + default_line_kwargs={}, + **kwargs, + ): + """Get validated line kwargs for each curve.""" + # Ensure parameters are of the correct length + names_ = [None] * n_curves if names is None else names + summary_values_ = ( + [None] * n_curves if summary_values is None else summary_values + ) + # `fold_line_kws` ignored for single curve plots + # `kwargs` ignored for multi-curve plots + if n_curves == 1: + fold_line_kws = [kwargs] else: - name = [f"{curve_type} fold {curve_idx}:" for curve_idx in range(n_multi)] - return ax, ax.figure, name + if fold_line_kws is None: + # We should not set color to be the same, otherwise legend is + # meaningless + fold_line_kws = [ + {"alpha": 0.5, "color": "tab:blue", "linestyle": "--"} + ] * n_curves + elif isinstance(fold_line_kws, Mapping): + fold_line_kws = [fold_line_kws] * n_curves + elif len(fold_line_kws) != n_curves: + raise ValueError( + "When `fold_line_kws` is a list, it must have the same length as " + "the number of curves to be plotted." 
+ ) + + line_kwargs = [] + for fold_idx, (curve_summary_value, curve_name) in enumerate( + zip(summary_values_, names_) + ): + if curve_summary_value is not None and curve_name is not None: + default_line_kwargs["label"] = ( + f"{curve_name} (AP = {curve_summary_value:0.2f})" + ) + elif curve_summary_value is not None: + default_line_kwargs["label"] = f"AP = {curve_summary_value:0.2f}" + elif curve_name is not None: + default_line_kwargs["label"] = curve_name + + line_kwargs.append( + _validate_style_kwargs(default_line_kwargs, fold_line_kws[fold_idx]) + ) + return line_kwargs @classmethod def _validate_and_get_response_values( @@ -183,3 +233,80 @@ def _despine(ax): ax.spines[s].set_visible(False) for s in ["bottom", "left"]: ax.spines[s].set_bounds(0, 1) + + +# TODO(1.9): remove +# Should this be a parent class method? +def _deprecate_singular(singular, plural, name): + """Deprecate the singular version of Display parameters. + + If only `singular` parameter passed, it will be returned as a list with a warning. + """ + if singular != "deprecated": + warnings.warn( + f"`{name}` was passed to `{name}s` in a list because `{name}` is " + f"deprecated in 1.7 and will be removed in 1.9. Use " + f"`{name}s` instead.", + FutureWarning, + ) + if plural: + raise ValueError( + f"Cannot use both `{name}` and `{name}s`. Use only `{name}s` as " + f"`{name}` is deprecated." + ) + return [singular] + return plural + + +# Should this be a parent class/mixin method? 
+def _check_param_lengths(required, optional, class_name): + """Check required and optional parameters are of the same length.""" + optional_provided = {} + for name, param in optional: + if isinstance(param, list): + optional_provided[name] = param + + all_params = {**required, **optional_provided} + if len({len(param) for param in all_params.values()}) > 1: + required_formatted = ", ".join(f"'{key}'" for key in required.keys()) + optional_formatted = ", ".join(f"'{key}'" for key in optional_provided.keys()) + lengths_formatted = ", ".join( + f"{key}: {len(value)}" for key, value in all_params.items() + ) + raise ValueError( + f"{required_formatted}, and optional parameters {optional_formatted} " + f"from `{class_name}` initialization, should all be lists of the same " + f"length. Got: {lengths_formatted}" + ) + + +def _process_fold_names_line_kwargs(n_curves, fold_names, fold_line_kwargs): + """Ensure that `fold_names` and `fold_line_kwargs` are of correct length.""" + msg = ( + "When `{param}` is provided, it must have the same length as " + "the number of curves to be plotted. Got {len_param} " + "instead of {n_curves}." 
+ ) + + if fold_names is None: + fold_names_ = [f"ROC fold: {idx}" for idx in range(n_curves)] + elif fold_names is not None and len(fold_names) != n_curves: + raise ValueError( + msg.format(param="fold_names", len_param=len(fold_names), n_curves=n_curves) + ) + else: + fold_names_ = fold_names + + if isinstance(fold_line_kwargs, Mapping): + fold_line_kws_ = [fold_line_kwargs] * n_curves + elif len(fold_line_kwargs) != n_curves: + raise ValueError( + msg.format( + param="fold_line_kwargs", + len_param=len(fold_line_kwargs), + n_curves=n_curves, + ) + ) + else: + fold_line_kws_ = fold_line_kwargs + return fold_names_, fold_line_kws_ From 7e77d4c72307e6ea709624606ad84867af0a4652 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Tue, 31 Dec 2024 15:12:44 +1100 Subject: [PATCH 05/63] fix --- sklearn/utils/_plotting.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index d5725462074f8..800cd244a1934 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -262,7 +262,7 @@ def _deprecate_singular(singular, plural, name): def _check_param_lengths(required, optional, class_name): """Check required and optional parameters are of the same length.""" optional_provided = {} - for name, param in optional: + for name, param in optional.items(): if isinstance(param, list): optional_provided[name] = param @@ -290,7 +290,7 @@ def _process_fold_names_line_kwargs(n_curves, fold_names, fold_line_kwargs): if fold_names is None: fold_names_ = [f"ROC fold: {idx}" for idx in range(n_curves)] - elif fold_names is not None and len(fold_names) != n_curves: + elif len(fold_names) != n_curves: raise ValueError( msg.format(param="fold_names", len_param=len(fold_names), n_curves=n_curves) ) @@ -299,7 +299,7 @@ def _process_fold_names_line_kwargs(n_curves, fold_names, fold_line_kwargs): if isinstance(fold_line_kwargs, Mapping): fold_line_kws_ = [fold_line_kwargs] * n_curves - elif 
len(fold_line_kwargs) != n_curves: + elif fold_names is not None and len(fold_line_kwargs) != n_curves: raise ValueError( msg.format( param="fold_line_kwargs", @@ -309,4 +309,5 @@ def _process_fold_names_line_kwargs(n_curves, fold_names, fold_line_kwargs): ) else: fold_line_kws_ = fold_line_kwargs + return fold_names_, fold_line_kws_ From 1a4a2f315e98747be435d083d53b957fbf14aafc Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 15 Jan 2025 11:01:45 +1100 Subject: [PATCH 06/63] review changes --- sklearn/metrics/_plot/roc_curve.py | 63 ++++++++++++++++-------------- sklearn/utils/_plotting.py | 3 +- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index c02016cce2da9..dbac1cf78da67 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,7 +1,7 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause -from ...utils import _safe_indexing +from ...utils import deprecated, _safe_indexing from ...utils._plotting import ( _BinaryClassifierCurveDisplayMixin, _check_param_lengths, @@ -90,8 +90,8 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): Attributes ---------- - line_ : matplotlib Artist or list of Artists - ROC Curve. + lines_ : list of matplotlib Artists + ROC Curves. chance_level_ : matplotlib Artist or None The chance level line. It is `None` if the chance level is not plotted. @@ -104,6 +104,13 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): figure_ : matplotlib Figure Figure containing the curve. + line_ : matplotlib Artist + ROC Curve. + + .. deprecated:: 1.7 + `line_` is deprecated in 1.7 and will be removed in 1.9. Use `lines_` + instead. + See Also -------- roc_curve : Compute Receiver operating characteristic (ROC) curve. 
@@ -148,16 +155,6 @@ def __init__( self.names = _deprecate_singular(name, names, "name") self.pos_label = pos_label - def _get_default_line_kwargs(self, roc_auc, name): - default_line_kwargs = {} - if roc_auc is not None and name is not None: - default_line_kwargs["label"] = f"{name} (AUC = {roc_auc:0.2f})" - elif roc_auc is not None: - default_line_kwargs["label"] = f"AUC = {roc_auc:0.2f}" - elif name is not None: - default_line_kwargs["label"] = name - return default_line_kwargs - def plot( self, ax=None, @@ -166,7 +163,7 @@ def plot( plot_chance_level=False, chance_level_kw=None, despine=False, - fold_line_kws=None, + fold_line_kwargs=None, name="deprecated", **kwargs, ): @@ -182,9 +179,9 @@ def plot( created. names : list of str, default=None - Names of each ROC curve for labeling. If `None`, use - name provided at `RocCurveDisplay` initialization. If not - provided at initialization, no labeling is shown. + Names of each ROC curve for labeling each curve in the legend. + If `None`, use name provided at `RocCurveDisplay` initialization. If none + provided at initialization, no legend is added. .. versionadded:: 1.7 @@ -204,13 +201,12 @@ def plot( .. versionadded:: 1.6 - fold_line_kws : dict or list of dict, default=None + fold_line_kwargs : dict or list of dict, default=None Dictionary with keywords passed to the matplotlib's `plot` function to draw the individual ROC curves. If a list is provided, the - parameters are applied to the ROC curves of each fold - sequentially. If a single dictionary is provided, the same - parameters are applied to all ROC curves. Ignored for single curve - plots. + parameters are applied to the ROC curves sequentially. If a single + dictionary is provided, the same parameters are applied to all ROC + curves. Ignored for single curve plots. .. versionadded:: 1.7 @@ -225,6 +221,7 @@ def plot( **kwargs : dict For a single curve plots only, keyword arguments to be passed to matplotlib's `plot`. Ignored for multi-curve plots. 
+ (Note req for backwards compat, maybe not ideal?) Returns ------- @@ -245,7 +242,7 @@ def plot( ) line_kwargs = self._get_line_kwargs( - n_curves, names_, self.roc_aucs, fold_line_kws, **kwargs + n_curves, names_, self.roc_aucs, fold_line_kwargs, **kwargs ) default_chance_level_line_kw = { @@ -261,12 +258,9 @@ def plot( default_chance_level_line_kw, chance_level_kw ) - self.line_ = [] + self.lines_ = [] for fpr, tpr, line_kw in zip(self.fprs, self.tprs, line_kwargs): self.line_.extend(self.ax_.plot(fpr, tpr, **line_kw)) - # Should we do this to be backwards compatible or have `line_` always be list? - if len(self.line_) == 1: - self.line_ = self.line_[0] info_pos_label = ( f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" @@ -296,6 +290,15 @@ def plot( return self + #TODO(1.9): Remove + @deprecated( + "Attribute `line_` is deprecated in 1.7 and will be removed in " + "1.9. Use `lines_` instead." + ) + @property + def line_(self): + return self.lines_[0] + @classmethod def from_estimator( cls, @@ -680,7 +683,7 @@ def from_cv_results( f"Expected {train_size + test_size}, got {_num_samples(X)}." 
) - fold_names_, fold_line_kws_ = _process_fold_names_line_kwargs( + fold_names_, fold_line_kwargs_ = _process_fold_names_line_kwargs( len(cv_results["estimator"]), fold_names, fold_line_kwargs ) @@ -715,13 +718,13 @@ def from_cv_results( viz = cls( fprs=fpr_all, tprs=tpr_all, + names=fold_names_, roc_aucs=auc_all, pos_label=pos_label, ) return viz.plot( ax=ax, - names=fold_names_, - fold_line_kws=fold_line_kws_, + fold_line_kwargs=fold_line_kwargs_, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, ) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 800cd244a1934..1c2657a3d2d0f 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -289,7 +289,8 @@ def _process_fold_names_line_kwargs(n_curves, fold_names, fold_line_kwargs): ) if fold_names is None: - fold_names_ = [f"ROC fold: {idx}" for idx in range(n_curves)] + # " fold ?" + fold_names_ = [f"Fold: {idx}" for idx in range(n_curves)] elif len(fold_names) != n_curves: raise ValueError( msg.format(param="fold_names", len_param=len(fold_names), n_curves=n_curves) From 042c0f669b83c0ec903717482560b0dbb410bc2c Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 15 Jan 2025 11:11:01 +1100 Subject: [PATCH 07/63] lint --- sklearn/metrics/_plot/roc_curve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index bb74da68fdbbe..f342c8fccda84 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,7 +1,7 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause -from ...utils import deprecated, _safe_indexing +from ...utils import _safe_indexing, deprecated from ...utils._plotting import ( _BinaryClassifierCurveDisplayMixin, _check_param_lengths, @@ -292,7 +292,7 @@ def plot( return self - #TODO(1.9): Remove + # TODO(1.9): Remove @deprecated( "Attribute `line_` is deprecated in 1.7 and will be removed in " 
"1.9. Use `lines_` instead." From e8a507336b648d53425fcd86a84865960301683a Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 15 Jan 2025 11:49:56 +1100 Subject: [PATCH 08/63] ignore mypy --- sklearn/metrics/_plot/roc_curve.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index f342c8fccda84..f190ef5679f5e 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -293,7 +293,9 @@ def plot( return self # TODO(1.9): Remove - @deprecated( + # Is it worth adding a global ignore for mypy error? + # mypy error: Decorated property not supported + @deprecated( # type: ignore "Attribute `line_` is deprecated in 1.7 and will be removed in " "1.9. Use `lines_` instead." ) From 4431c5527bf62a15972dad473323336d9f2f93bb Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 15 Jan 2025 12:35:45 +1100 Subject: [PATCH 09/63] fix validate plot param --- sklearn/metrics/_plot/roc_curve.py | 9 +++------ sklearn/utils/_plotting.py | 10 ++++++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index f190ef5679f5e..87ab4ee5cc450 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -229,7 +229,8 @@ def plot( Object that stores computed values. """ names = _deprecate_singular(name, names, "name") - names_ = self.names if (names[0] is None) else names + # Not sure about this, as ideally we would check params are correct first?? 
+ self.ax_, self.figure_, names_ = self._validate_plot_params(ax=ax, name=names) _check_param_lengths( {"self.fprs": self.fprs, "self.tprs": self.tprs}, {"roc_aucs": self.roc_aucs, "self.names (or names from `plot`)": names_}, @@ -237,10 +238,6 @@ def plot( ) n_curves = len(self.fprs) - self.ax_, self.figure_, _ = self._validate_plot_params( - ax=ax, - ) - line_kwargs = self._get_line_kwargs( n_curves, names_, self.roc_aucs, fold_line_kwargs, **kwargs ) @@ -664,7 +661,7 @@ def from_cv_results( >>> clf = SVC(random_state=0) >>> cv_results = cross_validate( ... clf, X, y, cv=3, return_estimator=True, return_indices=True) - >>> RocCurveDisplay.from_cv_results(cv_results, X, y, kind="both") + >>> RocCurveDisplay.from_cv_results(cv_results, X, y) <...> >>> plt.show() """ diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 1c2657a3d2d0f..eb64d52842862 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -19,16 +19,18 @@ class _BinaryClassifierCurveDisplayMixin: the target and gather the response of the estimator. 
""" - def _validate_plot_params(self, *, ax=None, names=None): + def _validate_plot_params(self, *, ax=None, name=None): check_matplotlib_support(f"{self.__class__.__name__}.plot") import matplotlib.pyplot as plt if ax is None: _, ax = plt.subplots() - if names is not None: - names = self.names if names[0] is None else names - return ax, ax.figure, names + if name is None: + name = getattr(self, "estimator_name", None) + elif isinstance(name, list): + name = self.names if name[0] is None else name + return ax, ax.figure, name @classmethod def _get_line_kwargs( From 34d8051ba731d1054aab436dced086628e5d6813 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 15 Jan 2025 13:44:54 +1100 Subject: [PATCH 10/63] fix from predictions --- sklearn/metrics/_plot/roc_curve.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 87ab4ee5cc450..aa34d54ea9eac 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -43,7 +43,7 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): Area under ROC curve. Should be list of the same length as `fprs` and `tprs` or None, in which case no area under ROC curve score is shown. - names : str or list of str, default=None + names : list of str, default=None Label for the ROC curve. Should be list of the same length as `fprs` and `tprs` or None, in which case no name is shown. @@ -79,7 +79,7 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): `roc_auc` is deprecated in 1.7 and will be removed in 1.9. Use `roc_aucs` instead. - name : str or list of str, default=None + name : str, default=None Label for the ROC curve. For multiple ROC curves, `name` can be a list of the same length as `tpr` and `fpr`. If None, no name is shown. @@ -180,8 +180,8 @@ def plot( names : list of str, default=None Names of each ROC curve for labeling each curve in the legend. 
- If `None`, use name provided at `RocCurveDisplay` initialization. If none - provided at initialization, no legend is added. + If `None`, use `names` provided at `RocCurveDisplay` initialization. If + also not provided at initialization, no legend is added. .. versionadded:: 1.7 @@ -257,7 +257,7 @@ def plot( self.lines_ = [] for fpr, tpr, line_kw in zip(self.fprs, self.tprs, line_kwargs): - self.line_.extend(self.ax_.plot(fpr, tpr, **line_kw)) + self.lines_.extend(self.ax_.plot(fpr, tpr, **line_kw)) info_pos_label = ( f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" @@ -544,16 +544,17 @@ def from_predictions( roc_auc = auc(fpr, tpr) viz = cls( - fpr=fpr, - tpr=tpr, - roc_auc=roc_auc, - name=name, + fprs=[fpr], + tprs=[tpr], + roc_aucs=[roc_auc], + names=[name], pos_label=pos_label_validated, ) return viz.plot( ax=ax, - name=name, + # Should we provide `name` to both `cls` and `plot` or just `cls`? + names=[name], plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, despine=despine, From dc6adceea499731239204736230ba919f06b5c54 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 15 Jan 2025 14:04:31 +1100 Subject: [PATCH 11/63] fix example in docstring --- sklearn/metrics/_plot/roc_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index aa34d54ea9eac..b9e253ea92fe5 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -129,7 +129,7 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): >>> pred = np.array([0.1, 0.4, 0.35, 0.8]) >>> fpr, tpr, thresholds = metrics.roc_curve(y, pred) >>> roc_auc = metrics.auc(fpr, tpr) - >>> display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, + >>> display = metrics.RocCurveDisplay(fprs=[fpr], tprs=[tpr], roc_aucs=[roc_auc], ... 
name='example estimator') >>> display.plot() <...> From bdf2e43d452548a1db082190d5cbdc008eca18f6 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 15 Jan 2025 14:41:19 +1100 Subject: [PATCH 12/63] fix tests --- sklearn/metrics/_plot/roc_curve.py | 13 +++++----- .../_plot/tests/test_common_curve_display.py | 4 +-- .../_plot/tests/test_roc_curve_display.py | 26 +++++++++---------- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index b9e253ea92fe5..762d1c3e35e21 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -44,8 +44,9 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): `tprs` or None, in which case no area under ROC curve score is shown. names : list of str, default=None - Label for the ROC curve. Should be list of the same length as - `fprs` and `tprs` or None, in which case no name is shown. + Names of each ROC curve, used for labeling curves in the legend. + Should be list of the same length as `fprs` and `tprs`, or None, in which + case no legend is added. pos_label : int, float, bool or str, default=None The class considered as the positive class when computing the roc auc @@ -179,7 +180,7 @@ def plot( created. names : list of str, default=None - Names of each ROC curve for labeling each curve in the legend. + Names of each ROC curve, used for labeling curves in the legend. If `None`, use `names` provided at `RocCurveDisplay` initialization. If also not provided at initialization, no legend is added. @@ -619,9 +620,9 @@ def from_cv_results( created. fold_names : list of str, default=None - Names used in the legend for each individual ROC curve. If `None`, - the name will be set to "ROC fold " where N is the index of the - CV fold. + Names of each ROC curve, used for labeling curves in the legend. + If `None`, the name will be set to "Fold " where N is the index of + the CV fold. 
fold_line_kwargs : dict or list of dict, default=None Dictionary with keywords passed to the matplotlib's `plot` function diff --git a/sklearn/metrics/_plot/tests/test_common_curve_display.py b/sklearn/metrics/_plot/tests/test_common_curve_display.py index 0014a73055e41..37bbd349f99b4 100644 --- a/sklearn/metrics/_plot/tests/test_common_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_common_curve_display.py @@ -133,7 +133,7 @@ def fit(self, X, y): @pytest.mark.parametrize( - "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay] + "Display", [DetCurveDisplay, PrecisionRecallDisplay] ) @pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) def test_display_curve_estimator_name_multiple_calls( @@ -177,7 +177,7 @@ def test_display_curve_estimator_name_multiple_calls( ], ) @pytest.mark.parametrize( - "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay] + "Display", [DetCurveDisplay, PrecisionRecallDisplay] ) def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display): """Check that a proper error is raised when the classifier is not diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index e7e2abd7bd5f5..dba947e6e04dc 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -99,11 +99,11 @@ def test_roc_curve_display_plotting( pos_label=pos_label, ) - assert_allclose(display.roc_auc, auc(fpr, tpr)) - assert_allclose(display.fpr, fpr) - assert_allclose(display.tpr, tpr) + assert_allclose(display.roc_aucs[0], auc(fpr, tpr)) + assert_allclose(display.fprs[0], fpr) + assert_allclose(display.tprs[0], tpr) - assert display.estimator_name == default_name + assert display.names[0] == default_name import matplotlib as mpl # noqal @@ -115,7 +115,7 @@ def test_roc_curve_display_plotting( assert display.ax_.get_aspect() in ("equal", 1.0) assert 
display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) - expected_label = f"{default_name} (AUC = {display.roc_auc:.2f})" + expected_label = f"{default_name} (AUC = {display.roc_auc[0]:.2f})" assert display.line_.get_label() == expected_label expected_pos_label = 1 if pos_label is None else pos_label @@ -254,11 +254,11 @@ def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructo name = "Classifier" assert name in display.line_.get_label() - assert display.estimator_name == name + assert display.names[0] == name @pytest.mark.parametrize( - "roc_auc, estimator_name, expected_label", + "roc_aucs, names, expected_label", [ (0.9, None, "AUC = 0.90"), (None, "my_est", "my_est"), @@ -266,13 +266,13 @@ def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructo ], ) def test_roc_curve_display_default_labels( - pyplot, roc_auc, estimator_name, expected_label + pyplot, roc_aucs, names, expected_label ): """Check the default labels used in the display.""" fpr = np.array([0, 0.5, 1]) tpr = np.array([0, 0.5, 1]) disp = RocCurveDisplay( - fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=estimator_name + fprs=[fpr], tprs=[tpr], roc_aucs=[roc_aucs], names=names ).plot() assert disp.line_.get_label() == expected_label @@ -329,8 +329,8 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): roc_auc_limit = 0.95679 - assert display.roc_auc == pytest.approx(roc_auc_limit) - assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) + assert display.roc_aucs[0] == pytest.approx(roc_auc_limit) + assert trapezoid(display.tprs[0], display.fprs[0]) == pytest.approx(roc_auc_limit) if constructor_name == "from_estimator": display = RocCurveDisplay.from_estimator( @@ -347,8 +347,8 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): pos_label="not cancer", ) - assert display.roc_auc == pytest.approx(roc_auc_limit) - assert trapezoid(display.tpr, display.fpr) == 
pytest.approx(roc_auc_limit) + assert display.roc_aucs[0] == pytest.approx(roc_auc_limit) + assert trapezoid(display.tpr[0], display.fpr[0]) == pytest.approx(roc_auc_limit) @pytest.mark.parametrize("despine", [True, False]) From f5dbb1d1f41888ea0d976a0f349f4a6f722ec7bb Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 15 Jan 2025 14:43:50 +1100 Subject: [PATCH 13/63] fix docstring example --- sklearn/metrics/_plot/roc_curve.py | 2 +- sklearn/metrics/_plot/tests/test_common_curve_display.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 762d1c3e35e21..951e912885d00 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -131,7 +131,7 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): >>> fpr, tpr, thresholds = metrics.roc_curve(y, pred) >>> roc_auc = metrics.auc(fpr, tpr) >>> display = metrics.RocCurveDisplay(fprs=[fpr], tprs=[tpr], roc_aucs=[roc_auc], - ... name='example estimator') + ... 
names=['example estimator']) >>> display.plot() <...> >>> plt.show() diff --git a/sklearn/metrics/_plot/tests/test_common_curve_display.py b/sklearn/metrics/_plot/tests/test_common_curve_display.py index 37bbd349f99b4..f073b7022303d 100644 --- a/sklearn/metrics/_plot/tests/test_common_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_common_curve_display.py @@ -176,6 +176,8 @@ def test_display_curve_estimator_name_multiple_calls( ), ], ) +# Add separate test for displays that have converted to names?, +# add note to remove this one in 1.9 @pytest.mark.parametrize( "Display", [DetCurveDisplay, PrecisionRecallDisplay] ) From 144fd1376376b4c0a84e641f870b7abefadbdf2c Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 15 Jan 2025 15:11:53 +1100 Subject: [PATCH 14/63] black --- sklearn/metrics/_plot/tests/test_roc_curve_display.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index dba947e6e04dc..1d9869d2a0088 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -265,9 +265,7 @@ def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructo (0.8, "my_est2", "my_est2 (AUC = 0.80)"), ], ) -def test_roc_curve_display_default_labels( - pyplot, roc_aucs, names, expected_label -): +def test_roc_curve_display_default_labels(pyplot, roc_aucs, names, expected_label): """Check the default labels used in the display.""" fpr = np.array([0, 0.5, 1]) tpr = np.array([0, 0.5, 1]) From 0aa775141d104b88aeed2325a7e1193f7ec05738 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 15 Jan 2025 15:12:27 +1100 Subject: [PATCH 15/63] black --- sklearn/metrics/_plot/tests/test_common_curve_display.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_common_curve_display.py 
b/sklearn/metrics/_plot/tests/test_common_curve_display.py index f073b7022303d..865099b540c66 100644 --- a/sklearn/metrics/_plot/tests/test_common_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_common_curve_display.py @@ -132,9 +132,7 @@ def fit(self, X, y): Display.from_estimator(clf, X, y, response_method=response_method) -@pytest.mark.parametrize( - "Display", [DetCurveDisplay, PrecisionRecallDisplay] -) +@pytest.mark.parametrize("Display", [DetCurveDisplay, PrecisionRecallDisplay]) @pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) def test_display_curve_estimator_name_multiple_calls( pyplot, @@ -178,9 +176,7 @@ def test_display_curve_estimator_name_multiple_calls( ) # Add separate test for displays that have converted to names?, # add note to remove this one in 1.9 -@pytest.mark.parametrize( - "Display", [DetCurveDisplay, PrecisionRecallDisplay] -) +@pytest.mark.parametrize("Display", [DetCurveDisplay, PrecisionRecallDisplay]) def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display): """Check that a proper error is raised when the classifier is not fitted.""" From 70f0127902f0eaac5f2ffc941c6a52293430766e Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 15 Jan 2025 18:57:14 +1100 Subject: [PATCH 16/63] fix tests --- sklearn/metrics/_plot/tests/test_roc_curve_display.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 1d9869d2a0088..cd62c92bbb4e6 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -115,7 +115,7 @@ def test_roc_curve_display_plotting( assert display.ax_.get_aspect() in ("equal", 1.0) assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) - expected_label = f"{default_name} (AUC = {display.roc_auc[0]:.2f})" + expected_label = f"{default_name} (AUC = 
{display.roc_aucs[0]:.2f})" assert display.line_.get_label() == expected_label expected_pos_label = 1 if pos_label is None else pos_label @@ -272,7 +272,7 @@ def test_roc_curve_display_default_labels(pyplot, roc_aucs, names, expected_labe disp = RocCurveDisplay( fprs=[fpr], tprs=[tpr], roc_aucs=[roc_aucs], names=names ).plot() - assert disp.line_.get_label() == expected_label + assert disp.lines_[0].get_label() == expected_label @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @@ -346,7 +346,7 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): ) assert display.roc_aucs[0] == pytest.approx(roc_auc_limit) - assert trapezoid(display.tpr[0], display.fpr[0]) == pytest.approx(roc_auc_limit) + assert trapezoid(display.tprs[0], display.fprs[0]) == pytest.approx(roc_auc_limit) @pytest.mark.parametrize("despine", [True, False]) From b9e1b0b210200a16571a0e085a2305d80e2d139c Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 15 Jan 2025 21:11:49 +1100 Subject: [PATCH 17/63] fix testst --- .../_plot/tests/test_roc_curve_display.py | 19 ++++++++++--------- sklearn/utils/_plotting.py | 14 +++++++++++--- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index cd62c92bbb4e6..ae79abf4b2a26 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -107,8 +107,8 @@ def test_roc_curve_display_plotting( import matplotlib as mpl # noqal - assert isinstance(display.line_, mpl.lines.Line2D) - assert display.line_.get_alpha() == 0.8 + assert isinstance(display.lines_[0], mpl.lines.Line2D) + assert display.lines_[0].get_alpha() == 0.8 assert isinstance(display.ax_, mpl.axes.Axes) assert isinstance(display.figure_, mpl.figure.Figure) assert display.ax_.get_adjustable() == "box" @@ -116,7 +116,7 @@ def test_roc_curve_display_plotting( 
assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) expected_label = f"{default_name} (AUC = {display.roc_aucs[0]:.2f})" - assert display.line_.get_label() == expected_label + assert display.lines_[0].get_label() == expected_label expected_pos_label = 1 if pos_label is None else pos_label expected_ylabel = f"True Positive Rate (Positive label: {expected_pos_label})" @@ -180,8 +180,8 @@ def test_roc_curve_chance_level_line( import matplotlib as mpl # noqa - assert isinstance(display.line_, mpl.lines.Line2D) - assert display.line_.get_alpha() == 0.8 + assert isinstance(display.lines_[0], mpl.lines.Line2D) + assert display.lines_[0].get_alpha() == 0.8 assert isinstance(display.ax_, mpl.axes.Axes) assert isinstance(display.figure_, mpl.figure.Figure) @@ -253,7 +253,7 @@ def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructo display = RocCurveDisplay.from_predictions(y, y) name = "Classifier" - assert name in display.line_.get_label() + assert name in display.lines_[0].get_label() assert display.names[0] == name @@ -261,8 +261,8 @@ def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructo "roc_aucs, names, expected_label", [ (0.9, None, "AUC = 0.90"), - (None, "my_est", "my_est"), - (0.8, "my_est2", "my_est2 (AUC = 0.80)"), + (None, ["my_est"], "my_est"), + (0.8, ["my_est2"], "my_est2 (AUC = 0.80)"), ], ) def test_roc_curve_display_default_labels(pyplot, roc_aucs, names, expected_label): @@ -272,7 +272,8 @@ def test_roc_curve_display_default_labels(pyplot, roc_aucs, names, expected_labe disp = RocCurveDisplay( fprs=[fpr], tprs=[tpr], roc_aucs=[roc_aucs], names=names ).plot() - assert disp.lines_[0].get_label() == expected_label + print(disp.lines_[0].get_label()) + # assert disp.lines_[0].get_label() == expected_label @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 
eb64d52842862..641995934c89a 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -26,10 +26,18 @@ def _validate_plot_params(self, *, ax=None, name=None): if ax is None: _, ax = plt.subplots() + # Displays will either have `estimator_name` or `names`, + # try one first, then the other. if name is None: - name = getattr(self, "estimator_name", None) - elif isinstance(name, list): - name = self.names if name[0] is None else name + for attr in ["estimator_name", "names"]: + name = getattr(self, attr, None) + if name is not None: + break + # One line shorter alternative, but looks funny: + # if name is None: + # name = getattr(self, "estimator_name", None) + # if name is None: + # name = getattr(self, "names", None) return ax, ax.figure, name @classmethod From 73984e3087c2dc94994f1a8d7c367dd7fdf68793 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 15 Jan 2025 21:40:26 +1100 Subject: [PATCH 18/63] fix tests --- sklearn/metrics/_plot/roc_curve.py | 2 +- sklearn/utils/_plotting.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 951e912885d00..82ab2369d3b70 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -240,7 +240,7 @@ def plot( n_curves = len(self.fprs) line_kwargs = self._get_line_kwargs( - n_curves, names_, self.roc_aucs, fold_line_kwargs, **kwargs + n_curves, names_, self.roc_aucs, "AUC", fold_line_kwargs, **kwargs ) default_chance_level_line_kw = { diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 641995934c89a..3d8fb71fdb0db 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -26,8 +26,6 @@ def _validate_plot_params(self, *, ax=None, name=None): if ax is None: _, ax = plt.subplots() - # Displays will either have `estimator_name` or `names`, - # try one first, then the other. 
if name is None: for attr in ["estimator_name", "names"]: name = getattr(self, attr, None) @@ -46,6 +44,7 @@ def _get_line_kwargs( n_curves, names, summary_values, + summary_value_name, fold_line_kws, default_line_kwargs={}, **kwargs, @@ -81,10 +80,10 @@ def _get_line_kwargs( ): if curve_summary_value is not None and curve_name is not None: default_line_kwargs["label"] = ( - f"{curve_name} (AP = {curve_summary_value:0.2f})" + f"{curve_name} ({summary_value_name} = {curve_summary_value:0.2f})" ) elif curve_summary_value is not None: - default_line_kwargs["label"] = f"AP = {curve_summary_value:0.2f}" + default_line_kwargs["label"] = f"{summary_value_name} = {curve_summary_value:0.2f}" elif curve_name is not None: default_line_kwargs["label"] = curve_name From d0926f6784c8a5fd958a511042ee96748446667c Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 16 Jan 2025 12:09:33 +1100 Subject: [PATCH 19/63] lint --- sklearn/utils/_plotting.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 3d8fb71fdb0db..d28a3f154ecab 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -83,7 +83,9 @@ def _get_line_kwargs( f"{curve_name} ({summary_value_name} = {curve_summary_value:0.2f})" ) elif curve_summary_value is not None: - default_line_kwargs["label"] = f"{summary_value_name} = {curve_summary_value:0.2f}" + default_line_kwargs["label"] = ( + f"{summary_value_name} = {curve_summary_value:0.2f}" + ) elif curve_name is not None: default_line_kwargs["label"] = curve_name From 819d6fe26d75e5aedf0f4b1e9f9fce4fe80dc1f7 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 16 Jan 2025 12:55:26 +1100 Subject: [PATCH 20/63] fix docstring --- sklearn/metrics/_plot/roc_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 82ab2369d3b70..8ada1de24342d 100644 --- a/sklearn/metrics/_plot/roc_curve.py 
+++ b/sklearn/metrics/_plot/roc_curve.py @@ -222,7 +222,7 @@ def plot( **kwargs : dict For a single curve plots only, keyword arguments to be passed to matplotlib's `plot`. Ignored for multi-curve plots. - (Note req for backwards compat, maybe not ideal?) + (Note req for backwards compat, maybe not ideal?). Returns ------- From 8b500d8f33702b72a75565457b2a038c684c0be3 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 16 Jan 2025 14:50:49 +1100 Subject: [PATCH 21/63] add whats new --- .../upcoming_changes/sklearn.metrics/30399.enhancement.rst | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 doc/whats_new/upcoming_changes/sklearn.metrics/30399.enhancement.rst diff --git a/doc/whats_new/upcoming_changes/sklearn.metrics/30399.enhancement.rst b/doc/whats_new/upcoming_changes/sklearn.metrics/30399.enhancement.rst new file mode 100644 index 0000000000000..e9b75a20a5000 --- /dev/null +++ b/doc/whats_new/upcoming_changes/sklearn.metrics/30399.enhancement.rst @@ -0,0 +1,3 @@ +- Add class method `from_cv_results` to :class:`metrics.RocCurveDisplay`, which allows + easy plotting of multiple ROC curves using cross-validation results. 
+ By :user:`Lucy Liu ` From e7244d66ba31ecfac7ce7520ae989ffd21bea0b6 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 16 Jan 2025 16:39:22 +1100 Subject: [PATCH 22/63] amend current test --- .../_plot/tests/test_roc_curve_display.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index ae79abf4b2a26..92761555f534b 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -9,7 +9,7 @@ from sklearn.exceptions import NotFittedError from sklearn.linear_model import LogisticRegression from sklearn.metrics import RocCurveDisplay, auc, roc_curve -from sklearn.model_selection import train_test_split +from sklearn.model_selection import cross_validate, train_test_split from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler from sklearn.utils import shuffle @@ -258,22 +258,23 @@ def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructo @pytest.mark.parametrize( - "roc_aucs, names, expected_label", + "roc_aucs, names, expected_labels", [ - (0.9, None, "AUC = 0.90"), - (None, ["my_est"], "my_est"), - (0.8, ["my_est2"], "my_est2 (AUC = 0.80)"), + ([0.9, 0.8], None, ["AUC = 0.90", "AUC = 0.80"]), + ([0.8, 0.7], [None, None], ["AUC = 0.80", "AUC = 0.70"]), + (None, ["fold1", "fold2"], ["fold1", "fold2"]), + ([0.8, 0.7], ["my_est2", "my_est2"], ["my_est2 (AUC = 0.80)", "my_est2 (AUC = 0.70)"]), ], ) -def test_roc_curve_display_default_labels(pyplot, roc_aucs, names, expected_label): +def test_roc_curve_display_default_labels(pyplot, roc_aucs, names, expected_labels): """Check the default labels used in the display.""" - fpr = np.array([0, 0.5, 1]) - tpr = np.array([0, 0.5, 1]) + fprs = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])] + tprs = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])] disp = 
RocCurveDisplay( - fprs=[fpr], tprs=[tpr], roc_aucs=[roc_aucs], names=names + fprs=fprs, tprs=tprs, roc_aucs=roc_aucs, names=names ).plot() - print(disp.lines_[0].get_label()) - # assert disp.lines_[0].get_label() == expected_label + for idx, expected_label in enumerate(expected_labels): + assert disp.lines_[idx].get_label() == expected_label @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @@ -301,7 +302,7 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): classifier = LogisticRegression() classifier.fit(X_train, y_train) - # sanity check to be sure the positive class is classes_[0] and that we + # sanity check to be sure the positive class is `classes_[0]` and that we # are betrayed by the class imbalance assert classifier.classes_.tolist() == ["cancer", "not cancer"] From bbfc1aab566d6a628cd707ce457d165e826776a8 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 16 Jan 2025 16:41:31 +1100 Subject: [PATCH 23/63] lint --- sklearn/metrics/_plot/tests/test_roc_curve_display.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 92761555f534b..b800689af087e 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -263,16 +263,18 @@ def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructo ([0.9, 0.8], None, ["AUC = 0.90", "AUC = 0.80"]), ([0.8, 0.7], [None, None], ["AUC = 0.80", "AUC = 0.70"]), (None, ["fold1", "fold2"], ["fold1", "fold2"]), - ([0.8, 0.7], ["my_est2", "my_est2"], ["my_est2 (AUC = 0.80)", "my_est2 (AUC = 0.70)"]), + ( + [0.8, 0.7], + ["my_est2", "my_est2"], + ["my_est2 (AUC = 0.80)", "my_est2 (AUC = 0.70)"], + ), ], ) def test_roc_curve_display_default_labels(pyplot, roc_aucs, names, expected_labels): """Check the default labels used in the display.""" 
fprs = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])] tprs = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])] - disp = RocCurveDisplay( - fprs=fprs, tprs=tprs, roc_aucs=roc_aucs, names=names - ).plot() + disp = RocCurveDisplay(fprs=fprs, tprs=tprs, roc_aucs=roc_aucs, names=names).plot() for idx, expected_label in enumerate(expected_labels): assert disp.lines_[idx].get_label() == expected_label From 65f3d21be3c8da4ad86f2d4dcd56164401252df4 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 16 Jan 2025 16:44:28 +1100 Subject: [PATCH 24/63] lint --- sklearn/metrics/_plot/tests/test_roc_curve_display.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index b800689af087e..80fb0f6b6cdfa 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -9,7 +9,7 @@ from sklearn.exceptions import NotFittedError from sklearn.linear_model import LogisticRegression from sklearn.metrics import RocCurveDisplay, auc, roc_curve -from sklearn.model_selection import cross_validate, train_test_split +from sklearn.model_selection import train_test_split from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler from sklearn.utils import shuffle @@ -105,7 +105,7 @@ def test_roc_curve_display_plotting( assert display.names[0] == default_name - import matplotlib as mpl # noqal + import matplotlib as mpl # noqa assert isinstance(display.lines_[0], mpl.lines.Line2D) assert display.lines_[0].get_alpha() == 0.8 From 47f6bc715e277913d8f4b205e6c8d780a5765ca7 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Tue, 21 Jan 2025 12:35:05 +1100 Subject: [PATCH 25/63] review --- sklearn/metrics/_plot/roc_curve.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 
8ada1de24342d..87391edffe5d2 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -55,7 +55,7 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): .. versionadded:: 0.24 - fpr : ndarray or list of ndarray + fpr : ndarray False positive rate. When plotting multiple ROC curves, `fpr` and `tpr` should lists of the same length. @@ -63,7 +63,7 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): `fpr` is deprecated in 1.7 and will be removed in 1.9. Use `fprs` instead. - tpr : ndarray or list of ndarray + tpr : ndarray True positive rate. When plotting multiple ROC curves, `fpr` and `tpr` should lists of the same length. @@ -71,7 +71,7 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): `tpr` is deprecated in 1.7 and will be removed in 1.9. Use `tprs` instead. - roc_auc : float or list of floats, default=None + roc_auc : float, default=None Area under ROC curve. When plotting multiple ROC curves, can be a list of the same length as `fpr` and `tpr`. If None, no roc_auc score is shown. From 231eb51562c5d3b77c82c7459ccf9f2bcb2207cb Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 23 Jan 2025 15:19:46 +1100 Subject: [PATCH 26/63] review --- sklearn/metrics/_plot/roc_curve.py | 7 ++++--- sklearn/utils/_plotting.py | 13 ++++--------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 87391edffe5d2..0407579b4942c 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -207,7 +207,8 @@ def plot( to draw the individual ROC curves. If a list is provided, the parameters are applied to the ROC curves sequentially. If a single dictionary is provided, the same parameters are applied to all ROC - curves. Ignored for single curve plots. + curves. Ignored for single curve plots - pass as `**kwargs` for + single curve plots. .. 
versionadded:: 1.7 @@ -221,8 +222,8 @@ def plot( **kwargs : dict For a single curve plots only, keyword arguments to be passed to - matplotlib's `plot`. Ignored for multi-curve plots. - (Note req for backwards compat, maybe not ideal?). + matplotlib's `plot`. Ignored for multi-curve plots - use `fold_line_kwargs` + for multi-curve plots. Returns ------- diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index d28a3f154ecab..4cae5f838a563 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -26,16 +26,11 @@ def _validate_plot_params(self, *, ax=None, name=None): if ax is None: _, ax = plt.subplots() + # write better comment if name is None: - for attr in ["estimator_name", "names"]: - name = getattr(self, attr, None) - if name is not None: - break - # One line shorter alternative, but looks funny: - # if name is None: - # name = getattr(self, "estimator_name", None) - # if name is None: - # name = getattr(self, "names", None) + name = getattr(self, "estimator_name", None) + if name is None: + name = getattr(self, "names", None) return ax, ax.figure, name @classmethod From 8fc20e86c10b732ba682614981c8f8cfffc11f14 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 13 Feb 2025 15:28:50 +1100 Subject: [PATCH 27/63] revert to singular only --- sklearn/metrics/_plot/roc_curve.py | 196 ++++++++++++----------------- sklearn/utils/_plotting.py | 9 +- 2 files changed, 87 insertions(+), 118 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 0407579b4942c..6cf6548da598c 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,11 +1,10 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause -from ...utils import _safe_indexing, deprecated +from ...utils import _safe_indexing from ...utils._plotting import ( _BinaryClassifierCurveDisplayMixin, _check_param_lengths, - _deprecate_singular, _despine, 
_process_fold_names_line_kwargs, _validate_style_kwargs, @@ -29,24 +28,27 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): Parameters ---------- - fprs : list of ndarray + fpr : ndarray or list of ndarrays False positive rates. Each ndarray should contain values for a single curve. If plotting multiple curves, list should be of same length as - and `tprs`. + and `tpr`. - tprs : list of ndarray + tpr : ndarray list of ndarrays True positive rates. Each ndarray should contain values for a single curve. If plotting multiple curves, list should be of same length as - and `fprs`. + and `fpr`. - roc_aucs : list of floats, default=None - Area under ROC curve. Should be list of the same length as `fprs` and - `tprs` or None, in which case no area under ROC curve score is shown. + roc_auc : float or list of floats, default=None + Area under ROC curve, used for labeling curves in the legend. + If plotting multiple curves, should be a list of the same length as `fpr` + and `tpr`. If `None`, no area under ROC curve score is shown. If `name` + is also `None` no legend is added. - names : list of str, default=None - Names of each ROC curve, used for labeling curves in the legend. - Should be list of the same length as `fprs` and `tprs`, or None, in which - case no legend is added. + name : str or list of str, default=None + Name of each ROC curve, used for labeling curves in the legend. + If plotting multiple curves, should be a list of the same length as `fpr` + and `tpr`. If `None`, no name is not shown in the legend. If `roc_auc` + is also `None` no legend is added. pos_label : int, float, bool or str, default=None The class considered as the positive class when computing the roc auc @@ -55,45 +57,15 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): .. versionadded:: 0.24 - fpr : ndarray - False positive rate. When plotting multiple ROC curves, `fpr` and `tpr` should - lists of the same length. - - .. 
deprecated:: 1.7 - `fpr` is deprecated in 1.7 and will be removed in 1.9. - Use `fprs` instead. - - tpr : ndarray - True positive rate. When plotting multiple ROC curves, `fpr` and `tpr` should - lists of the same length. - - .. deprecated:: 1.7 - `tpr` is deprecated in 1.7 and will be removed in 1.9. - Use `tprs` instead. - - roc_auc : float, default=None - Area under ROC curve. When plotting multiple ROC curves, can be a list - of the same length as `fpr` and `tpr`. - If None, no roc_auc score is shown. - - .. deprecated:: 1.7 - `roc_auc` is deprecated in 1.7 and will be removed in 1.9. - Use `roc_aucs` instead. - - name : str, default=None - Label for the ROC curve. For multiple ROC curves, `name` can be a list - of the same length as `tpr` and `fpr`. - If None, no name is shown. - - .. deprecated:: 1.7 - `name` is deprecated in 1.7 and will be removed in 1.9. - Use `names` instead. - Attributes ---------- - lines_ : list of matplotlib Artists + line_ : matplotlib Artist or list of matplotlib Artists ROC Curves. + .. versionchanged:: 1.7 + This attribute can now be a list of Artists, when multiple curves are + plotted. + chance_level_ : matplotlib Artist or None The chance level line. It is `None` if the chance level is not plotted. @@ -105,13 +77,6 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): figure_ : matplotlib Figure Figure containing the curve. - line_ : matplotlib Artist - ROC Curve. - - .. deprecated:: 1.7 - `line_` is deprecated in 1.7 and will be removed in 1.9. Use `lines_` - instead. - See Also -------- roc_curve : Compute Receiver operating characteristic (ROC) curve. @@ -130,8 +95,8 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): >>> pred = np.array([0.1, 0.4, 0.35, 0.8]) >>> fpr, tpr, thresholds = metrics.roc_curve(y, pred) >>> roc_auc = metrics.auc(fpr, tpr) - >>> display = metrics.RocCurveDisplay(fprs=[fpr], tprs=[tpr], roc_aucs=[roc_auc], - ... 
names=['example estimator']) + >>> display = metrics.RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, + ... name='example estimator') >>> display.plot() <...> >>> plt.show() @@ -140,32 +105,51 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): def __init__( self, *, - fprs, - tprs, - roc_aucs=None, - names=None, + fpr, + tpr, + roc_auc=None, + name=None, pos_label=None, - fpr="deprecated", - tpr="deprecated", - roc_auc="deprecated", - name="deprecated", ): - self.fprs = _deprecate_singular(fpr, fprs, "fpr") - self.tprs = _deprecate_singular(tpr, tprs, "tpr") - self.roc_aucs = _deprecate_singular(roc_auc, roc_aucs, "roc_auc") - self.names = _deprecate_singular(name, names, "name") + self.fpr_ = ( + fpr + if isinstance(fpr, list) + else [ + fpr, + ] + ) + self.tpr_ = ( + tpr + if isinstance(tpr, list) + else [ + tpr, + ] + ) + self.roc_auc_ = ( + roc_auc + if isinstance(roc_auc, list) + else [ + roc_auc, + ] + ) + self.name_ = ( + name + if isinstance(name, list) + else [ + name, + ] + ) self.pos_label = pos_label def plot( self, ax=None, *, - names=None, + name=None, plot_chance_level=False, chance_level_kw=None, despine=False, fold_line_kwargs=None, - name="deprecated", **kwargs, ): """Plot visualization. @@ -179,10 +163,10 @@ def plot( Axes object to plot on. If `None`, a new figure and axes is created. - names : list of str, default=None - Names of each ROC curve, used for labeling curves in the legend. - If `None`, use `names` provided at `RocCurveDisplay` initialization. If - also not provided at initialization, no legend is added. + name : str or list of str, default=None + Name of each ROC curve, used for labeling curves in the legend. + If `None`, use `name` provided at `RocCurveDisplay` initialization. If + also not provided at initialization, no name is shown in the legend. .. versionadded:: 1.7 @@ -212,14 +196,6 @@ def plot( .. versionadded:: 1.7 - name : str, default=None - Name of ROC Curve for labeling. 
If `None`, use `estimator_name` if - not `None`, otherwise no labeling is shown. - - .. deprecated:: 1.7 - `name` is deprecated in 1.7 and will be removed in 1.9. - Use `names` instead. - **kwargs : dict For a single curve plots only, keyword arguments to be passed to matplotlib's `plot`. Ignored for multi-curve plots - use `fold_line_kwargs` @@ -230,18 +206,18 @@ def plot( display : :class:`~sklearn.metrics.RocCurveDisplay` Object that stores computed values. """ - names = _deprecate_singular(name, names, "name") - # Not sure about this, as ideally we would check params are correct first?? - self.ax_, self.figure_, names_ = self._validate_plot_params(ax=ax, name=names) + # TODO: Not sure about this, as ideally we would check params are correct + # first?? + self.ax_, self.figure_, name_ = self._validate_plot_params(ax=ax, name=name) _check_param_lengths( - {"self.fprs": self.fprs, "self.tprs": self.tprs}, - {"roc_aucs": self.roc_aucs, "self.names (or names from `plot`)": names_}, + {"self.fpr": self.fpr_, "self.tpr": self.tpr_}, + {"self.roc_auc": self.roc_auc_, "`name` from `plot` (or self.name)": name_}, "RocCurveDisplay", ) - n_curves = len(self.fprs) + n_curves = len(self.fpr_) line_kwargs = self._get_line_kwargs( - n_curves, names_, self.roc_aucs, "AUC", fold_line_kwargs, **kwargs + n_curves, name_, self.roc_auc_, "AUC", fold_line_kwargs, **kwargs ) default_chance_level_line_kw = { @@ -257,9 +233,12 @@ def plot( default_chance_level_line_kw, chance_level_kw ) - self.lines_ = [] - for fpr, tpr, line_kw in zip(self.fprs, self.tprs, line_kwargs): - self.lines_.extend(self.ax_.plot(fpr, tpr, **line_kw)) + self.line_ = [] + for fpr, tpr, line_kw in zip(self.fpr_, self.tpr_, line_kwargs): + self.line_.extend(self.ax_.plot(fpr, tpr, **line_kw)) + # Return single artist if only one curve is plotted + if len(self.line_) == 1: + self.line_ = self.line_[0] info_pos_label = ( f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" @@ -291,17 +270,6 
@@ def plot( return self - # TODO(1.9): Remove - # Is it worth adding a global ignore for mypy error? - # mypy error: Decorated property not supported - @deprecated( # type: ignore - "Attribute `line_` is deprecated in 1.7 and will be removed in " - "1.9. Use `lines_` instead." - ) - @property - def line_(self): - return self.lines_[0] - @classmethod def from_estimator( cls, @@ -477,7 +445,7 @@ def from_predictions( error will be raised. name : str, default=None - Name of ROC curve for labeling. If `None`, name will be set to + Name of ROC curve for legend labeling. If `None`, name will be set to `"Classifier"`. ax : matplotlib axes, default=None @@ -546,17 +514,17 @@ def from_predictions( roc_auc = auc(fpr, tpr) viz = cls( - fprs=[fpr], - tprs=[tpr], - roc_aucs=[roc_auc], - names=[name], + fpr=fpr, + tpr=tpr, + roc_auc=roc_auc, + name=name, pos_label=pos_label_validated, ) return viz.plot( ax=ax, # Should we provide `name` to both `cls` and `plot` or just `cls`? - names=[name], + name=name, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, despine=despine, @@ -720,10 +688,10 @@ def from_cv_results( auc_all.append(roc_auc) viz = cls( - fprs=fpr_all, - tprs=tpr_all, - names=fold_names_, - roc_aucs=auc_all, + fpr=fpr_all, + tpr=tpr_all, + name=fold_names_, + roc_auc=auc_all, pos_label=pos_label, ) return viz.plot( diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 4cae5f838a563..93b1d6888d54b 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -26,11 +26,12 @@ def _validate_plot_params(self, *, ax=None, name=None): if ax is None: _, ax = plt.subplots() - # write better comment + # We are changing from `estimator_name` to `name`, Display objects will + # have one or the other. Try old attr name: `estimator_name` first. 
if name is None: name = getattr(self, "estimator_name", None) if name is None: - name = getattr(self, "names", None) + name = getattr(self, "name", None) return ax, ax.figure, name @classmethod @@ -281,8 +282,8 @@ def _check_param_lengths(required, optional, class_name): ) raise ValueError( f"{required_formatted}, and optional parameters {optional_formatted} " - f"from `{class_name}` initialization, should all be lists of the same " - f"length. Got: {lengths_formatted}" + f"from `{class_name}` initialization (or `plot`) should all be lists of " + f"the same length. Got: {lengths_formatted}" ) From 18b2b1ab8aa1584156a72c926585aa707e5fa0bb Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 13 Feb 2025 15:38:14 +1100 Subject: [PATCH 28/63] lint --- sklearn/metrics/_plot/roc_curve.py | 32 ++++-------------------------- 1 file changed, 4 insertions(+), 28 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 6cf6548da598c..e0bc069288773 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -111,34 +111,10 @@ def __init__( name=None, pos_label=None, ): - self.fpr_ = ( - fpr - if isinstance(fpr, list) - else [ - fpr, - ] - ) - self.tpr_ = ( - tpr - if isinstance(tpr, list) - else [ - tpr, - ] - ) - self.roc_auc_ = ( - roc_auc - if isinstance(roc_auc, list) - else [ - roc_auc, - ] - ) - self.name_ = ( - name - if isinstance(name, list) - else [ - name, - ] - ) + self.fpr_ = fpr if isinstance(fpr, list) else [fpr] + self.tpr_ = tpr if isinstance(tpr, list) else [tpr] + self.roc_auc_ = roc_auc if isinstance(roc_auc, list) else [roc_auc] + self.name_ = name if isinstance(name, list) else [name] self.pos_label = pos_label def plot( From a1824b9a46fd4e5b52b5149b798c193b855dab3e Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 13 Feb 2025 16:34:08 +1100 Subject: [PATCH 29/63] update tests, more fixes --- sklearn/metrics/_plot/roc_curve.py | 10 +++-- .../_plot/tests/test_roc_curve_display.py 
| 44 +++++++++---------- sklearn/utils/_plotting.py | 35 ++++----------- 3 files changed, 37 insertions(+), 52 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index e0bc069288773..c2c12beea2f59 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -5,6 +5,7 @@ from ...utils._plotting import ( _BinaryClassifierCurveDisplayMixin, _check_param_lengths, + _convert_to_list_leaving_none, _despine, _process_fold_names_line_kwargs, _validate_style_kwargs, @@ -111,10 +112,10 @@ def __init__( name=None, pos_label=None, ): - self.fpr_ = fpr if isinstance(fpr, list) else [fpr] - self.tpr_ = tpr if isinstance(tpr, list) else [tpr] - self.roc_auc_ = roc_auc if isinstance(roc_auc, list) else [roc_auc] - self.name_ = name if isinstance(name, list) else [name] + self.fpr_ = _convert_to_list_leaving_none(fpr) + self.tpr_ = _convert_to_list_leaving_none(tpr) + self.roc_auc_ = _convert_to_list_leaving_none(roc_auc) + self.name_ = _convert_to_list_leaving_none(name) self.pos_label = pos_label def plot( @@ -185,6 +186,7 @@ def plot( # TODO: Not sure about this, as ideally we would check params are correct # first?? 
self.ax_, self.figure_, name_ = self._validate_plot_params(ax=ax, name=name) + name_ = _convert_to_list_leaving_none(name_) _check_param_lengths( {"self.fpr": self.fpr_, "self.tpr": self.tpr_}, {"self.roc_auc": self.roc_auc_, "`name` from `plot` (or self.name)": name_}, diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 80fb0f6b6cdfa..251cd619bfbd6 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -99,24 +99,24 @@ def test_roc_curve_display_plotting( pos_label=pos_label, ) - assert_allclose(display.roc_aucs[0], auc(fpr, tpr)) - assert_allclose(display.fprs[0], fpr) - assert_allclose(display.tprs[0], tpr) + assert_allclose(display.roc_auc_[0], auc(fpr, tpr)) + assert_allclose(display.fpr_[0], fpr) + assert_allclose(display.tpr_[0], tpr) - assert display.names[0] == default_name + assert display.name_[0] == default_name import matplotlib as mpl # noqa - assert isinstance(display.lines_[0], mpl.lines.Line2D) - assert display.lines_[0].get_alpha() == 0.8 + assert isinstance(display.line_, mpl.lines.Line2D) + assert display.line_.get_alpha() == 0.8 assert isinstance(display.ax_, mpl.axes.Axes) assert isinstance(display.figure_, mpl.figure.Figure) assert display.ax_.get_adjustable() == "box" assert display.ax_.get_aspect() in ("equal", 1.0) assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) - expected_label = f"{default_name} (AUC = {display.roc_aucs[0]:.2f})" - assert display.lines_[0].get_label() == expected_label + expected_label = f"{default_name} (AUC = {display.roc_auc_[0]:.2f})" + assert display.line_.get_label() == expected_label expected_pos_label = 1 if pos_label is None else pos_label expected_ylabel = f"True Positive Rate (Positive label: {expected_pos_label})" @@ -180,8 +180,8 @@ def test_roc_curve_chance_level_line( import matplotlib as mpl # noqa - assert 
isinstance(display.lines_[0], mpl.lines.Line2D) - assert display.lines_[0].get_alpha() == 0.8 + assert isinstance(display.line_, mpl.lines.Line2D) + assert display.line_.get_alpha() == 0.8 assert isinstance(display.ax_, mpl.axes.Axes) assert isinstance(display.figure_, mpl.figure.Figure) @@ -253,12 +253,12 @@ def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructo display = RocCurveDisplay.from_predictions(y, y) name = "Classifier" - assert name in display.lines_[0].get_label() - assert display.names[0] == name + assert name in display.line_.get_label() + assert display.name_[0] == name @pytest.mark.parametrize( - "roc_aucs, names, expected_labels", + "roc_auc, name, expected_labels", [ ([0.9, 0.8], None, ["AUC = 0.90", "AUC = 0.80"]), ([0.8, 0.7], [None, None], ["AUC = 0.80", "AUC = 0.70"]), @@ -270,13 +270,13 @@ def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructo ), ], ) -def test_roc_curve_display_default_labels(pyplot, roc_aucs, names, expected_labels): +def test_roc_curve_display_default_labels(pyplot, roc_auc, name, expected_labels): """Check the default labels used in the display.""" - fprs = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])] - tprs = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])] - disp = RocCurveDisplay(fprs=fprs, tprs=tprs, roc_aucs=roc_aucs, names=names).plot() + fpr = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])] + tpr = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])] + disp = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, name=name).plot() for idx, expected_label in enumerate(expected_labels): - assert disp.lines_[idx].get_label() == expected_label + assert disp.line_[idx].get_label() == expected_label @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @@ -331,8 +331,8 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): roc_auc_limit = 0.95679 - assert display.roc_aucs[0] == pytest.approx(roc_auc_limit) - assert 
trapezoid(display.tprs[0], display.fprs[0]) == pytest.approx(roc_auc_limit) + assert display.roc_auc_[0] == pytest.approx(roc_auc_limit) + assert trapezoid(display.tpr_[0], display.fpr_[0]) == pytest.approx(roc_auc_limit) if constructor_name == "from_estimator": display = RocCurveDisplay.from_estimator( @@ -349,8 +349,8 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): pos_label="not cancer", ) - assert display.roc_aucs[0] == pytest.approx(roc_auc_limit) - assert trapezoid(display.tprs[0], display.fprs[0]) == pytest.approx(roc_auc_limit) + assert display.roc_auc_[0] == pytest.approx(roc_auc_limit) + assert trapezoid(display.tpr_[0], display.fpr_[0]) == pytest.approx(roc_auc_limit) @pytest.mark.parametrize("despine", [True, False]) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 93b1d6888d54b..7484bb943956d 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -1,6 +1,5 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause -import warnings from collections.abc import Mapping import numpy as np @@ -31,7 +30,7 @@ def _validate_plot_params(self, *, ax=None, name=None): if name is None: name = getattr(self, "estimator_name", None) if name is None: - name = getattr(self, "name", None) + name = getattr(self, "name_", None) return ax, ax.figure, name @classmethod @@ -57,11 +56,7 @@ def _get_line_kwargs( fold_line_kws = [kwargs] else: if fold_line_kws is None: - # We should not set color to be the same, otherwise legend is - # meaningless - fold_line_kws = [ - {"alpha": 0.5, "color": "tab:blue", "linestyle": "--"} - ] * n_curves + fold_line_kws = [{"alpha": 0.5, "linestyle": "--"}] * n_curves elif isinstance(fold_line_kws, Mapping): fold_line_kws = [fold_line_kws] * n_curves elif len(fold_line_kws) != n_curves: @@ -244,25 +239,13 @@ def _despine(ax): # TODO(1.9): remove # Should this be a parent class method? 
-def _deprecate_singular(singular, plural, name): - """Deprecate the singular version of Display parameters. - - If only `singular` parameter passed, it will be returned as a list with a warning. - """ - if singular != "deprecated": - warnings.warn( - f"`{name}` was passed to `{name}s` in a list because `{name}` is " - f"deprecated in 1.7 and will be removed in 1.9. Use " - f"`{name}s` instead.", - FutureWarning, - ) - if plural: - raise ValueError( - f"Cannot use both `{name}` and `{name}s`. Use only `{name}s` as " - f"`{name}` is deprecated." - ) - return [singular] - return plural +def _convert_to_list_leaving_none(param): + """Convert parameters to a list, leaving `None` as is.""" + if param is None: + return None + if isinstance(param, list): + return param + return [param] # Should this be a parent class/mixin method? From 2745405ea963407f65243a2a7401f5a3a5c7ecc4 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Mon, 17 Feb 2025 15:54:21 +1100 Subject: [PATCH 30/63] process attrs in plot --- sklearn/metrics/_plot/roc_curve.py | 24 ++++++++++++------- .../_plot/tests/test_roc_curve_display.py | 18 +++++++++----- sklearn/utils/_plotting.py | 2 +- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index c2c12beea2f59..fee9d0ee7d6b5 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -112,10 +112,10 @@ def __init__( name=None, pos_label=None, ): - self.fpr_ = _convert_to_list_leaving_none(fpr) - self.tpr_ = _convert_to_list_leaving_none(tpr) - self.roc_auc_ = _convert_to_list_leaving_none(roc_auc) - self.name_ = _convert_to_list_leaving_none(name) + self.fpr = fpr + self.tpr = tpr + self.roc_auc = roc_auc + self.name = name self.pos_label = pos_label def plot( @@ -183,19 +183,27 @@ def plot( display : :class:`~sklearn.metrics.RocCurveDisplay` Object that stores computed values. 
""" + self.fpr_ = _convert_to_list_leaving_none(self.fpr) + self.tpr_ = _convert_to_list_leaving_none(self.tpr) + self.roc_auc_ = _convert_to_list_leaving_none(self.roc_auc) # TODO: Not sure about this, as ideally we would check params are correct # first?? - self.ax_, self.figure_, name_ = self._validate_plot_params(ax=ax, name=name) - name_ = _convert_to_list_leaving_none(name_) + self.ax_, self.figure_, self.name_ = self._validate_plot_params( + ax=ax, name=name + ) + self.name_ = _convert_to_list_leaving_none(self.name_) _check_param_lengths( {"self.fpr": self.fpr_, "self.tpr": self.tpr_}, - {"self.roc_auc": self.roc_auc_, "`name` from `plot` (or self.name)": name_}, + { + "self.roc_auc": self.roc_auc_, + "`name` from `plot` (or self.name)": self.name_, + }, "RocCurveDisplay", ) n_curves = len(self.fpr_) line_kwargs = self._get_line_kwargs( - n_curves, name_, self.roc_auc_, "AUC", fold_line_kwargs, **kwargs + n_curves, self.name_, self.roc_auc_, "AUC", fold_line_kwargs, **kwargs ) default_chance_level_line_kw = { diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 251cd619bfbd6..e0fbc94e4a7a4 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -51,7 +51,7 @@ def test_roc_curve_display_plotting( constructor_name, default_name, ): - """Check the overall plotting behaviour.""" + """Check the overall plotting behaviour for single curve.""" X, y = data_binary pos_label = None @@ -99,11 +99,16 @@ def test_roc_curve_display_plotting( pos_label=pos_label, ) + # Both processed and unprocessed attributes should be the same for single curve assert_allclose(display.roc_auc_[0], auc(fpr, tpr)) + assert_allclose(display.roc_auc, auc(fpr, tpr)) assert_allclose(display.fpr_[0], fpr) + assert_allclose(display.fpr, fpr) assert_allclose(display.tpr_[0], tpr) + assert_allclose(display.tpr, tpr) assert display.name_[0] == 
default_name + assert display.name == default_name import matplotlib as mpl # noqa @@ -116,6 +121,7 @@ def test_roc_curve_display_plotting( assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) expected_label = f"{default_name} (AUC = {display.roc_auc_[0]:.2f})" + expected_label = f"{default_name} (AUC = {display.roc_auc:.2f})" assert display.line_.get_label() == expected_label expected_pos_label = 1 if pos_label is None else pos_label @@ -254,7 +260,7 @@ def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructo name = "Classifier" assert name in display.line_.get_label() - assert display.name_[0] == name + assert display.name == name @pytest.mark.parametrize( @@ -331,8 +337,8 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): roc_auc_limit = 0.95679 - assert display.roc_auc_[0] == pytest.approx(roc_auc_limit) - assert trapezoid(display.tpr_[0], display.fpr_[0]) == pytest.approx(roc_auc_limit) + assert display.roc_auc == pytest.approx(roc_auc_limit) + assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) if constructor_name == "from_estimator": display = RocCurveDisplay.from_estimator( @@ -349,8 +355,8 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): pos_label="not cancer", ) - assert display.roc_auc_[0] == pytest.approx(roc_auc_limit) - assert trapezoid(display.tpr_[0], display.fpr_[0]) == pytest.approx(roc_auc_limit) + assert display.roc_auc == pytest.approx(roc_auc_limit) + assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) @pytest.mark.parametrize("despine", [True, False]) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 7484bb943956d..b41b46780f0bc 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -30,7 +30,7 @@ def _validate_plot_params(self, *, ax=None, name=None): if name is None: name = getattr(self, "estimator_name", None) if name is None: - name = 
getattr(self, "name_", None) + name = getattr(self, "name", None) return ax, ax.figure, name @classmethod From 62c3f326258083f0a9f33e1bc071017d9a440782 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Tue, 18 Feb 2025 14:06:53 +1100 Subject: [PATCH 31/63] add ref --- sklearn/metrics/_plot/roc_curve.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index fee9d0ee7d6b5..4a554d61d4e0a 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -25,7 +25,8 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): a :class:`~sklearn.metrics.RocCurveDisplay`. All parameters are stored as attributes. - Read more in the :ref:`User Guide <visualizations>`. + For more about the ROC metric, see :ref:`roc_metrics`. + For more about scikit-learn visualization classes, see :ref:`visualizations`. Parameters ---------- From 25065a0ad19f9dca596b968924ea451047348d15 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 7 Mar 2025 15:37:17 +1100 Subject: [PATCH 32/63] review and add tests --- sklearn/metrics/_plot/roc_curve.py | 112 ++--- .../_plot/tests/test_common_curve_display.py | 30 +- .../_plot/tests/test_roc_curve_display.py | 414 +++++++++++++++--- sklearn/utils/_plotting.py | 252 +++++++---- 4 files changed, 618 insertions(+), 190 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 4a554d61d4e0a..426d28821c848 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -6,12 +6,12 @@ _BinaryClassifierCurveDisplayMixin, _check_param_lengths, _convert_to_list_leaving_none, + _deprecate_estimator_name, _despine, - _process_fold_names_line_kwargs, + _validate_line_kwargs, _validate_style_kwargs, ) from ...utils._response import _get_response_values_binary -from ...utils.validation import _num_samples from .._ranking import auc, roc_curve @@ -59,13 +59,18 @@ class 
RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): .. versionadded:: 0.24 + estimator_name : str, default=None + Name of estimator. If None, the estimator name is not shown. + + .. deprecated:: 1.7 + Attributes ---------- line_ : matplotlib Artist or list of matplotlib Artists ROC Curves. .. versionchanged:: 1.7 - This attribute can now be a list of Artists, when multiple curves are + This attribute can now be a list of Artists, for when multiple curves are plotted. chance_level_ : matplotlib Artist or None @@ -112,13 +117,31 @@ def __init__( roc_auc=None, name=None, pos_label=None, + estimator_name="deprecated", ): self.fpr = fpr self.tpr = tpr self.roc_auc = roc_auc - self.name = name + self.name = _deprecate_estimator_name(estimator_name, name, "1.7") self.pos_label = pos_label + def _validate_plot_params(self, *, ax=None, name=None): + self.ax_, self.figure_, name_ = super()._validate_plot_params(ax=ax, name=name) + + self.fpr_ = _convert_to_list_leaving_none(self.fpr) + self.tpr_ = _convert_to_list_leaving_none(self.tpr) + self.roc_auc_ = _convert_to_list_leaving_none(self.roc_auc) + self.name_ = _convert_to_list_leaving_none(name_) + + _check_param_lengths( + required={"self.fpr": self.fpr_, "self.tpr": self.tpr_}, + optional={ + "self.roc_auc": self.roc_auc_, + "`name` from `plot` (or self.name)": self.name_, + }, + class_name="RocCurveDisplay", + ) + def plot( self, ax=None, @@ -184,27 +207,16 @@ def plot( display : :class:`~sklearn.metrics.RocCurveDisplay` Object that stores computed values. """ - self.fpr_ = _convert_to_list_leaving_none(self.fpr) - self.tpr_ = _convert_to_list_leaving_none(self.tpr) - self.roc_auc_ = _convert_to_list_leaving_none(self.roc_auc) - # TODO: Not sure about this, as ideally we would check params are correct - # first?? 
- self.ax_, self.figure_, self.name_ = self._validate_plot_params( - ax=ax, name=name - ) - self.name_ = _convert_to_list_leaving_none(self.name_) - _check_param_lengths( - {"self.fpr": self.fpr_, "self.tpr": self.tpr_}, - { - "self.roc_auc": self.roc_auc_, - "`name` from `plot` (or self.name)": self.name_, - }, - "RocCurveDisplay", - ) + self._validate_plot_params(ax=ax, name=name) n_curves = len(self.fpr_) line_kwargs = self._get_line_kwargs( - n_curves, self.name_, self.roc_auc_, "AUC", fold_line_kwargs, **kwargs + n_curves, + self.name_, + self.roc_auc_, + "AUC", + fold_line_kwargs=fold_line_kwargs, + **kwargs, ) default_chance_level_line_kw = { @@ -533,7 +545,8 @@ def from_cv_results( fold_names=None, fold_line_kwargs=None, plot_chance_level=False, - chance_level_kw=None, + chance_level_kwargs=None, + despine=False, ): """Create a multi-fold ROC curve display given cross-validation results. @@ -590,10 +603,13 @@ def from_cv_results( plot_chance_level : bool, default=False Whether to plot the chance level. - chance_level_kw : dict, default=None + chance_level_kwargs : dict, default=None Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. + despine : bool, default=False + Whether to remove the top and right spines from the plot. + Returns ------- display : :class:`~sklearn.metrics.RocCurveDisplay` @@ -623,27 +639,18 @@ def from_cv_results( <...> >>> plt.show() """ - required_keys = {"estimator", "indices"} - if not all(key in cv_results for key in required_keys): - raise ValueError( - "cv_results does not contain one of the following required keys: " - f"{required_keys}. Set explicitly the parameters return_estimator=True " - "and return_indices=True to the function cross_validate." 
- ) - - train_size, test_size = ( - len(cv_results["indices"]["train"][0]), - len(cv_results["indices"]["test"][0]), + pos_label, fold_names_ = cls._validate_from_cv_results_params( + cv_results, + X, + y, + sample_weight=sample_weight, + pos_label=pos_label, + fold_names=fold_names, ) - - if _num_samples(X) != train_size + test_size: - raise ValueError( - "X does not contain the correct number of samples. " - f"Expected {train_size + test_size}, got {_num_samples(X)}." - ) - - fold_names_, fold_line_kwargs_ = _process_fold_names_line_kwargs( - len(cv_results["estimator"]), fold_names, fold_line_kwargs + fold_line_kwargs_ = _validate_line_kwargs( + len(cv_results["estimator"]), + fold_line_kwargs, + default_line_kwargs={"alpha": 0.5, "linestyle": "--"}, ) fpr_all = [] @@ -659,17 +666,20 @@ def from_cv_results( response_method=response_method, pos_label=pos_label, )[0] - # Should we use `_validate_from_predictions_params` here? - # The check would technically only be needed once though + sample_weight_fold = ( + None + if sample_weight is None + else _safe_indexing(sample_weight, test_indices) + ) fpr, tpr, _ = roc_curve( y_true, y_pred, pos_label=pos_label, - sample_weight=sample_weight, + sample_weight=sample_weight_fold, drop_intermediate=drop_intermediate, ) roc_auc = auc(fpr, tpr) - # Append all + fpr_all.append(fpr) tpr_all.append(tpr) auc_all.append(roc_auc) @@ -677,13 +687,15 @@ def from_cv_results( viz = cls( fpr=fpr_all, tpr=tpr_all, + # Should we provide `name` to both `cls` and `plot` or just `cls`? 
name=fold_names_, roc_auc=auc_all, pos_label=pos_label, ) return viz.plot( ax=ax, - fold_line_kwargs=fold_line_kwargs_, plot_chance_level=plot_chance_level, - chance_level_kw=chance_level_kw, + chance_level_kw=chance_level_kwargs, + despine=despine, + fold_line_kwargs=fold_line_kwargs_, ) diff --git a/sklearn/metrics/_plot/tests/test_common_curve_display.py b/sklearn/metrics/_plot/tests/test_common_curve_display.py index 865099b540c66..0c1c27927e127 100644 --- a/sklearn/metrics/_plot/tests/test_common_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_common_curve_display.py @@ -174,10 +174,8 @@ def test_display_curve_estimator_name_multiple_calls( ), ], ) -# Add separate test for displays that have converted to names?, -# add note to remove this one in 1.9 @pytest.mark.parametrize("Display", [DetCurveDisplay, PrecisionRecallDisplay]) -def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display): +def test_display_curve_not_fitted_errors_old_name(pyplot, data_binary, clf, Display): """Check that a proper error is raised when the classifier is not fitted.""" X, y = data_binary @@ -192,6 +190,32 @@ def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display): assert disp.estimator_name == model.__class__.__name__ +# This uses the new `name` parameter +@pytest.mark.parametrize( + "clf", + [ + LogisticRegression(), + make_pipeline(StandardScaler(), LogisticRegression()), + make_pipeline( + make_column_transformer((StandardScaler(), [0, 1])), LogisticRegression() + ), + ], +) +@pytest.mark.parametrize("Display", [RocCurveDisplay]) +def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display): + """Check that a proper error is raised when the classifier is not fitted.""" + X, y = data_binary + # clone since we parametrize the test and the classifier will be fitted + # when testing the second and subsequent plotting function + model = clone(clf) + with pytest.raises(NotFittedError): + Display.from_estimator(model, X, y) + 
model.fit(X, y) + disp = Display.from_estimator(model, X, y) + assert model.__class__.__name__ in disp.line_.get_label() + assert disp.name_[0] == model.__class__.__name__ + + @pytest.mark.parametrize( "Display", [DetCurveDisplay, PrecisionRecallDisplay, RocCurveDisplay] ) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index e0fbc94e4a7a4..9f6510c041581 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -1,3 +1,5 @@ +from collections.abc import Mapping + import numpy as np import pytest from numpy.testing import assert_allclose @@ -9,10 +11,11 @@ from sklearn.exceptions import NotFittedError from sklearn.linear_model import LogisticRegression from sklearn.metrics import RocCurveDisplay, auc, roc_curve -from sklearn.model_selection import train_test_split +from sklearn.model_selection import cross_validate, train_test_split from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler -from sklearn.utils import shuffle +from sklearn.utils import _safe_indexing, shuffle +from sklearn.utils._response import _get_response_values_binary @pytest.fixture(scope="module") @@ -24,12 +27,32 @@ def data(): return X, y +# This data always (with and without `drop_intermediate`) +# results in an AUC of 1.0, should we consider changing the data used?? 
@pytest.fixture(scope="module") def data_binary(data): X, y = data return X[y < 2], y[y < 2] +def _check_figure_axes_and_labels(display, pos_label): + """Check mpl axes and figure defaults are correct.""" + import matplotlib as mpl # noqa + + assert isinstance(display.ax_, mpl.axes.Axes) + assert isinstance(display.figure_, mpl.figure.Figure) + assert display.ax_.get_adjustable() == "box" + assert display.ax_.get_aspect() in ("equal", 1.0) + assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) + + expected_pos_label = 1 if pos_label is None else pos_label + expected_ylabel = f"True Positive Rate (Positive label: {expected_pos_label})" + expected_xlabel = f"False Positive Rate (Positive label: {expected_pos_label})" + + assert display.ax_.get_ylabel() == expected_ylabel + assert display.ax_.get_xlabel() == expected_xlabel + + @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @pytest.mark.parametrize("with_sample_weight", [True, False]) @pytest.mark.parametrize("drop_intermediate", [True, False]) @@ -112,24 +135,222 @@ def test_roc_curve_display_plotting( import matplotlib as mpl # noqa + _check_figure_axes_and_labels(display, pos_label) assert isinstance(display.line_, mpl.lines.Line2D) assert display.line_.get_alpha() == 0.8 - assert isinstance(display.ax_, mpl.axes.Axes) - assert isinstance(display.figure_, mpl.figure.Figure) - assert display.ax_.get_adjustable() == "box" - assert display.ax_.get_aspect() in ("equal", 1.0) - assert display.ax_.get_xlim() == display.ax_.get_ylim() == (-0.01, 1.01) expected_label = f"{default_name} (AUC = {display.roc_auc_[0]:.2f})" expected_label = f"{default_name} (AUC = {display.roc_auc:.2f})" assert display.line_.get_label() == expected_label - expected_pos_label = 1 if pos_label is None else pos_label - expected_ylabel = f"True Positive Rate (Positive label: {expected_pos_label})" - expected_xlabel = f"False Positive Rate (Positive label: {expected_pos_label})" - assert 
display.ax_.get_ylabel() == expected_ylabel - assert display.ax_.get_xlabel() == expected_xlabel +def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary, data): + """Check parameter validation is correct.""" + X, y = data_binary + + # `cv_results` missing key + cv_results_no_est = cross_validate( + LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=False + ) + cv_results_no_indices = cross_validate( + LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=False + ) + for cv_results in (cv_results_no_est, cv_results_no_indices): + with pytest.raises( + ValueError, + match="'cv_results' does not contain one of the following required", + ): + RocCurveDisplay.from_cv_results(cv_results, X, y) + + cv_results = cross_validate( + LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True + ) + + # `X` wrong length + with pytest.raises(ValueError, match="'X' does not contain the correct"): + RocCurveDisplay.from_cv_results(cv_results, X[:10, :], y) + + # `y` not binary + X_mutli, y_multi = data + with pytest.raises(ValueError, match="The target y is not binary."): + RocCurveDisplay.from_cv_results(cv_results, X, y_multi) + + # input inconsistent length + with pytest.raises(ValueError, match="Found input variables with inconsistent"): + RocCurveDisplay.from_cv_results(cv_results, X, y[:10]) + with pytest.raises(ValueError, match="Found input variables with inconsistent"): + RocCurveDisplay.from_cv_results(cv_results, X, y, sample_weight=[1, 2]) + + # `pos_label` inconsistency + X_bad_pos_label, y_bad_pos_label = X_mutli[y_multi > 0], y_multi[y_multi > 0] + with pytest.raises(ValueError, match=r"y takes value in \{1, 2\}"): + RocCurveDisplay.from_cv_results(cv_results, X_bad_pos_label, y_bad_pos_label) + + # `fold_names` incorrect length + with pytest.raises(ValueError, match="When 'fold_names' is provided, it must"): + RocCurveDisplay.from_cv_results(cv_results, X, y, fold_names=["fold"]) + # 
`fold_line_kwargs` incorrect length + with pytest.raises( + ValueError, match="When 'fold_line_kwargs' is provided, it must" + ): + RocCurveDisplay.from_cv_results( + cv_results, X, y, fold_line_kwargs=[{"alpha": 1}] + ) + + +@pytest.mark.parametrize("drop_intermediate", [True, False]) +@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) +@pytest.mark.parametrize("with_sample_weight", [True, False]) +@pytest.mark.parametrize("with_strings", [True, False]) +def test_test_roc_curve_display_plotting_from_cv_results( + pyplot, + data_binary, + with_strings, + with_sample_weight, + response_method, + drop_intermediate, +): + """Check overall plotting of `from_cv_results`.""" + X, y = data_binary + + pos_label = None + if with_strings: + y = np.array(["c", "b"])[y] + pos_label = "c" + + if with_sample_weight: + rng = np.random.RandomState(42) + sample_weight = rng.randint(1, 4, size=(X.shape[0])) + else: + sample_weight = None + + cv_results = cross_validate( + LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True + ) + display = RocCurveDisplay.from_cv_results( + cv_results, + X, + y, + sample_weight=sample_weight, + drop_intermediate=drop_intermediate, + response_method=response_method, + pos_label=pos_label, + ) + + for idx, (estimator, test_indices) in enumerate( + zip(cv_results["estimator"], cv_results["indices"]["test"]) + ): + y_true = _safe_indexing(y, test_indices) + y_pred = _get_response_values_binary( + estimator, + _safe_indexing(X, test_indices), + response_method=response_method, + pos_label=pos_label, + )[0] + sample_weight_fold = ( + None + if sample_weight is None + else _safe_indexing(sample_weight, test_indices) + ) + fpr, tpr, _ = roc_curve( + y_true, + y_pred, + sample_weight=sample_weight_fold, + drop_intermediate=drop_intermediate, + pos_label=pos_label, + ) + assert_allclose(display.roc_auc_[idx], auc(fpr, tpr)) + assert_allclose(display.fpr_[idx], fpr) + assert_allclose(display.tpr_[idx], 
tpr) + + fold_names = ["Fold 0", "Fold 1", "Fold 2"] + assert display.name_ == fold_names + + import matplotlib as mpl # noqa + + _check_figure_axes_and_labels(display, pos_label) + for idx, line in enumerate(display.line_): + assert isinstance(line, mpl.lines.Line2D) + # Default alpha for `from_cv_results` + line.get_alpha() == 0.5 + expected_label = f"{fold_names[idx]} (AUC = {display.roc_auc_[idx]:.2f})" + assert display.line_.get_label() == expected_label + + +@pytest.mark.parametrize("fold_names", [None, ["one", "two", "three"]]) +def test_roc_curve_from_cv_results_fold_names(pyplot, data_binary, fold_names): + """Check fold names behaviour correct in `from_cv_results`.""" + X, y = data_binary + cv_results = cross_validate( + LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True + ) + display = RocCurveDisplay.from_cv_results(cv_results, X, y, fold_names=fold_names) + legend = display.ax_.get_legend() + legend_labels = [text.get_text() for text in legend.get_texts()] + expected_names = ( + ["Fold 0", "Fold 1", "Fold 2"] if fold_names is None else fold_names + ) + assert display.name_ == expected_names + expected_labels = [name + " (AUC = 1.00)" for name in expected_names] + assert legend_labels == expected_labels + + +@pytest.mark.parametrize( + "fold_line_kwargs", + [None, {"color": "red"}, [{"c": "red"}, {"c": "green"}, {"c": "yellow"}]], +) +def test_roc_curve_from_cv_results_line_kwargs(pyplot, data_binary, fold_line_kwargs): + """Check line kwargs passed correctly in `from_cv_results`.""" + import matplotlib as mpl # noqa + + X, y = data_binary + cv_results = cross_validate( + LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True + ) + display = RocCurveDisplay.from_cv_results( + cv_results, X, y, fold_line_kwargs=fold_line_kwargs + ) + + mpl_default_colors = mpl.rcParams["axes.prop_cycle"].by_key()["color"] + for idx, line in enumerate(display.line_): + color = line.get_color() + if fold_line_kwargs is None: 
+ assert color == mpl_default_colors[idx] + elif isinstance(fold_line_kwargs, Mapping): + assert color == "red" + else: + assert color == fold_line_kwargs[idx]["c"] + + +def _check_chance_level(plot_chance_level, chance_level_kw, display): + """Check chance level line and line styles correct.""" + import matplotlib as mpl # noqa + + if plot_chance_level: + assert isinstance(display.chance_level_, mpl.lines.Line2D) + assert tuple(display.chance_level_.get_xdata()) == (0, 1) + assert tuple(display.chance_level_.get_ydata()) == (0, 1) + else: + assert display.chance_level_ is None + + # Checking for chance level line styles + if plot_chance_level and chance_level_kw is None: + assert display.chance_level_.get_color() == "k" + assert display.chance_level_.get_linestyle() == "--" + assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)" + elif plot_chance_level: + if "c" in chance_level_kw: + assert display.chance_level_.get_color() == chance_level_kw["c"] + else: + assert display.chance_level_.get_color() == chance_level_kw["color"] + if "lw" in chance_level_kw: + assert display.chance_level_.get_linewidth() == chance_level_kw["lw"] + else: + assert display.chance_level_.get_linewidth() == chance_level_kw["linewidth"] + if "ls" in chance_level_kw: + assert display.chance_level_.get_linestyle() == chance_level_kw["ls"] + else: + assert display.chance_level_.get_linestyle() == chance_level_kw["linestyle"] @pytest.mark.parametrize("plot_chance_level", [True, False]) @@ -155,7 +376,7 @@ def test_roc_curve_chance_level_line( label, constructor_name, ): - """Check the chance level line plotting behaviour.""" + """Check chance level plotting behavior of `from_predictions`, `from_estimator`.""" X, y = data_binary lr = LogisticRegression() @@ -191,32 +412,10 @@ def test_roc_curve_chance_level_line( assert isinstance(display.ax_, mpl.axes.Axes) assert isinstance(display.figure_, mpl.figure.Figure) - if plot_chance_level: - assert isinstance(display.chance_level_, 
mpl.lines.Line2D) - assert tuple(display.chance_level_.get_xdata()) == (0, 1) - assert tuple(display.chance_level_.get_ydata()) == (0, 1) - else: - assert display.chance_level_ is None + _check_chance_level(plot_chance_level, chance_level_kw, display) - # Checking for chance level line styles - if plot_chance_level and chance_level_kw is None: - assert display.chance_level_.get_color() == "k" - assert display.chance_level_.get_linestyle() == "--" - assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)" - elif plot_chance_level: - if "c" in chance_level_kw: - assert display.chance_level_.get_color() == chance_level_kw["c"] - else: - assert display.chance_level_.get_color() == chance_level_kw["color"] - if "lw" in chance_level_kw: - assert display.chance_level_.get_linewidth() == chance_level_kw["lw"] - else: - assert display.chance_level_.get_linewidth() == chance_level_kw["linewidth"] - if "ls" in chance_level_kw: - assert display.chance_level_.get_linestyle() == chance_level_kw["ls"] - else: - assert display.chance_level_.get_linestyle() == chance_level_kw["linestyle"] - # Checking for legend behaviour + # Checking for legend behaviour + if plot_chance_level and chance_level_kw is not None: if label is not None or chance_level_kw.get("label") is not None: legend = display.ax_.get_legend() assert legend is not None # Legend should be present if any label is set @@ -229,6 +428,62 @@ def test_roc_curve_chance_level_line( assert display.ax_.get_legend() is None +@pytest.mark.parametrize("plot_chance_level", [True, False]) +@pytest.mark.parametrize( + "chance_level_kw", + [ + None, + {"linewidth": 1, "color": "red", "linestyle": "-", "label": "DummyEstimator"}, + {"lw": 1, "c": "red", "ls": "-", "label": "DummyEstimator"}, + {"lw": 1, "color": "blue", "ls": "-", "label": None}, + ], +) +# To ensure both curve line kwargs and chance level line kwargs are passed correctly +@pytest.mark.parametrize("fold_line_kwargs", [None, {"alpha": 0.8}]) +def 
test_roc_curve_chance_level_line_from_cv_results( + pyplot, + data_binary, + plot_chance_level, + chance_level_kw, + fold_line_kwargs, +): + """Check chance level plotting behavior with `from_cv_results`.""" + X, y = data_binary + n_cv = 3 + cv_results = cross_validate( + LogisticRegression(), X, y, cv=n_cv, return_estimator=True, return_indices=True + ) + + display = RocCurveDisplay.from_cv_results( + cv_results, + X, + y, + plot_chance_level=plot_chance_level, + chance_level_kw=chance_level_kw, + fold_line_kwargs=fold_line_kwargs, + ) + + import matplotlib as mpl # noqa + + assert all(isinstance(line, mpl.lines.Line2D) for line in display.line_) + if fold_line_kwargs: + assert all(line.get_alpha() == 0.8 for line in display.line_) + assert isinstance(display.ax_, mpl.axes.Axes) + assert isinstance(display.figure_, mpl.figure.Figure) + + _check_chance_level(plot_chance_level, chance_level_kw, display) + + legend = display.ax_.get_legend() + # There is always a legend, to indicate each 'Fold' curve + assert legend is not None + legend_labels = [text.get_text() for text in legend.get_texts()] + if plot_chance_level and chance_level_kw is not None: + if chance_level_kw.get("label") is not None: + assert chance_level_kw["label"] in legend_labels + else: + assert len(legend_labels) == n_cv + + @pytest.mark.parametrize( "clf", [ @@ -285,8 +540,22 @@ def test_roc_curve_display_default_labels(pyplot, roc_auc, name, expected_labels assert disp.line_[idx].get_label() == expected_label +def _check_auc(display, constructor_name): + roc_auc_limit = 0.95679 + roc_auc_limit_multi = [0.97007, 0.985915, 0.980952] + + if constructor_name == "from_cv_results": + for idx, roc_auc in enumerate(display.roc_auc_): + assert roc_auc == pytest.approx(roc_auc_limit_multi[idx]) + else: + assert display.roc_auc == pytest.approx(roc_auc_limit) + assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) + + @pytest.mark.parametrize("response_method", ["predict_proba", 
"decision_function"]) -@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) +@pytest.mark.parametrize( + "constructor_name", ["from_estimator", "from_predictions", "from_cv_results"] +) def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): # check that we can provide the positive label and display the proper # statistics @@ -309,9 +578,13 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): classifier = LogisticRegression() classifier.fit(X_train, y_train) + cv_results = cross_validate( + LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True + ) - # sanity check to be sure the positive class is `classes_[0]` and that we - # are betrayed by the class imbalance + # Sanity check to be sure the positive class is `classes_[0]` + # Class imbalance ensures a large difference in prediction values between classes, + # allowing us to catch errors when we switch `pos_label` assert classifier.classes_.tolist() == ["cancer", "not cancer"] y_pred = getattr(classifier, response_method)(X_test) @@ -320,63 +593,86 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): y_pred_cancer = -1 * y_pred if y_pred.ndim == 1 else y_pred[:, 0] y_pred_not_cancer = y_pred if y_pred.ndim == 1 else y_pred[:, 1] + pos_label = "cancer" + y_pred = y_pred_cancer if constructor_name == "from_estimator": display = RocCurveDisplay.from_estimator( classifier, X_test, y_test, - pos_label="cancer", + pos_label=pos_label, response_method=response_method, ) - else: + elif constructor_name == "from_predictions": display = RocCurveDisplay.from_predictions( y_test, - y_pred_cancer, - pos_label="cancer", + y_pred, + pos_label=pos_label, + ) + else: + display = RocCurveDisplay.from_cv_results( + cv_results, + X, + y, + response_method=response_method, + pos_label=pos_label, ) - roc_auc_limit = 0.95679 - - assert display.roc_auc == pytest.approx(roc_auc_limit) - assert 
trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) + _check_auc(display, constructor_name) + pos_label = "not cancer" + y_pred = y_pred_not_cancer if constructor_name == "from_estimator": display = RocCurveDisplay.from_estimator( classifier, X_test, y_test, response_method=response_method, - pos_label="not cancer", + pos_label=pos_label, ) - else: + elif constructor_name == "from_predictions": display = RocCurveDisplay.from_predictions( y_test, - y_pred_not_cancer, - pos_label="not cancer", + y_pred, + pos_label=pos_label, + ) + else: + display = RocCurveDisplay.from_cv_results( + cv_results, + X, + y, + response_method=response_method, + pos_label=pos_label, ) - assert display.roc_auc == pytest.approx(roc_auc_limit) - assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) + _check_auc(display, constructor_name) @pytest.mark.parametrize("despine", [True, False]) -@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) +@pytest.mark.parametrize( + "constructor_name", ["from_estimator", "from_predictions", "from_cv_results"] +) def test_plot_roc_curve_despine(pyplot, data_binary, despine, constructor_name): # Check that the despine keyword is working correctly X, y = data_binary lr = LogisticRegression().fit(X, y) lr.fit(X, y) + cv_results = cross_validate( + LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True + ) y_pred = lr.decision_function(X) - # safe guard for the binary if/else construction - assert constructor_name in ("from_estimator", "from_predictions") + # safe guard for the if/else construction + assert constructor_name in ("from_estimator", "from_predictions", "from_cv_results") if constructor_name == "from_estimator": display = RocCurveDisplay.from_estimator(lr, X, y, despine=despine) - else: + elif constructor_name == "from_predictions": display = RocCurveDisplay.from_predictions(y, y_pred, despine=despine) + else: + display = 
RocCurveDisplay.from_cv_results(cv_results, X, y, despine=despine) for s in ["top", "right"]: assert display.ax_.spines[s].get_visible() is not despine diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index b41b46780f0bc..beaff75ebaae0 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -1,5 +1,6 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause +import warnings from collections.abc import Mapping import numpy as np @@ -7,8 +8,15 @@ from . import check_consistent_length from ._optional_dependencies import check_matplotlib_support from ._response import _get_response_values_binary +from .fixes import parse_version from .multiclass import type_of_target -from .validation import _check_pos_label_consistency +from .validation import _check_pos_label_consistency, _num_samples + +MULTI_PARAM_ERROR_MSG = ( + "When '{param}' is provided, it must have the same length as " + "the number of curves to be plotted. Got: {len_param}; " + "expected: {n_curves}." +) class _BinaryClassifierCurveDisplayMixin: @@ -33,6 +41,94 @@ def _validate_plot_params(self, *, ax=None, name=None): name = getattr(self, "name", None) return ax, ax.figure, name + @classmethod + def _validate_and_get_response_values( + cls, estimator, X, y, *, response_method="auto", pos_label=None, name=None + ): + check_matplotlib_support(f"{cls.__name__}.from_estimator") + + name = estimator.__class__.__name__ if name is None else name + + y_pred, pos_label = _get_response_values_binary( + estimator, + X, + response_method=response_method, + pos_label=pos_label, + ) + + return y_pred, pos_label, name + + @classmethod + def _validate_from_predictions_params( + cls, y_true, y_pred, *, sample_weight=None, pos_label=None, name=None + ): + check_matplotlib_support(f"{cls.__name__}.from_predictions") + + if type_of_target(y_true) != "binary": + raise ValueError( + f"The target y is not binary. Got {type_of_target(y_true)} type of" + " target." 
+ ) + + check_consistent_length(y_true, y_pred, sample_weight) + pos_label = _check_pos_label_consistency(pos_label, y_true) + + name = name if name is not None else "Classifier" + + return pos_label, name + + @classmethod + def _validate_from_cv_results_params( + cls, cv_results, X, y, *, sample_weight=None, pos_label=None, fold_names=None + ): + check_matplotlib_support(f"{cls.__name__}.from_predictions") + + required_keys = {"estimator", "indices"} + if not all(key in cv_results for key in required_keys): + raise ValueError( + "'cv_results' does not contain one of the following required keys: " + f"{required_keys}. Set explicitly the parameters return_estimator=True " + "and return_indices=True to the function cross_validate." + ) + + train_size, test_size = ( + len(cv_results["indices"]["train"][0]), + len(cv_results["indices"]["test"][0]), + ) + + if _num_samples(X) != train_size + test_size: + raise ValueError( + "'X' does not contain the correct number of samples. " + f"Expected {train_size + test_size}, got {_num_samples(X)}." + ) + + if type_of_target(y) != "binary": + raise ValueError( + f"The target y is not binary. Got {type_of_target(y)} type of" + " target." 
+ ) + check_consistent_length(X, y, sample_weight) + + try: + pos_label = _check_pos_label_consistency(pos_label, y) + except ValueError as e: + # Alter error message + raise ValueError(str(e).replace("y_true", "y")) + + n_curves = len(cv_results["estimator"]) + if fold_names is None: + fold_names = [f"Fold {idx}" for idx in range(n_curves)] + elif len(fold_names) != n_curves: + raise ValueError( + MULTI_PARAM_ERROR_MSG.format( + param="fold_names", len_param=len(fold_names), n_curves=n_curves + ) + ) + else: + fold_names = fold_names + + return pos_label, fold_names + @classmethod def _get_line_kwargs( cls, @@ -40,31 +136,57 @@ def _get_line_kwargs( names, summary_values, summary_value_name, - fold_line_kws, - default_line_kwargs={}, + fold_line_kwargs, + default_line_kwargs=None, **kwargs, ): - """Get validated line kwargs for each curve.""" + """Get validated line kwargs for each curve. + + Parameters + ---------- + n_curves : int + Number of curves. + + names : list[str] + Names of each curve. + + summary_values : list[float] + List of summary values for each curve (e.g., ROC AUC, average precision). + + summary_value_name : str + Name of the summary value provided in `summary_values`. + + fold_line_kwargs : dict or list of dict + Dictionary with keywords passed to the matplotlib's `plot` function + to draw the individual ROC curves. If a list is provided, the + parameters are applied to the ROC curves sequentially. If a single + dictionary is provided, the same parameters are applied to all ROC + curves. Ignored for single curve plots - pass as `**kwargs` for + single curve plots. + + default_line_kwargs : dict, default=None + Default line kwargs to be used in all curves, unless overridden by + `fold_line_kwargs`. + + **kwargs : dict + For a single curve plots only, keyword arguments to be passed to + matplotlib's `plot`. Ignored for multi-curve plots - use `fold_line_kwargs` + for multi-curve plots. 
+ """ # Ensure parameters are of the correct length names_ = [None] * n_curves if names is None else names summary_values_ = ( [None] * n_curves if summary_values is None else summary_values ) - # `fold_line_kws` ignored for single curve plots + # `fold_line_kwargs` ignored for single curve plots # `kwargs` ignored for multi-curve plots if n_curves == 1: - fold_line_kws = [kwargs] + fold_line_kwargs = [kwargs] else: - if fold_line_kws is None: - fold_line_kws = [{"alpha": 0.5, "linestyle": "--"}] * n_curves - elif isinstance(fold_line_kws, Mapping): - fold_line_kws = [fold_line_kws] * n_curves - elif len(fold_line_kws) != n_curves: - raise ValueError( - "When `fold_line_kws` is a list, it must have the same length as " - "the number of curves to be plotted." - ) + fold_line_kwargs = _validate_line_kwargs(n_curves, fold_line_kwargs) + if default_line_kwargs is None: + default_line_kwargs = {} line_kwargs = [] for fold_idx, (curve_summary_value, curve_name) in enumerate( zip(summary_values_, names_) @@ -81,46 +203,10 @@ def _get_line_kwargs( default_line_kwargs["label"] = curve_name line_kwargs.append( - _validate_style_kwargs(default_line_kwargs, fold_line_kws[fold_idx]) + _validate_style_kwargs(default_line_kwargs, fold_line_kwargs[fold_idx]) ) return line_kwargs - @classmethod - def _validate_and_get_response_values( - cls, estimator, X, y, *, response_method="auto", pos_label=None, name=None - ): - check_matplotlib_support(f"{cls.__name__}.from_estimator") - - name = estimator.__class__.__name__ if name is None else name - - y_pred, pos_label = _get_response_values_binary( - estimator, - X, - response_method=response_method, - pos_label=pos_label, - ) - - return y_pred, pos_label, name - - @classmethod - def _validate_from_predictions_params( - cls, y_true, y_pred, *, sample_weight=None, pos_label=None, name=None - ): - check_matplotlib_support(f"{cls.__name__}.from_predictions") - - if type_of_target(y_true) != "binary": - raise ValueError( - f"The target y is 
not binary. Got {type_of_target(y_true)} type of" - " target." - ) - - check_consistent_length(y_true, y_pred, sample_weight) - pos_label = _check_pos_label_consistency(pos_label, y_true) - - name = name if name is not None else "Classifier" - - return pos_label, name - def _validate_score_name(score_name, scoring, negate_score): """Validate the `score_name` parameter. @@ -237,8 +323,29 @@ def _despine(ax): ax.spines[s].set_bounds(0, 1) -# TODO(1.9): remove -# Should this be a parent class method? +def _deprecate_estimator_name(old, new, version): + """Deprecate `estimator_name` in favour of `name`.""" + version = parse_version(version) + # Not sure if I should hard code this because this wouldn't work if we release + # a new major version ? + version_remove = f"{version.major}.{version.minor + 2}" + if old != "deprecated": + if new: + raise ValueError( + f"Both 'estimator_name' and 'name' provided, please only use 'name' " + f"as 'estimator_name' is deprecated in {version} and will be removed " + f"in {version_remove}." + ) + warnings.warn( + f"'estimator_name' was passed to 'name' as 'estimator_name' is deprecated " + f"in {version} and will be removed in {version_remove}. Please use " + f"'name' in future.", + FutureWarning, + ) + return old + return new + + def _convert_to_list_leaving_none(param): """Convert parameters to a list, leaving `None` as is.""" if param is None: @@ -248,7 +355,6 @@ def _convert_to_list_leaving_none(param): return [param] -# Should this be a parent class/mixin method? 
def _check_param_lengths(required, optional, class_name): """Check required and optional parameters are of the same length.""" optional_provided = {} @@ -270,35 +376,25 @@ def _check_param_lengths(required, optional, class_name): ) -def _process_fold_names_line_kwargs(n_curves, fold_names, fold_line_kwargs): +# Potentially useful for non binary displays `LearningCurveDisplay` and +# `ValidationCurveDisplay`, so not placed under `_BinaryClassifierCurveDisplayMixin` +def _validate_line_kwargs(n_curves, fold_line_kwargs=None, default_line_kwargs=None): """Ensure that `fold_names` and `fold_line_kwargs` are of correct length.""" - msg = ( - "When `{param}` is provided, it must have the same length as " - "the number of curves to be plotted. Got {len_param} " - "instead of {n_curves}." - ) - - if fold_names is None: - # " fold ?" - fold_names_ = [f"Fold: {idx}" for idx in range(n_curves)] - elif len(fold_names) != n_curves: - raise ValueError( - msg.format(param="fold_names", len_param=len(fold_names), n_curves=n_curves) - ) - else: - fold_names_ = fold_names - - if isinstance(fold_line_kwargs, Mapping): - fold_line_kws_ = [fold_line_kwargs] * n_curves - elif fold_names is not None and len(fold_line_kwargs) != n_curves: + if fold_line_kwargs is None and default_line_kwargs is not None: + fold_line_kwargs = default_line_kwargs + elif fold_line_kwargs is None: + fold_line_kwargs = [{}] * n_curves + elif isinstance(fold_line_kwargs, Mapping): + fold_line_kwargs = [fold_line_kwargs] * n_curves + elif len(fold_line_kwargs) != n_curves: raise ValueError( - msg.format( + MULTI_PARAM_ERROR_MSG.format( param="fold_line_kwargs", len_param=len(fold_line_kwargs), n_curves=n_curves, ) ) else: - fold_line_kws_ = fold_line_kwargs + fold_line_kwargs = fold_line_kwargs - return fold_names_, fold_line_kws_ + return fold_line_kwargs From d3990716fc89e28421f12c2148603d9c53dc6b8c Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 7 Mar 2025 16:39:24 +1100 Subject: [PATCH 33/63] review and 
fix tests --- .../sklearn.metrics/30399.enhancement.rst | 3 ++- sklearn/metrics/_plot/roc_curve.py | 9 +++++---- sklearn/metrics/_plot/tests/test_roc_curve_display.py | 6 +++--- sklearn/utils/_plotting.py | 8 +++----- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/doc/whats_new/upcoming_changes/sklearn.metrics/30399.enhancement.rst b/doc/whats_new/upcoming_changes/sklearn.metrics/30399.enhancement.rst index e9b75a20a5000..c3b6d77c5aefb 100644 --- a/doc/whats_new/upcoming_changes/sklearn.metrics/30399.enhancement.rst +++ b/doc/whats_new/upcoming_changes/sklearn.metrics/30399.enhancement.rst @@ -1,3 +1,4 @@ - Add class method `from_cv_results` to :class:`metrics.RocCurveDisplay`, which allows - easy plotting of multiple ROC curves using cross-validation results. + easy plotting of multiple ROC curves from :func:`model_selection.cross_validate` + results. By :user:`Lucy Liu ` diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 426d28821c848..06ef94bea4ec3 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -47,6 +47,7 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): is also `None` no legend is added. name : str or list of str, default=None + (Do we prefer curve_name) ? Name of each ROC curve, used for labeling curves in the legend. If plotting multiple curves, should be a list of the same length as `fpr` and `tpr`. If `None`, no name is not shown in the legend. 
If `roc_auc` @@ -261,9 +262,8 @@ def plot( if despine: _despine(self.ax_) - if ( - line_kwargs[0].get("label") is not None - or chance_level_kw.get("label") is not None + if line_kwargs[0].get("label") is not None or ( + plot_chance_level and chance_level_kw.get("label") is not None ): self.ax_.legend(loc="lower right") @@ -556,7 +556,8 @@ def from_cv_results( ---------- cv_results : dict Dictionary as returned by :func:`~sklearn.model_selection.cross_validate` - using `return_estimator=True` and `return_indices=True`. + using `return_estimator=True` and `return_indices=True` (i.e., dictionary + should contain the keys "estimator" and "indices"). X : {array-like, sparse matrix} of shape (n_samples, n_features) Input values. diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 9f6510c041581..786db2ff5b427 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -202,7 +202,7 @@ def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary, data): @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @pytest.mark.parametrize("with_sample_weight", [True, False]) @pytest.mark.parametrize("with_strings", [True, False]) -def test_test_roc_curve_display_plotting_from_cv_results( +def test_roc_curve_display_plotting_from_cv_results( pyplot, data_binary, with_strings, @@ -274,7 +274,7 @@ def test_test_roc_curve_display_plotting_from_cv_results( # Default alpha for `from_cv_results` line.get_alpha() == 0.5 expected_label = f"{fold_names[idx]} (AUC = {display.roc_auc_[idx]:.2f})" - assert display.line_.get_label() == expected_label + assert line.get_label() == expected_label @pytest.mark.parametrize("fold_names", [None, ["one", "two", "three"]]) @@ -459,7 +459,7 @@ def test_roc_curve_chance_level_line_from_cv_results( X, y, plot_chance_level=plot_chance_level, - 
chance_level_kw=chance_level_kw, + chance_level_kwargs=chance_level_kw, fold_line_kwargs=fold_line_kwargs, ) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index beaff75ebaae0..893e8d500e3bb 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -33,12 +33,10 @@ def _validate_plot_params(self, *, ax=None, name=None): if ax is None: _, ax = plt.subplots() - # We are changing from `estimator_name` to `name`, Display objects will - # have one or the other. Try old attr name: `estimator_name` first. + # Display classes are in process of changing from `estimator_name` to `name`. + # Try old attr name: `estimator_name` first. if name is None: - name = getattr(self, "estimator_name", None) - if name is None: - name = getattr(self, "name", None) + name = getattr(self, "estimator_name", getattr(self, "name", None)) return ax, ax.figure, name @classmethod From 126dfb9b5566a0c9cd1a453a355112b8a20291da Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 7 Mar 2025 21:38:18 +1100 Subject: [PATCH 34/63] comment --- sklearn/metrics/_plot/tests/test_common_curve_display.py | 3 ++- sklearn/metrics/_plot/tests/test_roc_curve_display.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_common_curve_display.py b/sklearn/metrics/_plot/tests/test_common_curve_display.py index 0c1c27927e127..7c8516d5619e0 100644 --- a/sklearn/metrics/_plot/tests/test_common_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_common_curve_display.py @@ -164,6 +164,8 @@ def test_display_curve_estimator_name_multiple_calls( assert clf_name in disp.line_.get_label() +# TODO: remove this test once classes moved to using `name_` instead of +# `estimator_name` @pytest.mark.parametrize( "clf", [ @@ -190,7 +192,6 @@ def test_display_curve_not_fitted_errors_old_name(pyplot, data_binary, clf, Disp assert disp.estimator_name == model.__class__.__name__ -# This uses the new `name` para @pytest.mark.parametrize( "clf", [ 
diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 786db2ff5b427..7dfe477e949e5 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -122,7 +122,8 @@ def test_roc_curve_display_plotting( pos_label=pos_label, ) - # Both processed and unprocessed attributes should be the same for single curve + # Both processed (e.g., `roc_auc_`) and unprocessed (e.g., `roc_auc`) attributes + # should be the same for single curve assert_allclose(display.roc_auc_[0], auc(fpr, tpr)) assert_allclose(display.roc_auc, auc(fpr, tpr)) assert_allclose(display.fpr_[0], fpr) From 412c31311923a627aa3d2958de28eb40c228d5de Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 19 Mar 2025 15:02:08 +1100 Subject: [PATCH 35/63] review and add tests --- sklearn/metrics/_plot/roc_curve.py | 10 +- .../_plot/tests/test_roc_curve_display.py | 42 ++++++++ sklearn/utils/_plotting.py | 99 +++++++++++++------ 3 files changed, 114 insertions(+), 37 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 06ef94bea4ec3..291ad284501a4 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -8,7 +8,6 @@ _convert_to_list_leaving_none, _deprecate_estimator_name, _despine, - _validate_line_kwargs, _validate_style_kwargs, ) from ...utils._response import _get_response_values_binary @@ -522,8 +521,6 @@ def from_predictions( return viz.plot( ax=ax, - # Should we provide `name` to both `cls` and `plot` or just `cls`? - name=name, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, despine=despine, @@ -595,8 +592,8 @@ def from_cv_results( the CV fold. fold_line_kwargs : dict or list of dict, default=None - Dictionary with keywords passed to the matplotlib's `plot` function - to draw the individual ROC curves. 
If a list is provided, the + Keywords arguments to be passed to matplotlib's `plot` function + to draw individual ROC curves. If a list is provided, the parameters are applied to the ROC curves of each CV fold sequentially. If a single dictionary is provided, the same parameters are applied to all ROC curves. @@ -648,7 +645,7 @@ def from_cv_results( pos_label=pos_label, fold_names=fold_names, ) - fold_line_kwargs_ = _validate_line_kwargs( + fold_line_kwargs_ = cls._validate_line_kwargs( len(cv_results["estimator"]), fold_line_kwargs, default_line_kwargs={"alpha": 0.5, "linestyle": "--"}, @@ -688,7 +685,6 @@ def from_cv_results( viz = cls( fpr=fpr_all, tpr=tpr_all, - # Should we provide `name` to both `cls` and `plot` or just `cls`? name=fold_names_, roc_auc=auc_all, pos_label=pos_label, diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index b9e631ada4d97..d7e18564eef59 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -199,6 +199,48 @@ def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary, data): ) +@pytest.mark.parametrize( + "fold_line_kwargs", + [None, {"alpha": 0.2}, [{"alpha": 0.2}, {"alpha": 0.3}, {"alpha": 0.4}]], +) +def test_roc_curve_display_from_cv_results_validate_line_kwargs( + pyplot, data_binary, fold_line_kwargs +): + """Check `_validate_line_kwargs` correctly validates line kwargs.""" + X, y = data_binary + n_cv = 3 + cv_results = cross_validate( + LogisticRegression(), X, y, cv=n_cv, return_estimator=True, return_indices=True + ) + display = RocCurveDisplay.from_cv_results( + cv_results, + X, + y, + fold_line_kwargs=fold_line_kwargs, + ) + if fold_line_kwargs is None: + # Default `alpha` used + assert all(line.get_alpha() == 0.5 for line in display.line_) + elif isinstance(fold_line_kwargs, Mapping): + # `alpha` from dict used for all curves + assert all(line.get_alpha() == 0.2 
for line in display.line_) + else: + # Different `alpha` used for each curve + assert all( + line.get_alpha() == fold_line_kwargs[i]["alpha"] + for i, line in enumerate(display.line_) + ) + + +# TODO : Remove in 1.9 +def test_roc_curve_display_estimator_name_deprecation(pyplot): + """Check deprecation of `estimator_name`.""" + fpr = np.array([0, 0.5, 1]) + tpr = np.array([0, 0.5, 1]) + with pytest.warns(FutureWarning, match="'estimator_name' is deprecated in"): + RocCurveDisplay(fpr=fpr, tpr=tpr, estimator_name="test") + + @pytest.mark.parametrize("drop_intermediate", [True, False]) @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @pytest.mark.parametrize("with_sample_weight", [True, False]) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 893e8d500e3bb..1854db4f68fac 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -181,7 +181,7 @@ def _get_line_kwargs( if n_curves == 1: fold_line_kwargs = [kwargs] else: - fold_line_kwargs = _validate_line_kwargs(n_curves, fold_line_kwargs) + fold_line_kwargs = cls._validate_line_kwargs(n_curves, fold_line_kwargs) if default_line_kwargs is None: default_line_kwargs = {} @@ -205,6 +205,71 @@ def _get_line_kwargs( ) return line_kwargs + # TODO: if useful for non binary displays (e.g.,`LearningCurveDisplay`, + # `ValidationCurveDisplay`) amend to function + @classmethod + def _validate_line_kwargs( + cls, n_curves, fold_line_kwargs=None, default_line_kwargs=None + ): + """Ensure `fold_line_kwargs` length and incorporate default kwargs. + + * If `fold_line_kwargs` is None: + * If `default_line_kwargs` is None, list of `n_curves` empty dictionaries + is returned. + * If `default_line_kwargs` is not None, list of `n_curves` dictionaries + of `default_line_kwargs` returned. 
+ * If `fold_line_kwargs` is a single dictionary, it is incorporated with + `default_line_kwargs` using `_validate_style_kwargs`, and the resulting + dictionary is repeated `n_curves` times and returned. + * If `fold_line_kwargs` is a list of length `n_curves`, each dict is + incorporated with `default_line_kwargs` using `_validate_style_kwargs` and + returned as list of `n_curves` dictionaries. + + If `fold_line_kwargs` is a list not of length `n_curves`, an error is raised. + + Parameters + ---------- + n_curves : int + Number of curves. + + fold_line_kwargs : dict or list of dict, default=None + Keywords arguments to be passed to matplotlib's `plot` function + to draw ROC curves. + + default_line_kwargs : dict, default=None + Default line kwargs to be used in all curves, unless overridden by + `fold_line_kwargs`. + + Returns + ------- + fold_line_kwargs : list of dict + List of `n_curves` dictionaries. + """ + if fold_line_kwargs is None and default_line_kwargs is None: + fold_line_kwargs = [{}] * n_curves + elif fold_line_kwargs is None and default_line_kwargs is not None: + fold_line_kwargs = default_line_kwargs + + if isinstance(fold_line_kwargs, Mapping): + fold_line_kwargs = [fold_line_kwargs] * n_curves + elif len(fold_line_kwargs) != n_curves: + raise ValueError( + MULTI_PARAM_ERROR_MSG.format( + param="fold_line_kwargs", + len_param=len(fold_line_kwargs), + n_curves=n_curves, + ) + ) + else: + fold_line_kwargs = fold_line_kwargs + + if default_line_kwargs is not None: + fold_line_kwargs = [ + _validate_style_kwargs(default_line_kwargs, single_kwargs) + for single_kwargs in fold_line_kwargs + ] + return fold_line_kwargs + def _validate_score_name(score_name, scoring, negate_score): """Validate the `score_name` parameter. 
@@ -324,8 +389,6 @@ def _despine(ax): def _deprecate_estimator_name(old, new, version): """Deprecate `estimator_name` in favour of `name`.""" version = parse_version(version) - # Not sure if I should hard code this because this wouldn't work if we release - # a new major version ? version_remove = f"{version.major}.{version.minor + 2}" if old != "deprecated": if new: @@ -335,9 +398,9 @@ def _deprecate_estimator_name(old, new, version): f"in {version_remove}." ) warnings.warn( - f"'estimator_name' was passed to 'name' as 'estimator_name' is deprecated " - f"in {version} and will be removed in {version_remove}. Please use " - f"'name' in future.", + f"'estimator_name' is deprecated in {version} and will be removed in " + f"{version_remove}. The value of 'estimator_name' was passed to 'name'" + "but please use 'name' in future.", FutureWarning, ) return old @@ -372,27 +435,3 @@ def _check_param_lengths(required, optional, class_name): f"from `{class_name}` initialization (or `plot`) should all be lists of " f"the same length. 
Got: {lengths_formatted}" ) - - -# Potentially useful for non binary displays `LearningCurveDisplay` and -# `ValidationCurveDisplay`, so not placed under `_BinaryClassifierCurveDisplayMixin` -def _validate_line_kwargs(n_curves, fold_line_kwargs=None, default_line_kwargs=None): - """Ensure that `fold_names` and `fold_line_kwargs` are of correct length.""" - if fold_line_kwargs is None and default_line_kwargs is not None: - fold_line_kwargs = default_line_kwargs - elif fold_line_kwargs is None: - fold_line_kwargs = [{}] * n_curves - elif isinstance(fold_line_kwargs, Mapping): - fold_line_kwargs = [fold_line_kwargs] * n_curves - elif len(fold_line_kwargs) != n_curves: - raise ValueError( - MULTI_PARAM_ERROR_MSG.format( - param="fold_line_kwargs", - len_param=len(fold_line_kwargs), - n_curves=n_curves, - ) - ) - else: - fold_line_kwargs = fold_line_kwargs - - return fold_line_kwargs From d3af77ce854ab1435cef2e6fe434543a89345369 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Tue, 25 Mar 2025 14:51:53 +1100 Subject: [PATCH 36/63] add aggregate legend --- sklearn/metrics/_plot/roc_curve.py | 107 ++++++++++--- .../_plot/tests/test_roc_curve_display.py | 148 ++++++++++-------- sklearn/utils/_plotting.py | 128 +++++++++------ 3 files changed, 253 insertions(+), 130 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 291ad284501a4..2b5174d59bf33 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,6 +1,8 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause +import numpy as np + from ...utils import _safe_indexing from ...utils._plotting import ( _BinaryClassifierCurveDisplayMixin, @@ -40,17 +42,32 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): and `fpr`. roc_auc : float or list of floats, default=None - Area under ROC curve, used for labeling curves in the legend. + Area under ROC curve, used for labeling each curve in the legend. 
If plotting multiple curves, should be a list of the same length as `fpr` - and `tpr`. If `None`, no area under ROC curve score is shown. If `name` - is also `None` no legend is added. + and `tpr`. If `None`, individual ROC AUC scores are not shown. See + `roc_auc_aggregate` for alternative. + If `name` and `roc_auc_aggregate` are also `None` no legend is added. + + roc_auc_aggregate : tuple(float), default=None + ROC AUC mean and standard deviation. An alternative to `roc_auc` when + plotting multiple curves and a single legend entry showing ROC AUC mean and + standard deviation for all curves is desired. + If `True`, `name` cannot be a list of length >1. name : str or list of str, default=None (Do we prefer curve_name) ? - Name of each ROC curve, used for labeling curves in the legend. - If plotting multiple curves, should be a list of the same length as `fpr` - and `tpr`. If `None`, no name is not shown in the legend. If `roc_auc` - is also `None` no legend is added. + Name for labeling legend entries. For single ROC curve, should be a str or + list of length one. For multiple ROC curves: + + * if list of names provided, should be the same length as `fpr` + and `tpr`. Each individual curve will be labeled in the legend. Cannot + be used in conjunction with `roc_auc_aggregate`. + * if a single name provided (as str or list of length one), a single legend + entry will be used to label all curves. Cannot be used in conjunction with + `roc_auc`. + + If `None`, no name is not shown in the legend. If `roc_auc` + and `roc_auc_aggregate` are also `None` no legend is added. 
pos_label : int, float, bool or str, default=None The class considered as the positive class when computing the roc auc @@ -115,6 +132,7 @@ def __init__( fpr, tpr, roc_auc=None, + roc_auc_aggregate=None, name=None, pos_label=None, estimator_name="deprecated", @@ -122,26 +140,47 @@ def __init__( self.fpr = fpr self.tpr = tpr self.roc_auc = roc_auc + self.roc_auc_aggregate = roc_auc_aggregate self.name = _deprecate_estimator_name(estimator_name, name, "1.7") self.pos_label = pos_label def _validate_plot_params(self, *, ax=None, name=None): self.ax_, self.figure_, name_ = super()._validate_plot_params(ax=ax, name=name) + if self.roc_auc_aggregate: + if self.roc_auc is not None: + raise ValueError( + "'self.roc_auc' and 'self.roc_auc_aggregate' cannot both be " + "provided." + ) + if isinstance(name_, list) and len(name_) != 1: + raise ValueError( + "When 'roc_auc_aggregate' is True, 'name' (or self.name) " + "must be a string or a list of length one." + ) + self.fpr_ = _convert_to_list_leaving_none(self.fpr) self.tpr_ = _convert_to_list_leaving_none(self.tpr) self.roc_auc_ = _convert_to_list_leaving_none(self.roc_auc) self.name_ = _convert_to_list_leaving_none(name_) + optional = {"self.roc_auc": self.roc_auc_} + if self.name_ is not None and len(self.name_) != 1: + optional.update({"'name' (or self.name)": self.name_}) + _check_param_lengths( required={"self.fpr": self.fpr_, "self.tpr": self.tpr_}, - optional={ - "self.roc_auc": self.roc_auc_, - "`name` from `plot` (or self.name)": self.name_, - }, + optional=optional, class_name="RocCurveDisplay", ) + if self.roc_auc: + if isinstance(name_, list) and len(name_) == 1: + raise ValueError( + "When 'roc_auc' is provided, 'name' (or self.name) " + f"must be None or a list of length {len(self.fpr_)}." + ) + def plot( self, ax=None, @@ -208,13 +247,20 @@ def plot( Object that stores computed values. 
""" self._validate_plot_params(ax=ax, name=name) + summary_value, summary_value_name = self.roc_auc_, "AUC" + if self.roc_auc_aggregate: + summary_value, summary_value_name = self.roc_auc_aggregate, "AUC" + elif self.roc_auc: + summary_value, summary_value_name = self.roc_auc_, "AUC" + else: + summary_value, summary_value_name = None, None n_curves = len(self.fpr_) line_kwargs = self._get_line_kwargs( n_curves, self.name_, - self.roc_auc_, - "AUC", + summary_value, + summary_value_name, fold_line_kwargs=fold_line_kwargs, **kwargs, ) @@ -539,8 +585,9 @@ def from_cv_results( response_method="auto", pos_label=None, ax=None, - fold_names=None, + name=None, fold_line_kwargs=None, + show_aggregate_score=True, plot_chance_level=False, chance_level_kwargs=None, despine=False, @@ -586,10 +633,12 @@ def from_cv_results( Axes object to plot on. If `None`, a new figure and axes is created. - fold_names : list of str, default=None - Names of each ROC curve, used for labeling curves in the legend. - If `None`, the name will be set to "Fold " where N is the index of - the CV fold. + name : list of str or str, default=None + Name for labeling legend entries. To label each individual curve, + provide a list of names the same length as the number of cross-validation + folds. In this case `show_aggregate_score` cannot be `True`. + To label all curves using a single legend entry, provide a str + or list of length one. If `None`, no name is shown in the legend. fold_line_kwargs : dict or list of dict, default=None Keywords arguments to be passed to matplotlib's `plot` function @@ -598,6 +647,13 @@ def from_cv_results( sequentially. If a single dictionary is provided, the same parameters are applied to all ROC curves. + show_aggregate_score : bool, default=True + Whether to show the ROC AUC mean and standard deviation of curves from + all folds as a single legend entry. 
If `True`, `name` should be a single + string and `fold_line_kwargs` should be a single dictionary, to prevent + confusion in the legend. If `False`, `name` should be None or a list the + same length as the number of cross-validation folds. + plot_chance_level : bool, default=False Whether to plot the chance level. @@ -637,14 +693,17 @@ def from_cv_results( <...> >>> plt.show() """ - pos_label, fold_names_ = cls._validate_from_cv_results_params( + pos_label, name_ = cls._validate_from_cv_results_params( cv_results, X, y, sample_weight=sample_weight, pos_label=pos_label, - fold_names=fold_names, + name=name, + fold_line_kwargs=fold_line_kwargs, + show_aggregate_score=show_aggregate_score, ) + fold_line_kwargs_ = cls._validate_line_kwargs( len(cv_results["estimator"]), fold_line_kwargs, @@ -682,11 +741,17 @@ def from_cv_results( tpr_all.append(tpr) auc_all.append(roc_auc) + roc_auc_aggregate = None + if show_aggregate_score: + roc_auc_aggregate = (np.mean(auc_all), np.std(auc_all)) + auc_all = None + viz = cls( fpr=fpr_all, tpr=tpr_all, - name=fold_names_, + name=name_, roc_auc=auc_all, + roc_auc_aggregate=roc_auc_aggregate, pos_label=pos_label, ) return viz.plot( diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index d7e18564eef59..d42b8368b6517 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -188,48 +188,51 @@ def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary, data): RocCurveDisplay.from_cv_results(cv_results, X_bad_pos_label, y_bad_pos_label) # `fold_names` incorrect length - with pytest.raises(ValueError, match="When 'fold_names' is provided, it must"): - RocCurveDisplay.from_cv_results(cv_results, X, y, fold_names=["fold"]) + with pytest.raises(ValueError, match="'name' must be None or list of length"): + RocCurveDisplay.from_cv_results( + cv_results, X, y, name=["fold"], 
show_aggregate_score=False + ) # `fold_line_kwargs` incorrect length with pytest.raises( - ValueError, match="When 'fold_line_kwargs' is provided, it must" + ValueError, match="'fold_line_kwargs' must be a single dictionary to" ): RocCurveDisplay.from_cv_results( cv_results, X, y, fold_line_kwargs=[{"alpha": 1}] ) -@pytest.mark.parametrize( - "fold_line_kwargs", - [None, {"alpha": 0.2}, [{"alpha": 0.2}, {"alpha": 0.3}, {"alpha": 0.4}]], -) -def test_roc_curve_display_from_cv_results_validate_line_kwargs( - pyplot, data_binary, fold_line_kwargs -): - """Check `_validate_line_kwargs` correctly validates line kwargs.""" - X, y = data_binary - n_cv = 3 - cv_results = cross_validate( - LogisticRegression(), X, y, cv=n_cv, return_estimator=True, return_indices=True - ) - display = RocCurveDisplay.from_cv_results( - cv_results, - X, - y, - fold_line_kwargs=fold_line_kwargs, - ) - if fold_line_kwargs is None: - # Default `alpha` used - assert all(line.get_alpha() == 0.5 for line in display.line_) - elif isinstance(fold_line_kwargs, Mapping): - # `alpha` from dict used for all curves - assert all(line.get_alpha() == 0.2 for line in display.line_) - else: - # Different `alpha` used for each curve - assert all( - line.get_alpha() == fold_line_kwargs[i]["alpha"] - for i, line in enumerate(display.line_) - ) +# @pytest.mark.parametrize( +# "fold_line_kwargs", +# [None, {"alpha": 0.2}, [{"alpha": 0.2}, {"alpha": 0.3}, {"alpha": 0.4}]], +# ) +# def test_roc_curve_display_from_cv_results_validate_line_kwargs( +# pyplot, data_binary, fold_line_kwargs +# ): +# """Check `_validate_line_kwargs` correctly validates line kwargs.""" +# X, y = data_binary +# n_cv = 3 +# cv_results = cross_validate( +# LogisticRegression(), X, y, cv=n_cv, return_estimator=True, +# return_indices=True +# ) +# display = RocCurveDisplay.from_cv_results( +# cv_results, +# X, +# y, +# fold_line_kwargs=fold_line_kwargs, +# ) +# if fold_line_kwargs is None: +# # Default `alpha` used +# assert all(line.get_alpha() 
== 0.5 for line in display.line_) +# elif isinstance(fold_line_kwargs, Mapping): +# # `alpha` from dict used for all curves +# assert all(line.get_alpha() == 0.2 for line in display.line_) +# else: +# # Different `alpha` used for each curve +# assert all( +# line.get_alpha() == fold_line_kwargs[i]["alpha"] +# for i, line in enumerate(display.line_) +# ) # TODO : Remove in 1.9 @@ -241,6 +244,7 @@ def test_roc_curve_display_estimator_name_deprecation(pyplot): RocCurveDisplay(fpr=fpr, tpr=tpr, estimator_name="test") +@pytest.mark.parametrize("show_aggregate_score", [True, False]) @pytest.mark.parametrize("drop_intermediate", [True, False]) @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @pytest.mark.parametrize("with_sample_weight", [True, False]) @@ -252,6 +256,7 @@ def test_roc_curve_display_plotting_from_cv_results( with_sample_weight, response_method, drop_intermediate, + show_aggregate_score, ): """Check overall plotting of `from_cv_results`.""" X, y = data_binary @@ -278,8 +283,10 @@ def test_roc_curve_display_plotting_from_cv_results( drop_intermediate=drop_intermediate, response_method=response_method, pos_label=pos_label, + show_aggregate_score=show_aggregate_score, ) + auc_all = [] for idx, (estimator, test_indices) in enumerate( zip(cv_results["estimator"], cv_results["indices"]["test"]) ): @@ -302,40 +309,52 @@ def test_roc_curve_display_plotting_from_cv_results( drop_intermediate=drop_intermediate, pos_label=pos_label, ) - assert_allclose(display.roc_auc_[idx], auc(fpr, tpr)) - assert_allclose(display.fpr_[idx], fpr) - assert_allclose(display.tpr_[idx], tpr) - - fold_names = ["Fold 0", "Fold 1", "Fold 2"] - assert display.name_ == fold_names + if show_aggregate_score: + auc_all.append(auc(fpr, tpr)) + else: + assert_allclose(display.roc_auc_[idx], auc(fpr, tpr)) + assert_allclose(display.fpr_[idx], fpr) + assert_allclose(display.tpr_[idx], tpr) + + if show_aggregate_score: + mean, std = np.mean(auc_all), np.std(auc_all) 
+ assert (mean, std) == pytest.approx(display.roc_auc_aggregate) + assert display.name_ is None + else: + fold_names = ["Fold 0", "Fold 1", "Fold 2"] + assert display.name_ == fold_names import matplotlib as mpl _check_figure_axes_and_labels(display, pos_label) + aggregate_expected_labels = ["AUC = 1.00 +/- 0.00", "_child1", "_child2"] for idx, line in enumerate(display.line_): assert isinstance(line, mpl.lines.Line2D) # Default alpha for `from_cv_results` line.get_alpha() == 0.5 - expected_label = f"{fold_names[idx]} (AUC = {display.roc_auc_[idx]:.2f})" - assert line.get_label() == expected_label - - -@pytest.mark.parametrize("fold_names", [None, ["one", "two", "three"]]) -def test_roc_curve_from_cv_results_fold_names(pyplot, data_binary, fold_names): - """Check fold names behaviour correct in `from_cv_results`.""" - X, y = data_binary - cv_results = cross_validate( - LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True - ) - display = RocCurveDisplay.from_cv_results(cv_results, X, y, fold_names=fold_names) - legend = display.ax_.get_legend() - legend_labels = [text.get_text() for text in legend.get_texts()] - expected_names = ( - ["Fold 0", "Fold 1", "Fold 2"] if fold_names is None else fold_names - ) - assert display.name_ == expected_names - expected_labels = [name + " (AUC = 1.00)" for name in expected_names] - assert legend_labels == expected_labels + if show_aggregate_score: + assert line.get_label() == aggregate_expected_labels[idx] + else: + expected_label = f"{fold_names[idx]} (AUC = {display.roc_auc_[idx]:.2f})" + assert line.get_label() == expected_label + + +# @pytest.mark.parametrize("fold_names", [None, ["one", "two", "three"]]) +# def test_roc_curve_from_cv_results_fold_names(pyplot, data_binary, fold_names): +# """Check fold names behaviour correct in `from_cv_results`.""" +# X, y = data_binary +# cv_results = cross_validate( +# LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True +# ) +# display = 
RocCurveDisplay.from_cv_results(cv_results, X, y, fold_names=fold_names) +# legend = display.ax_.get_legend() +# legend_labels = [text.get_text() for text in legend.get_texts()] +# expected_names = ( +# ["Fold 0", "Fold 1", "Fold 2"] if fold_names is None else fold_names +# ) +# assert display.name_ == expected_names +# expected_labels = [name + " (AUC = 1.00)" for name in expected_names] +# assert legend_labels == expected_labels @pytest.mark.parametrize( @@ -351,7 +370,7 @@ def test_roc_curve_from_cv_results_line_kwargs(pyplot, data_binary, fold_line_kw LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True ) display = RocCurveDisplay.from_cv_results( - cv_results, X, y, fold_line_kwargs=fold_line_kwargs + cv_results, X, y, fold_line_kwargs=fold_line_kwargs, show_aggregate_score=False ) mpl_default_colors = mpl.rcParams["axes.prop_cycle"].by_key()["color"] @@ -524,7 +543,7 @@ def test_roc_curve_chance_level_line_from_cv_results( if chance_level_kw.get("label") is not None: assert chance_level_kw["label"] in legend_labels else: - assert len(legend_labels) == n_cv + assert len(legend_labels) == 1 @pytest.mark.parametrize( @@ -585,11 +604,10 @@ def test_roc_curve_display_default_labels(pyplot, roc_auc, name, expected_labels def _check_auc(display, constructor_name): roc_auc_limit = 0.95679 - roc_auc_limit_multi = [0.97007, 0.985915, 0.980952] + roc_auc_mean_std = (0.978979, 0.006617449) if constructor_name == "from_cv_results": - for idx, roc_auc in enumerate(display.roc_auc_): - assert roc_auc == pytest.approx(roc_auc_limit_multi[idx]) + assert display.roc_auc_aggregate == pytest.approx(roc_auc_mean_std) else: assert display.roc_auc == pytest.approx(roc_auc_limit) assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 1854db4f68fac..df066da2f5fa1 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -12,10 +12,9 @@ from 
.multiclass import type_of_target from .validation import _check_pos_label_consistency, _num_samples -MULTI_PARAM_ERROR_MSG = ( - "When '{param}' is provided, it must have the same length as " - "the number of curves to be plotted. Got: {len_param}; " - "expected: {n_curves}." +AGGREGATE_ERROR_MESSAGE = ( + "'fold_line_kwargs' must be a single dictionary to be applied to all curves " + "when {param_value}, as only one legend entry will be added." ) @@ -77,7 +76,16 @@ def _validate_from_predictions_params( @classmethod def _validate_from_cv_results_params( - cls, cv_results, X, y, *, sample_weight=None, pos_label=None, fold_names=None + cls, + cv_results, + X, + y, + *, + sample_weight, + pos_label, + name, + fold_line_kwargs, + show_aggregate_score, ): check_matplotlib_support(f"{cls.__name__}.from_predictions") @@ -114,25 +122,48 @@ def _validate_from_cv_results_params( raise ValueError(str(e).replace("y_true", "y")) n_curves = len(cv_results["estimator"]) - if fold_names is None: - fold_names = [f"Fold {idx}" for idx in range(n_curves)] - elif len(fold_names) != n_curves: - raise ValueError( - MULTI_PARAM_ERROR_MSG.format( - param="fold_names", len_param=len(fold_names), n_curves=n_curves + if show_aggregate_score: + if isinstance(name, list) and len(name) != 1: + raise ValueError( + "'name' must be a string or list of length one when " + "'show_aggregate_score' is True, as only one legend entry " + "will be added." + ) + # Should we allow a list of length 1 (in addition to single dict) ?? + if isinstance(fold_line_kwargs, list): + raise ValueError( + AGGREGATE_ERROR_MESSAGE.format( + param_value="'show_aggregate_score' is True" + ) ) - ) else: - fold_names = fold_names + # Individual ROC AUC scores shown + if name is not None and (isinstance(name, list) and len(name) != n_curves): + raise ValueError( + f"'name' must be None or list of length {n_curves} when " + f"'show_aggregate_score' is False." 
+ ) + if name is None: + name = [f"Fold {idx}" for idx in range(n_curves)] + return pos_label, name - return pos_label, fold_names + @classmethod + def _get_legend_label(cls, curve_summary_value, curve_name, summary_value_name): + """Helper to get legend label using `name_` and `summary_value_`""" + if curve_summary_value is not None and curve_name is not None: + label = f"{curve_name} ({summary_value_name} = {curve_summary_value:0.2f})" + elif curve_summary_value is not None: + label = f"{summary_value_name} = {curve_summary_value:0.2f}" + elif curve_name is not None: + label = curve_name + return label @classmethod def _get_line_kwargs( cls, n_curves, - names, - summary_values, + name, + summary_value, summary_value_name, fold_line_kwargs, default_line_kwargs=None, @@ -145,13 +176,14 @@ def _get_line_kwargs( n_curves : int Number of curves. - names : list[str] - Names of each curve. + name : list[str] or None + Name for labeling legend entries. - summary_values : list[float] - List of summary values for each curve (e.g., ROC AUC, average precision). + summary_value : list[float] or tuple(float, float) or None + Either list of `n_curves` summary values for each curve (e.g., ROC AUC, + average precision) or a single float summary value for all curves. - summary_value_name : str + summary_value_name : str or None Name of the summary value provided in `summary_values`. fold_line_kwargs : dict or list of dict @@ -172,34 +204,40 @@ def _get_line_kwargs( for multi-curve plots. 
""" # Ensure parameters are of the correct length - names_ = [None] * n_curves if names is None else names - summary_values_ = ( - [None] * n_curves if summary_values is None else summary_values - ) + name_ = [None] * n_curves if name is None else name + summary_value_ = [None] * n_curves if summary_value is None else summary_value # `fold_line_kwargs` ignored for single curve plots # `kwargs` ignored for multi-curve plots if n_curves == 1: fold_line_kwargs = [kwargs] else: + # Should we add an extra check ensuring `fold_line_kwargs` is the + # same when `summary_value` is (mean, std) ? or allow this flexibility + # when using `plot` - note this is checked and prevented in + # `from_cv_results` fold_line_kwargs = cls._validate_line_kwargs(n_curves, fold_line_kwargs) if default_line_kwargs is None: default_line_kwargs = {} - line_kwargs = [] - for fold_idx, (curve_summary_value, curve_name) in enumerate( - zip(summary_values_, names_) - ): - if curve_summary_value is not None and curve_name is not None: - default_line_kwargs["label"] = ( - f"{curve_name} ({summary_value_name} = {curve_summary_value:0.2f})" - ) - elif curve_summary_value is not None: - default_line_kwargs["label"] = ( - f"{summary_value_name} = {curve_summary_value:0.2f}" + + labels = [] + if isinstance(summary_value_, tuple): + label_aggregate = cls._get_legend_label( + summary_value_[0], name_[0], summary_value_name + ) + label_aggregate = label_aggregate + f" +/- {summary_value_[1]:0.2f}" + labels.extend([label_aggregate] + [None] * (n_curves - 1)) + else: + for curve_summary_value, curve_name in zip(summary_value_, name_): + labels.append( + cls._get_legend_label( + curve_summary_value, curve_name, summary_value_name + ) ) - elif curve_name is not None: - default_line_kwargs["label"] = curve_name + line_kwargs = [] + for fold_idx, label in enumerate(labels): + default_line_kwargs["label"] = label line_kwargs.append( _validate_style_kwargs(default_line_kwargs, fold_line_kwargs[fold_idx]) ) @@ 
-209,7 +247,11 @@ def _get_line_kwargs( # `ValidationCurveDisplay`) amend to function @classmethod def _validate_line_kwargs( - cls, n_curves, fold_line_kwargs=None, default_line_kwargs=None + cls, + n_curves, + fold_line_kwargs=None, + default_line_kwargs=None, + aggregate_error=False, ): """Ensure `fold_line_kwargs` length and incorporate default kwargs. @@ -254,11 +296,8 @@ def _validate_line_kwargs( fold_line_kwargs = [fold_line_kwargs] * n_curves elif len(fold_line_kwargs) != n_curves: raise ValueError( - MULTI_PARAM_ERROR_MSG.format( - param="fold_line_kwargs", - len_param=len(fold_line_kwargs), - n_curves=n_curves, - ) + f"'fold_line_kwargs' must be a list of length {n_curves} or a " + f"single dictionary. Got list of length: {len(fold_line_kwargs)}." ) else: fold_line_kwargs = fold_line_kwargs @@ -268,6 +307,7 @@ def _validate_line_kwargs( _validate_style_kwargs(default_line_kwargs, single_kwargs) for single_kwargs in fold_line_kwargs ] + return fold_line_kwargs From d22158068417869344d1261c2fa83584327f38e3 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 26 Mar 2025 20:18:49 +1100 Subject: [PATCH 37/63] amend condition --- sklearn/utils/_plotting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index df066da2f5fa1..ef5594de7d14f 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -138,10 +138,10 @@ def _validate_from_cv_results_params( ) else: # Individual ROC AUC scores shown - if name is not None and (isinstance(name, list) and len(name) != n_curves): + if isinstance(name, list) and len(name) != n_curves: raise ValueError( f"'name' must be None or list of length {n_curves} when " - f"'show_aggregate_score' is False." + f"'show_aggregate_score' is False. 
Got list of length {len(name)}" ) if name is None: name = [f"Fold {idx}" for idx in range(n_curves)] From be649de540de591b88688137518070a596271c5f Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 2 Apr 2025 14:50:02 +1100 Subject: [PATCH 38/63] review --- sklearn/metrics/_plot/roc_curve.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 2b5174d59bf33..c190b4ed08c51 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -33,13 +33,11 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): ---------- fpr : ndarray or list of ndarrays False positive rates. Each ndarray should contain values for a single curve. - If plotting multiple curves, list should be of same length as - and `tpr`. + If plotting multiple curves, list should be of same length as `tpr`. - tpr : ndarray list of ndarrays + tpr : ndarray or list of ndarrays True positive rates. Each ndarray should contain values for a single curve. - If plotting multiple curves, list should be of same length as - and `fpr`. + If plotting multiple curves, list should be of same length as `fpr`. roc_auc : float or list of floats, default=None Area under ROC curve, used for labeling each curve in the legend. 
From e1f455b62527b231b2faf01c6f3c5eddd7b09474 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 3 Apr 2025 15:40:19 +1100 Subject: [PATCH 39/63] wip remove agg params --- sklearn/metrics/_plot/roc_curve.py | 155 ++++++++---------- .../_plot/tests/test_roc_curve_display.py | 44 +++-- sklearn/utils/_plotting.py | 65 ++++---- 3 files changed, 118 insertions(+), 146 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index c190b4ed08c51..ed846b1df11f6 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,10 +1,13 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause +from collections.abc import Mapping + import numpy as np from ...utils import _safe_indexing from ...utils._plotting import ( + MULTICURVE_LABELLING_ERROR, _BinaryClassifierCurveDisplayMixin, _check_param_lengths, _convert_to_list_leaving_none, @@ -35,37 +38,36 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): False positive rates. Each ndarray should contain values for a single curve. If plotting multiple curves, list should be of same length as `tpr`. + .. versionchanged:: 1.7 + Now accepts a list for plotting multiple curves. + tpr : ndarray or list of ndarrays True positive rates. Each ndarray should contain values for a single curve. If plotting multiple curves, list should be of same length as `fpr`. + .. versionchanged:: 1.7 + Now accepts a list for plotting multiple curves. + roc_auc : float or list of floats, default=None Area under ROC curve, used for labeling each curve in the legend. If plotting multiple curves, should be a list of the same length as `fpr` - and `tpr`. If `None`, individual ROC AUC scores are not shown. See - `roc_auc_aggregate` for alternative. - If `name` and `roc_auc_aggregate` are also `None` no legend is added. + and `tpr`. If `None`, ROC AUC scores are not shown in the legend. 
- roc_auc_aggregate : tuple(float), default=None - ROC AUC mean and standard deviation. An alternative to `roc_auc` when - plotting multiple curves and a single legend entry showing ROC AUC mean and - standard deviation for all curves is desired. - If `True`, `name` cannot be a list of length >1. + .. versionchanged:: 1.7 + Now accepts a list for plotting multiple curves. name : str or list of str, default=None - (Do we prefer curve_name) ? - Name for labeling legend entries. For single ROC curve, should be a str or - list of length one. For multiple ROC curves: - - * if list of names provided, should be the same length as `fpr` - and `tpr`. Each individual curve will be labeled in the legend. Cannot - be used in conjunction with `roc_auc_aggregate`. - * if a single name provided (as str or list of length one), a single legend - entry will be used to label all curves. Cannot be used in conjunction with - `roc_auc`. + Name for labeling legend entries. For single ROC curve, should be a string. + For multiple ROC curves, the number of legend entries depend on + the `fold_line_kwargs` passed to `plot`. + To label each curve, provide a list of strings. To avoid labeling + individual curves that have the same appearance, this cannot be used in + conjunction with `fold_line_kwargs` being a list. If a string is + provided, either label the single legend entry or if there are + multiple legend entries, label each individual curve with the + same name. If `None`, no name is shown in the legend. - If `None`, no name is not shown in the legend. If `roc_auc` - and `roc_auc_aggregate` are also `None` no legend is added. + .. 
versionadded:: 1.7 pos_label : int, float, bool or str, default=None The class considered as the positive class when computing the roc auc @@ -142,42 +144,29 @@ def __init__( self.name = _deprecate_estimator_name(estimator_name, name, "1.7") self.pos_label = pos_label - def _validate_plot_params(self, *, ax=None, name=None): + def _validate_plot_params(self, *, ax=None, name=None, fold_line_kwargs=None): self.ax_, self.figure_, name_ = super()._validate_plot_params(ax=ax, name=name) - if self.roc_auc_aggregate: - if self.roc_auc is not None: - raise ValueError( - "'self.roc_auc' and 'self.roc_auc_aggregate' cannot both be " - "provided." - ) - if isinstance(name_, list) and len(name_) != 1: - raise ValueError( - "When 'roc_auc_aggregate' is True, 'name' (or self.name) " - "must be a string or a list of length one." - ) - self.fpr_ = _convert_to_list_leaving_none(self.fpr) self.tpr_ = _convert_to_list_leaving_none(self.tpr) self.roc_auc_ = _convert_to_list_leaving_none(self.roc_auc) self.name_ = _convert_to_list_leaving_none(name_) - optional = {"self.roc_auc": self.roc_auc_} - if self.name_ is not None and len(self.name_) != 1: - optional.update({"'name' (or self.name)": self.name_}) - _check_param_lengths( required={"self.fpr": self.fpr_, "self.tpr": self.tpr_}, - optional=optional, + optional={ + "self.roc_auc": self.roc_auc_, + "'name' (or self.name)": self.name_, + }, class_name="RocCurveDisplay", ) - if self.roc_auc: - if isinstance(name_, list) and len(name_) == 1: - raise ValueError( - "When 'roc_auc' is provided, 'name' (or self.name) " - f"must be None or a list of length {len(self.fpr_)}." - ) + if ( + isinstance(self.name_, list) + and len(self.name_) != 1 + and (isinstance(fold_line_kwargs, Mapping) or fold_line_kwargs is None) + ): + raise ValueError(MULTICURVE_LABELLING_ERROR.format(n_curves=len(self.fpr_))) def plot( self, @@ -202,9 +191,14 @@ def plot( created. 
name : str or list of str, default=None - Name of each ROC curve, used for labeling curves in the legend. - If `None`, use `name` provided at `RocCurveDisplay` initialization. If - also not provided at initialization, no name is shown in the legend. + Name for labeling legend entries. + To label each curve, provide a list of strings. To avoid labeling + individual curves that have the same appearance, this cannot be used in + conjunction with `fold_line_kwargs` being a list. If a string is + provided, either label the single legend entry or if there are + multiple legend entries, label each individual curve with the + same name. If `None`, set to `name` provided at `RocCurveDisplay` + initialization. If still `None`, no name is shown in the legend. .. versionadded:: 1.7 @@ -225,12 +219,13 @@ def plot( .. versionadded:: 1.6 fold_line_kwargs : dict or list of dict, default=None - Dictionary with keywords passed to the matplotlib's `plot` function - to draw the individual ROC curves. If a list is provided, the - parameters are applied to the ROC curves sequentially. If a single - dictionary is provided, the same parameters are applied to all ROC - curves. Ignored for single curve plots - pass as `**kwargs` for - single curve plots. + Keywords arguments to be passed to matplotlib's `plot` function + to draw individual ROC curves. If a list is provided the + parameters are applied to the ROC curves of each CV fold + sequentially and a legend entry is added for each curve. + If a single dictionary is provided, the same + parameters are applied to all ROC curves and a single legend + entry for all curves is added, labeled with the mean ROC AUC score. .. 
versionadded:: 1.7 @@ -246,12 +241,12 @@ def plot( """ self._validate_plot_params(ax=ax, name=name) summary_value, summary_value_name = self.roc_auc_, "AUC" - if self.roc_auc_aggregate: - summary_value, summary_value_name = self.roc_auc_aggregate, "AUC" - elif self.roc_auc: - summary_value, summary_value_name = self.roc_auc_, "AUC" - else: - summary_value, summary_value_name = None, None + if ( + self.roc_auc_ + and isinstance(fold_line_kwargs, list) + and len(fold_line_kwargs) != 1 + ): + summary_value = (np.mean(self.roc_auc_), np.std(self.roc_auc_)) n_curves = len(self.fpr_) line_kwargs = self._get_line_kwargs( @@ -585,7 +580,6 @@ def from_cv_results( ax=None, name=None, fold_line_kwargs=None, - show_aggregate_score=True, plot_chance_level=False, chance_level_kwargs=None, despine=False, @@ -631,26 +625,24 @@ def from_cv_results( Axes object to plot on. If `None`, a new figure and axes is created. - name : list of str or str, default=None - Name for labeling legend entries. To label each individual curve, - provide a list of names the same length as the number of cross-validation - folds. In this case `show_aggregate_score` cannot be `True`. - To label all curves using a single legend entry, provide a str - or list of length one. If `None`, no name is shown in the legend. + name : str or list of str, default=None + Name for labeling legend entries. The number of legend entries + depends on `fold_line_kwargs`. + To label each curve, provide a list of strings. To avoid labeling + individual curves that have the same appearance, this cannot be used in + conjunction with `fold_line_kwargs` being a list. If a string is + provided, either label the single legend entry or if there are + multiple legend entries, label each individual curve with the + same name. If `None`, no name is shown in the legend. fold_line_kwargs : dict or list of dict, default=None Keywords arguments to be passed to matplotlib's `plot` function - to draw individual ROC curves. 
If a list is provided, the + to draw individual ROC curves. If a list is provided the parameters are applied to the ROC curves of each CV fold - sequentially. If a single dictionary is provided, the same - parameters are applied to all ROC curves. - - show_aggregate_score : bool, default=True - Whether to show the ROC AUC mean and standard deviation of curves from - all folds as a single legend entry. If `True`, `name` should be a single - string and `fold_line_kwargs` should be a single dictionary, to prevent - confusion in the legend. If `False`, `name` should be None or a list the - same length as the number of cross-validation folds. + sequentially and a legend entry is added for each curve. + If a single dictionary is provided, the same + parameters are applied to all ROC curves and a single legend + entry for all curves is added, labeled with the mean ROC AUC score. plot_chance_level : bool, default=False Whether to plot the chance level. @@ -691,7 +683,7 @@ def from_cv_results( <...> >>> plt.show() """ - pos_label, name_ = cls._validate_from_cv_results_params( + pos_label = cls._validate_from_cv_results_params( cv_results, X, y, @@ -699,7 +691,6 @@ def from_cv_results( pos_label=pos_label, name=name, fold_line_kwargs=fold_line_kwargs, - show_aggregate_score=show_aggregate_score, ) fold_line_kwargs_ = cls._validate_line_kwargs( @@ -739,17 +730,11 @@ def from_cv_results( tpr_all.append(tpr) auc_all.append(roc_auc) - roc_auc_aggregate = None - if show_aggregate_score: - roc_auc_aggregate = (np.mean(auc_all), np.std(auc_all)) - auc_all = None - viz = cls( fpr=fpr_all, tpr=tpr_all, - name=name_, + name=name, roc_auc=auc_all, - roc_auc_aggregate=roc_auc_aggregate, pos_label=pos_label, ) return viz.plot( diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index d42b8368b6517..3062f616af0b4 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ 
b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -244,7 +244,10 @@ def test_roc_curve_display_estimator_name_deprecation(pyplot): RocCurveDisplay(fpr=fpr, tpr=tpr, estimator_name="test") -@pytest.mark.parametrize("show_aggregate_score", [True, False]) +@pytest.mark.parametrize( + "fold_line_kwargs", + [None, [{"color": "blue"}, {"color": "green"}, {"color": "red"}]], +) @pytest.mark.parametrize("drop_intermediate", [True, False]) @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @pytest.mark.parametrize("with_sample_weight", [True, False]) @@ -256,7 +259,7 @@ def test_roc_curve_display_plotting_from_cv_results( with_sample_weight, response_method, drop_intermediate, - show_aggregate_score, + fold_line_kwargs, ): """Check overall plotting of `from_cv_results`.""" X, y = data_binary @@ -283,10 +286,9 @@ def test_roc_curve_display_plotting_from_cv_results( drop_intermediate=drop_intermediate, response_method=response_method, pos_label=pos_label, - show_aggregate_score=show_aggregate_score, + fold_line_kwargs=fold_line_kwargs, ) - auc_all = [] for idx, (estimator, test_indices) in enumerate( zip(cv_results["estimator"], cv_results["indices"]["test"]) ): @@ -309,20 +311,11 @@ def test_roc_curve_display_plotting_from_cv_results( drop_intermediate=drop_intermediate, pos_label=pos_label, ) - if show_aggregate_score: - auc_all.append(auc(fpr, tpr)) - else: - assert_allclose(display.roc_auc_[idx], auc(fpr, tpr)) - assert_allclose(display.fpr_[idx], fpr) - assert_allclose(display.tpr_[idx], tpr) - - if show_aggregate_score: - mean, std = np.mean(auc_all), np.std(auc_all) - assert (mean, std) == pytest.approx(display.roc_auc_aggregate) - assert display.name_ is None - else: - fold_names = ["Fold 0", "Fold 1", "Fold 2"] - assert display.name_ == fold_names + assert_allclose(display.roc_auc_[idx], auc(fpr, tpr)) + assert_allclose(display.fpr_[idx], fpr) + assert_allclose(display.tpr_[idx], tpr) + + assert display.name_ is None import 
matplotlib as mpl @@ -332,11 +325,13 @@ def test_roc_curve_display_plotting_from_cv_results( assert isinstance(line, mpl.lines.Line2D) # Default alpha for `from_cv_results` line.get_alpha() == 0.5 - if show_aggregate_score: - assert line.get_label() == aggregate_expected_labels[idx] + if fold_line_kwargs is None: + print(line.get_label()) + # assert line.get_label() == aggregate_expected_labels[idx] else: - expected_label = f"{fold_names[idx]} (AUC = {display.roc_auc_[idx]:.2f})" - assert line.get_label() == expected_label + # expected_label = f"AUC = {display.roc_auc_[idx]:.2f}" + # assert line.get_label() == expected_label + print(line.get_label()) # @pytest.mark.parametrize("fold_names", [None, ["one", "two", "three"]]) @@ -604,10 +599,11 @@ def test_roc_curve_display_default_labels(pyplot, roc_auc, name, expected_labels def _check_auc(display, constructor_name): roc_auc_limit = 0.95679 - roc_auc_mean_std = (0.978979, 0.006617449) + roc_auc_limit_multi = [0.97007, 0.985915, 0.980952] if constructor_name == "from_cv_results": - assert display.roc_auc_aggregate == pytest.approx(roc_auc_mean_std) + for idx, roc_auc in enumerate(display.roc_auc_): + assert roc_auc == pytest.approx(roc_auc_limit_multi[idx]) else: assert display.roc_auc == pytest.approx(roc_auc_limit) assert trapezoid(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index ef5594de7d14f..99f3352ef005f 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -12,9 +12,11 @@ from .multiclass import type_of_target from .validation import _check_pos_label_consistency, _num_samples -AGGREGATE_ERROR_MESSAGE = ( - "'fold_line_kwargs' must be a single dictionary to be applied to all curves " - "when {param_value}, as only one legend entry will be added." 
+MULTICURVE_LABELLING_ERROR = ( + "To avoid labeling individual curves that have the same appearance, " + "`fold_line_kwargs` should be a list of {n_curves} dictionaries. Alternatively, " + "set `name` to `None` or a single string to add a single legend entry with mean " + "ROC AUC score of all curves." ) @@ -85,7 +87,6 @@ def _validate_from_cv_results_params( pos_label, name, fold_line_kwargs, - show_aggregate_score, ): check_matplotlib_support(f"{cls.__name__}.from_predictions") @@ -122,30 +123,19 @@ def _validate_from_cv_results_params( raise ValueError(str(e).replace("y_true", "y")) n_curves = len(cv_results["estimator"]) - if show_aggregate_score: - if isinstance(name, list) and len(name) != 1: - raise ValueError( - "'name' must be a string or list of length one when " - "'show_aggregate_score' is True, as only one legend entry " - "will be added." - ) - # Should we allow a list of length 1 (in addition to single dict) ?? - if isinstance(fold_line_kwargs, list): - raise ValueError( - AGGREGATE_ERROR_MESSAGE.format( - param_value="'show_aggregate_score' is True" - ) - ) + if ( + isinstance(name, list) + and len(name) != 1 + and (isinstance(fold_line_kwargs, Mapping) or fold_line_kwargs is None) + ): + raise ValueError(MULTICURVE_LABELLING_ERROR.format(n_curves=n_curves)) else: - # Individual ROC AUC scores shown - if isinstance(name, list) and len(name) != n_curves: + if isinstance(name, list) and len(name) not in (1, n_curves): raise ValueError( - f"'name' must be None or list of length {n_curves} when " - f"'show_aggregate_score' is False. Got list of length {len(name)}" + f"`name` must be a list of length {n_curves} or a string. " + f"Got list of length: {len(name)}." 
) - if name is None: - name = [f"Fold {idx}" for idx in range(n_curves)] - return pos_label, name + return pos_label @classmethod def _get_legend_label(cls, curve_summary_value, curve_name, summary_value_name): @@ -211,10 +201,6 @@ def _get_line_kwargs( if n_curves == 1: fold_line_kwargs = [kwargs] else: - # Should we add an extra check ensuring `fold_line_kwargs` is the - # same when `summary_value` is (mean, std) ? or allow this flexibility - # when using `plot` - note this is checked and prevented in - # `from_cv_results` fold_line_kwargs = cls._validate_line_kwargs(n_curves, fold_line_kwargs) if default_line_kwargs is None: @@ -226,6 +212,7 @@ def _get_line_kwargs( summary_value_[0], name_[0], summary_value_name ) label_aggregate = label_aggregate + f" +/- {summary_value_[1]:0.2f}" + # Add `label` for first curve only, set to `None` for remaining curves labels.extend([label_aggregate] + [None] * (n_curves - 1)) else: for curve_summary_value, curve_name in zip(summary_value_, name_): @@ -244,14 +231,13 @@ def _get_line_kwargs( return line_kwargs # TODO: if useful for non binary displays (e.g.,`LearningCurveDisplay`, - # `ValidationCurveDisplay`) amend to function + # `ValidationCurveDisplay`) change to function @classmethod def _validate_line_kwargs( cls, n_curves, fold_line_kwargs=None, default_line_kwargs=None, - aggregate_error=False, ): """Ensure `fold_line_kwargs` length and incorporate default kwargs. @@ -294,10 +280,11 @@ def _validate_line_kwargs( if isinstance(fold_line_kwargs, Mapping): fold_line_kwargs = [fold_line_kwargs] * n_curves + # Should we allow list with single dict? elif len(fold_line_kwargs) != n_curves: raise ValueError( f"'fold_line_kwargs' must be a list of length {n_curves} or a " - f"single dictionary. Got list of length: {len(fold_line_kwargs)}." + f"dictionary. Got list of length: {len(fold_line_kwargs)}." 
) else: fold_line_kwargs = fold_line_kwargs @@ -465,13 +452,17 @@ def _check_param_lengths(required, optional, class_name): all_params = {**required, **optional_provided} if len({len(param) for param in all_params.values()}) > 1: - required_formatted = ", ".join(f"'{key}'" for key in required.keys()) - optional_formatted = ", ".join(f"'{key}'" for key in optional_provided.keys()) + param_keys = [key for key in all_params.keys()] + # Note: below code requires `len(param_keys) >= 2`, which is the case for all + # display classes + params_formatted = " and ".join([", ".join(param_keys[:-1]), param_keys[-1]]) + or_plot = "" + if "'name' (or self.name)" in param_keys: + or_plot = " (or `plot`)" lengths_formatted = ", ".join( f"{key}: {len(value)}" for key, value in all_params.items() ) raise ValueError( - f"{required_formatted}, and optional parameters {optional_formatted} " - f"from `{class_name}` initialization (or `plot`) should all be lists of " - f"the same length. Got: {lengths_formatted}" + f"{params_formatted} from `{class_name}` initialization{or_plot}, " + f"should all be lists of the same length. 
Got: {lengths_formatted}" ) From 81e00d93d740582ad96dc9ab2c7ced5ce0729cbb Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 4 Apr 2025 11:29:24 +1100 Subject: [PATCH 40/63] iter remove aggre param --- sklearn/metrics/_plot/roc_curve.py | 26 ++++++++++++++++++++------ sklearn/utils/_plotting.py | 12 +++--------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index ed846b1df11f6..f9de9ec8a348e 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -693,11 +693,25 @@ def from_cv_results( fold_line_kwargs=fold_line_kwargs, ) - fold_line_kwargs_ = cls._validate_line_kwargs( - len(cv_results["estimator"]), - fold_line_kwargs, - default_line_kwargs={"alpha": 0.5, "linestyle": "--"}, - ) + n_curves = len(cv_results["estimator"]) + default_curve_kwargs = {"alpha": 0.5, "linestyle": "--"} + if fold_line_kwargs is None: + fold_line_kwargs = default_curve_kwargs + elif isinstance(fold_line_kwargs, Mapping): + fold_line_kwargs = _validate_style_kwargs( + default_curve_kwargs, fold_line_kwargs + ) + elif isinstance(fold_line_kwargs, list): + if len(fold_line_kwargs) != n_curves: + raise ValueError( + f"'fold_line_kwargs' must be a list of length {n_curves} or a " + f"dictionary. Got list of length: {len(fold_line_kwargs)}." + ) + else: + fold_line_kwargs = [ + _validate_style_kwargs(default_curve_kwargs, single_kwargs) + for single_kwargs in fold_line_kwargs + ] fpr_all = [] tpr_all = [] @@ -742,5 +756,5 @@ def from_cv_results( plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kwargs, despine=despine, - fold_line_kwargs=fold_line_kwargs_, + fold_line_kwargs=fold_line_kwargs, ) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 99f3352ef005f..071659b181ae7 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -184,10 +184,6 @@ def _get_line_kwargs( curves. 
Ignored for single curve plots - pass as `**kwargs` for single curve plots. - default_line_kwargs : dict, default=None - Default line kwargs to be used in all curves, unless overridden by - `fold_line_kwargs`. - **kwargs : dict For a single curve plots only, keyword arguments to be passed to matplotlib's `plot`. Ignored for multi-curve plots - use `fold_line_kwargs` @@ -203,9 +199,6 @@ def _get_line_kwargs( else: fold_line_kwargs = cls._validate_line_kwargs(n_curves, fold_line_kwargs) - if default_line_kwargs is None: - default_line_kwargs = {} - labels = [] if isinstance(summary_value_, tuple): label_aggregate = cls._get_legend_label( @@ -223,10 +216,11 @@ def _get_line_kwargs( ) line_kwargs = [] + label_kwarg = {} for fold_idx, label in enumerate(labels): - default_line_kwargs["label"] = label + label_kwarg["label"] = label line_kwargs.append( - _validate_style_kwargs(default_line_kwargs, fold_line_kwargs[fold_idx]) + _validate_style_kwargs(label_kwarg, fold_line_kwargs[fold_idx]) ) return line_kwargs From 9ee37e2e9b5edf0c3662484c81750c700b49620a Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 4 Apr 2025 16:58:49 +1100 Subject: [PATCH 41/63] remove agg, tests pass --- sklearn/metrics/_plot/roc_curve.py | 66 ++++----- .../_plot/tests/test_roc_curve_display.py | 66 +++++---- sklearn/utils/_plotting.py | 129 ++++++++---------- 3 files changed, 120 insertions(+), 141 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index f9de9ec8a348e..30c01033b2986 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -63,8 +63,8 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): To label each curve, provide a list of strings. To avoid labeling individual curves that have the same appearance, this cannot be used in conjunction with `fold_line_kwargs` being a list. 
If a string is - provided, either label the single legend entry or if there are - multiple legend entries, label each individual curve with the + provided, it will be used to either label the single legend entry or if + there are multiple legend entries, label each individual curve with the same name. If `None`, no name is shown in the legend. .. versionadded:: 1.7 @@ -144,7 +144,7 @@ def __init__( self.name = _deprecate_estimator_name(estimator_name, name, "1.7") self.pos_label = pos_label - def _validate_plot_params(self, *, ax=None, name=None, fold_line_kwargs=None): + def _validate_plot_params(self, *, ax, name, fold_line_kwargs): self.ax_, self.figure_, name_ = super()._validate_plot_params(ax=ax, name=name) self.fpr_ = _convert_to_list_leaving_none(self.fpr) @@ -152,12 +152,12 @@ def _validate_plot_params(self, *, ax=None, name=None, fold_line_kwargs=None): self.roc_auc_ = _convert_to_list_leaving_none(self.roc_auc) self.name_ = _convert_to_list_leaving_none(name_) + optional = {"self.roc_auc": self.roc_auc_} + if isinstance(self.name_, list) and len(self.name_) != 1: + optional.update({"'name' (or self.name)": self.name_}) _check_param_lengths( required={"self.fpr": self.fpr_, "self.tpr": self.tpr_}, - optional={ - "self.roc_auc": self.roc_auc_, - "'name' (or self.name)": self.name_, - }, + optional=optional, class_name="RocCurveDisplay", ) @@ -195,8 +195,8 @@ def plot( To label each curve, provide a list of strings. To avoid labeling individual curves that have the same appearance, this cannot be used in conjunction with `fold_line_kwargs` being a list. If a string is - provided, either label the single legend entry or if there are - multiple legend entries, label each individual curve with the + provided, it will be used to either label the single legend entry or + if there are multiple legend entries, label each individual curve with the same name. If `None`, set to `name` provided at `RocCurveDisplay` initialization. 
If still `None`, no name is shown in the legend. @@ -239,14 +239,14 @@ def plot( display : :class:`~sklearn.metrics.RocCurveDisplay` Object that stores computed values. """ - self._validate_plot_params(ax=ax, name=name) + self._validate_plot_params(ax=ax, name=name, fold_line_kwargs=fold_line_kwargs) + n_curves = len(self.fpr_) summary_value, summary_value_name = self.roc_auc_, "AUC" - if ( - self.roc_auc_ - and isinstance(fold_line_kwargs, list) - and len(fold_line_kwargs) != 1 - ): - summary_value = (np.mean(self.roc_auc_), np.std(self.roc_auc_)) + if not isinstance(fold_line_kwargs, list) and n_curves > 1: + if self.roc_auc_: + summary_value = (np.mean(self.roc_auc_), np.std(self.roc_auc_)) + else: + summary_value = (None, None) n_curves = len(self.fpr_) line_kwargs = self._get_line_kwargs( @@ -631,8 +631,8 @@ def from_cv_results( To label each curve, provide a list of strings. To avoid labeling individual curves that have the same appearance, this cannot be used in conjunction with `fold_line_kwargs` being a list. If a string is - provided, either label the single legend entry or if there are - multiple legend entries, label each individual curve with the + provided, it will be used to either label the single legend entry or + if there are multiple legend entries, label each individual curve with the same name. If `None`, no name is shown in the legend. 
fold_line_kwargs : dict or list of dict, default=None @@ -683,7 +683,7 @@ def from_cv_results( <...> >>> plt.show() """ - pos_label = cls._validate_from_cv_results_params( + pos_label_, fold_line_kwargs_ = cls._validate_from_cv_results_params( cv_results, X, y, @@ -693,26 +693,6 @@ def from_cv_results( fold_line_kwargs=fold_line_kwargs, ) - n_curves = len(cv_results["estimator"]) - default_curve_kwargs = {"alpha": 0.5, "linestyle": "--"} - if fold_line_kwargs is None: - fold_line_kwargs = default_curve_kwargs - elif isinstance(fold_line_kwargs, Mapping): - fold_line_kwargs = _validate_style_kwargs( - default_curve_kwargs, fold_line_kwargs - ) - elif isinstance(fold_line_kwargs, list): - if len(fold_line_kwargs) != n_curves: - raise ValueError( - f"'fold_line_kwargs' must be a list of length {n_curves} or a " - f"dictionary. Got list of length: {len(fold_line_kwargs)}." - ) - else: - fold_line_kwargs = [ - _validate_style_kwargs(default_curve_kwargs, single_kwargs) - for single_kwargs in fold_line_kwargs - ] - fpr_all = [] tpr_all = [] auc_all = [] @@ -724,7 +704,7 @@ def from_cv_results( estimator, _safe_indexing(X, test_indices), response_method=response_method, - pos_label=pos_label, + pos_label=pos_label_, )[0] sample_weight_fold = ( None @@ -734,7 +714,7 @@ def from_cv_results( fpr, tpr, _ = roc_curve( y_true, y_pred, - pos_label=pos_label, + pos_label=pos_label_, sample_weight=sample_weight_fold, drop_intermediate=drop_intermediate, ) @@ -749,12 +729,12 @@ def from_cv_results( tpr=tpr_all, name=name, roc_auc=auc_all, - pos_label=pos_label, + pos_label=pos_label_, ) return viz.plot( ax=ax, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kwargs, despine=despine, - fold_line_kwargs=fold_line_kwargs, + fold_line_kwargs=fold_line_kwargs_, ) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 3062f616af0b4..6a13421f67e4e 100644 --- 
a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -187,19 +187,29 @@ def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary, data): with pytest.raises(ValueError, match=r"y takes value in \{1, 2\}"): RocCurveDisplay.from_cv_results(cv_results, X_bad_pos_label, y_bad_pos_label) - # `fold_names` incorrect length - with pytest.raises(ValueError, match="'name' must be None or list of length"): - RocCurveDisplay.from_cv_results( - cv_results, X, y, name=["fold"], show_aggregate_score=False - ) + # `name` is list while `fold_line_kwargs` is None or dict + for fold_line_kwargs in (None, {"alpha": 0.2}): + with pytest.raises(ValueError, match="To avoid labeling individual curves"): + RocCurveDisplay.from_cv_results( + cv_results, + X, + y, + name=["one", "two", "three"], + fold_line_kwargs=fold_line_kwargs, + ) + # `fold_line_kwargs` incorrect length - with pytest.raises( - ValueError, match="'fold_line_kwargs' must be a single dictionary to" - ): + with pytest.raises(ValueError, match="`fold_line_kwargs` must be None, a list"): RocCurveDisplay.from_cv_results( cv_results, X, y, fold_line_kwargs=[{"alpha": 1}] ) + # `fold_line_kwargs` both alias provided + with pytest.raises(TypeError, match="Got both c and"): + RocCurveDisplay.from_cv_results( + cv_results, X, y, fold_line_kwargs={"c": "blue", "color": "red"} + ) + # @pytest.mark.parametrize( # "fold_line_kwargs", @@ -246,7 +256,11 @@ def test_roc_curve_display_estimator_name_deprecation(pyplot): @pytest.mark.parametrize( "fold_line_kwargs", - [None, [{"color": "blue"}, {"color": "green"}, {"color": "red"}]], + [ + None, + {"color": "blue"}, + [{"color": "blue"}, {"color": "green"}, {"color": "red"}], + ], ) @pytest.mark.parametrize("drop_intermediate", [True, False]) @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @@ -325,13 +339,12 @@ def test_roc_curve_display_plotting_from_cv_results( assert 
isinstance(line, mpl.lines.Line2D) # Default alpha for `from_cv_results` line.get_alpha() == 0.5 - if fold_line_kwargs is None: - print(line.get_label()) - # assert line.get_label() == aggregate_expected_labels[idx] + if isinstance(fold_line_kwargs, list): + # Each individual curve labelled + assert line.get_label() == f"AUC = {display.roc_auc_[idx]:.2f}" else: - # expected_label = f"AUC = {display.roc_auc_[idx]:.2f}" - # assert line.get_label() == expected_label - print(line.get_label()) + # Single aggregate label + assert line.get_label() == aggregate_expected_labels[idx] # @pytest.mark.parametrize("fold_names", [None, ["one", "two", "three"]]) @@ -365,7 +378,7 @@ def test_roc_curve_from_cv_results_line_kwargs(pyplot, data_binary, fold_line_kw LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True ) display = RocCurveDisplay.from_cv_results( - cv_results, X, y, fold_line_kwargs=fold_line_kwargs, show_aggregate_score=False + cv_results, X, y, fold_line_kwargs=fold_line_kwargs ) mpl_default_colors = mpl.rcParams["axes.prop_cycle"].by_key()["color"] @@ -576,23 +589,28 @@ def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructo @pytest.mark.parametrize( - "roc_auc, name, expected_labels", + "roc_auc, name, fold_line_kwargs, expected_labels", [ - ([0.9, 0.8], None, ["AUC = 0.90", "AUC = 0.80"]), - ([0.8, 0.7], [None, None], ["AUC = 0.80", "AUC = 0.70"]), - (None, ["fold1", "fold2"], ["fold1", "fold2"]), + ([0.9, 0.8], None, None, ["AUC = 0.85 +/- 0.05", "_child1"]), + ([0.9, 0.8], "Est name", None, ["Est name (AUC = 0.85 +/- 0.05)", "_child1"]), ( [0.8, 0.7], - ["my_est2", "my_est2"], - ["my_est2 (AUC = 0.80)", "my_est2 (AUC = 0.70)"], + ["fold1", "fold2"], + [{"c": "blue"}, {"c": "red"}], + ["fold1 (AUC = 0.80)", "fold2 (AUC = 0.70)"], ), + (None, ["fold1", "fold2"], [{"c": "blue"}, {"c": "red"}], ["fold1", "fold2"]), ], ) -def test_roc_curve_display_default_labels(pyplot, roc_auc, name, expected_labels): +def 
test_roc_curve_display_default_labels( + pyplot, roc_auc, name, fold_line_kwargs, expected_labels +): """Check the default labels used in the display.""" fpr = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])] tpr = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])] - disp = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, name=name).plot() + disp = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, name=name).plot( + fold_line_kwargs=fold_line_kwargs + ) for idx, expected_label in enumerate(expected_labels): assert disp.line_[idx].get_label() == expected_label diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 071659b181ae7..5ecba178b465f 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -19,6 +19,10 @@ "ROC AUC score of all curves." ) +CURVE_KWARGS_ERROR = ( + "`fold_line_kwargs` must be None, a list of length {n_curves} or a dictionary." +) + class _BinaryClassifierCurveDisplayMixin: """Mixin class to be used in Displays requiring a binary classifier. @@ -119,7 +123,7 @@ def _validate_from_cv_results_params( try: pos_label = _check_pos_label_consistency(pos_label, y) except ValueError as e: - # Alter error message + # Adapt error message raise ValueError(str(e).replace("y_true", "y")) n_curves = len(cv_results["estimator"]) @@ -135,7 +139,33 @@ def _validate_from_cv_results_params( f"`name` must be a list of length {n_curves} or a string. " f"Got list of length: {len(name)}." ) - return pos_label + + # Add `default_curve_kwargs` to `fold_line_kwargs` + default_curve_kwargs = {"alpha": 0.5, "linestyle": "--"} + if fold_line_kwargs is None: + fold_line_kwargs_ = default_curve_kwargs + elif isinstance(fold_line_kwargs, Mapping): + fold_line_kwargs_ = _validate_style_kwargs( + default_curve_kwargs, fold_line_kwargs + ) + elif isinstance(fold_line_kwargs, list): + if len(fold_line_kwargs) != n_curves: + raise ValueError( + CURVE_KWARGS_ERROR.format(n_curves=n_curves) + + f" Got list of length: {len(fold_line_kwargs)}." 
+ ) + else: + fold_line_kwargs_ = [ + _validate_style_kwargs(default_curve_kwargs, single_kwargs) + for single_kwargs in fold_line_kwargs + ] + else: + raise ValueError( + CURVE_KWARGS_ERROR.format(n_curves=n_curves) + + f" Got: {fold_line_kwargs}." + ) + + return pos_label, fold_line_kwargs_ @classmethod def _get_legend_label(cls, curve_summary_value, curve_name, summary_value_name): @@ -146,6 +176,8 @@ def _get_legend_label(cls, curve_summary_value, curve_name, summary_value_name): label = f"{summary_value_name} = {curve_summary_value:0.2f}" elif curve_name is not None: label = curve_name + else: + label = None return label @classmethod @@ -156,7 +188,6 @@ def _get_line_kwargs( summary_value, summary_value_name, fold_line_kwargs, - default_line_kwargs=None, **kwargs, ): """Get validated line kwargs for each curve. @@ -190,6 +221,8 @@ def _get_line_kwargs( for multi-curve plots. """ # Ensure parameters are of the correct length + if isinstance(name, list) and len(name) == 1: + name_ = name * n_curves name_ = [None] * n_curves if name is None else name summary_value_ = [None] * n_curves if summary_value is None else summary_value # `fold_line_kwargs` ignored for single curve plots @@ -197,14 +230,30 @@ def _get_line_kwargs( if n_curves == 1: fold_line_kwargs = [kwargs] else: - fold_line_kwargs = cls._validate_line_kwargs(n_curves, fold_line_kwargs) + # Ensure `fold_line_kwargs` is of correct length + if fold_line_kwargs is None: + fold_line_kwargs = [{}] * n_curves + elif isinstance(fold_line_kwargs, Mapping): + fold_line_kwargs = [fold_line_kwargs] * n_curves + elif len(fold_line_kwargs) != n_curves: + raise ValueError( + CURVE_KWARGS_ERROR.format(n_curves=n_curves) + + f" Got list of length: {len(fold_line_kwargs)}." 
+ ) labels = [] if isinstance(summary_value_, tuple): label_aggregate = cls._get_legend_label( summary_value_[0], name_[0], summary_value_name ) - label_aggregate = label_aggregate + f" +/- {summary_value_[1]:0.2f}" + # Add the "+/- std" to the end (in brackets if name provided) + if summary_value_[1] is not None: + if name_[0] is not None: + label_aggregate = ( + label_aggregate[:-1] + f" +/- {summary_value_[1]:0.2f})" + ) + else: + label_aggregate = label_aggregate + f" +/- {summary_value_[1]:0.2f}" # Add `label` for first curve only, set to `None` for remaining curves labels.extend([label_aggregate] + [None] * (n_curves - 1)) else: @@ -216,81 +265,13 @@ def _get_line_kwargs( ) line_kwargs = [] - label_kwarg = {} for fold_idx, label in enumerate(labels): - label_kwarg["label"] = label + label_kwarg = {"label": label} line_kwargs.append( _validate_style_kwargs(label_kwarg, fold_line_kwargs[fold_idx]) ) return line_kwargs - # TODO: if useful for non binary displays (e.g.,`LearningCurveDisplay`, - # `ValidationCurveDisplay`) change to function - @classmethod - def _validate_line_kwargs( - cls, - n_curves, - fold_line_kwargs=None, - default_line_kwargs=None, - ): - """Ensure `fold_line_kwargs` length and incorporate default kwargs. - - * If `fold_line_kwargs` is None: - * If `default_line_kwargs` is None, list of `n_curves` empty dictionaries - is returned. - * If `default_line_kwargs` is not None, list of `n_curves` dictionaries - of `default_line_kwargs` returned. - * If `fold_line_kwargs` is a single dictionary, it is incorporated with - `default_line_kwargs` using `_validate_style_kwargs`, and the resulting - dictionary is repeated `n_curves` times and returned. - * If `fold_line_kwargs` is a list of length `n_curves`, each dict is - incorporated with `default_line_kwargs` using `_validate_style_kwargs` and - returned as list of `n_curves` dictionaries. - - If `fold_line_kwargs` is a list not of length `n_curves`, an error is raised. 
- - Parameters - ---------- - n_curves : int - Number of curves. - - fold_line_kwargs : dict or list of dict, default=None - Keywords arguments to be passed to matplotlib's `plot` function - to draw ROC curves. - - default_line_kwargs : dict, default=None - Default line kwargs to be used in all curves, unless overridden by - `fold_line_kwargs`. - - Returns - ------- - fold_line_kwargs : list of dict - List of `n_curves` dictionaries. - """ - if fold_line_kwargs is None and default_line_kwargs is None: - fold_line_kwargs = [{}] * n_curves - elif fold_line_kwargs is None and default_line_kwargs is not None: - fold_line_kwargs = default_line_kwargs - - if isinstance(fold_line_kwargs, Mapping): - fold_line_kwargs = [fold_line_kwargs] * n_curves - # Should we allow list with single dict? - elif len(fold_line_kwargs) != n_curves: - raise ValueError( - f"'fold_line_kwargs' must be a list of length {n_curves} or a " - f"dictionary. Got list of length: {len(fold_line_kwargs)}." - ) - else: - fold_line_kwargs = fold_line_kwargs - - if default_line_kwargs is not None: - fold_line_kwargs = [ - _validate_style_kwargs(default_line_kwargs, single_kwargs) - for single_kwargs in fold_line_kwargs - ] - - return fold_line_kwargs - def _validate_score_name(score_name, scoring, negate_score): """Validate the `score_name` parameter. 
From 5577b8b0ddfef8291874b163f2d14353f58d08d9 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Wed, 9 Apr 2025 10:52:00 +1000 Subject: [PATCH 42/63] rm old param --- sklearn/metrics/_plot/roc_curve.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 30c01033b2986..3bff54472347c 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -132,7 +132,6 @@ def __init__( fpr, tpr, roc_auc=None, - roc_auc_aggregate=None, name=None, pos_label=None, estimator_name="deprecated", @@ -140,7 +139,6 @@ def __init__( self.fpr = fpr self.tpr = tpr self.roc_auc = roc_auc - self.roc_auc_aggregate = roc_auc_aggregate self.name = _deprecate_estimator_name(estimator_name, name, "1.7") self.pos_label = pos_label From 001e00e5080d5cd462174cdb97c830c20b092e32 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 11 Apr 2025 10:39:02 +1000 Subject: [PATCH 43/63] guillaume review --- sklearn/metrics/_plot/roc_curve.py | 2 +- sklearn/utils/_plotting.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 3bff54472347c..5b74214c63fd6 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -247,7 +247,7 @@ def plot( summary_value = (None, None) n_curves = len(self.fpr_) - line_kwargs = self._get_line_kwargs( + line_kwargs = self._validate_line_kwargs( n_curves, self.name_, summary_value, diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 5ecba178b465f..5a973d20d13e6 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -181,7 +181,7 @@ def _get_legend_label(cls, curve_summary_value, curve_name, summary_value_name): return label @classmethod - def _get_line_kwargs( + def _validate_line_kwargs( cls, n_curves, name, @@ -197,7 +197,7 @@ def _get_line_kwargs( n_curves : int Number of curves. 
- name : list[str] or None + name : list of str or None Name for labeling legend entries. summary_value : list[float] or tuple(float, float) or None From 8ec3659a7caba748bd4bc6f90a28b0240c757cee Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 11 Apr 2025 16:10:36 +1000 Subject: [PATCH 44/63] validate only in plot --- sklearn/metrics/_plot/roc_curve.py | 98 ++++++------ .../_plot/tests/test_roc_curve_display.py | 8 +- sklearn/utils/_plotting.py | 140 +++++++++--------- 3 files changed, 124 insertions(+), 122 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 5b74214c63fd6..05cbf31f2c036 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,13 +1,11 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause -from collections.abc import Mapping import numpy as np from ...utils import _safe_indexing from ...utils._plotting import ( - MULTICURVE_LABELLING_ERROR, _BinaryClassifierCurveDisplayMixin, _check_param_lengths, _convert_to_list_leaving_none, @@ -57,15 +55,15 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): Now accepts a list for plotting multiple curves. name : str or list of str, default=None - Name for labeling legend entries. For single ROC curve, should be a string. - For multiple ROC curves, the number of legend entries depend on - the `fold_line_kwargs` passed to `plot`. + Name for labeling legend entries. The number of legend entries + is determined by the `fold_line_kwargs` passed to `plot`. To label each curve, provide a list of strings. To avoid labeling individual curves that have the same appearance, this cannot be used in - conjunction with `fold_line_kwargs` being a list. If a string is - provided, it will be used to either label the single legend entry or if - there are multiple legend entries, label each individual curve with the - same name. If `None`, no name is shown in the legend. 
+ conjunction with `fold_line_kwargs` being a dictionary or None. If a + string is provided, it will be used to either label the single legend entry + or if there are multiple legend entries, label each individual curve with + the same name. If `None`, set to `name` provided at `RocCurveDisplay` + initialization. If still `None`, no name is shown in the legend. .. versionadded:: 1.7 @@ -80,6 +78,7 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): Name of estimator. If None, the estimator name is not shown. .. deprecated:: 1.7 + `estimator_name` is deprecated and will be removed in 1.9. Use `name` instead. Attributes ---------- @@ -142,7 +141,7 @@ def __init__( self.name = _deprecate_estimator_name(estimator_name, name, "1.7") self.pos_label = pos_label - def _validate_plot_params(self, *, ax, name, fold_line_kwargs): + def _validate_plot_params(self, *, ax, name): self.ax_, self.figure_, name_ = super()._validate_plot_params(ax=ax, name=name) self.fpr_ = _convert_to_list_leaving_none(self.fpr) @@ -159,13 +158,6 @@ def _validate_plot_params(self, *, ax, name, fold_line_kwargs): class_name="RocCurveDisplay", ) - if ( - isinstance(self.name_, list) - and len(self.name_) != 1 - and (isinstance(fold_line_kwargs, Mapping) or fold_line_kwargs is None) - ): - raise ValueError(MULTICURVE_LABELLING_ERROR.format(n_curves=len(self.fpr_))) - def plot( self, ax=None, @@ -179,9 +171,6 @@ def plot( ): """Plot visualization. - For single curve plots, extra keyword arguments will be passed to - matplotlib's ``plot``. - Parameters ---------- ax : matplotlib axes, default=None @@ -189,13 +178,14 @@ def plot( created. name : str or list of str, default=None - Name for labeling legend entries. + Name for labeling legend entries. The number of legend entries + is determined by `fold_line_kwargs`. To label each curve, provide a list of strings. 
To avoid labeling individual curves that have the same appearance, this cannot be used in - conjunction with `fold_line_kwargs` being a list. If a string is - provided, it will be used to either label the single legend entry or - if there are multiple legend entries, label each individual curve with the - same name. If `None`, set to `name` provided at `RocCurveDisplay` + conjunction with `fold_line_kwargs` being a dictionary or None. If a + string is provided, it will be used to either label the single legend entry + or if there are multiple legend entries, label each individual curve with + the same name. If `None`, set to `name` provided at `RocCurveDisplay` initialization. If still `None`, no name is shown in the legend. .. versionadded:: 1.7 @@ -218,26 +208,29 @@ def plot( fold_line_kwargs : dict or list of dict, default=None Keywords arguments to be passed to matplotlib's `plot` function - to draw individual ROC curves. If a list is provided the + to draw individual ROC curves. For single curve plotting, should be + a dictionary. For multi-curve plotting, if a list is provided the parameters are applied to the ROC curves of each CV fold sequentially and a legend entry is added for each curve. - If a single dictionary is provided, the same - parameters are applied to all ROC curves and a single legend - entry for all curves is added, labeled with the mean ROC AUC score. + If a single dictionary is provided, the same parameters are applied + to all ROC curves and a single legend entry for all curves is added, + labeled with the mean ROC AUC score. .. versionadded:: 1.7 **kwargs : dict - For a single curve plots only, keyword arguments to be passed to - matplotlib's `plot`. Ignored for multi-curve plots - use `fold_line_kwargs` - for multi-curve plots. + Keyword arguments to be passed to matplotlib's `plot`. + + .. deprecated:: 1.7 + **kwargs is deprecated and will be removed in 1.9. Pass matplotlib + arguments to `fold_line_kwargs` as a dictionary instead. 
Returns ------- display : :class:`~sklearn.metrics.RocCurveDisplay` Object that stores computed values. """ - self._validate_plot_params(ax=ax, name=name, fold_line_kwargs=fold_line_kwargs) + self._validate_plot_params(ax=ax, name=name) n_curves = len(self.fpr_) summary_value, summary_value_name = self.roc_auc_, "AUC" if not isinstance(fold_line_kwargs, list) and n_curves > 1: @@ -246,7 +239,6 @@ def plot( else: summary_value = (None, None) - n_curves = len(self.fpr_) line_kwargs = self._validate_line_kwargs( n_curves, self.name_, @@ -380,9 +372,18 @@ def from_estimator( .. versionadded:: 1.6 + fold_line_kwargs : dict, default=None + Keywords arguments to be passed to matplotlib's `plot` function. + + .. versionadded:: 1.7 + **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. + .. deprecated:: 1.7 + **kwargs is deprecated and will be removed in 1.9. Pass matplotlib + arguments to `fold_line_kwargs` as a dictionary instead. + Returns ------- display : :class:`~sklearn.metrics.RocCurveDisplay` @@ -503,9 +504,18 @@ def from_predictions( .. versionadded:: 1.6 + fold_line_kwargs : dict, default=None + Keywords arguments to be passed to matplotlib's `plot` function. + + .. versionadded:: 1.7 + **kwargs : dict Additional keywords arguments passed to matplotlib `plot` function. + .. deprecated:: 1.7 + **kwargs is deprecated and will be removed in 1.9. Pass matplotlib + arguments to `fold_line_kwargs` as a dictionary instead. + Returns ------- display : :class:`~sklearn.metrics.RocCurveDisplay` @@ -625,22 +635,22 @@ def from_cv_results( name : str or list of str, default=None Name for labeling legend entries. The number of legend entries - depends on `fold_line_kwargs`. + is determined by `fold_line_kwargs`. To label each curve, provide a list of strings. To avoid labeling individual curves that have the same appearance, this cannot be used in - conjunction with `fold_line_kwargs` being a list. 
If a string is - provided, it will be used to either label the single legend entry or - if there are multiple legend entries, label each individual curve with the - same name. If `None`, no name is shown in the legend. + conjunction with `fold_line_kwargs` being a dictionary or None. If a + string is provided, it will be used to either label the single legend entry + or if there are multiple legend entries, label each individual curve with + the same name. If `None`, no name is shown in the legend. fold_line_kwargs : dict or list of dict, default=None Keywords arguments to be passed to matplotlib's `plot` function to draw individual ROC curves. If a list is provided the parameters are applied to the ROC curves of each CV fold sequentially and a legend entry is added for each curve. - If a single dictionary is provided, the same - parameters are applied to all ROC curves and a single legend - entry for all curves is added, labeled with the mean ROC AUC score. + If a single dictionary is provided, the same parameters are applied + to all ROC curves and a single legend entry for all curves is added, + labeled with the mean ROC AUC score. plot_chance_level : bool, default=False Whether to plot the chance level. 
@@ -681,7 +691,7 @@ def from_cv_results( <...> >>> plt.show() """ - pos_label_, fold_line_kwargs_ = cls._validate_from_cv_results_params( + pos_label_ = cls._validate_from_cv_results_params( cv_results, X, y, @@ -734,5 +744,5 @@ def from_cv_results( plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kwargs, despine=despine, - fold_line_kwargs=fold_line_kwargs_, + fold_line_kwargs=fold_line_kwargs, ) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 6a13421f67e4e..cb89a812abd51 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -159,7 +159,7 @@ def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary, data): for cv_results in (cv_results_no_est, cv_results_no_indices): with pytest.raises( ValueError, - match="'cv_results' does not contain one of the following required", + match="`cv_results` does not contain one of the following required", ): RocCurveDisplay.from_cv_results(cv_results, X, y) @@ -168,12 +168,12 @@ def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary, data): ) # `X` wrong length - with pytest.raises(ValueError, match="'X' does not contain the correct"): + with pytest.raises(ValueError, match="`X` does not contain the correct"): RocCurveDisplay.from_cv_results(cv_results, X[:10, :], y) # `y` not binary X_mutli, y_multi = data - with pytest.raises(ValueError, match="The target y is not binary."): + with pytest.raises(ValueError, match="The target `y` is not binary."): RocCurveDisplay.from_cv_results(cv_results, X, y_multi) # input inconsistent length @@ -250,7 +250,7 @@ def test_roc_curve_display_estimator_name_deprecation(pyplot): """Check deprecation of `estimator_name`.""" fpr = np.array([0, 0.5, 1]) tpr = np.array([0, 0.5, 1]) - with pytest.warns(FutureWarning, match="'estimator_name' is deprecated in"): + with pytest.warns(FutureWarning, 
match="`estimator_name` is deprecated in"): RocCurveDisplay(fpr=fpr, tpr=tpr, estimator_name="test") diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 5a973d20d13e6..7babe6f7a018b 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -12,17 +12,6 @@ from .multiclass import type_of_target from .validation import _check_pos_label_consistency, _num_samples -MULTICURVE_LABELLING_ERROR = ( - "To avoid labeling individual curves that have the same appearance, " - "`fold_line_kwargs` should be a list of {n_curves} dictionaries. Alternatively, " - "set `name` to `None` or a single string to add a single legend entry with mean " - "ROC AUC score of all curves." -) - -CURVE_KWARGS_ERROR = ( - "`fold_line_kwargs` must be None, a list of length {n_curves} or a dictionary." -) - class _BinaryClassifierCurveDisplayMixin: """Mixin class to be used in Displays requiring a binary classifier. @@ -97,9 +86,10 @@ def _validate_from_cv_results_params( required_keys = {"estimator", "indices"} if not all(key in cv_results for key in required_keys): raise ValueError( - "'cv_results' does not contain one of the following required keys: " - f"{required_keys}. Set explicitly the parameters return_estimator=True " - "and return_indices=True to the function cross_validate." + "`cv_results` does not contain one of the following required keys: " + f"{required_keys}. Set explicitly the parameters " + "`return_estimator=True` and `return_indices=True` to the function" + "`cross_validate`." ) train_size, test_size = ( @@ -109,13 +99,13 @@ def _validate_from_cv_results_params( if _num_samples(X) != train_size + test_size: raise ValueError( - "'X' does not contain the correct number of samples. " + "`X` does not contain the correct number of samples. " f"Expected {train_size + test_size}, got {_num_samples(X)}." ) if type_of_target(y) != "binary": raise ValueError( - f"The target y is not binary. 
Got {type_of_target(y)} type of" + f"The target `y` is not binary. Got {type_of_target(y)} type of" " target." ) check_consistent_length(X, y, sample_weight) @@ -127,46 +117,37 @@ def _validate_from_cv_results_params( raise ValueError(str(e).replace("y_true", "y")) n_curves = len(cv_results["estimator"]) + # NB: Both these also checked in `plot`, but thought it best to fail earlier. + cls._validate_multi_fold_line_kwargs(cls, fold_line_kwargs, name, n_curves) + if isinstance(name, list) and len(name) not in (1, n_curves): + raise ValueError( + f"`name` must be None, a list of length {n_curves} or a single " + f"string. Got list of length: {len(name)}." + ) + + return pos_label + + def _validate_multi_fold_line_kwargs(cls, fold_line_kwargs, name, n_curves): + """Check `fold_line_kwargs`, including combination with `name`, is valid.""" + if isinstance(fold_line_kwargs, list) and len(fold_line_kwargs) != n_curves: + raise ValueError( + f"`fold_line_kwargs` must be None, a list of length {n_curves} or a " + f"dictionary. Got: {fold_line_kwargs}." + ) + + # Ensure valid `name` and `fold_line_kwargs` combination. if ( isinstance(name, list) and len(name) != 1 and (isinstance(fold_line_kwargs, Mapping) or fold_line_kwargs is None) ): - raise ValueError(MULTICURVE_LABELLING_ERROR.format(n_curves=n_curves)) - else: - if isinstance(name, list) and len(name) not in (1, n_curves): - raise ValueError( - f"`name` must be a list of length {n_curves} or a string. " - f"Got list of length: {len(name)}." 
- ) - - # Add `default_curve_kwargs` to `fold_line_kwargs` - default_curve_kwargs = {"alpha": 0.5, "linestyle": "--"} - if fold_line_kwargs is None: - fold_line_kwargs_ = default_curve_kwargs - elif isinstance(fold_line_kwargs, Mapping): - fold_line_kwargs_ = _validate_style_kwargs( - default_curve_kwargs, fold_line_kwargs - ) - elif isinstance(fold_line_kwargs, list): - if len(fold_line_kwargs) != n_curves: - raise ValueError( - CURVE_KWARGS_ERROR.format(n_curves=n_curves) - + f" Got list of length: {len(fold_line_kwargs)}." - ) - else: - fold_line_kwargs_ = [ - _validate_style_kwargs(default_curve_kwargs, single_kwargs) - for single_kwargs in fold_line_kwargs - ] - else: raise ValueError( - CURVE_KWARGS_ERROR.format(n_curves=n_curves) - + f" Got: {fold_line_kwargs}." + "To avoid labeling individual curves that have the same appearance, " + f"`fold_line_kwargs` should be a list of {n_curves} dictionaries. " + "Alternatively, set `name` to `None` or a single string to label " + "a single legend entry with mean ROC AUC score of all curves." ) - return pos_label, fold_line_kwargs_ - @classmethod def _get_legend_label(cls, curve_summary_value, curve_name, summary_value_name): """Helper to get legend label using `name_` and `summary_value_`""" @@ -200,9 +181,10 @@ def _validate_line_kwargs( name : list of str or None Name for labeling legend entries. - summary_value : list[float] or tuple(float, float) or None - Either list of `n_curves` summary values for each curve (e.g., ROC AUC, - average precision) or a single float summary value for all curves. + summary_value : list of float or tuple of float or None + A list of `n_curves` summary values for each curve (e.g., ROC AUC, + average precision) or a tuple of mean and standard deviation values for + all curves or None. summary_value_name : str or None Name of the summary value provided in `summary_values`. @@ -216,30 +198,41 @@ def _validate_line_kwargs( single curve plots. 
**kwargs : dict - For a single curve plots only, keyword arguments to be passed to - matplotlib's `plot`. Ignored for multi-curve plots - use `fold_line_kwargs` - for multi-curve plots. + Deprecated. Keyword arguments to be passed to matplotlib's `plot`. """ + # Deprecate **kwargs + if fold_line_kwargs and kwargs: + raise ValueError( + "Cannot provide both `fold_line_kwargs` and `kwargs`. `**kwargs` is " + "deprecated in 1.7 and will be removed in 1.9. Pass all matplotlib " + "arguments to `fold_line_kwargs` as a dictionary." + ) + if kwargs: + warnings.warn( + "`**kwargs` is deprecated and will be removed in 1.9. Pass all " + "matplotlib arguments to `fold_line_kwargs` as a dictionary instead.", + FutureWarning, + ) + fold_line_kwargs = kwargs + + cls._validate_multi_fold_line_kwargs(cls, fold_line_kwargs, name, n_curves) + # Ensure parameters are of the correct length if isinstance(name, list) and len(name) == 1: name_ = name * n_curves name_ = [None] * n_curves if name is None else name summary_value_ = [None] * n_curves if summary_value is None else summary_value - # `fold_line_kwargs` ignored for single curve plots - # `kwargs` ignored for multi-curve plots - if n_curves == 1: - fold_line_kwargs = [kwargs] - else: - # Ensure `fold_line_kwargs` is of correct length + + # Ensure `fold_line_kwargs` is of correct length + if fold_line_kwargs is None: + fold_line_kwargs = [{}] * n_curves + if isinstance(fold_line_kwargs, Mapping): + fold_line_kwargs = [fold_line_kwargs] * n_curves + + default_multi_curve_kwargs = {"alpha": 0.5, "linestyle": "--", "color": "blue"} + if n_curves > 1: if fold_line_kwargs is None: - fold_line_kwargs = [{}] * n_curves - elif isinstance(fold_line_kwargs, Mapping): - fold_line_kwargs = [fold_line_kwargs] * n_curves - elif len(fold_line_kwargs) != n_curves: - raise ValueError( - CURVE_KWARGS_ERROR.format(n_curves=n_curves) - + f" Got list of length: {len(fold_line_kwargs)}." 
- ) + fold_line_kwargs = [default_multi_curve_kwargs] * n_curves labels = [] if isinstance(summary_value_, tuple): @@ -395,14 +388,13 @@ def _deprecate_estimator_name(old, new, version): if old != "deprecated": if new: raise ValueError( - f"Both 'estimator_name' and 'name' provided, please only use 'name' " - f"as 'estimator_name' is deprecated in {version} and will be removed " - f"in {version_remove}." + "Cannot provide both `estimator_name` and `name`. `estimator_name` " + f"is deprecated in {version} and will be removed in {version_remove}. " + "Use `name` only." ) warnings.warn( - f"'estimator_name' is deprecated in {version} and will be removed in " - f"{version_remove}. The value of 'estimator_name' was passed to 'name'" - "but please use 'name' in future.", + f"`estimator_name` is deprecated in {version} and will be removed in " + f"{version_remove}. Use `name` instead.", FutureWarning, ) return old From cda029667599e5f471f64b555a37f47ce4d4f4b8 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 11 Apr 2025 16:30:58 +1000 Subject: [PATCH 45/63] name change to curve_kwargs --- sklearn/metrics/_plot/roc_curve.py | 48 ++++++------- .../_plot/tests/test_roc_curve_display.py | 72 +++++++++---------- sklearn/utils/_plotting.py | 64 ++++++++--------- 3 files changed, 90 insertions(+), 94 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 05cbf31f2c036..72d5ab09441e0 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -56,10 +56,10 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): name : str or list of str, default=None Name for labeling legend entries. The number of legend entries - is determined by the `fold_line_kwargs` passed to `plot`. + is determined by the `curve_kwargs` passed to `plot`. To label each curve, provide a list of strings. 
To avoid labeling individual curves that have the same appearance, this cannot be used in - conjunction with `fold_line_kwargs` being a dictionary or None. If a + conjunction with `curve_kwargs` being a dictionary or None. If a string is provided, it will be used to either label the single legend entry or if there are multiple legend entries, label each individual curve with the same name. If `None`, set to `name` provided at `RocCurveDisplay` @@ -166,7 +166,7 @@ def plot( plot_chance_level=False, chance_level_kw=None, despine=False, - fold_line_kwargs=None, + curve_kwargs=None, **kwargs, ): """Plot visualization. @@ -179,10 +179,10 @@ def plot( name : str or list of str, default=None Name for labeling legend entries. The number of legend entries - is determined by `fold_line_kwargs`. + is determined by `curve_kwargs`. To label each curve, provide a list of strings. To avoid labeling individual curves that have the same appearance, this cannot be used in - conjunction with `fold_line_kwargs` being a dictionary or None. If a + conjunction with `curve_kwargs` being a dictionary or None. If a string is provided, it will be used to either label the single legend entry or if there are multiple legend entries, label each individual curve with the same name. If `None`, set to `name` provided at `RocCurveDisplay` @@ -206,7 +206,7 @@ def plot( .. versionadded:: 1.6 - fold_line_kwargs : dict or list of dict, default=None + curve_kwargs : dict or list of dict, default=None Keywords arguments to be passed to matplotlib's `plot` function to draw individual ROC curves. For single curve plotting, should be a dictionary. For multi-curve plotting, if a list is provided the @@ -223,7 +223,7 @@ def plot( .. deprecated:: 1.7 **kwargs is deprecated and will be removed in 1.9. Pass matplotlib - arguments to `fold_line_kwargs` as a dictionary instead. + arguments to `curve_kwargs` as a dictionary instead. 
Returns ------- @@ -233,18 +233,18 @@ def plot( self._validate_plot_params(ax=ax, name=name) n_curves = len(self.fpr_) summary_value, summary_value_name = self.roc_auc_, "AUC" - if not isinstance(fold_line_kwargs, list) and n_curves > 1: + if not isinstance(curve_kwargs, list) and n_curves > 1: if self.roc_auc_: summary_value = (np.mean(self.roc_auc_), np.std(self.roc_auc_)) else: summary_value = (None, None) - line_kwargs = self._validate_line_kwargs( + curve_kwargs = self._validate_curve_kwargs( n_curves, self.name_, summary_value, summary_value_name, - fold_line_kwargs=fold_line_kwargs, + curve_kwargs=curve_kwargs, **kwargs, ) @@ -262,7 +262,7 @@ def plot( ) self.line_ = [] - for fpr, tpr, line_kw in zip(self.fpr_, self.tpr_, line_kwargs): + for fpr, tpr, line_kw in zip(self.fpr_, self.tpr_, curve_kwargs): self.line_.extend(self.ax_.plot(fpr, tpr, **line_kw)) # Return single artist if only one curve is plotted if len(self.line_) == 1: @@ -290,7 +290,7 @@ def plot( if despine: _despine(self.ax_) - if line_kwargs[0].get("label") is not None or ( + if curve_kwargs[0].get("label") is not None or ( plot_chance_level and chance_level_kw.get("label") is not None ): self.ax_.legend(loc="lower right") @@ -372,7 +372,7 @@ def from_estimator( .. versionadded:: 1.6 - fold_line_kwargs : dict, default=None + curve_kwargs : dict, default=None Keywords arguments to be passed to matplotlib's `plot` function. .. versionadded:: 1.7 @@ -381,8 +381,8 @@ def from_estimator( Keyword arguments to be passed to matplotlib's `plot`. .. deprecated:: 1.7 - **kwargs is deprecated and will be removed in 1.9. Pass matplotlib - arguments to `fold_line_kwargs` as a dictionary instead. + **kwargs is deprecated and will be removed in 1.9. Pass matplotlib + arguments to `curve_kwargs` as a dictionary instead. Returns ------- @@ -504,7 +504,7 @@ def from_predictions( .. 
versionadded:: 1.6 - fold_line_kwargs : dict, default=None + curve_kwargs : dict, default=None Keywords arguments to be passed to matplotlib's `plot` function. .. versionadded:: 1.7 @@ -513,8 +513,8 @@ def from_predictions( Additional keywords arguments passed to matplotlib `plot` function. .. deprecated:: 1.7 - **kwargs is deprecated and will be removed in 1.9. Pass matplotlib - arguments to `fold_line_kwargs` as a dictionary instead. + **kwargs is deprecated and will be removed in 1.9. Pass matplotlib + arguments to `curve_kwargs` as a dictionary instead. Returns ------- @@ -587,7 +587,7 @@ def from_cv_results( pos_label=None, ax=None, name=None, - fold_line_kwargs=None, + curve_kwargs=None, plot_chance_level=False, chance_level_kwargs=None, despine=False, @@ -635,15 +635,15 @@ def from_cv_results( name : str or list of str, default=None Name for labeling legend entries. The number of legend entries - is determined by `fold_line_kwargs`. + is determined by `curve_kwargs`. To label each curve, provide a list of strings. To avoid labeling individual curves that have the same appearance, this cannot be used in - conjunction with `fold_line_kwargs` being a dictionary or None. If a + conjunction with `curve_kwargs` being a dictionary or None. If a string is provided, it will be used to either label the single legend entry or if there are multiple legend entries, label each individual curve with the same name. If `None`, no name is shown in the legend. - fold_line_kwargs : dict or list of dict, default=None + curve_kwargs : dict or list of dict, default=None Keywords arguments to be passed to matplotlib's `plot` function to draw individual ROC curves. 
If a list is provided the parameters are applied to the ROC curves of each CV fold @@ -698,7 +698,7 @@ def from_cv_results( sample_weight=sample_weight, pos_label=pos_label, name=name, - fold_line_kwargs=fold_line_kwargs, + curve_kwargs=curve_kwargs, ) fpr_all = [] @@ -744,5 +744,5 @@ def from_cv_results( plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kwargs, despine=despine, - fold_line_kwargs=fold_line_kwargs, + curve_kwargs=curve_kwargs, ) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index cb89a812abd51..f3b88b2125494 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -187,38 +187,36 @@ def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary, data): with pytest.raises(ValueError, match=r"y takes value in \{1, 2\}"): RocCurveDisplay.from_cv_results(cv_results, X_bad_pos_label, y_bad_pos_label) - # `name` is list while `fold_line_kwargs` is None or dict - for fold_line_kwargs in (None, {"alpha": 0.2}): + # `name` is list while `curve_kwargs` is None or dict + for curve_kwargs in (None, {"alpha": 0.2}): with pytest.raises(ValueError, match="To avoid labeling individual curves"): RocCurveDisplay.from_cv_results( cv_results, X, y, name=["one", "two", "three"], - fold_line_kwargs=fold_line_kwargs, + curve_kwargs=curve_kwargs, ) - # `fold_line_kwargs` incorrect length - with pytest.raises(ValueError, match="`fold_line_kwargs` must be None, a list"): - RocCurveDisplay.from_cv_results( - cv_results, X, y, fold_line_kwargs=[{"alpha": 1}] - ) + # `curve_kwargs` incorrect length + with pytest.raises(ValueError, match="`curve_kwargs` must be None, a list"): + RocCurveDisplay.from_cv_results(cv_results, X, y, curve_kwargs=[{"alpha": 1}]) - # `fold_line_kwargs` both alias provided + # `curve_kwargs` both alias provided with pytest.raises(TypeError, match="Got both c and"): 
RocCurveDisplay.from_cv_results( - cv_results, X, y, fold_line_kwargs={"c": "blue", "color": "red"} + cv_results, X, y, curve_kwargs={"c": "blue", "color": "red"} ) # @pytest.mark.parametrize( -# "fold_line_kwargs", +# "curve_kwargs", # [None, {"alpha": 0.2}, [{"alpha": 0.2}, {"alpha": 0.3}, {"alpha": 0.4}]], # ) -# def test_roc_curve_display_from_cv_results_validate_line_kwargs( -# pyplot, data_binary, fold_line_kwargs +# def test_roc_curve_display_from_cv_results_validate_curve_kwargs( +# pyplot, data_binary, curve_kwargs # ): -# """Check `_validate_line_kwargs` correctly validates line kwargs.""" +# """Check `_validate_curve_kwargs` correctly validates line kwargs.""" # X, y = data_binary # n_cv = 3 # cv_results = cross_validate( @@ -229,18 +227,18 @@ def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary, data): # cv_results, # X, # y, -# fold_line_kwargs=fold_line_kwargs, +# curve_kwargs=curve_kwargs, # ) -# if fold_line_kwargs is None: +# if curve_kwargs is None: # # Default `alpha` used # assert all(line.get_alpha() == 0.5 for line in display.line_) -# elif isinstance(fold_line_kwargs, Mapping): +# elif isinstance(curve_kwargs, Mapping): # # `alpha` from dict used for all curves # assert all(line.get_alpha() == 0.2 for line in display.line_) # else: # # Different `alpha` used for each curve # assert all( -# line.get_alpha() == fold_line_kwargs[i]["alpha"] +# line.get_alpha() == curve_kwargs[i]["alpha"] # for i, line in enumerate(display.line_) # ) @@ -255,7 +253,7 @@ def test_roc_curve_display_estimator_name_deprecation(pyplot): @pytest.mark.parametrize( - "fold_line_kwargs", + "curve_kwargs", [ None, {"color": "blue"}, @@ -273,7 +271,7 @@ def test_roc_curve_display_plotting_from_cv_results( with_sample_weight, response_method, drop_intermediate, - fold_line_kwargs, + curve_kwargs, ): """Check overall plotting of `from_cv_results`.""" X, y = data_binary @@ -300,7 +298,7 @@ def test_roc_curve_display_plotting_from_cv_results( 
drop_intermediate=drop_intermediate, response_method=response_method, pos_label=pos_label, - fold_line_kwargs=fold_line_kwargs, + curve_kwargs=curve_kwargs, ) for idx, (estimator, test_indices) in enumerate( @@ -339,7 +337,7 @@ def test_roc_curve_display_plotting_from_cv_results( assert isinstance(line, mpl.lines.Line2D) # Default alpha for `from_cv_results` line.get_alpha() == 0.5 - if isinstance(fold_line_kwargs, list): + if isinstance(curve_kwargs, list): # Each individual curve labelled assert line.get_label() == f"AUC = {display.roc_auc_[idx]:.2f}" else: @@ -366,30 +364,28 @@ def test_roc_curve_display_plotting_from_cv_results( @pytest.mark.parametrize( - "fold_line_kwargs", + "curve_kwargs", [None, {"color": "red"}, [{"c": "red"}, {"c": "green"}, {"c": "yellow"}]], ) -def test_roc_curve_from_cv_results_line_kwargs(pyplot, data_binary, fold_line_kwargs): +def test_roc_curve_from_cv_results_curve_kwargs(pyplot, data_binary, curve_kwargs): """Check line kwargs passed correctly in `from_cv_results`.""" - import matplotlib as mpl X, y = data_binary cv_results = cross_validate( LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True ) display = RocCurveDisplay.from_cv_results( - cv_results, X, y, fold_line_kwargs=fold_line_kwargs + cv_results, X, y, curve_kwargs=curve_kwargs ) - mpl_default_colors = mpl.rcParams["axes.prop_cycle"].by_key()["color"] for idx, line in enumerate(display.line_): color = line.get_color() - if fold_line_kwargs is None: - assert color == mpl_default_colors[idx] - elif isinstance(fold_line_kwargs, Mapping): + if curve_kwargs is None: + assert color == "blue" + elif isinstance(curve_kwargs, Mapping): assert color == "red" else: - assert color == fold_line_kwargs[idx]["c"] + assert color == curve_kwargs[idx]["c"] def _check_chance_level(plot_chance_level, chance_level_kw, display): @@ -509,13 +505,13 @@ def test_roc_curve_chance_level_line( ], ) # To ensure both curve line kwargs and change line kwargs passed correctly 
-@pytest.mark.parametrize("fold_line_kwargs", [None, {"alpha": 0.8}]) +@pytest.mark.parametrize("curve_kwargs", [None, {"alpha": 0.8}]) def test_roc_curve_chance_level_line_from_cv_results( pyplot, data_binary, plot_chance_level, chance_level_kw, - fold_line_kwargs, + curve_kwargs, ): """Check chance level plotting behavior with `from_cv_results`.""" X, y = data_binary @@ -530,13 +526,13 @@ def test_roc_curve_chance_level_line_from_cv_results( y, plot_chance_level=plot_chance_level, chance_level_kwargs=chance_level_kw, - fold_line_kwargs=fold_line_kwargs, + curve_kwargs=curve_kwargs, ) import matplotlib as mpl assert all(isinstance(line, mpl.lines.Line2D) for line in display.line_) - if fold_line_kwargs: + if curve_kwargs: assert all(line.get_alpha() == 0.8 for line in display.line_) assert isinstance(display.ax_, mpl.axes.Axes) assert isinstance(display.figure_, mpl.figure.Figure) @@ -589,7 +585,7 @@ def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructo @pytest.mark.parametrize( - "roc_auc, name, fold_line_kwargs, expected_labels", + "roc_auc, name, curve_kwargs, expected_labels", [ ([0.9, 0.8], None, None, ["AUC = 0.85 +/- 0.05", "_child1"]), ([0.9, 0.8], "Est name", None, ["Est name (AUC = 0.85 +/- 0.05)", "_child1"]), @@ -603,13 +599,13 @@ def test_roc_curve_display_complex_pipeline(pyplot, data_binary, clf, constructo ], ) def test_roc_curve_display_default_labels( - pyplot, roc_auc, name, fold_line_kwargs, expected_labels + pyplot, roc_auc, name, curve_kwargs, expected_labels ): """Check the default labels used in the display.""" fpr = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])] tpr = [np.array([0, 0.5, 1]), np.array([0, 0.3, 1])] disp = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc, name=name).plot( - fold_line_kwargs=fold_line_kwargs + curve_kwargs=curve_kwargs ) for idx, expected_label in enumerate(expected_labels): assert disp.line_[idx].get_label() == expected_label diff --git a/sklearn/utils/_plotting.py 
b/sklearn/utils/_plotting.py index 7babe6f7a018b..c80ba61dc9c34 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -79,7 +79,7 @@ def _validate_from_cv_results_params( sample_weight, pos_label, name, - fold_line_kwargs, + curve_kwargs, ): check_matplotlib_support(f"{cls.__name__}.from_predictions") @@ -118,7 +118,7 @@ def _validate_from_cv_results_params( n_curves = len(cv_results["estimator"]) # NB: Both these also checked in `plot`, but thought it best to fail earlier. - cls._validate_multi_fold_line_kwargs(cls, fold_line_kwargs, name, n_curves) + cls._validate_multi_curve_kwargs(cls, curve_kwargs, name, n_curves) if isinstance(name, list) and len(name) not in (1, n_curves): raise ValueError( f"`name` must be None, a list of length {n_curves} or a single " @@ -127,23 +127,23 @@ def _validate_from_cv_results_params( return pos_label - def _validate_multi_fold_line_kwargs(cls, fold_line_kwargs, name, n_curves): - """Check `fold_line_kwargs`, including combination with `name`, is valid.""" - if isinstance(fold_line_kwargs, list) and len(fold_line_kwargs) != n_curves: + def _validate_multi_curve_kwargs(cls, curve_kwargs, name, n_curves): + """Check `curve_kwargs`, including combination with `name`, is valid.""" + if isinstance(curve_kwargs, list) and len(curve_kwargs) != n_curves: raise ValueError( - f"`fold_line_kwargs` must be None, a list of length {n_curves} or a " - f"dictionary. Got: {fold_line_kwargs}." + f"`curve_kwargs` must be None, a list of length {n_curves} or a " + f"dictionary. Got: {curve_kwargs}." ) - # Ensure valid `name` and `fold_line_kwargs` combination. + # Ensure valid `name` and `curve_kwargs` combination. 
if ( isinstance(name, list) and len(name) != 1 - and (isinstance(fold_line_kwargs, Mapping) or fold_line_kwargs is None) + and (isinstance(curve_kwargs, Mapping) or curve_kwargs is None) ): raise ValueError( "To avoid labeling individual curves that have the same appearance, " - f"`fold_line_kwargs` should be a list of {n_curves} dictionaries. " + f"`curve_kwargs` should be a list of {n_curves} dictionaries. " "Alternatively, set `name` to `None` or a single string to label " "a single legend entry with mean ROC AUC score of all curves." ) @@ -162,13 +162,13 @@ def _get_legend_label(cls, curve_summary_value, curve_name, summary_value_name): return label @classmethod - def _validate_line_kwargs( + def _validate_curve_kwargs( cls, n_curves, name, summary_value, summary_value_name, - fold_line_kwargs, + curve_kwargs, **kwargs, ): """Get validated line kwargs for each curve. @@ -189,7 +189,7 @@ def _validate_line_kwargs( summary_value_name : str or None Name of the summary value provided in `summary_values`. - fold_line_kwargs : dict or list of dict + curve_kwargs : dict or list of dict Dictionary with keywords passed to the matplotlib's `plot` function to draw the individual ROC curves. If a list is provided, the parameters are applied to the ROC curves sequentially. If a single @@ -201,38 +201,38 @@ def _validate_line_kwargs( Deprecated. Keyword arguments to be passed to matplotlib's `plot`. """ # Deprecate **kwargs - if fold_line_kwargs and kwargs: + if curve_kwargs and kwargs: raise ValueError( - "Cannot provide both `fold_line_kwargs` and `kwargs`. `**kwargs` is " + "Cannot provide both `curve_kwargs` and `kwargs`. `**kwargs` is " "deprecated in 1.7 and will be removed in 1.9. Pass all matplotlib " - "arguments to `fold_line_kwargs` as a dictionary." + "arguments to `curve_kwargs` as a dictionary." ) if kwargs: warnings.warn( "`**kwargs` is deprecated and will be removed in 1.9. 
Pass all " - "matplotlib arguments to `fold_line_kwargs` as a dictionary instead.", + "matplotlib arguments to `curve_kwargs` as a dictionary instead.", FutureWarning, ) - fold_line_kwargs = kwargs + curve_kwargs = kwargs - cls._validate_multi_fold_line_kwargs(cls, fold_line_kwargs, name, n_curves) + cls._validate_multi_curve_kwargs(cls, curve_kwargs, name, n_curves) - # Ensure parameters are of the correct length + # Ensure `name` is of the correct length if isinstance(name, list) and len(name) == 1: name_ = name * n_curves name_ = [None] * n_curves if name is None else name summary_value_ = [None] * n_curves if summary_value is None else summary_value - # Ensure `fold_line_kwargs` is of correct length - if fold_line_kwargs is None: - fold_line_kwargs = [{}] * n_curves - if isinstance(fold_line_kwargs, Mapping): - fold_line_kwargs = [fold_line_kwargs] * n_curves + # Ensure `curve_kwargs` is of correct length + if isinstance(curve_kwargs, Mapping): + curve_kwargs = [curve_kwargs] * n_curves default_multi_curve_kwargs = {"alpha": 0.5, "linestyle": "--", "color": "blue"} - if n_curves > 1: - if fold_line_kwargs is None: - fold_line_kwargs = [default_multi_curve_kwargs] * n_curves + if curve_kwargs is None: + if n_curves > 1: + curve_kwargs = [default_multi_curve_kwargs] * n_curves + else: + curve_kwargs = [{}] labels = [] if isinstance(summary_value_, tuple): @@ -257,13 +257,13 @@ def _validate_line_kwargs( ) ) - line_kwargs = [] + curve_kwargs_ = [] for fold_idx, label in enumerate(labels): label_kwarg = {"label": label} - line_kwargs.append( - _validate_style_kwargs(label_kwarg, fold_line_kwargs[fold_idx]) + curve_kwargs_.append( + _validate_style_kwargs(label_kwarg, curve_kwargs[fold_idx]) ) - return line_kwargs + return curve_kwargs_ def _validate_score_name(score_name, scoring, negate_score): From 3cc2cee6b4fa84e182bd84ac5f693b65924f9e49 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 11 Apr 2025 17:34:01 +1000 Subject: [PATCH 46/63] pass curve_kwarg properly 
and order in docstring --- sklearn/metrics/_plot/roc_curve.py | 85 ++++++++++--------- .../_plot/tests/test_roc_curve_display.py | 60 ++++++------- sklearn/utils/_plotting.py | 2 +- 3 files changed, 77 insertions(+), 70 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 72d5ab09441e0..f4a7dbf4b5e23 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -142,31 +142,32 @@ def __init__( self.pos_label = pos_label def _validate_plot_params(self, *, ax, name): - self.ax_, self.figure_, name_ = super()._validate_plot_params(ax=ax, name=name) + self.ax_, self.figure_, name = super()._validate_plot_params(ax=ax, name=name) - self.fpr_ = _convert_to_list_leaving_none(self.fpr) - self.tpr_ = _convert_to_list_leaving_none(self.tpr) - self.roc_auc_ = _convert_to_list_leaving_none(self.roc_auc) - self.name_ = _convert_to_list_leaving_none(name_) + fpr = _convert_to_list_leaving_none(self.fpr) + tpr = _convert_to_list_leaving_none(self.tpr) + roc_auc = _convert_to_list_leaving_none(self.roc_auc) + name = _convert_to_list_leaving_none(name) - optional = {"self.roc_auc": self.roc_auc_} - if isinstance(self.name_, list) and len(self.name_) != 1: - optional.update({"'name' (or self.name)": self.name_}) + optional = {"self.roc_auc": roc_auc} + if isinstance(name, list) and len(name) != 1: + optional.update({"'name' (or self.name)": name}) _check_param_lengths( - required={"self.fpr": self.fpr_, "self.tpr": self.tpr_}, + required={"self.fpr": fpr, "self.tpr": tpr}, optional=optional, class_name="RocCurveDisplay", ) + return fpr, tpr, roc_auc, name def plot( self, ax=None, *, name=None, + curve_kwargs=None, plot_chance_level=False, chance_level_kw=None, despine=False, - curve_kwargs=None, **kwargs, ): """Plot visualization. @@ -190,6 +191,18 @@ def plot( .. 
versionadded:: 1.7 + curve_kwargs : dict or list of dict, default=None + Keywords arguments to be passed to matplotlib's `plot` function + to draw individual ROC curves. For single curve plotting, should be + a dictionary. For multi-curve plotting, if a list is provided the + parameters are applied to the ROC curves of each CV fold + sequentially and a legend entry is added for each curve. + If a single dictionary is provided, the same parameters are applied + to all ROC curves and a single legend entry for all curves is added, + labeled with the mean ROC AUC score. + + .. versionadded:: 1.7 + plot_chance_level : bool, default=False Whether to plot the chance level. @@ -206,18 +219,6 @@ def plot( .. versionadded:: 1.6 - curve_kwargs : dict or list of dict, default=None - Keywords arguments to be passed to matplotlib's `plot` function - to draw individual ROC curves. For single curve plotting, should be - a dictionary. For multi-curve plotting, if a list is provided the - parameters are applied to the ROC curves of each CV fold - sequentially and a legend entry is added for each curve. - If a single dictionary is provided, the same parameters are applied - to all ROC curves and a single legend entry for all curves is added, - labeled with the mean ROC AUC score. - - .. versionadded:: 1.7 - **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. @@ -230,18 +231,18 @@ def plot( display : :class:`~sklearn.metrics.RocCurveDisplay` Object that stores computed values. 
""" - self._validate_plot_params(ax=ax, name=name) - n_curves = len(self.fpr_) - summary_value, summary_value_name = self.roc_auc_, "AUC" + fpr, tpr, roc_auc, name = self._validate_plot_params(ax=ax, name=name) + n_curves = len(fpr) + summary_value, summary_value_name = roc_auc, "AUC" if not isinstance(curve_kwargs, list) and n_curves > 1: - if self.roc_auc_: - summary_value = (np.mean(self.roc_auc_), np.std(self.roc_auc_)) + if roc_auc: + summary_value = (np.mean(roc_auc), np.std(roc_auc)) else: summary_value = (None, None) curve_kwargs = self._validate_curve_kwargs( n_curves, - self.name_, + name, summary_value, summary_value_name, curve_kwargs=curve_kwargs, @@ -262,7 +263,7 @@ def plot( ) self.line_ = [] - for fpr, tpr, line_kw in zip(self.fpr_, self.tpr_, curve_kwargs): + for fpr, tpr, line_kw in zip(fpr, tpr, curve_kwargs): self.line_.extend(self.ax_.plot(fpr, tpr, **line_kw)) # Return single artist if only one curve is plotted if len(self.line_) == 1: @@ -310,6 +311,7 @@ def from_estimator( pos_label=None, name=None, ax=None, + curve_kwargs=None, plot_chance_level=False, chance_level_kw=None, despine=False, @@ -356,6 +358,11 @@ def from_estimator( ax : matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is created. + curve_kwargs : dict, default=None + Keywords arguments to be passed to matplotlib's `plot` function. + + .. versionadded:: 1.7 + plot_chance_level : bool, default=False Whether to plot the chance level. @@ -372,11 +379,6 @@ def from_estimator( .. versionadded:: 1.6 - curve_kwargs : dict, default=None - Keywords arguments to be passed to matplotlib's `plot` function. - - .. versionadded:: 1.7 - **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. 
@@ -429,6 +431,7 @@ def from_estimator( name=name, ax=ax, pos_label=pos_label, + curve_kwargs=curve_kwargs, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, despine=despine, @@ -446,6 +449,7 @@ def from_predictions( pos_label=None, name=None, ax=None, + curve_kwargs=None, plot_chance_level=False, chance_level_kw=None, despine=False, @@ -488,6 +492,11 @@ def from_predictions( Axes object to plot on. If `None`, a new figure and axes is created. + curve_kwargs : dict, default=None + Keywords arguments to be passed to matplotlib's `plot` function. + + .. versionadded:: 1.7 + plot_chance_level : bool, default=False Whether to plot the chance level. @@ -504,11 +513,6 @@ def from_predictions( .. versionadded:: 1.6 - curve_kwargs : dict, default=None - Keywords arguments to be passed to matplotlib's `plot` function. - - .. versionadded:: 1.7 - **kwargs : dict Additional keywords arguments passed to matplotlib `plot` function. @@ -568,6 +572,7 @@ def from_predictions( return viz.plot( ax=ax, + curve_kwargs=curve_kwargs, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, despine=despine, @@ -741,8 +746,8 @@ def from_cv_results( ) return viz.plot( ax=ax, + curve_kwargs=curve_kwargs, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kwargs, despine=despine, - curve_kwargs=curve_kwargs, ) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index f3b88b2125494..edb8189e617ad 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -102,7 +102,7 @@ def test_roc_curve_display_plotting( sample_weight=sample_weight, drop_intermediate=drop_intermediate, pos_label=pos_label, - alpha=0.8, + curve_kwargs={"alpha": 0.8}, ) else: display = RocCurveDisplay.from_predictions( @@ -111,7 +111,7 @@ def test_roc_curve_display_plotting( sample_weight=sample_weight, 
drop_intermediate=drop_intermediate, pos_label=pos_label, - alpha=0.8, + curve_kwargs={"alpha": 0.8}, ) fpr, tpr, _ = roc_curve( @@ -122,16 +122,10 @@ def test_roc_curve_display_plotting( pos_label=pos_label, ) - # Both processed (e.g., `roc_auc_`) and unprocessed (e.g., `roc_auc`) attributes - # should be the same for single curve - assert_allclose(display.roc_auc_[0], auc(fpr, tpr)) assert_allclose(display.roc_auc, auc(fpr, tpr)) - assert_allclose(display.fpr_[0], fpr) assert_allclose(display.fpr, fpr) - assert_allclose(display.tpr_[0], tpr) assert_allclose(display.tpr, tpr) - assert display.name_[0] == default_name assert display.name == default_name import matplotlib as mpl @@ -140,7 +134,6 @@ def test_roc_curve_display_plotting( assert isinstance(display.line_, mpl.lines.Line2D) assert display.line_.get_alpha() == 0.8 - expected_label = f"{default_name} (AUC = {display.roc_auc_[0]:.2f})" expected_label = f"{default_name} (AUC = {display.roc_auc:.2f})" assert display.line_.get_label() == expected_label @@ -323,11 +316,11 @@ def test_roc_curve_display_plotting_from_cv_results( drop_intermediate=drop_intermediate, pos_label=pos_label, ) - assert_allclose(display.roc_auc_[idx], auc(fpr, tpr)) - assert_allclose(display.fpr_[idx], fpr) - assert_allclose(display.tpr_[idx], tpr) + assert_allclose(display.roc_auc[idx], auc(fpr, tpr)) + assert_allclose(display.fpr[idx], fpr) + assert_allclose(display.tpr[idx], tpr) - assert display.name_ is None + assert display.name is None import matplotlib as mpl @@ -339,28 +332,37 @@ def test_roc_curve_display_plotting_from_cv_results( line.get_alpha() == 0.5 if isinstance(curve_kwargs, list): # Each individual curve labelled - assert line.get_label() == f"AUC = {display.roc_auc_[idx]:.2f}" + assert line.get_label() == f"AUC = {display.roc_auc[idx]:.2f}" else: # Single aggregate label assert line.get_label() == aggregate_expected_labels[idx] -# @pytest.mark.parametrize("fold_names", [None, ["one", "two", "three"]]) -# def 
test_roc_curve_from_cv_results_fold_names(pyplot, data_binary, fold_names): -# """Check fold names behaviour correct in `from_cv_results`.""" +# @pytest.mark.parametrize("curve_kwargs", [None, {"color": "red"}, [{"c": "red"}, {"c": "green"}, {"c": "yellow"}]]) +# @pytest.mark.parametrize("name", [None, "single", ["one", "two", "three"]]) +# def test_roc_curve_from_cv_results_legend_label(pyplot, data_binary, name, curve_kwargs): +# """Check legend label correct with all `curve_kwargs`, `name` combinations.""" # X, y = data_binary # cv_results = cross_validate( # LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True # ) -# display = RocCurveDisplay.from_cv_results(cv_results, X, y, fold_names=fold_names) +# if not isinstance(curve_kwargs, list) and +# display = RocCurveDisplay.from_cv_results( +# cv_results, X, y, name=name, curve_kwargs=curve_kwargs +# ) # legend = display.ax_.get_legend() # legend_labels = [text.get_text() for text in legend.get_texts()] -# expected_names = ( -# ["Fold 0", "Fold 1", "Fold 2"] if fold_names is None else fold_names -# ) -# assert display.name_ == expected_names -# expected_labels = [name + " (AUC = 1.00)" for name in expected_names] -# assert legend_labels == expected_labels +# print(legend_labels) +# print(display.name) +# if isinstance(curve_kwargs, list): +# print(display.name_) + +# expected_names = ( +# ["Fold 0", "Fold 1", "Fold 2"] if name is None else name +# ) +# assert display.name_ == expected_names +# expected_labels = [name + " (AUC = 1.00)" for name in expected_names] +# assert legend_labels == expected_labels @pytest.mark.parametrize( @@ -381,8 +383,10 @@ def test_roc_curve_from_cv_results_curve_kwargs(pyplot, data_binary, curve_kwarg for idx, line in enumerate(display.line_): color = line.get_color() if curve_kwargs is None: + # Default color assert color == "blue" elif isinstance(curve_kwargs, Mapping): + # All curves "red" assert color == "red" else: assert color == curve_kwargs[idx]["c"] @@ 
-456,8 +460,7 @@ def test_roc_curve_chance_level_line( lr, X, y, - label=label, - alpha=0.8, + curve_kwargs={"alpha": 0.8, "label": label}, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, ) @@ -465,8 +468,7 @@ def test_roc_curve_chance_level_line( display = RocCurveDisplay.from_predictions( y, y_pred, - label=label, - alpha=0.8, + curve_kwargs={"alpha": 0.8, "label": label}, plot_chance_level=plot_chance_level, chance_level_kw=chance_level_kw, ) @@ -616,7 +618,7 @@ def _check_auc(display, constructor_name): roc_auc_limit_multi = [0.97007, 0.985915, 0.980952] if constructor_name == "from_cv_results": - for idx, roc_auc in enumerate(display.roc_auc_): + for idx, roc_auc in enumerate(display.roc_auc): assert roc_auc == pytest.approx(roc_auc_limit_multi[idx]) else: assert display.roc_auc == pytest.approx(roc_auc_limit) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index c80ba61dc9c34..a15a9ca0c6161 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -139,7 +139,7 @@ def _validate_multi_curve_kwargs(cls, curve_kwargs, name, n_curves): if ( isinstance(name, list) and len(name) != 1 - and (isinstance(curve_kwargs, Mapping) or curve_kwargs is None) + and not isinstance(curve_kwargs, list) ): raise ValueError( "To avoid labeling individual curves that have the same appearance, " From f71e6c774dc2b1ee70a76112f04486531e3c037c Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 11 Apr 2025 17:54:03 +1000 Subject: [PATCH 47/63] fix naming --- .../_plot/tests/test_roc_curve_display.py | 69 ++++++++++++------- sklearn/utils/_plotting.py | 22 +++--- 2 files changed, 55 insertions(+), 36 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index edb8189e617ad..acf00451b0f84 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -338,31 +338,50 @@ def 
test_roc_curve_display_plotting_from_cv_results( assert line.get_label() == aggregate_expected_labels[idx] -# @pytest.mark.parametrize("curve_kwargs", [None, {"color": "red"}, [{"c": "red"}, {"c": "green"}, {"c": "yellow"}]]) -# @pytest.mark.parametrize("name", [None, "single", ["one", "two", "three"]]) -# def test_roc_curve_from_cv_results_legend_label(pyplot, data_binary, name, curve_kwargs): -# """Check legend label correct with all `curve_kwargs`, `name` combinations.""" -# X, y = data_binary -# cv_results = cross_validate( -# LogisticRegression(), X, y, cv=3, return_estimator=True, return_indices=True -# ) -# if not isinstance(curve_kwargs, list) and -# display = RocCurveDisplay.from_cv_results( -# cv_results, X, y, name=name, curve_kwargs=curve_kwargs -# ) -# legend = display.ax_.get_legend() -# legend_labels = [text.get_text() for text in legend.get_texts()] -# print(legend_labels) -# print(display.name) -# if isinstance(curve_kwargs, list): -# print(display.name_) - -# expected_names = ( -# ["Fold 0", "Fold 1", "Fold 2"] if name is None else name -# ) -# assert display.name_ == expected_names -# expected_labels = [name + " (AUC = 1.00)" for name in expected_names] -# assert legend_labels == expected_labels +@pytest.mark.parametrize( + "curve_kwargs", + [None, {"color": "red"}, [{"c": "red"}, {"c": "green"}, {"c": "yellow"}]], +) +@pytest.mark.parametrize("name", [None, "single", ["one", "two", "three"]]) +def test_roc_curve_from_cv_results_legend_label( + pyplot, data_binary, name, curve_kwargs +): + """Check legend label correct with all `curve_kwargs`, `name` combinations.""" + X, y = data_binary + n_cv = 3 + cv_results = cross_validate( + LogisticRegression(), X, y, cv=n_cv, return_estimator=True, return_indices=True + ) + if not isinstance(curve_kwargs, list) and isinstance(name, list): + with pytest.raises(ValueError, match="To avoid labeling individual curves"): + display = RocCurveDisplay.from_cv_results( + cv_results, X, y, name=name, 
curve_kwargs=curve_kwargs + ) + else: + display = RocCurveDisplay.from_cv_results( + cv_results, X, y, name=name, curve_kwargs=curve_kwargs + ) + legend = display.ax_.get_legend() + legend_labels = [text.get_text() for text in legend.get_texts()] + if isinstance(curve_kwargs, list): + # Multiple labels in legend + assert len(legend_labels) == 3 + for idx, label in enumerate(legend_labels): + if name is None: + assert label == "AUC = 1.00" + elif isinstance(name, str): + assert label == "single (AUC = 1.00)" + else: + # `name` is a list of different strings + assert label == f"{name[idx]} (AUC = 1.00)" + else: + # Single label in legend + assert len(legend_labels) == 1 + if name is None: + assert legend_labels[0] == "AUC = 1.00 +/- 0.00" + else: + # name is single string + assert legend_labels[0] == "single (AUC = 1.00 +/- 0.00)" @pytest.mark.parametrize( diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index a15a9ca0c6161..ee11dad43c4d1 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -150,7 +150,7 @@ def _validate_multi_curve_kwargs(cls, curve_kwargs, name, n_curves): @classmethod def _get_legend_label(cls, curve_summary_value, curve_name, summary_value_name): - """Helper to get legend label using `name_` and `summary_value_`""" + """Helper to get legend label using `name` and `summary_value`""" if curve_summary_value is not None and curve_name is not None: label = f"{curve_name} ({summary_value_name} = {curve_summary_value:0.2f})" elif curve_summary_value is not None: @@ -219,9 +219,9 @@ def _validate_curve_kwargs( # Ensure `name` is of the correct length if isinstance(name, list) and len(name) == 1: - name_ = name * n_curves - name_ = [None] * n_curves if name is None else name - summary_value_ = [None] * n_curves if summary_value is None else summary_value + name = name * n_curves + name = [None] * n_curves if name is None else name + summary_value = [None] * n_curves if summary_value is None else summary_value # 
Ensure `curve_kwargs` is of correct length if isinstance(curve_kwargs, Mapping): @@ -235,22 +235,22 @@ def _validate_curve_kwargs( curve_kwargs = [{}] labels = [] - if isinstance(summary_value_, tuple): + if isinstance(summary_value, tuple): label_aggregate = cls._get_legend_label( - summary_value_[0], name_[0], summary_value_name + summary_value[0], name[0], summary_value_name ) # Add the "+/- std" to the end (in brackets if name provided) - if summary_value_[1] is not None: - if name_[0] is not None: + if summary_value[1] is not None: + if name[0] is not None: label_aggregate = ( - label_aggregate[:-1] + f" +/- {summary_value_[1]:0.2f})" + label_aggregate[:-1] + f" +/- {summary_value[1]:0.2f})" ) else: - label_aggregate = label_aggregate + f" +/- {summary_value_[1]:0.2f}" + label_aggregate = label_aggregate + f" +/- {summary_value[1]:0.2f}" # Add `label` for first curve only, set to `None` for remaining curves labels.extend([label_aggregate] + [None] * (n_curves - 1)) else: - for curve_summary_value, curve_name in zip(summary_value_, name_): + for curve_summary_value, curve_name in zip(summary_value, name): labels.append( cls._get_legend_label( curve_summary_value, curve_name, summary_value_name From b171f1a0a03ea2d4ab639676f0c6f5c913a3fa41 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 11 Apr 2025 18:41:39 +1000 Subject: [PATCH 48/63] add kwarg dep test --- .../_plot/tests/test_roc_curve_display.py | 93 ++++++++++++------- 1 file changed, 57 insertions(+), 36 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index acf00451b0f84..48e2a8aec3640 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -202,38 +202,37 @@ def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary, data): ) -# @pytest.mark.parametrize( -# "curve_kwargs", -# [None, {"alpha": 0.2}, [{"alpha": 0.2}, {"alpha": 0.3}, 
{"alpha": 0.4}]], -# ) -# def test_roc_curve_display_from_cv_results_validate_curve_kwargs( -# pyplot, data_binary, curve_kwargs -# ): -# """Check `_validate_curve_kwargs` correctly validates line kwargs.""" -# X, y = data_binary -# n_cv = 3 -# cv_results = cross_validate( -# LogisticRegression(), X, y, cv=n_cv, return_estimator=True, -# return_indices=True -# ) -# display = RocCurveDisplay.from_cv_results( -# cv_results, -# X, -# y, -# curve_kwargs=curve_kwargs, -# ) -# if curve_kwargs is None: -# # Default `alpha` used -# assert all(line.get_alpha() == 0.5 for line in display.line_) -# elif isinstance(curve_kwargs, Mapping): -# # `alpha` from dict used for all curves -# assert all(line.get_alpha() == 0.2 for line in display.line_) -# else: -# # Different `alpha` used for each curve -# assert all( -# line.get_alpha() == curve_kwargs[i]["alpha"] -# for i, line in enumerate(display.line_) -# ) +@pytest.mark.parametrize( + "curve_kwargs", + [None, {"alpha": 0.2}, [{"alpha": 0.2}, {"alpha": 0.3}, {"alpha": 0.4}]], +) +def test_roc_curve_display_from_cv_results_curve_kwargs( + pyplot, data_binary, curve_kwargs +): + """Check `curve_kwargs` correctly passed.""" + X, y = data_binary + n_cv = 3 + cv_results = cross_validate( + LogisticRegression(), X, y, cv=n_cv, return_estimator=True, return_indices=True + ) + display = RocCurveDisplay.from_cv_results( + cv_results, + X, + y, + curve_kwargs=curve_kwargs, + ) + if curve_kwargs is None: + # Default `alpha` used + assert all(line.get_alpha() == 0.5 for line in display.line_) + elif isinstance(curve_kwargs, Mapping): + # `alpha` from dict used for all curves + assert all(line.get_alpha() == 0.2 for line in display.line_) + else: + # Different `alpha` used for each curve + assert all( + line.get_alpha() == curve_kwargs[i]["alpha"] + for i, line in enumerate(display.line_) + ) # TODO : Remove in 1.9 @@ -245,6 +244,31 @@ def test_roc_curve_display_estimator_name_deprecation(pyplot): RocCurveDisplay(fpr=fpr, tpr=tpr, 
estimator_name="test") +# TODO : Remove in 1.9 +@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) +def test_roc_curve_display_kwargs_deprecation(pyplot, data_binary, constructor_name): + """Check **kwargs deprecated correctly in favour of `curve_kwargs`.""" + X, y = data_binary + lr = LogisticRegression() + lr.fit(X, y) + # Error when both `curve_kwargs` and `**kwargs` provided + with pytest.raises(ValueError, match="Cannot provide both `curve_kwargs`"): + if constructor_name == "from_estimator": + RocCurveDisplay.from_estimator( + lr, X, y, curve_kwargs={"alpha": 1}, label="test" + ) + else: + RocCurveDisplay.from_predictions( + y, y, curve_kwargs={"alpha": 1}, label="test" + ) + # Warning when `**kwargs`` provided + with pytest.warns(FutureWarning, match=r"`\*\*kwargs` is deprecated and will be"): + if constructor_name == "from_estimator": + RocCurveDisplay.from_estimator(lr, X, y, label="test") + else: + RocCurveDisplay.from_predictions(y, y, label="test") + + @pytest.mark.parametrize( "curve_kwargs", [ @@ -453,10 +477,7 @@ def _check_chance_level(plot_chance_level, chance_level_kw, display): {"lw": 1, "color": "blue", "ls": "-", "label": None}, ], ) -@pytest.mark.parametrize( - "constructor_name", - ["from_estimator", "from_predictions"], -) +@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) def test_roc_curve_chance_level_line( pyplot, data_binary, From 436985b4899df2f6c8213c47fa08d68ed3bc298a Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 11 Apr 2025 19:38:46 +1000 Subject: [PATCH 49/63] guillaume review --- sklearn/metrics/_plot/roc_curve.py | 46 +++--- .../_plot/tests/test_roc_curve_display.py | 45 ++++-- sklearn/utils/_plotting.py | 132 ++++++++++-------- 3 files changed, 133 insertions(+), 90 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index f4a7dbf4b5e23..3f6480edc6196 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ 
b/sklearn/metrics/_plot/roc_curve.py @@ -233,18 +233,20 @@ def plot( """ fpr, tpr, roc_auc, name = self._validate_plot_params(ax=ax, name=name) n_curves = len(fpr) - summary_value, summary_value_name = roc_auc, "AUC" if not isinstance(curve_kwargs, list) and n_curves > 1: if roc_auc: - summary_value = (np.mean(roc_auc), np.std(roc_auc)) + legend_metric = {"mean": np.mean(roc_auc), "std": np.std(roc_auc)} else: - summary_value = (None, None) + legend_metric = {"mean": None, "std": None} + else: + roc_auc = roc_auc if roc_auc is not None else [None] * n_curves + legend_metric = {"metric": roc_auc} curve_kwargs = self._validate_curve_kwargs( n_curves, name, - summary_value, - summary_value_name, + legend_metric, + "AUC", curve_kwargs=curve_kwargs, **kwargs, ) @@ -347,8 +349,8 @@ def from_estimator( :term:`decision_function` is tried next. pos_label : int, float, bool or str, default=None - The class considered as the positive class when computing the roc auc - metrics. By default, `estimators.classes_[1]` is considered + The class considered as the positive class when computing the ROC AUC. + By default, `estimators.classes_[1]` is considered as the positive class. name : str, default=None @@ -480,9 +482,9 @@ def from_predictions( ROC curves. pos_label : int, float, bool or str, default=None - The label of the positive class. When `pos_label=None`, if `y_true` - is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an - error will be raised. + The label of the positive class when computing the ROC AUC. + When `pos_label=None`, if `y_true` is in {-1, 1} or {0, 1}, `pos_label` + is set to 1, otherwise an error will be raised. name : str, default=None Name of ROC curve for legend labeling. If `None`, name will be set to @@ -629,8 +631,8 @@ def from_cv_results( :term:`predict_proba` is tried first and if it does not exist :term:`decision_function` is tried next. 
- pos_label : str or int, default=None - The class considered as the positive class when computing the roc auc + pos_label : int, float, bool or str, default=None + The class considered as the positive class when computing the ROC AUC metrics. By default, `estimators.classes_[1]` is considered as the positive class. @@ -706,19 +708,17 @@ def from_cv_results( curve_kwargs=curve_kwargs, ) - fpr_all = [] - tpr_all = [] - auc_all = [] + fpr_folds, tpr_folds, auc_folds = [], [], [] for estimator, test_indices in zip( cv_results["estimator"], cv_results["indices"]["test"] ): y_true = _safe_indexing(y, test_indices) - y_pred = _get_response_values_binary( + y_pred, _ = _get_response_values_binary( estimator, _safe_indexing(X, test_indices), response_method=response_method, pos_label=pos_label_, - )[0] + ) sample_weight_fold = ( None if sample_weight is None @@ -733,15 +733,15 @@ def from_cv_results( ) roc_auc = auc(fpr, tpr) - fpr_all.append(fpr) - tpr_all.append(tpr) - auc_all.append(roc_auc) + fpr_folds.append(fpr) + tpr_folds.append(tpr) + auc_folds.append(roc_auc) viz = cls( - fpr=fpr_all, - tpr=tpr_all, + fpr=fpr_folds, + tpr=tpr_folds, name=name, - roc_auc=auc_all, + roc_auc=auc_folds, pos_label=pos_label_, ) return viz.plot( diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 48e2a8aec3640..32fa02254fd10 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -192,7 +192,7 @@ def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary, data): ) # `curve_kwargs` incorrect length - with pytest.raises(ValueError, match="`curve_kwargs` must be None, a list"): + with pytest.raises(ValueError, match="`curve_kwargs` must be None, a dictionary"): RocCurveDisplay.from_cv_results(cv_results, X, y, curve_kwargs=[{"alpha": 1}]) # `curve_kwargs` both alias provided @@ -245,28 +245,40 @@ def 
test_roc_curve_display_estimator_name_deprecation(pyplot): # TODO : Remove in 1.9 -@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) +@pytest.mark.parametrize( + "constructor_name", ["from_estimator", "from_predictions", "plot"] +) def test_roc_curve_display_kwargs_deprecation(pyplot, data_binary, constructor_name): """Check **kwargs deprecated correctly in favour of `curve_kwargs`.""" X, y = data_binary lr = LogisticRegression() lr.fit(X, y) + fpr = np.array([0, 0.5, 1]) + tpr = np.array([0, 0.5, 1]) + # Error when both `curve_kwargs` and `**kwargs` provided with pytest.raises(ValueError, match="Cannot provide both `curve_kwargs`"): if constructor_name == "from_estimator": RocCurveDisplay.from_estimator( lr, X, y, curve_kwargs={"alpha": 1}, label="test" ) - else: + elif constructor_name == "from_predictions": RocCurveDisplay.from_predictions( y, y, curve_kwargs={"alpha": 1}, label="test" ) + else: + RocCurveDisplay(fpr=fpr, tpr=tpr).plot( + curve_kwargs={"alpha": 1}, label="test" + ) + # Warning when `**kwargs`` provided with pytest.warns(FutureWarning, match=r"`\*\*kwargs` is deprecated and will be"): if constructor_name == "from_estimator": RocCurveDisplay.from_estimator(lr, X, y, label="test") - else: + elif constructor_name == "from_predictions": RocCurveDisplay.from_predictions(y, y, label="test") + else: + RocCurveDisplay(fpr=fpr, tpr=tpr).plot(label="test") @pytest.mark.parametrize( @@ -367,8 +379,9 @@ def test_roc_curve_display_plotting_from_cv_results( [None, {"color": "red"}, [{"c": "red"}, {"c": "green"}, {"c": "yellow"}]], ) @pytest.mark.parametrize("name", [None, "single", ["one", "two", "three"]]) +@pytest.mark.parametrize("constructor_name", ["from_cv_results", "plot"]) def test_roc_curve_from_cv_results_legend_label( - pyplot, data_binary, name, curve_kwargs + pyplot, data_binary, constructor_name, name, curve_kwargs ): """Check legend label correct with all `curve_kwargs`, `name` combinations.""" X, y = 
data_binary @@ -376,15 +389,29 @@ def test_roc_curve_from_cv_results_legend_label( cv_results = cross_validate( LogisticRegression(), X, y, cv=n_cv, return_estimator=True, return_indices=True ) + fpr = [np.array([0, 0.5, 1]), np.array([0, 0.5, 1]), np.array([0, 0.5, 1])] + tpr = [np.array([0, 0.5, 1]), np.array([0, 0.5, 1]), np.array([0, 0.5, 1])] + roc_auc = [1.0, 1.0, 1.0] if not isinstance(curve_kwargs, list) and isinstance(name, list): with pytest.raises(ValueError, match="To avoid labeling individual curves"): + if constructor_name == "from_cv_results": + RocCurveDisplay.from_cv_results( + cv_results, X, y, name=name, curve_kwargs=curve_kwargs + ) + else: + RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot( + name=name, curve_kwargs=curve_kwargs + ) + + else: + if constructor_name == "from_cv_results": display = RocCurveDisplay.from_cv_results( cv_results, X, y, name=name, curve_kwargs=curve_kwargs ) - else: - display = RocCurveDisplay.from_cv_results( - cv_results, X, y, name=name, curve_kwargs=curve_kwargs - ) + else: + display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot( + name=name, curve_kwargs=curve_kwargs + ) legend = display.ax_.get_legend() legend_labels = [text.get_text() for text in legend.get_texts()] if isinstance(curve_kwargs, list): diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index ee11dad43c4d1..9d0027791055c 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -81,7 +81,7 @@ def _validate_from_cv_results_params( name, curve_kwargs, ): - check_matplotlib_support(f"{cls.__name__}.from_predictions") + check_matplotlib_support(f"{cls.__name__}.from_cv_results") required_keys = {"estimator", "indices"} if not all(key in cv_results for key in required_keys): @@ -116,58 +116,57 @@ def _validate_from_cv_results_params( # Adapt error message raise ValueError(str(e).replace("y_true", "y")) - n_curves = len(cv_results["estimator"]) - # NB: Both these also checked in `plot`, but thought 
it best to fail earlier. - cls._validate_multi_curve_kwargs(cls, curve_kwargs, name, n_curves) - if isinstance(name, list) and len(name) not in (1, n_curves): - raise ValueError( - f"`name` must be None, a list of length {n_curves} or a single " - f"string. Got list of length: {len(name)}." - ) + # n_curves = len(cv_results["estimator"]) + # # NB: Both these also checked in `plot`, but thought it best to fail earlier. + # cls._validate_multi_curve_kwargs(cls, curve_kwargs, name, n_curves) + # if isinstance(name, list) and len(name) not in (1, n_curves): + # raise ValueError( + # f"`name` must be None, a list of length {n_curves} or a single " + # f"string. Got list of length: {len(name)}." + # ) return pos_label - def _validate_multi_curve_kwargs(cls, curve_kwargs, name, n_curves): - """Check `curve_kwargs`, including combination with `name`, is valid.""" - if isinstance(curve_kwargs, list) and len(curve_kwargs) != n_curves: - raise ValueError( - f"`curve_kwargs` must be None, a list of length {n_curves} or a " - f"dictionary. Got: {curve_kwargs}." - ) - - # Ensure valid `name` and `curve_kwargs` combination. - if ( - isinstance(name, list) - and len(name) != 1 - and not isinstance(curve_kwargs, list) - ): - raise ValueError( - "To avoid labeling individual curves that have the same appearance, " - f"`curve_kwargs` should be a list of {n_curves} dictionaries. " - "Alternatively, set `name` to `None` or a single string to label " - "a single legend entry with mean ROC AUC score of all curves." 
- ) - - @classmethod - def _get_legend_label(cls, curve_summary_value, curve_name, summary_value_name): - """Helper to get legend label using `name` and `summary_value`""" - if curve_summary_value is not None and curve_name is not None: - label = f"{curve_name} ({summary_value_name} = {curve_summary_value:0.2f})" - elif curve_summary_value is not None: - label = f"{summary_value_name} = {curve_summary_value:0.2f}" + # def _validate_multi_curve_kwargs(cls, curve_kwargs, name, n_curves): + # """Check `curve_kwargs`, including combination with `name`, is valid.""" + # if isinstance(curve_kwargs, list) and len(curve_kwargs) != n_curves: + # raise ValueError( + # f"`curve_kwargs` must be None, a dictionary or a list of length " + # f"{n_curves}. Got: {curve_kwargs}." + # ) + + # # Ensure valid `name` and `curve_kwargs` combination. + # if ( + # isinstance(name, list) + # and len(name) != 1 + # and not isinstance(curve_kwargs, list) + # ): + # raise ValueError( + # "To avoid labeling individual curves that have the same appearance, " + # f"`curve_kwargs` should be a list of {n_curves} dictionaries. " + # "Alternatively, set `name` to `None` or a single string to label " + # "a single legend entry with mean ROC AUC score of all curves." 
+ # ) + + @staticmethod + def _get_legend_label(curve_legend_metric, curve_name, legend_metric_name): + """Helper to get legend label using `name` and `legend_metric`""" + if curve_legend_metric is not None and curve_name is not None: + label = f"{curve_name} ({legend_metric_name} = {curve_legend_metric:0.2f})" + elif curve_legend_metric is not None: + label = f"{legend_metric_name} = {curve_legend_metric:0.2f}" elif curve_name is not None: label = curve_name else: label = None return label - @classmethod + @staticmethod def _validate_curve_kwargs( - cls, n_curves, name, - summary_value, - summary_value_name, + legend_metric, + legend_metric_name, curve_kwargs, **kwargs, ): @@ -181,13 +180,12 @@ def _validate_curve_kwargs( name : list of str or None Name for labeling legend entries. - summary_value : list of float or tuple of float or None - A list of `n_curves` summary values for each curve (e.g., ROC AUC, - average precision) or a tuple of mean and standard deviation values for - all curves or None. + legend_metric : dict or None + Dictionary with "mean" and "std" keys, or "metric" key of metric + values for each curve. If None, "label" will not contain metric values. - summary_value_name : str or None - Name of the summary value provided in `summary_values`. + legend_metric_name : str or None + Name of the summary value provided in `legend_metrics`. curve_kwargs : dict or list of dict Dictionary with keywords passed to the matplotlib's `plot` function @@ -215,13 +213,29 @@ def _validate_curve_kwargs( ) curve_kwargs = kwargs - cls._validate_multi_curve_kwargs(cls, curve_kwargs, name, n_curves) + if isinstance(curve_kwargs, list) and len(curve_kwargs) != n_curves: + raise ValueError( + f"`curve_kwargs` must be None, a dictionary or a list of length " + f"{n_curves}. Got: {curve_kwargs}." + ) + + # Ensure valid `name` and `curve_kwargs` combination. 
+ if ( + isinstance(name, list) + and len(name) != 1 + and not isinstance(curve_kwargs, list) + ): + raise ValueError( + "To avoid labeling individual curves that have the same appearance, " + f"`curve_kwargs` should be a list of {n_curves} dictionaries. " + "Alternatively, set `name` to `None` or a single string to label " + "a single legend entry with mean ROC AUC score of all curves." + ) # Ensure `name` is of the correct length if isinstance(name, list) and len(name) == 1: name = name * n_curves name = [None] * n_curves if name is None else name - summary_value = [None] * n_curves if summary_value is None else summary_value # Ensure `curve_kwargs` is of correct length if isinstance(curve_kwargs, Mapping): @@ -235,25 +249,27 @@ def _validate_curve_kwargs( curve_kwargs = [{}] labels = [] - if isinstance(summary_value, tuple): - label_aggregate = cls._get_legend_label( - summary_value[0], name[0], summary_value_name + if "mean" in legend_metric: + label_aggregate = _BinaryClassifierCurveDisplayMixin._get_legend_label( + legend_metric["mean"], name[0], legend_metric_name ) # Add the "+/- std" to the end (in brackets if name provided) - if summary_value[1] is not None: + if legend_metric["mean"] is not None: if name[0] is not None: label_aggregate = ( - label_aggregate[:-1] + f" +/- {summary_value[1]:0.2f})" + label_aggregate[:-1] + f" +/- {legend_metric['std']:0.2f})" ) else: - label_aggregate = label_aggregate + f" +/- {summary_value[1]:0.2f}" + label_aggregate = ( + label_aggregate + f" +/- {legend_metric['std']:0.2f}" + ) # Add `label` for first curve only, set to `None` for remaining curves labels.extend([label_aggregate] + [None] * (n_curves - 1)) else: - for curve_summary_value, curve_name in zip(summary_value, name): + for curve_legend_metric, curve_name in zip(legend_metric["metric"], name): labels.append( - cls._get_legend_label( - curve_summary_value, curve_name, summary_value_name + _BinaryClassifierCurveDisplayMixin._get_legend_label( + 
curve_legend_metric, curve_name, legend_metric_name ) ) From 0a25f173b11b9977163994b09badbb464c9e6de9 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 11 Apr 2025 20:25:08 +1000 Subject: [PATCH 50/63] fix test --- sklearn/metrics/_plot/tests/test_common_curve_display.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/tests/test_common_curve_display.py b/sklearn/metrics/_plot/tests/test_common_curve_display.py index 7c8516d5619e0..2dde6cc76be97 100644 --- a/sklearn/metrics/_plot/tests/test_common_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_common_curve_display.py @@ -214,7 +214,7 @@ def test_display_curve_not_fitted_errors(pyplot, data_binary, clf, Display): model.fit(X, y) disp = Display.from_estimator(model, X, y) assert model.__class__.__name__ in disp.line_.get_label() - assert disp.name_[0] == model.__class__.__name__ + assert disp.name == model.__class__.__name__ @pytest.mark.parametrize( From 345ed9c16a3a957fdea8a81d28eb1790091a3bfb Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 11 Apr 2025 21:23:21 +1000 Subject: [PATCH 51/63] format --- sklearn/metrics/_plot/roc_curve.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 3f6480edc6196..90670766834c0 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -78,7 +78,8 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): Name of estimator. If None, the estimator name is not shown. .. deprecated:: 1.7 - `estimator_name` is deprecated and will be removed in 1.9. Use `name` instead. + `estimator_name` is deprecated and will be removed in 1.9. Use `name` + instead. 
Attributes ---------- From f230ba219e76d9949e1a341110561855c7bacec2 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 11 Apr 2025 22:04:48 +1000 Subject: [PATCH 52/63] format --- sklearn/metrics/_plot/roc_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 90670766834c0..b73d431762b5f 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -224,7 +224,7 @@ def plot( Keyword arguments to be passed to matplotlib's `plot`. .. deprecated:: 1.7 - **kwargs is deprecated and will be removed in 1.9. Pass matplotlib + kwargs is deprecated and will be removed in 1.9. Pass matplotlib arguments to `curve_kwargs` as a dictionary instead. Returns From a378188b20c500bc4070bbfd82fc39800831a99f Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 11 Apr 2025 22:59:26 +1000 Subject: [PATCH 53/63] format --- sklearn/metrics/_plot/roc_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index b73d431762b5f..7af72dad41e9a 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -386,7 +386,7 @@ def from_estimator( Keyword arguments to be passed to matplotlib's `plot`. .. deprecated:: 1.7 - **kwargs is deprecated and will be removed in 1.9. Pass matplotlib + kwargs is deprecated and will be removed in 1.9. Pass matplotlib arguments to `curve_kwargs` as a dictionary instead. 
Returns From 1769887ba75d01a21c4d4532c3bbeac3b2d6a136 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Sat, 12 Apr 2025 10:26:48 +1000 Subject: [PATCH 54/63] format --- sklearn/metrics/_plot/roc_curve.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 7af72dad41e9a..054fe6bc29168 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -220,7 +220,7 @@ def plot( .. versionadded:: 1.6 - **kwargs : dict + kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. .. deprecated:: 1.7 @@ -382,7 +382,7 @@ def from_estimator( .. versionadded:: 1.6 - **kwargs : dict + kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. .. deprecated:: 1.7 @@ -516,11 +516,11 @@ def from_predictions( .. versionadded:: 1.6 - **kwargs : dict + kwargs : dict Additional keywords arguments passed to matplotlib `plot` function. .. deprecated:: 1.7 - **kwargs is deprecated and will be removed in 1.9. Pass matplotlib + kwargs is deprecated and will be removed in 1.9. Pass matplotlib arguments to `curve_kwargs` as a dictionary instead. Returns From be576908067b8403b84112f2ed5d5526ecf4b964 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Sat, 12 Apr 2025 21:04:19 +1000 Subject: [PATCH 55/63] revert **kwargs --- sklearn/metrics/_plot/roc_curve.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 054fe6bc29168..97f0aea6d0ed4 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -220,7 +220,7 @@ def plot( .. versionadded:: 1.6 - kwargs : dict + **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. .. deprecated:: 1.7 @@ -382,7 +382,7 @@ def from_estimator( .. versionadded:: 1.6 - kwargs : dict + **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. .. 
deprecated:: 1.7 @@ -516,7 +516,7 @@ def from_predictions( .. versionadded:: 1.6 - kwargs : dict + **kwargs : dict Additional keywords arguments passed to matplotlib `plot` function. .. deprecated:: 1.7 From c57debbdb49b8783e8e48418235e6cc9e6c6a260 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Mon, 14 Apr 2025 11:57:29 +1000 Subject: [PATCH 56/63] add tests codecov --- .../_plot/tests/test_roc_curve_display.py | 175 +++++++++++++++--- 1 file changed, 154 insertions(+), 21 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 8de105e6cde99..56f11c0b44dda 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -138,6 +138,95 @@ def test_roc_curve_display_plotting( assert display.line_.get_label() == expected_label +@pytest.mark.parametrize( + "params, err_msg", + [ + ( + { + "fpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])], + "tpr": [np.array([0, 0.5, 1])], + "roc_auc": None, + "name": None, + }, + "self.fpr and self.tpr from `RocCurveDisplay` initialization,", + ), + ( + { + "fpr": [np.array([0, 0.5, 1])], + "tpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])], + "roc_auc": [0.8, 0.9], + "name": None, + }, + "self.fpr, self.tpr and self.roc_auc from `RocCurveDisplay`", + ), + ( + { + "fpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])], + "tpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])], + "roc_auc": [0.8], + "name": None, + }, + "Got: self.fpr: 2, self.tpr: 2, self.roc_auc: 1", + ), + ( + { + "fpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])], + "tpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])], + "roc_auc": [0.8, 0.9], + "name": ["curve1", "curve2", "curve3"], + }, + r"self.fpr, self.tpr, self.roc_auc and 'name' \(or self.name\)", + ), + ( + { + "fpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])], + "tpr": [np.array([0, 0.5, 1]), np.array([0, 0.5, 1])], + "roc_auc": [0.8, 0.9], + # 
List of length 1 is always allowed + "name": ["curve1"], + }, + None, + ), + ], +) +def test_roc_curve_plot_parameter_length_validation(pyplot, params, err_msg): + """Check `plot` parameter length validation performed correctly.""" + display = RocCurveDisplay(**params) + if err_msg: + with pytest.raises(ValueError, match=err_msg): + display.plot() + else: + # No error should be raised + display.plot() + + +def test_validate_plot_params(pyplot): + """Check `_validate_plot_params` returns the correct variables.""" + fpr = np.array([0, 0.5, 1]) + tpr = [np.array([0, 0.5, 1])] + roc_auc = None + name = "test_curve" + + # Initialize display with test inputs + display = RocCurveDisplay( + fpr=fpr, + tpr=tpr, + roc_auc=roc_auc, + name=name, + pos_label=None, + ) + fpr_out, tpr_out, roc_auc_out, name_out = display._validate_plot_params( + ax=None, name=None + ) + + assert isinstance(fpr_out, list) + assert isinstance(tpr_out, list) + assert len(fpr_out) == 1 + assert len(tpr_out) == 1 + assert roc_auc_out is None + assert name_out == ["test_curve"] + + def test_roc_curve_from_cv_results_param_validation(pyplot, data_binary, data): """Check parameter validation is correct.""" X, y = data_binary @@ -374,14 +463,70 @@ def test_roc_curve_display_plotting_from_cv_results( assert line.get_label() == aggregate_expected_labels[idx] +@pytest.mark.parametrize("roc_auc", [[1.0, 1.0, 1.0], None]) +@pytest.mark.parametrize( + "curve_kwargs", + [None, {"color": "red"}, [{"c": "red"}, {"c": "green"}, {"c": "yellow"}]], +) +@pytest.mark.parametrize("name", [None, "single", ["one", "two", "three"]]) +def test_roc_curve_plot_legend_label(pyplot, data_binary, name, curve_kwargs, roc_auc): + """Check legend label correct with all `curve_kwargs`, `name` combinations.""" + fpr = [np.array([0, 0.5, 1]), np.array([0, 0.5, 1]), np.array([0, 0.5, 1])] + tpr = [np.array([0, 0.5, 1]), np.array([0, 0.5, 1]), np.array([0, 0.5, 1])] + if not isinstance(curve_kwargs, list) and isinstance(name, list): + 
with pytest.raises(ValueError, match="To avoid labeling individual curves"): + RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot( + name=name, curve_kwargs=curve_kwargs + ) + + else: + display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot( + name=name, curve_kwargs=curve_kwargs + ) + legend = display.ax_.get_legend() + if legend is None: + # No legend is created, exit test early + assert name is None + assert roc_auc is None + return + else: + legend_labels = [text.get_text() for text in legend.get_texts()] + + if isinstance(curve_kwargs, list): + # Multiple labels in legend + assert len(legend_labels) == 3 + for idx, label in enumerate(legend_labels): + if name is None: + expected_label = "AUC = 1.00" if roc_auc else None + assert label == expected_label + elif isinstance(name, str): + expected_label = "single (AUC = 1.00)" if roc_auc else "single" + assert label == expected_label + else: + # `name` is a list of different strings + expected_label = ( + f"{name[idx]} (AUC = 1.00)" if roc_auc else f"{name[idx]}" + ) + assert label == expected_label + else: + # Single label in legend + assert len(legend_labels) == 1 + if name is None: + expected_label = "AUC = 1.00 +/- 0.00" if roc_auc else None + assert legend_labels[0] == expected_label + else: + # name is single string + expected_label = "single (AUC = 1.00 +/- 0.00)" if roc_auc else "single" + assert legend_labels[0] == expected_label + + @pytest.mark.parametrize( "curve_kwargs", [None, {"color": "red"}, [{"c": "red"}, {"c": "green"}, {"c": "yellow"}]], ) @pytest.mark.parametrize("name", [None, "single", ["one", "two", "three"]]) -@pytest.mark.parametrize("constructor_name", ["from_cv_results", "plot"]) def test_roc_curve_from_cv_results_legend_label( - pyplot, data_binary, constructor_name, name, curve_kwargs + pyplot, data_binary, name, curve_kwargs ): """Check legend label correct with all `curve_kwargs`, `name` combinations.""" X, y = data_binary @@ -389,29 +534,17 @@ def 
test_roc_curve_from_cv_results_legend_label( cv_results = cross_validate( LogisticRegression(), X, y, cv=n_cv, return_estimator=True, return_indices=True ) - fpr = [np.array([0, 0.5, 1]), np.array([0, 0.5, 1]), np.array([0, 0.5, 1])] - tpr = [np.array([0, 0.5, 1]), np.array([0, 0.5, 1]), np.array([0, 0.5, 1])] - roc_auc = [1.0, 1.0, 1.0] + if not isinstance(curve_kwargs, list) and isinstance(name, list): with pytest.raises(ValueError, match="To avoid labeling individual curves"): - if constructor_name == "from_cv_results": - RocCurveDisplay.from_cv_results( - cv_results, X, y, name=name, curve_kwargs=curve_kwargs - ) - else: - RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot( - name=name, curve_kwargs=curve_kwargs - ) - - else: - if constructor_name == "from_cv_results": - display = RocCurveDisplay.from_cv_results( + RocCurveDisplay.from_cv_results( cv_results, X, y, name=name, curve_kwargs=curve_kwargs ) - else: - display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc).plot( - name=name, curve_kwargs=curve_kwargs - ) + else: + display = RocCurveDisplay.from_cv_results( + cv_results, X, y, name=name, curve_kwargs=curve_kwargs + ) + legend = display.ax_.get_legend() legend_labels = [text.get_text() for text in legend.get_texts()] if isinstance(curve_kwargs, list): From d8f11a567edde1589c79d7f7a083a71d06cec8d5 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Mon, 14 Apr 2025 12:28:40 +1000 Subject: [PATCH 57/63] add _BinaryClassifierCurveDisplayMixin tests --- sklearn/utils/_plotting.py | 4 +- sklearn/utils/tests/test_plotting.py | 98 ++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 9d0027791055c..ee54684635202 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -78,8 +78,8 @@ def _validate_from_cv_results_params( *, sample_weight, pos_label, - name, - curve_kwargs, + # name, + # curve_kwargs, ): 
check_matplotlib_support(f"{cls.__name__}.from_cv_results") diff --git a/sklearn/utils/tests/test_plotting.py b/sklearn/utils/tests/test_plotting.py index 1f0c675577bca..caaea133a1ee5 100644 --- a/sklearn/utils/tests/test_plotting.py +++ b/sklearn/utils/tests/test_plotting.py @@ -2,6 +2,7 @@ import pytest from sklearn.utils._plotting import ( + _BinaryClassifierCurveDisplayMixin, _despine, _interval_max_min_ratio, _validate_score_name, @@ -9,6 +10,103 @@ ) +@pytest.mark.parametrize( + "params, err_msg", + [ + ( + { + # Missing "indices" key + "cv_results": {"estimator": "dummy"}, + "X": np.array([[1, 2], [3, 4]]), + "y": np.array([0, 1]), + "sample_weight": None, + "pos_label": None, + }, + "`cv_results` does not contain one of the following", + ), + ( + { + "cv_results": { + "estimator": "dummy", + "indices": {"test": [[1, 2], [1, 2]], "train": [[3, 4], [3, 4]]}, + }, + # `X` wrong length + "X": np.array([[1, 2]]), + "y": np.array([0, 1]), + "sample_weight": None, + "pos_label": None, + }, + "`X` does not contain the correct number of", + ), + ( + { + "cv_results": { + "estimator": "dummy", + "indices": {"test": [[1, 2], [1, 2]], "train": [[3, 4], [3, 4]]}, + }, + "X": np.array([1, 2, 3, 4]), + # `y` not binary + "y": np.array([0, 2, 1, 3]), + "sample_weight": None, + "pos_label": None, + }, + "The target `y` is not binary", + ), + ( + { + "cv_results": { + "estimator": "dummy", + "indices": {"test": [[1, 2], [1, 2]], "train": [[3, 4], [3, 4]]}, + }, + "X": np.array([1, 2, 3, 4]), + "y": np.array([0, 1, 0, 1]), + # `sample_weight` wrong length + "sample_weight": np.array([0.5]), + "pos_label": None, + }, + "Found input variables with inconsistent", + ), + ( + { + "cv_results": { + "estimator": "dummy", + "indices": {"test": [[1, 2], [1, 2]], "train": [[3, 4], [3, 4]]}, + }, + "X": np.array([1, 2, 3, 4]), + "y": np.array([2, 3, 2, 3]), + "sample_weight": None, + # Not specified when `y` not in {0, 1} or {-1, 1} + "pos_label": None, + }, + "y takes value in {2, 3} 
and pos_label is not specified", + ), + ], +) +def test_validate_from_cv_results_params(pyplot, params, err_msg): + """Check parameter validation is performed correctly.""" + with pytest.raises(ValueError, match=err_msg): + _BinaryClassifierCurveDisplayMixin()._validate_from_cv_results_params(**params) + + +@pytest.mark.parametrize( + "curve_legend_metric, curve_name, expected_label", + [ + (0.85, None, "AUC = 0.85"), + (None, "Model A", "Model A"), + (0.95, "Random Forest", "Random Forest (AUC = 0.95)"), + (None, None, None), + ], +) +def test_get_legend_label(curve_legend_metric, curve_name, expected_label): + """Check `_get_legend_label` returns the correct label.""" + legend_metric_name = "AUC" + label = _BinaryClassifierCurveDisplayMixin._get_legend_label( + curve_legend_metric, curve_name, legend_metric_name + ) + print(label) + assert label == expected_label + + def metric(): pass # pragma: no cover From 77511e24a53ebf0a67c93c50e63a77eac9c2bf7e Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Mon, 14 Apr 2025 12:42:51 +1000 Subject: [PATCH 58/63] typo --- sklearn/metrics/_plot/roc_curve.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index e8f167ebda145..586366dfbf2f4 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -737,8 +737,6 @@ def from_cv_results( y, sample_weight=sample_weight, pos_label=pos_label, - name=name, - curve_kwargs=curve_kwargs, ) fpr_folds, tpr_folds, auc_folds = [], [], [] From 1c9c67cc79366d1d7bcb51f4d9d1b9ba4955b698 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Mon, 14 Apr 2025 13:37:26 +1000 Subject: [PATCH 59/63] change to feature --- .../sklearn.metrics/{30399.enhancement.rst => 30399.feature.rst} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename doc/whats_new/upcoming_changes/sklearn.metrics/{30399.enhancement.rst => 30399.feature.rst} (100%) diff --git 
a/doc/whats_new/upcoming_changes/sklearn.metrics/30399.enhancement.rst b/doc/whats_new/upcoming_changes/sklearn.metrics/30399.feature.rst similarity index 100% rename from doc/whats_new/upcoming_changes/sklearn.metrics/30399.enhancement.rst rename to doc/whats_new/upcoming_changes/sklearn.metrics/30399.feature.rst From 4b82ba3c02621ef07e6132ea25308887da152b6d Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 17 Apr 2025 11:18:55 +1000 Subject: [PATCH 60/63] remove early param validation --- sklearn/utils/_plotting.py | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index ee54684635202..b30ad40affacd 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -78,8 +78,6 @@ def _validate_from_cv_results_params( *, sample_weight, pos_label, - # name, - # curve_kwargs, ): check_matplotlib_support(f"{cls.__name__}.from_cv_results") @@ -116,38 +114,8 @@ def _validate_from_cv_results_params( # Adapt error message raise ValueError(str(e).replace("y_true", "y")) - # n_curves = len(cv_results["estimator"]) - # # NB: Both these also checked in `plot`, but thought it best to fail earlier. - # cls._validate_multi_curve_kwargs(cls, curve_kwargs, name, n_curves) - # if isinstance(name, list) and len(name) not in (1, n_curves): - # raise ValueError( - # f"`name` must be None, a list of length {n_curves} or a single " - # f"string. Got list of length: {len(name)}." - # ) - return pos_label - # def _validate_multi_curve_kwargs(cls, curve_kwargs, name, n_curves): - # """Check `curve_kwargs`, including combination with `name`, is valid.""" - # if isinstance(curve_kwargs, list) and len(curve_kwargs) != n_curves: - # raise ValueError( - # f"`curve_kwargs` must be None, a dictionary or a list of length " - # f"{n_curves}. Got: {curve_kwargs}." - # ) - - # # Ensure valid `name` and `curve_kwargs` combination. 
- # if ( - # isinstance(name, list) - # and len(name) != 1 - # and not isinstance(curve_kwargs, list) - # ): - # raise ValueError( - # "To avoid labeling individual curves that have the same appearance, " - # f"`curve_kwargs` should be a list of {n_curves} dictionaries. " - # "Alternatively, set `name` to `None` or a single string to label " - # "a single legend entry with mean ROC AUC score of all curves." - # ) - @staticmethod def _get_legend_label(curve_legend_metric, curve_name, legend_metric_name): """Helper to get legend label using `name` and `legend_metric`""" From d6c1d9f78ec97e2d17e40ba51f37f84452eed9c1 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 17 Apr 2025 11:30:25 +1000 Subject: [PATCH 61/63] lint --- sklearn/utils/_plotting.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index b30ad40affacd..497faa3627555 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -103,8 +103,7 @@ def _validate_from_cv_results_params( if type_of_target(y) != "binary": raise ValueError( - f"The target `y` is not binary. Got {type_of_target(y)} type of" - " target." + f"The target `y` is not binary. Got {type_of_target(y)} type of target." 
) check_consistent_length(X, y, sample_weight) From df177ab0468add807cd5adadb4dd275e6416ac42 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 24 Apr 2025 13:40:23 +1000 Subject: [PATCH 62/63] review --- .../_plot/tests/test_roc_curve_display.py | 8 +- sklearn/utils/_plotting.py | 45 ++--- sklearn/utils/tests/test_plotting.py | 167 +++++++++++++++++- 3 files changed, 193 insertions(+), 27 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 56f11c0b44dda..3f788009a21a9 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -27,8 +27,6 @@ def data(): return X, y -# This data always (with and without `drop_intermediate`) -# results in an AUC of 1.0, should we consider changing the data used?? @pytest.fixture(scope="module") def data_binary(data): X, y = data @@ -324,7 +322,7 @@ def test_roc_curve_display_from_cv_results_curve_kwargs( ) -# TODO : Remove in 1.9 +# TODO(1.9): Remove in 1.9 def test_roc_curve_display_estimator_name_deprecation(pyplot): """Check deprecation of `estimator_name`.""" fpr = np.array([0, 0.5, 1]) @@ -333,7 +331,7 @@ def test_roc_curve_display_estimator_name_deprecation(pyplot): RocCurveDisplay(fpr=fpr, tpr=tpr, estimator_name="test") -# TODO : Remove in 1.9 +# TODO(1.9): Remove in 1.9 @pytest.mark.parametrize( "constructor_name", ["from_estimator", "from_predictions", "plot"] ) @@ -706,7 +704,6 @@ def test_roc_curve_chance_level_line( {"lw": 1, "color": "blue", "ls": "-", "label": None}, ], ) -# To ensure both curve line kwargs and change line kwargs passed correctly @pytest.mark.parametrize("curve_kwargs", [None, {"alpha": 0.8}]) def test_roc_curve_chance_level_line_from_cv_results( pyplot, @@ -734,6 +731,7 @@ def test_roc_curve_chance_level_line_from_cv_results( import matplotlib as mpl assert all(isinstance(line, mpl.lines.Line2D) for line in display.line_) + # Ensure both curve 
line kwargs passed correctly as well if curve_kwargs: assert all(line.get_alpha() == 0.8 for line in display.line_) assert isinstance(display.ax_, mpl.axes.Axes) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 497faa3627555..e1f13674c42e2 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -147,24 +147,24 @@ def _validate_curve_kwargs( name : list of str or None Name for labeling legend entries. - legend_metric : dict or None + legend_metric : dict Dictionary with "mean" and "std" keys, or "metric" key of metric values for each curve. If None, "label" will not contain metric values. - legend_metric_name : str or None + legend_metric_name : str Name of the summary value provided in `legend_metrics`. - curve_kwargs : dict or list of dict + curve_kwargs : dict or list of dict or None Dictionary with keywords passed to the matplotlib's `plot` function - to draw the individual ROC curves. If a list is provided, the - parameters are applied to the ROC curves sequentially. If a single - dictionary is provided, the same parameters are applied to all ROC - curves. Ignored for single curve plots - pass as `**kwargs` for - single curve plots. + to draw the individual curves. If a list is provided, the + parameters are applied to the curves sequentially. If a single + dictionary is provided, the same parameters are applied to all + curves. **kwargs : dict Deprecated. Keyword arguments to be passed to matplotlib's `plot`. """ + # TODO(1.9): Remove # Deprecate **kwargs if curve_kwargs and kwargs: raise ValueError( @@ -186,6 +186,9 @@ def _validate_curve_kwargs( f"{n_curves}. Got: {curve_kwargs}." ) + if isinstance(name, str): + name = [name] + # Ensure valid `name` and `curve_kwargs` combination. 
if ( isinstance(name, list) @@ -220,8 +223,10 @@ def _validate_curve_kwargs( label_aggregate = _BinaryClassifierCurveDisplayMixin._get_legend_label( legend_metric["mean"], name[0], legend_metric_name ) - # Add the "+/- std" to the end (in brackets if name provided) - if legend_metric["mean"] is not None: + # Note: "std" always `None` when "mean" is `None` - no metric value added + # to label in this case + if legend_metric["std"] is not None: + # Add the "+/- std" to the end (in brackets if name provided) if name[0] is not None: label_aggregate = ( label_aggregate[:-1] + f" +/- {legend_metric['std']:0.2f})" @@ -240,12 +245,10 @@ def _validate_curve_kwargs( ) ) - curve_kwargs_ = [] - for fold_idx, label in enumerate(labels): - label_kwarg = {"label": label} - curve_kwargs_.append( - _validate_style_kwargs(label_kwarg, curve_kwargs[fold_idx]) - ) + curve_kwargs_ = [ + _validate_style_kwargs({"label": label}, curve_kwargs[fold_idx]) + for fold_idx, label in enumerate(labels) + ] return curve_kwargs_ @@ -364,12 +367,12 @@ def _despine(ax): ax.spines[s].set_bounds(0, 1) -def _deprecate_estimator_name(old, new, version): +def _deprecate_estimator_name(estimator_name, name, version): """Deprecate `estimator_name` in favour of `name`.""" version = parse_version(version) version_remove = f"{version.major}.{version.minor + 2}" - if old != "deprecated": - if new: + if estimator_name != "deprecated": + if name: raise ValueError( "Cannot provide both `estimator_name` and `name`. `estimator_name` " f"is deprecated in {version} and will be removed in {version_remove}. " @@ -380,8 +383,8 @@ def _deprecate_estimator_name(old, new, version): f"{version_remove}. 
Use `name` instead.", FutureWarning, ) - return old - return new + return estimator_name + return name def _convert_to_list_leaving_none(param): diff --git a/sklearn/utils/tests/test_plotting.py b/sklearn/utils/tests/test_plotting.py index caaea133a1ee5..1cf36c648b64f 100644 --- a/sklearn/utils/tests/test_plotting.py +++ b/sklearn/utils/tests/test_plotting.py @@ -3,6 +3,7 @@ from sklearn.utils._plotting import ( _BinaryClassifierCurveDisplayMixin, + _deprecate_estimator_name, _despine, _interval_max_min_ratio, _validate_score_name, @@ -103,10 +104,146 @@ def test_get_legend_label(curve_legend_metric, curve_name, expected_label): label = _BinaryClassifierCurveDisplayMixin._get_legend_label( curve_legend_metric, curve_name, legend_metric_name ) - print(label) assert label == expected_label +# TODO(1.9) : Remove +@pytest.mark.parametrize("curve_kwargs", [{"alpha": 1.0}, None]) +@pytest.mark.parametrize("kwargs", [{}, {"alpha": 1.0}]) +def test_validate_curve_kwargs_deprecate_kwargs(curve_kwargs, kwargs): + """Check `_validate_curve_kwargs` deprecates kwargs correctly.""" + n_curves = 1 + name = None + legend_metric = {"mean": 0.8, "std": 0.1} + legend_metric_name = "AUC" + + if curve_kwargs and kwargs: + with pytest.raises(ValueError, match="Cannot provide both `curve_kwargs`"): + _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves, + name, + legend_metric, + legend_metric_name, + curve_kwargs, + **kwargs, + ) + elif kwargs: + with pytest.warns(FutureWarning, match=r"`\*\*kwargs` is deprecated and"): + _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves, + name, + legend_metric, + legend_metric_name, + curve_kwargs, + **kwargs, + ) + else: + # No warning or error should be raised + _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves, name, legend_metric, legend_metric_name, curve_kwargs, **kwargs + ) + + +def test_validate_curve_kwargs_error(): + """Check `_validate_curve_kwargs` performs parameter validation 
correctly.""" + n_curves = 3 + legend_metric = {"mean": 0.8, "std": 0.1} + legend_metric_name = "AUC" + with pytest.raises(ValueError, match="`curve_kwargs` must be None"): + _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves=n_curves, + name=None, + legend_metric=legend_metric, + legend_metric_name=legend_metric_name, + curve_kwargs=[{"alpha": 1.0}], + ) + with pytest.raises(ValueError, match="To avoid labeling individual curves"): + name = ["one", "two", "three"] + _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves=n_curves, + name=name, + legend_metric=legend_metric, + legend_metric_name=legend_metric_name, + curve_kwargs=None, + ) + _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves=n_curves, + name=name, + legend_metric=legend_metric, + legend_metric_name=legend_metric_name, + curve_kwargs={"alpha": 1.0}, + ) + + +@pytest.mark.parametrize("name", [None, "curve_name", ["curve_name"]]) +@pytest.mark.parametrize( + "legend_metric", + [ + {"mean": 0.8, "std": 0.2}, + {"mean": None, "std": None}, + ], +) +@pytest.mark.parametrize("legend_metric_name", ["AUC", "AP"]) +@pytest.mark.parametrize( + "curve_kwargs", + [ + None, + {"color": "red"}, + ], +) +def test_validate_curve_kwargs_single_legend( + name, legend_metric, legend_metric_name, curve_kwargs +): + """Check `_validate_curve_kwargs` returns correct kwargs for single legend entry.""" + n_curves = 3 + + curve_kwargs_out = _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves=n_curves, + name=name, + legend_metric=legend_metric, + legend_metric_name=legend_metric_name, + curve_kwargs=curve_kwargs, + ) + + assert isinstance(curve_kwargs_out, list) + assert len(curve_kwargs_out) == n_curves + + expected_label = None + if isinstance(name, list): + name = name[0] + if name is not None: + expected_label = name + if legend_metric["mean"] is not None: + expected_label = expected_label + f" ({legend_metric_name} = 0.80 +/- 0.20)" + elif 
legend_metric["mean"] is not None: + expected_label = f"{legend_metric_name} = 0.80 +/- 0.20" + + assert curve_kwargs_out[0]["label"] == expected_label + # All remaining curves should have None as "label" + assert curve_kwargs_out[1]["label"] is None + assert curve_kwargs_out[2]["label"] is None + + default_multi_curve_kwargs = {"alpha": 0.5, "linestyle": "--", "color": "blue"} + # Default multi-curve kwargs + if curve_kwargs is None: + assert all(len(kwargs) == 4 for kwargs in curve_kwargs_out) + assert all(kwargs["alpha"] == 0.5 for kwargs in curve_kwargs_out) + assert all(kwargs["linestyle"] == "--" for kwargs in curve_kwargs_out) + assert all(kwargs["color"] == "blue" for kwargs in curve_kwargs_out) + else: + assert all(len(kwargs) == 2 for kwargs in curve_kwargs_out) + assert all(kwargs["color"] == "red" for kwargs in curve_kwargs_out) + + +@pytest.mark.parametrize( + "legend_metric", [{"metric": [1.0, 1.0, 1.0]}, {"metric": None}] +) +def test_validate_curve_kwargs_multi_legend( + name, legend_metric, legend_metric_name, curve_kwargs +): + """Check `_validate_curve_kwargs` returns correct kwargs for multi legend entry.""" + + def metric(): pass # pragma: no cover @@ -236,3 +373,31 @@ def test_despine(pyplot): assert ax.spines["right"].get_visible() is False assert ax.spines["bottom"].get_bounds() == (0, 1) assert ax.spines["left"].get_bounds() == (0, 1) + + +@pytest.mark.parametrize("estimator_name", ["my_est_name", "deprecated"]) +@pytest.mark.parametrize("name", [None, "my_name"]) +def test_deprecate_estimator_name(estimator_name, name): + """Check `_deprecate_estimator_name` behaves correctly""" + version = "1.7" + version_remove = "1.9" + + if estimator_name == "deprecated": + name_out = _deprecate_estimator_name(estimator_name, name, version) + assert name_out == name + # `estimator_name` is provided and `name` is: + elif name is None: + warning_message = ( + f"`estimator_name` is deprecated in {version} and will be removed in " + f"{version_remove}. 
Use `name` instead." + ) + with pytest.warns(FutureWarning, match=warning_message): + result = _deprecate_estimator_name(estimator_name, name, version) + assert result == estimator_name + elif name is not None: + error_message = ( + f"Cannot provide both `estimator_name` and `name`. `estimator_name` " + f"is deprecated in {version} and will be removed in {version_remove}. " + ) + with pytest.raises(ValueError, match=error_message): + _deprecate_estimator_name(estimator_name, name, version) From 2de0f78a9720a82e3e280e6ed5d0c9305f16a49f Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Thu, 24 Apr 2025 14:37:39 +1000 Subject: [PATCH 63/63] add multi test --- sklearn/utils/_plotting.py | 5 ++- sklearn/utils/tests/test_plotting.py | 46 ++++++++++++++++++++++++---- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index e1f13674c42e2..ac893282ea6cf 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -186,9 +186,6 @@ def _validate_curve_kwargs( f"{n_curves}. Got: {curve_kwargs}." ) - if isinstance(name, str): - name = [name] - # Ensure valid `name` and `curve_kwargs` combination. 
if ( isinstance(name, list) @@ -203,6 +200,8 @@ def _validate_curve_kwargs( ) # Ensure `name` is of the correct length + if isinstance(name, str): + name = [name] if isinstance(name, list) and len(name) == 1: name = name * n_curves name = [None] * n_curves if name is None else name diff --git a/sklearn/utils/tests/test_plotting.py b/sklearn/utils/tests/test_plotting.py index 1cf36c648b64f..c0cff3265c621 100644 --- a/sklearn/utils/tests/test_plotting.py +++ b/sklearn/utils/tests/test_plotting.py @@ -196,7 +196,6 @@ def test_validate_curve_kwargs_single_legend( ): """Check `_validate_curve_kwargs` returns correct kwargs for single legend entry.""" n_curves = 3 - curve_kwargs_out = _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( n_curves=n_curves, name=name, @@ -215,6 +214,7 @@ def test_validate_curve_kwargs_single_legend( expected_label = name if legend_metric["mean"] is not None: expected_label = expected_label + f" ({legend_metric_name} = 0.80 +/- 0.20)" + # `name` is None elif legend_metric["mean"] is not None: expected_label = f"{legend_metric_name} = 0.80 +/- 0.20" @@ -223,7 +223,6 @@ def test_validate_curve_kwargs_single_legend( assert curve_kwargs_out[1]["label"] is None assert curve_kwargs_out[2]["label"] is None - default_multi_curve_kwargs = {"alpha": 0.5, "linestyle": "--", "color": "blue"} # Default multi-curve kwargs if curve_kwargs is None: assert all(len(kwargs) == 4 for kwargs in curve_kwargs_out) @@ -235,13 +234,48 @@ def test_validate_curve_kwargs_single_legend( assert all(kwargs["color"] == "red" for kwargs in curve_kwargs_out) +@pytest.mark.parametrize("name", [None, "curve_name", ["one", "two", "three"]]) @pytest.mark.parametrize( - "legend_metric", [{"metric": [1.0, 1.0, 1.0]}, {"metric": None}] + "legend_metric", [{"metric": [1.0, 1.0, 1.0]}, {"metric": [None, None, None]}] ) -def test_validate_curve_kwargs_multi_legend( - name, legend_metric, legend_metric_name, curve_kwargs -): +@pytest.mark.parametrize("legend_metric_name", 
["AUC", "AP"]) +def test_validate_curve_kwargs_multi_legend(name, legend_metric, legend_metric_name): """Check `_validate_curve_kwargs` returns correct kwargs for multi legend entry.""" + n_curves = 3 + curve_kwargs = [{"color": "red"}, {"color": "yellow"}, {"color": "blue"}] + curve_kwargs_out = _BinaryClassifierCurveDisplayMixin._validate_curve_kwargs( + n_curves=n_curves, + name=name, + legend_metric=legend_metric, + legend_metric_name=legend_metric_name, + curve_kwargs=curve_kwargs, + ) + + assert isinstance(curve_kwargs_out, list) + assert len(curve_kwargs_out) == n_curves + + expected_labels = [None, None, None] + if isinstance(name, str): + expected_labels = "curve_name" + if legend_metric["metric"][0] is not None: + expected_labels = expected_labels + f" ({legend_metric_name} = 1.00)" + expected_labels = [expected_labels] * n_curves + elif isinstance(name, list) and legend_metric["metric"][0] is None: + expected_labels = name + elif isinstance(name, list) and legend_metric["metric"][0] is not None: + expected_labels = [ + f"{name_single} ({legend_metric_name} = 1.00)" for name_single in name + ] + # `name` is None + elif legend_metric["metric"][0] is not None: + expected_labels = [f"{legend_metric_name} = 1.00"] * n_curves + + for idx, expected_label in enumerate(expected_labels): + assert curve_kwargs_out[idx]["label"] == expected_label + + assert all(len(kwargs) == 2 for kwargs in curve_kwargs_out) + for curve_kwarg, curve_kwarg_out in zip(curve_kwargs, curve_kwargs_out): + assert curve_kwarg_out["color"] == curve_kwarg["color"] def metric():