Skip to content

ENH add from_cv_results in RocCurveDisplay (list of displays) #30370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 196 additions & 0 deletions sklearn/metrics/_plot/roc_curve.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from collections.abc import Mapping

from ...utils import _safe_indexing
from ...utils._optional_dependencies import check_matplotlib_support
from ...utils._plotting import (
_BinaryClassifierCurveDisplayMixin,
_despine,
_validate_style_kwargs,
)
from ...utils._response import _get_response_values_binary
from ...utils.validation import _num_samples
from .._ranking import auc, roc_curve


Expand Down Expand Up @@ -449,3 +454,194 @@ def from_predictions(
despine=despine,
**kwargs,
)

@classmethod
def from_cv_results(
    cls,
    cv_results,
    X,
    y,
    *,
    response_method="auto",
    sample_weight=None,
    drop_intermediate=True,
    pos_label=None,
    ax=None,
    fold_names=None,
    fold_line_kw=None,
    plot_chance_level=False,
    chance_level_kw=None,
    despine=False,
):
    """Create a multi-fold ROC curve display given cross-validation results.

    .. versionadded:: 1.7

    Parameters
    ----------
    cv_results : dict
        Dictionary as returned by :func:`~sklearn.model_selection.cross_validate`
        using `return_estimator=True` and `return_indices=True`.

    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Input values.

    y : array-like of shape (n_samples,)
        Target values.

    response_method : {'predict_proba', 'decision_function', 'auto'} \
            default='auto'
        Specifies whether to use :term:`predict_proba` or
        :term:`decision_function` as the target response. If set to 'auto',
        :term:`predict_proba` is tried first and if it does not exist
        :term:`decision_function` is tried next.

    sample_weight : array-like of shape (n_samples,), default=None
        Sample weights.

    drop_intermediate : bool, default=True
        Whether to drop some suboptimal thresholds which would not appear
        on a plotted ROC curve. This is useful in order to create lighter
        ROC curves.

    pos_label : str or int, default=None
        The class considered as the positive class when computing the roc auc
        metrics. By default, `estimators.classes_[1]` is considered
        as the positive class.

    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is
        created.

    fold_names : list of str, default=None
        Name used in the legend for each individual ROC curve. If `None`,
        the name will be set to "ROC fold #N" where N is the index of the
        CV fold.

    fold_line_kw : dict or list of dict, default=None
        Dictionary with keywords passed to the matplotlib's `plot` function
        to draw the individual ROC curves. If a list is provided, the
        parameters are applied to the ROC curves of each CV fold
        sequentially. If a single dictionary is provided, the same
        parameters are applied to all ROC curves.

    plot_chance_level : bool, default=False
        Whether to plot the chance level.

    chance_level_kw : dict, default=None
        Keyword arguments to be passed to matplotlib's `plot` for rendering
        the chance level line.

    despine : bool, default=False
        Whether to remove the top and right spines from the plot.

    Returns
    -------
    display : list of :class:`~sklearn.metrics.RocCurveDisplay`
        A :class:`~sklearn.metrics.RocCurveDisplay` for each fold.

    See Also
    --------
    roc_curve : Compute Receiver operating characteristic (ROC) curve.
    RocCurveDisplay.from_estimator : ROC Curve visualization given an
        estimator and some data.
    RocCurveDisplay.from_predictions : ROC Curve visualization given the
        probabilities or scores of a classifier.
    roc_auc_score : Compute the area under the ROC curve.

    Examples
    --------
    >>> import matplotlib.pyplot as plt
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.metrics import RocCurveDisplay
    >>> from sklearn.model_selection import cross_validate
    >>> from sklearn.svm import SVC
    >>> X, y = make_classification(random_state=0)
    >>> clf = SVC(random_state=0)
    >>> cv_results = cross_validate(
    ...     clf, X, y, cv=3, return_estimator=True, return_indices=True)
    >>> displays = RocCurveDisplay.from_cv_results(cv_results, X, y)
    >>> plt.show()
    """
    # `_BinaryClassifierCurveDisplayMixin._validate_plot_params` is an
    # instance method (uses `self`), so it cannot be reused in this
    # classmethod; replicate the relevant checks manually.
    # BUGFIX: was `cls.__class__.__name__`, which names the metaclass
    # (e.g. "ABCMeta"), not the display class.
    check_matplotlib_support(f"{cls.__name__}.from_cv_results")
    import matplotlib.pyplot as plt

    required_keys = {"estimator", "indices"}
    if not all(key in cv_results for key in required_keys):
        raise ValueError(
            "cv_results does not contain one of the following required keys: "
            f"{required_keys}. Set explicitly the parameters return_estimator=True "
            "and return_indices=True to the function cross_validate."
        )

    train_size, test_size = (
        len(cv_results["indices"]["train"][0]),
        len(cv_results["indices"]["test"][0]),
    )

    # Sanity check that `X` matches the data used during cross-validation;
    # assumes non-overlapping train/test folds covering all samples.
    if _num_samples(X) != train_size + test_size:
        raise ValueError(
            "X does not contain the correct number of samples. "
            f"Expected {train_size + test_size}, got {_num_samples(X)}."
        )

    n_curves = len(cv_results["estimator"])
    if fold_names is None:
        # Placeholder per curve so `zip` below yields a name slot for every
        # fold; the default "ROC fold {i}" label is filled in later.
        fold_names_ = [None] * n_curves
    elif len(fold_names) != n_curves:
        raise ValueError(
            "When `fold_names` is provided, it must have the same length as "
            f"the number of ROC curves to be plotted. Got {len(fold_names)} names "
            f"instead of {n_curves}."
        )
    else:
        fold_names_ = fold_names

    if fold_line_kw is None:
        fold_line_kw = [
            {"alpha": 0.5, "color": "tab:blue", "linestyle": "--"}
        ] * n_curves
    elif isinstance(fold_line_kw, Mapping):
        fold_line_kw = [fold_line_kw] * n_curves
    elif len(fold_line_kw) != n_curves:
        raise ValueError(
            "When `fold_line_kw` is a list, it must have the same length as "
            "the number of ROC curves to be plotted."
        )

    if ax is None:
        _, ax = plt.subplots()

    displays = []
    for fold_id, (estimator, test_indices, name) in enumerate(
        zip(cv_results["estimator"], cv_results["indices"]["test"], fold_names_)
    ):
        y_true = _safe_indexing(y, test_indices)
        y_pred = _get_response_values_binary(
            estimator,
            _safe_indexing(X, test_indices),
            response_method=response_method,
            pos_label=pos_label,
        )[0]
        # BUGFIX: subset the sample weights to the current test fold so they
        # align with `y_true`/`y_pred`; previously `sample_weight[fold_id]`
        # picked a single scalar weight per fold.
        fold_sample_weight = (
            None
            if sample_weight is None
            else _safe_indexing(sample_weight, test_indices)
        )
        displays.append(
            cls.from_predictions(
                y_true,
                y_pred,
                sample_weight=fold_sample_weight,
                drop_intermediate=drop_intermediate,
                pos_label=pos_label,
                name=f"ROC fold {fold_id}" if name is None else name,
                ax=ax,
                # Draw the chance-level line only once (on the last fold);
                # previously it was re-drawn per fold, duplicating the line
                # and its legend entry.
                plot_chance_level=(
                    plot_chance_level and fold_id == n_curves - 1
                ),
                chance_level_kw=chance_level_kw,
                despine=despine,
                **fold_line_kw[fold_id],
            )
        )
    return displays
Loading