Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build_tools/circle/doc_environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dependencies:
- towncrier
- jupyterlite-sphinx
- jupyterlite-pyodide-kernel
- ipympl
- pip
- pip:
- sphinxcontrib-sass
28 changes: 23 additions & 5 deletions build_tools/circle/doc_linux-64_conda.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions build_tools/update_environments_and_lock_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def remove_from(alist, to_remove):
"towncrier",
"jupyterlite-sphinx",
"jupyterlite-pyodide-kernel",
"ipympl",
],
"pip_dependencies": [
"sphinxcontrib-sass",
Expand Down
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,7 @@ def reset_sklearn_config(gallery_conf, fname):
"plot_gallery": "True",
"recommender": {"enable": True, "n_examples": 4, "min_df": 12},
"reset_modules": ("matplotlib", "seaborn", reset_sklearn_config),
"first_notebook_cell": "%matplotlib widget",
}
if with_jupyterlite:
sphinx_gallery_conf["jupyterlite"] = {
Expand Down
19 changes: 17 additions & 2 deletions sklearn/metrics/_plot/precision_recall_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
_BinaryClassifierCurveDisplayMixin,
_deprecate_y_pred_parameter,
_despine,
_LineTooltipMixin,
_validate_style_kwargs,
)


class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin):
class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin, _LineTooltipMixin):
"""Precision Recall visualization.

It is recommended to use
Expand All @@ -34,6 +35,13 @@ class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin):
recall : ndarray
Recall values.

thresholds : ndarray, default=None
The thresholds at which the precision and recall have been computed, for a
single curve.
Only used to display the threshold values along the curve as a tooltip. If None,
only the precision and recall values are displayed. In Jupyter notebooks, the
tooltip requires the `ipympl` package to be installed and the
`%matplotlib widget` magic to be enabled.

average_precision : float, default=None
Average precision. If None, the average precision is not shown.

Expand Down Expand Up @@ -117,6 +125,7 @@ def __init__(
precision,
recall,
*,
thresholds=None,
average_precision=None,
estimator_name=None,
pos_label=None,
Expand All @@ -125,6 +134,7 @@ def __init__(
self.estimator_name = estimator_name
self.precision = precision
self.recall = recall
self.thresholds = thresholds
self.average_precision = average_precision
self.pos_label = pos_label
self.prevalence_pos_label = prevalence_pos_label
Expand Down Expand Up @@ -206,6 +216,10 @@ def plot(

(self.line_,) = self.ax_.plot(self.recall, self.precision, **line_kwargs)

self._add_line_tooltip(
x_label="R", y_label="P", t_label="threshold", t_vals=[self.thresholds]
)

info_pos_label = (
f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
)
Expand Down Expand Up @@ -537,7 +551,7 @@ def from_predictions(
y_true, y_score, sample_weight=sample_weight, pos_label=pos_label, name=name
)

precision, recall, _ = precision_recall_curve(
precision, recall, thresholds = precision_recall_curve(
y_true,
y_score,
pos_label=pos_label,
Expand All @@ -554,6 +568,7 @@ def from_predictions(
viz = cls(
precision=precision,
recall=recall,
thresholds=thresholds,
average_precision=average_precision,
estimator_name=name,
pos_label=pos_label,
Expand Down
32 changes: 26 additions & 6 deletions sklearn/metrics/_plot/roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
_deprecate_estimator_name,
_deprecate_y_pred_parameter,
_despine,
_LineTooltipMixin,
_validate_style_kwargs,
)
from sklearn.utils._response import _get_response_values_binary


class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin):
class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin, _LineTooltipMixin):
"""ROC Curve visualization.

It is recommended to use
Expand Down Expand Up @@ -49,6 +50,13 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin):
.. versionchanged:: 1.7
Now accepts a list for plotting multiple curves.

thresholds : ndarray or list of ndarrays, default=None
The thresholds at which the fpr and tpr have been computed. Each ndarray should
contain values for a single curve. If plotting multiple curves, the list should
be of the same length as `fpr` and `tpr`.
Only used to display the threshold values along the curve as a tooltip. If None,
only the fpr and tpr values are displayed. In Jupyter notebooks, the tooltip
requires the `ipympl` package to be installed and the `%matplotlib widget`
magic to be enabled.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this docstring should mention that it's necessary to install the ipympl package and use the %matplotlib widget magic to use this feature in jupyter notebooks.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The user guide could also benefit from a new paragraph to explain how to enable such interactive tooltips.


roc_auc : float or list of floats, default=None
Area under ROC curve, used for labeling each curve in the legend.
If plotting multiple curves, should be a list of the same length as `fpr`
Expand Down Expand Up @@ -132,13 +140,15 @@ def __init__(
*,
fpr,
tpr,
thresholds=None,
roc_auc=None,
name=None,
pos_label=None,
estimator_name="deprecated",
):
self.fpr = fpr
self.tpr = tpr
self.thresholds = thresholds
self.roc_auc = roc_auc
self.name = _deprecate_estimator_name(estimator_name, name, "1.7")
self.pos_label = pos_label
Expand All @@ -148,6 +158,7 @@ def _validate_plot_params(self, *, ax, name):

fpr = _convert_to_list_leaving_none(self.fpr)
tpr = _convert_to_list_leaving_none(self.tpr)
thresholds = _convert_to_list_leaving_none(self.thresholds)
roc_auc = _convert_to_list_leaving_none(self.roc_auc)
name = _convert_to_list_leaving_none(name)

Expand All @@ -159,7 +170,7 @@ def _validate_plot_params(self, *, ax, name):
optional=optional,
class_name="RocCurveDisplay",
)
return fpr, tpr, roc_auc, name
return fpr, tpr, thresholds, roc_auc, name

def plot(
self,
Expand Down Expand Up @@ -233,7 +244,9 @@ def plot(
display : :class:`~sklearn.metrics.RocCurveDisplay`
Object that stores computed values.
"""
fpr, tpr, roc_auc, name = self._validate_plot_params(ax=ax, name=name)
fpr, tpr, thresholds, roc_auc, name = self._validate_plot_params(
ax=ax, name=name
)
n_curves = len(fpr)
if not isinstance(curve_kwargs, list) and n_curves > 1:
if roc_auc:
Expand Down Expand Up @@ -273,6 +286,10 @@ def plot(
if len(self.line_) == 1:
self.line_ = self.line_[0]

self._add_line_tooltip(
x_label="FPR", y_label="TPR", t_label="threshold", t_vals=thresholds
)

info_pos_label = (
f" (Positive label: {self.pos_label})" if self.pos_label is not None else ""
)
Expand Down Expand Up @@ -580,7 +597,7 @@ def from_predictions(
y_true, y_score, sample_weight=sample_weight, pos_label=pos_label, name=name
)

fpr, tpr, _ = roc_curve(
fpr, tpr, thresholds = roc_curve(
y_true,
y_score,
pos_label=pos_label,
Expand All @@ -592,6 +609,7 @@ def from_predictions(
viz = cls(
fpr=fpr,
tpr=tpr,
thresholds=thresholds,
roc_auc=roc_auc,
name=name,
pos_label=pos_label_validated,
Expand Down Expand Up @@ -731,7 +749,7 @@ def from_cv_results(
pos_label=pos_label,
)

fpr_folds, tpr_folds, auc_folds = [], [], []
fpr_folds, tpr_folds, thresholds_folds, auc_folds = [], [], [], []
for estimator, test_indices in zip(
cv_results["estimator"], cv_results["indices"]["test"]
):
Expand All @@ -747,7 +765,7 @@ def from_cv_results(
if sample_weight is None
else _safe_indexing(sample_weight, test_indices)
)
fpr, tpr, _ = roc_curve(
fpr, tpr, thresholds = roc_curve(
y_true,
y_pred,
pos_label=pos_label_,
Expand All @@ -758,11 +776,13 @@ def from_cv_results(

fpr_folds.append(fpr)
tpr_folds.append(tpr)
thresholds_folds.append(thresholds)
auc_folds.append(roc_auc)

viz = cls(
fpr=fpr_folds,
tpr=tpr_folds,
thresholds=thresholds_folds,
roc_auc=auc_folds,
name=name,
pos_label=pos_label_,
Expand Down
27 changes: 25 additions & 2 deletions sklearn/metrics/_plot/tests/test_roc_curve_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sklearn.preprocessing import StandardScaler
from sklearn.utils import _safe_indexing, shuffle
from sklearn.utils._response import _get_response_values_binary
from sklearn.utils.tests.test_plotting import _simulate_mouse_event


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -201,25 +202,29 @@ def test_validate_plot_params(pyplot):
"""Check `_validate_plot_params` returns the correct variables."""
fpr = np.array([0, 0.5, 1])
tpr = [np.array([0, 0.5, 1])]
thresholds = np.array([1, 0.5, 0])
roc_auc = None
name = "test_curve"

# Initialize display with test inputs
display = RocCurveDisplay(
fpr=fpr,
tpr=tpr,
thresholds=thresholds,
roc_auc=roc_auc,
name=name,
pos_label=None,
)
fpr_out, tpr_out, roc_auc_out, name_out = display._validate_plot_params(
ax=None, name=None
fpr_out, tpr_out, thresholds_out, roc_auc_out, name_out = (
display._validate_plot_params(ax=None, name=None)
)

assert isinstance(fpr_out, list)
assert isinstance(tpr_out, list)
assert isinstance(thresholds_out, list)
assert len(fpr_out) == 1
assert len(tpr_out) == 1
assert len(thresholds_out) == 1
assert roc_auc_out is None
assert name_out == ["test_curve"]

Expand Down Expand Up @@ -980,3 +985,21 @@ def test_plot_roc_curve_despine(pyplot, data_binary, despine, constructor_name):
if despine:
for s in ["bottom", "left"]:
assert display.ax_.spines[s].get_bounds() == (0, 1)


def test_roc_curve_display_line_tooltip(pyplot, data_binary):
    """Check that a mouse event on the ROC line shows a labelled tooltip."""
    X, y = data_binary

    classifier = LogisticRegression().fit(X, y)
    display = RocCurveDisplay.from_estimator(classifier, X, y)

    assert hasattr(display, "line_tooltip_")

    # Fire the simulated mouse event at the coordinates of one point of the curve.
    point_idx = 10
    _simulate_mouse_event(display, display.fpr[point_idx], display.tpr[point_idx])

    # The tooltip must now be shown and mention every expected label.
    assert display.line_tooltip_.get_visible()
    tooltip_text = display.line_tooltip_.get_text()
    expected_labels = ("FPR", "TPR", "threshold")
    assert all(label in tooltip_text for label in expected_labels)
Loading
Loading