From 02c6fbad9b7f29cc63bc6d02b950a483da14037b Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 29 Nov 2024 16:25:35 +1100 Subject: [PATCH 1/3] first commit --- sklearn/metrics/_plot/roc_curve.py | 194 ++++++++++++++++++++++++++++- 1 file changed, 193 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 058b3612baa61..d1605eaaa1a83 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -1,12 +1,17 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause +from collections.abc import Mapping +from .._ranking import auc, roc_curve +from ...utils import _safe_indexing +from ...utils._optional_dependencies import check_matplotlib_support from ...utils._plotting import ( _BinaryClassifierCurveDisplayMixin, _despine, _validate_style_kwargs, ) -from .._ranking import auc, roc_curve +from ...utils._response import _get_response_values_binary +from ...utils.validation import _num_samples class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): @@ -449,3 +454,190 @@ def from_predictions( despine=despine, **kwargs, ) + + @classmethod + def from_cv_results( + cls, + cv_results, + X, + y, + *, + response_method="auto", + sample_weight=None, + drop_intermediate=True, + pos_label=None, + ax=None, + fold_names=None, + fold_line_kw=None, + plot_chance_level=False, + chance_level_kw=None, + despine=False, + ): + """Create a multi-fold ROC curve display given cross-validation results. + + .. versionadded:: 1.7 + + Parameters + ---------- + cv_results : dict + Dictionary as returned by :func:`~sklearn.model_selection.cross_validate` + using `return_estimator=True` and `return_indices=True`. + + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Input values. + + y : array-like of shape (n_samples,) + Target values. + + response_method : {'predict_proba', 'decision_function', 'auto'} \ + default='auto' + Specifies whether to use :term:`predict_proba` or + :term:`decision_function` as the target response. If set to 'auto', + :term:`predict_proba` is tried first and if it does not exist + :term:`decision_function` is tried next. + + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. + + drop_intermediate : bool, default=True + Whether to drop some suboptimal thresholds which would not appear + on a plotted ROC curve. This is useful in order to create lighter + ROC curves. + + pos_label : str or int, default=None + The class considered as the positive class when computing the roc auc + metrics. By default, `estimators.classes_[1]` is considered + as the positive class. + + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + fold_names : list of str, default=None + Name used in the legend for each individual ROC curve. If `None`, + the name will be set to "ROC fold #N" where N is the index of the + CV fold. + + fold_line_kw : dict or list of dict, default=None + Dictionary with keywords passed to the matplotlib's `plot` function + to draw the individual ROC curves. If a list is provided, the + parameters are applied to the ROC curves of each CV fold + sequentially. If a single dictionary is provided, the same + parameters are applied to all ROC curves. + + plot_chance_level : bool, default=False + Whether to plot the chance level. + + chance_level_kw : dict, default=None + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + despine : bool, default=False + Whether to remove the top and right spines from the plot. + + Returns + ------- + display : list of :class:`~sklearn.metrics.RocCurveDisplay` + A :class:`~sklearn.metrics.RocCurveDisplay` for each fold. + + See Also + -------- + roc_curve : Compute Receiver operating characteristic (ROC) curve. + RocCurveDisplay.from_estimator : ROC Curve visualization given an + estimator and some data. + RocCurveDisplay.from_predictions : ROC Curve visualization given the + probabilities of scores of a classifier. + roc_auc_score : Compute the area under the ROC curve. + + Examples + -------- + >>> import matplotlib.pyplot as plt + >>> from sklearn.datasets import make_classification + >>> from sklearn.metrics import RocCurveDisplay + >>> from sklearn.model_selection import cross_validate + >>> from sklearn.svm import SVC + >>> X, y = make_classification(random_state=0) + >>> clf = SVC(random_state=0) + >>> cv_results = cross_validate( + ... clf, X, y, cv=3, return_estimator=True, return_indices=True) + >>> RocCurveDisplay.from_cv_results(cv_results, X, y, kind="both") + <...> + >>> plt.show() + """ + # Ideally we would use `_validate_plot_params` but that is a instance + # method + check_matplotlib_support(f"{cls.__class__.__name__}.plot") + import matplotlib.pyplot as plt + + required_keys = {"estimator", "indices"} + if not all(key in cv_results for key in required_keys): + raise ValueError( + "cv_results does not contain one of the following required keys: " + f"{required_keys}. Set explicitly the parameters return_estimator=True " + "and return_indices=True to the function cross_validate." + ) + + train_size, test_size = ( + len(cv_results["indices"]["train"][0]), + len(cv_results["indices"]["test"][0]), + ) + + if _num_samples(X) != train_size + test_size: + raise ValueError( + "X does not contain the correct number of samples. " + f"Expected {train_size + test_size}, got {_num_samples(X)}." + ) + + n_curves = len(cv_results["estimator"]) + if fold_names is None: + # create an iterable of the same length as the number of ROC curves + fold_names_ = [None] * n_curves + elif fold_names is not None and len(fold_names) != n_curves: + raise ValueError( + "When `fold_names` is provided, it must have the same length as " + f"the number of ROC curves to be plotted. Got {len(fold_names)} names " + f"instead of {n_curves}." + ) + else: + fold_names_ = fold_names + + if fold_line_kw is None: + fold_line_kw = [ + {"alpha": 0.5, "color": "tab:blue", "linestyle": "--"} + ] * n_curves + elif isinstance(fold_line_kw, Mapping): + fold_line_kw = [fold_line_kw] * n_curves + elif len(fold_line_kw) != n_curves: + raise ValueError( + "When `fold_line_kw` is a list, it must have the same length as " + "the number of ROC curves to be plotted." + ) + + if ax is None: + _, ax = plt.subplots() + + displays = [] + for fold_id, (estimator, test_indices, name) in enumerate( + zip(cv_results["estimator"], cv_results["indices"]["test"], fold_names_) + ): + y_true = _safe_indexing(y, test_indices) + y_pred = _get_response_values_binary( + estimator, + _safe_indexing(X, test_indices), + response_method=response_method, + pos_label=pos_label, + )[0] + displays.append(cls.from_predictions( + y_true, + y_pred, + sample_weight=None if sample_weight is None else sample_weight[fold_id], + drop_intermediate=drop_intermediate, + pos_label=pos_label, + name=f"ROC fold {fold_id}" if name is None else name, + ax=ax, + plot_chance_level=plot_chance_level, + chance_level_kw=chance_level_kw, + despine=despine, + **fold_line_kw[fold_id], + )) + return displays From 963966b0c5f10a86d1ba8a683c04419d7dc8407e Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 29 Nov 2024 16:32:47 +1100 Subject: [PATCH 2/3] lint --- sklearn/metrics/_plot/roc_curve.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index d1605eaaa1a83..43eb5328fb235 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -627,17 +627,21 @@ def from_cv_results( response_method=response_method, pos_label=pos_label, )[0] - displays.append(cls.from_predictions( - y_true, - y_pred, - sample_weight=None if sample_weight is None else sample_weight[fold_id], - drop_intermediate=drop_intermediate, - pos_label=pos_label, - name=f"ROC fold {fold_id}" if name is None else name, - ax=ax, - plot_chance_level=plot_chance_level, - chance_level_kw=chance_level_kw, - despine=despine, - **fold_line_kw[fold_id], - )) + displays.append( + cls.from_predictions( + y_true, + y_pred, + sample_weight=( + None if sample_weight is None else sample_weight[fold_id] + ), + drop_intermediate=drop_intermediate, + pos_label=pos_label, + name=f"ROC fold {fold_id}" if name is None else name, + ax=ax, + plot_chance_level=plot_chance_level, + chance_level_kw=chance_level_kw, + despine=despine, + **fold_line_kw[fold_id], + ) + ) return displays From c7966bf7ed376c548dd277e89b6c700d6ab4e142 Mon Sep 17 00:00:00 2001 From: Lucy Liu Date: Fri, 29 Nov 2024 16:37:27 +1100 Subject: [PATCH 3/3] lint --- sklearn/metrics/_plot/roc_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 43eb5328fb235..88308c5f8c1f8 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: BSD-3-Clause from collections.abc import Mapping -from .._ranking import auc, roc_curve from ...utils import _safe_indexing from ...utils._optional_dependencies import check_matplotlib_support from ...utils._plotting import ( @@ -12,6 +11,7 @@ ) from ...utils._response import _get_response_values_binary from ...utils.validation import _num_samples +from .._ranking import auc, roc_curve class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin):