From 0da36ae3a105b0ffa3792c882fad124cfb009fb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Tue, 29 Jul 2025 15:38:04 +0200 Subject: [PATCH 01/11] wip --- sklearn/metrics/_plot/roc_curve.py | 44 ++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 59c01f2db91a0..97aee3811da1c 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -132,6 +132,7 @@ def __init__( *, fpr, tpr, + thresholds=None, roc_auc=None, name=None, pos_label=None, @@ -139,6 +140,7 @@ def __init__( ): self.fpr = fpr self.tpr = tpr + self.thresholds = thresholds self.roc_auc = roc_auc self.name = _deprecate_estimator_name(estimator_name, name, "1.7") self.pos_label = pos_label @@ -148,6 +150,7 @@ def _validate_plot_params(self, *, ax, name): fpr = _convert_to_list_leaving_none(self.fpr) tpr = _convert_to_list_leaving_none(self.tpr) + thresholds = _convert_to_list_leaving_none(self.thresholds) roc_auc = _convert_to_list_leaving_none(self.roc_auc) name = _convert_to_list_leaving_none(name) @@ -159,7 +162,7 @@ def _validate_plot_params(self, *, ax, name): optional=optional, class_name="RocCurveDisplay", ) - return fpr, tpr, roc_auc, name + return fpr, tpr, thresholds,roc_auc, name def plot( self, @@ -233,7 +236,7 @@ def plot( display : :class:`~sklearn.metrics.RocCurveDisplay` Object that stores computed values. """ - fpr, tpr, roc_auc, name = self._validate_plot_params(ax=ax, name=name) + fpr, tpr, thresholds, roc_auc, name = self._validate_plot_params(ax=ax, name=name) n_curves = len(fpr) if not isinstance(curve_kwargs, list) and n_curves > 1: if roc_auc: @@ -273,6 +276,40 @@ def plot( if len(self.line_) == 1: self.line_ = self.line_[0] + annot = self.ax_.annotate( + "", xy=(0,0), xytext=(20,20),textcoords="offset points", bbox=dict(boxstyle="round", fc="w"), arrowprops=dict(arrowstyle="->") + ) + annot.set_visible(False) + + def update_annot(ind): + x,y = self.line_.get_data() + annot.xy = (x[ind["ind"][0]], y[ind["ind"][0]]) + + # Find the index of the closest point in fpr to x[ind['ind'][0]] + # use it to get the corresponding threshold + idx = np.argmin(np.abs(self.fpr - x[ind["ind"][0]])) + + text = f"FPR: {x[ind['ind'][0]]:.2f}, TPR: {y[ind['ind'][0]]:.2f}, Threshold: {self.thresholds[idx]:.2f}" + # text = f"Threshold: {self.thresholds[idx]:.2f}" + + annot.set_text(text) + annot.get_bbox_patch().set_alpha(0.4) + + def hover(event): + vis = annot.get_visible() + if event.inaxes == self.ax_: + cont, ind = self.line_.contains(event) + if cont: + update_annot(ind) + annot.set_visible(True) + self.figure_.canvas.draw_idle() + else: + if vis: + annot.set_visible(False) + self.figure_.canvas.draw_idle() + + self.figure_.canvas.mpl_connect("motion_notify_event", hover) + info_pos_label = ( f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" ) @@ -580,7 +617,7 @@ def from_predictions( y_true, y_score, sample_weight=sample_weight, pos_label=pos_label, name=name ) - fpr, tpr, _ = roc_curve( + fpr, tpr, thresholds = roc_curve( y_true, y_score, pos_label=pos_label, @@ -592,6 +629,7 @@ def from_predictions( viz = cls( fpr=fpr, tpr=tpr, + thresholds=thresholds, roc_auc=roc_auc, name=name, pos_label=pos_label_validated, From 30ad5c4ec25a8a694867b65cc5be8e5f6354ada9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Thu, 31 Jul 2025 11:45:39 +0200 Subject: [PATCH 02/11] wip --- sklearn/metrics/_plot/roc_curve.py | 56 +++++--------- sklearn/utils/_plotting.py | 118 ++++++++++++++++++++++++++++- 2 files changed, 136 insertions(+), 38 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 97aee3811da1c..21c4457250a15 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -7,6 +7,7 @@ from ...utils import _safe_indexing from ...utils._plotting import ( _BinaryClassifierCurveDisplayMixin, + _LineTooltipMixin, _check_param_lengths, _convert_to_list_leaving_none, _deprecate_estimator_name, @@ -18,7 +19,7 @@ from .._ranking import auc, roc_curve -class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): +class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin, _LineTooltipMixin): """ROC Curve visualization. It is recommended to use @@ -49,6 +50,13 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin): .. versionchanged:: 1.7 Now accepts a list for plotting multiple curves. + threshold : ndarray or list of ndarrays, default=None + The thresholds at which the fpr and tpr have been computed. Each ndarray should + contain values for a single curve. If plotting multiple curves, list should be + of same length as `fpr` and `tpr`. + Only used to display the threshold values along the curve as a tooltip. If None, + only the fpr and tpr values are displayed. + roc_auc : float or list of floats, default=None Area under ROC curve, used for labeling each curve in the legend. If plotting multiple curves, should be a list of the same length as `fpr` @@ -162,7 +170,7 @@ def _validate_plot_params(self, *, ax, name): optional=optional, class_name="RocCurveDisplay", ) - return fpr, tpr, thresholds,roc_auc, name + return fpr, tpr, thresholds, roc_auc, name def plot( self, @@ -236,7 +244,9 @@ def plot( display : :class:`~sklearn.metrics.RocCurveDisplay` Object that stores computed values. """ - fpr, tpr, thresholds, roc_auc, name = self._validate_plot_params(ax=ax, name=name) + fpr, tpr, thresholds, roc_auc, name = self._validate_plot_params( + ax=ax, name=name + ) n_curves = len(fpr) if not isinstance(curve_kwargs, list) and n_curves > 1: if roc_auc: @@ -276,39 +286,9 @@ def plot( if len(self.line_) == 1: self.line_ = self.line_[0] - annot = self.ax_.annotate( - "", xy=(0,0), xytext=(20,20),textcoords="offset points", bbox=dict(boxstyle="round", fc="w"), arrowprops=dict(arrowstyle="->") + self._add_line_tooltip( + x_label="FPR", y_label="TPR", t_label="threshold", t_vals=thresholds ) - annot.set_visible(False) - - def update_annot(ind): - x,y = self.line_.get_data() - annot.xy = (x[ind["ind"][0]], y[ind["ind"][0]]) - - # Find the index of the closest point in fpr to x[ind['ind'][0]] - # use it to get the corresponding threshold - idx = np.argmin(np.abs(self.fpr - x[ind["ind"][0]])) - - text = f"FPR: {x[ind['ind'][0]]:.2f}, TPR: {y[ind['ind'][0]]:.2f}, Threshold: {self.thresholds[idx]:.2f}" - # text = f"Threshold: {self.thresholds[idx]:.2f}" - - annot.set_text(text) - annot.get_bbox_patch().set_alpha(0.4) - - def hover(event): - vis = annot.get_visible() - if event.inaxes == self.ax_: - cont, ind = self.line_.contains(event) - if cont: - update_annot(ind) - annot.set_visible(True) - self.figure_.canvas.draw_idle() - else: - if vis: - annot.set_visible(False) - self.figure_.canvas.draw_idle() - - self.figure_.canvas.mpl_connect("motion_notify_event", hover) info_pos_label = ( f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" @@ -769,7 +749,7 @@ def from_cv_results( pos_label=pos_label, ) - fpr_folds, tpr_folds, auc_folds = [], [], [] + fpr_folds, tpr_folds, thresholds_folds, auc_folds = [], [], [], [] for estimator, test_indices in zip( cv_results["estimator"], cv_results["indices"]["test"] ): @@ -785,7 +765,7 @@ def from_cv_results( if sample_weight is None else _safe_indexing(sample_weight, test_indices) ) - fpr, tpr, _ = roc_curve( + fpr, tpr, thresholds = roc_curve( y_true, y_pred, pos_label=pos_label_, @@ -796,11 +776,13 @@ def from_cv_results( fpr_folds.append(fpr) tpr_folds.append(tpr) + thresholds_folds.append(thresholds) auc_folds.append(roc_auc) viz = cls( fpr=fpr_folds, tpr=tpr_folds, + thresholds=thresholds_folds, roc_auc=auc_folds, name=name, pos_label=pos_label_, diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index e4447978df78f..c864189ad1c7e 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import warnings from collections.abc import Mapping - +from functools import partial import numpy as np from . import check_consistent_length @@ -250,6 +250,122 @@ def _validate_curve_kwargs( return curve_kwargs_ +class _LineTooltipMixin: + """Mixin class to add a tooltip to a line in a plot. + + The tooltip displays 2 to 3 informations: the x value, the y value and an optional + t value for parametric curves (x(t), y(t)). + """ + + def _add_line_tooltip(self, *, x_label, y_label, t_label=None, t_vals=None): + """Create the line tooltip and connect it to a mouse hover event. + + Parameters + ---------- + x_label : str + The label for the x value. + + y_label : str + The label for the y value. + + t_label : str, default=None + The label for the parameter value. + + t_vals : list of ndarrays, default=None + The parameter values along the line. + """ + self._x_label_short = x_label + self._y_label_short = y_label + self._t_label = t_label + + self.line_tooltip_ = self.ax_.annotate( + text="", + xy=(0, 0), + xytext=(20, 20), + textcoords="offset points", + fontsize="small", + bbox=dict(boxstyle="round", fc=(0.8, 0.8, 0.8, 0.8)), + arrowprops=dict(arrowstyle="-"), + zorder=10, # bring to front + ) + self.line_tooltip_.set_visible(False) + + # Set an attribute on the axes annotation to be able to keep only one visible + # at a time when there are multiple display instances that share an axes. + setattr(self.line_tooltip_, "_skl_line_tooltip", True) + + self.ax_.figure.canvas.mpl_connect( + "motion_notify_event", partial(self._hover, t_vals=t_vals) + ) + + def _hover(self, event, t_vals): + """Callback for the mouse over event. + + Parameters + ---------- + event : matplotlib event + the event triggering the callback. + + t_vals : list of ndarrays or None + The parameter values along the line. + """ + if event.inaxes != self.ax_: + return + + lines = _convert_to_list_leaving_none(self.line_) + for i, line in enumerate(lines): + contains, indexes = line.contains(event) + # stop at the first line on which the event occured + if contains: + idx = indexes["ind"][0] + x_vals, y_vals = line.get_data() + t = t_vals[i][idx] if t_vals is not None else t_vals + self._update_line_tooltip(x_vals[idx], y_vals[idx], t) + self.line_tooltip_.set_visible(True) + break + else: # hide the tooltip if the event didn't occur on any line + if self.line_tooltip_.get_visible(): + self.line_tooltip_.set_visible(False) + + found_visible = False + for child in self.ax_.get_children(): + if hasattr(child, "_skl_line_tooltip") and child.get_visible(): + if not found_visible: + found_visible = True + else: + child.set_visible(False) + + self.figure_.canvas.draw_idle() + + def _update_line_tooltip(self, x, y, t): + """Update the text in the line tooltip. + + Parameters + ---------- + x : float + the x value to display. + + y : float + the y value to display. + + t : float or None + the parameter value to display. + """ + text = f"{self._x_label_short}: {x:.2f}, {self._y_label_short}: {y:.2f}" + if t is not None: + text += f", {self._t_label}: {t:.2f}" + + # Compute an offset for the text depending on the quadrant where the cursor is + # to keep the tooltip somewhat inside the axes. + xlim, ylim = self.ax_.get_xlim(), self.ax_.get_ylim() + x_offset = 20 if x < (xlim[0] + xlim[1]) / 2 else -160 + y_offset = 20 if y < (ylim[0] + ylim[1]) / 2 else -30 + + self.line_tooltip_.xyann = (x_offset, y_offset) + self.line_tooltip_.xy = (x, y) + self.line_tooltip_.set_text(text) + + def _validate_score_name(score_name, scoring, negate_score): """Validate the `score_name` parameter. From 4454932d0718cf971e3de81332e4bee68a264c31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Thu, 31 Jul 2025 17:07:18 +0200 Subject: [PATCH 03/11] add tests and doc --- sklearn/metrics/_plot/roc_curve.py | 2 +- .../_plot/tests/test_roc_curve_display.py | 27 +++- sklearn/utils/_plotting.py | 1 + sklearn/utils/tests/test_plotting.py | 130 ++++++++++++++++++ 4 files changed, 157 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 21c4457250a15..1c8c0dd1f938b 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -7,12 +7,12 @@ from ...utils import _safe_indexing from ...utils._plotting import ( _BinaryClassifierCurveDisplayMixin, - _LineTooltipMixin, _check_param_lengths, _convert_to_list_leaving_none, _deprecate_estimator_name, _deprecate_y_pred_parameter, _despine, + _LineTooltipMixin, _validate_style_kwargs, ) from ...utils._response import _get_response_values_binary diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 33461456d8e84..31b3de3aa2a8c 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -16,6 +16,7 @@ from sklearn.preprocessing import StandardScaler from sklearn.utils import _safe_indexing, shuffle from sklearn.utils._response import _get_response_values_binary +from sklearn.utils.tests.test_plotting import _simulate_mouse_event @pytest.fixture(scope="module") @@ -201,6 +202,7 @@ def test_validate_plot_params(pyplot): """Check `_validate_plot_params` returns the correct variables.""" fpr = np.array([0, 0.5, 1]) tpr = [np.array([0, 0.5, 1])] + thresholds = np.array([1, 0.5, 0]) roc_auc = None name = "test_curve" @@ -208,18 +210,21 @@ def test_validate_plot_params(pyplot): display = RocCurveDisplay( fpr=fpr, tpr=tpr, + thresholds=thresholds, roc_auc=roc_auc, name=name, pos_label=None, ) - fpr_out, tpr_out, roc_auc_out, name_out = display._validate_plot_params( - ax=None, name=None + fpr_out, tpr_out, thresholds_out, roc_auc_out, name_out = ( + display._validate_plot_params(ax=None, name=None) ) assert isinstance(fpr_out, list) assert isinstance(tpr_out, list) + assert isinstance(thresholds_out, list) assert len(fpr_out) == 1 assert len(tpr_out) == 1 + assert len(thresholds_out) == 1 assert roc_auc_out is None assert name_out == ["test_curve"] @@ -980,3 +985,21 @@ def test_plot_roc_curve_despine(pyplot, data_binary, despine, constructor_name): if despine: for s in ["bottom", "left"]: assert display.ax_.spines[s].get_bounds() == (0, 1) + + +def test_roc_curve_display_line_tooltip(pyplot, data_binary): + """Test the line tooltip on a Roc curve.""" + X, y = data_binary + + lr = LogisticRegression().fit(X, y) + display = RocCurveDisplay.from_estimator(lr, X, y) + + assert hasattr(display, "line_tooltip_") + + # simulate a mouse event on the line + x, y = display.fpr[10], display.tpr[10] + _simulate_mouse_event(display, x, y) + + assert display.line_tooltip_.get_visible() + text = display.line_tooltip_.get_text() + assert all(name in text for name in ("FPR", "TPR", "threshold")) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index c864189ad1c7e..2c73899004c48 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -3,6 +3,7 @@ import warnings from collections.abc import Mapping from functools import partial + import numpy as np from . import check_consistent_length diff --git a/sklearn/utils/tests/test_plotting.py b/sklearn/utils/tests/test_plotting.py index db2f797ac2547..e315b171c4160 100644 --- a/sklearn/utils/tests/test_plotting.py +++ b/sklearn/utils/tests/test_plotting.py @@ -7,6 +7,7 @@ _deprecate_estimator_name, _despine, _interval_max_min_ratio, + _LineTooltipMixin, _validate_score_name, _validate_style_kwargs, ) @@ -542,3 +543,132 @@ def test_deprecate_estimator_name(estimator_name, name): ) with pytest.raises(ValueError, match=error_message): _deprecate_estimator_name(estimator_name, name, version) + + +class _TestCurveDisplay(_LineTooltipMixin): + def __init__(self, n_curves=1, x_max=1.0, y_max=1.0, parametric=False): + self.n_curves = n_curves + self.x_max = x_max + self.y_max = y_max + self.parametric = parametric + + def plot(self, ax=None): + import matplotlib.pyplot as plt + + if ax is None: + _, ax = plt.subplots() + self.ax_, self.figure_ = ax, ax.figure + + self.line_ = [] + self.t_vals, self.x_vals, self.y_vals = [], [], [] + x_ends = np.linspace( + self.x_max if self.n_curves == 1 else self.x_max / 2, + self.x_max, + self.n_curves, + ) + for i, x_end in enumerate(x_ends): + self.t_vals.append(t := np.linspace(0, 1, 100)) + self.x_vals.append(x := x_end * t) + self.y_vals.append(y := self.y_max * t) + self.line_.extend(self.ax_.plot(x, y, label=f"curve {i + 1}")) + + params = {"t_label": "t", "t_vals": self.t_vals} if self.parametric else {} + self._add_line_tooltip(x_label="x", y_label="y", **params) + + return self + + +def _simulate_mouse_event(display, x, y): + """Emit a mouse location event at data coordinates (x, y).""" + import matplotlib as mpl + + # needed to update xlim and ylim before using transData since we're not calling + # plt.show(). See https://github.com/matplotlib/matplotlib/issues/28075. + display.ax_.autoscale_view() + + x_display, y_display = display.ax_.transData.transform((x, y)) + event = mpl.backend_bases.MouseEvent( + name="motion_notify_event", + canvas=display.figure_.canvas, + x=x_display, + y=y_display, + ) + display.ax_.figure.canvas.callbacks.process("motion_notify_event", event) + + +@pytest.mark.parametrize("n_curves", [1, 5]) +@pytest.mark.parametrize("x_max", [1, 10]) +@pytest.mark.parametrize("y_max", [1, 10]) +@pytest.mark.parametrize("parametric", [True, False]) +def test_line_tooltip(n_curves, x_max, y_max, parametric, pyplot): + """Test the behavior for the line tooltip.""" + import matplotlib as mpl + + display = _TestCurveDisplay( + n_curves=n_curves, + x_max=x_max, + y_max=y_max, + parametric=parametric, + ).plot() + + assert hasattr(display, "line_tooltip_") + assert isinstance(display.line_tooltip_, mpl.text.Annotation) + assert display.line_tooltip_.get_text() == "" + assert not display.line_tooltip_.get_visible() + + # simulate a mouse event occuring on a line. Take the point in the middle of the + # last curve for instance + x = display.x_vals[-1][50] + y = display.y_vals[-1][50] + _simulate_mouse_event(display, x, y) + + assert display.line_tooltip_.get_visible() + text = display.line_tooltip_.get_text() + assert all(name in text for name in ("x", "y")) + if parametric: + assert "t" in text + + # simulate a second event, not occuring on any line this time. None of these curves + # ever touch the lower right corner + _simulate_mouse_event(display, x_max, 0) + + assert not display.line_tooltip_.get_visible() + + # simulate a third event occuring on several lines at once. (0, 0) belongs to all + # curves. + _simulate_mouse_event(display, 0, 0) + + assert display.line_tooltip_.get_visible() + text = display.line_tooltip_.get_text() + assert all(name in text for name in ("x", "y")) + if parametric: + assert "t" in text + + +@pytest.mark.parametrize("y_max", [1, 10]) +def test_line_tooltip_multiple_displays(pyplot, y_max): + """Test the line tooltip when different displays share the same axes.""" + display = _TestCurveDisplay(x_max=0.5, y_max=y_max).plot() + display2 = _TestCurveDisplay(x_max=1, y_max=y_max).plot(ax=display.ax_) + + assert hasattr(display, "line_tooltip_") + assert hasattr(display2, "line_tooltip_") + assert not display.line_tooltip_.get_visible() + assert not display2.line_tooltip_.get_visible() + + # simulate a mouse event occuring on the first line + _simulate_mouse_event(display, 0.5, y_max) + + assert display.line_tooltip_.get_visible() + assert not display2.line_tooltip_.get_visible() + + # simulate a mouse event occuring on the second line + _simulate_mouse_event(display2, 1, y_max) + + assert not display.line_tooltip_.get_visible() + assert display2.line_tooltip_.get_visible() + + # simulate a mouse event occuring on both lines. Only 1 line tooltip is visible + _simulate_mouse_event(display, 0, 0) + + assert display.line_tooltip_.get_visible() != display2.line_tooltip_.get_visible() From 1ecf20b5de86f4314fd31ebe2ee74529458cade8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Thu, 31 Jul 2025 17:22:37 +0200 Subject: [PATCH 04/11] fix docstring --- sklearn/metrics/_plot/roc_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 1c8c0dd1f938b..edfb5018091fa 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -50,7 +50,7 @@ class RocCurveDisplay(_BinaryClassifierCurveDisplayMixin, _LineTooltipMixin): .. versionchanged:: 1.7 Now accepts a list for plotting multiple curves. - threshold : ndarray or list of ndarrays, default=None + thresholds : ndarray or list of ndarrays, default=None The thresholds at which the fpr and tpr have been computed. Each ndarray should contain values for a single curve. If plotting multiple curves, list should be of same length as `fpr` and `tpr`. From 4006c57858e157bf9b454181ae60898ed0af4d8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Thu, 31 Jul 2025 17:54:35 +0200 Subject: [PATCH 05/11] comment --- sklearn/utils/_plotting.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 2c73899004c48..259d4a3f1a3aa 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -328,6 +328,9 @@ def _hover(self, event, t_vals): if self.line_tooltip_.get_visible(): self.line_tooltip_.set_visible(False) + # Loop through all line tooltips contained the axes and hide all but one. + # This is necessary in addition to the loop over the display lines above to + # account for the case when multiple display instances share the same axes. found_visible = False for child in self.ax_.get_children(): if hasattr(child, "_skl_line_tooltip") and child.get_visible(): From 26b2b1210141d80b255abf1a635c309a883db38c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Thu, 31 Jul 2025 18:15:47 +0200 Subject: [PATCH 06/11] typos --- sklearn/utils/_plotting.py | 6 +++--- sklearn/utils/tests/test_plotting.py | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 259d4a3f1a3aa..b6e2592e2dd1d 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -254,8 +254,8 @@ def _validate_curve_kwargs( class _LineTooltipMixin: """Mixin class to add a tooltip to a line in a plot. - The tooltip displays 2 to 3 informations: the x value, the y value and an optional - t value for parametric curves (x(t), y(t)). + The tooltip displays 2 to 3 info: the x value, the y value and an optional t value + for parametric curves (x(t), y(t)). """ def _add_line_tooltip(self, *, x_label, y_label, t_label=None, t_vals=None): @@ -316,7 +316,7 @@ def _hover(self, event, t_vals): lines = _convert_to_list_leaving_none(self.line_) for i, line in enumerate(lines): contains, indexes = line.contains(event) - # stop at the first line on which the event occured + # stop at the first line on which the event occurred if contains: idx = indexes["ind"][0] x_vals, y_vals = line.get_data() diff --git a/sklearn/utils/tests/test_plotting.py b/sklearn/utils/tests/test_plotting.py index e315b171c4160..c4bbf8d52b2b6 100644 --- a/sklearn/utils/tests/test_plotting.py +++ b/sklearn/utils/tests/test_plotting.py @@ -616,7 +616,7 @@ def test_line_tooltip(n_curves, x_max, y_max, parametric, pyplot): assert display.line_tooltip_.get_text() == "" assert not display.line_tooltip_.get_visible() - # simulate a mouse event occuring on a line. Take the point in the middle of the + # simulate a mouse event occurring on a line. Take the point in the middle of the # last curve for instance x = display.x_vals[-1][50] y = display.y_vals[-1][50] @@ -628,13 +628,13 @@ def test_line_tooltip(n_curves, x_max, y_max, parametric, pyplot): if parametric: assert "t" in text - # simulate a second event, not occuring on any line this time. None of these curves + # simulate a second event, not occurring on any line this time. None of these curves # ever touch the lower right corner _simulate_mouse_event(display, x_max, 0) assert not display.line_tooltip_.get_visible() - # simulate a third event occuring on several lines at once. (0, 0) belongs to all + # simulate a third event occurring on several lines at once. (0, 0) belongs to all # curves. _simulate_mouse_event(display, 0, 0) @@ -656,19 +656,19 @@ def test_line_tooltip_multiple_displays(pyplot, y_max): assert not display.line_tooltip_.get_visible() assert not display2.line_tooltip_.get_visible() - # simulate a mouse event occuring on the first line + # simulate a mouse event occurring on the first line _simulate_mouse_event(display, 0.5, y_max) assert display.line_tooltip_.get_visible() assert not display2.line_tooltip_.get_visible() - # simulate a mouse event occuring on the second line + # simulate a mouse event occurring on the second line _simulate_mouse_event(display2, 1, y_max) assert not display.line_tooltip_.get_visible() assert display2.line_tooltip_.get_visible() - # simulate a mouse event occuring on both lines. Only 1 line tooltip is visible + # simulate a mouse event occurring on both lines. Only 1 line tooltip is visible _simulate_mouse_event(display, 0, 0) assert display.line_tooltip_.get_visible() != display2.line_tooltip_.get_visible() From d6d699871a3774b184d1765ef0b22b5cc42693c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Mon, 4 Aug 2025 18:20:09 +0200 Subject: [PATCH 07/11] try %matplotlib widget in an example --- build_tools/circle/doc_environment.yml | 1 + build_tools/circle/doc_linux-64_conda.lock | 28 +++++++++++++++---- .../update_environments_and_lock_files.py | 1 + .../plot_cost_sensitive_learning.py | 3 ++ 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/build_tools/circle/doc_environment.yml b/build_tools/circle/doc_environment.yml index dcf3f0b0db699..c1c9095c6b475 100644 --- a/build_tools/circle/doc_environment.yml +++ b/build_tools/circle/doc_environment.yml @@ -39,6 +39,7 @@ dependencies: - towncrier - jupyterlite-sphinx - jupyterlite-pyodide-kernel + - ipympl - pip - pip: - sphinxcontrib-sass diff --git a/build_tools/circle/doc_linux-64_conda.lock b/build_tools/circle/doc_linux-64_conda.lock index d179ba70af52c..d258abd0bb584 100644 --- a/build_tools/circle/doc_linux-64_conda.lock +++ b/build_tools/circle/doc_linux-64_conda.lock @@ -1,6 +1,6 @@ # Generated by conda-lock. # platform: linux-64 -# input_hash: 9bc9ca426bc05685148b1ae7e671907e9d514e40b6bb1c8d7c916d4fdc8b70f2 +# input_hash: afdbcc323f4c61ff76aa07208fd051f771b8199561066eb3d776125c9fee92ec @EXPLICIT https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2#d7c89558ba9fa0495403155b64376d81 https://conda.anaconda.org/conda-forge/noarch/font-ttf-dejavu-sans-mono-2.37-hab24e00_0.tar.bz2#0c96522c6bdaed4b1566d11387caaf45 @@ -107,6 +107,7 @@ https://conda.anaconda.org/conda-forge/linux-64/xcb-util-wm-0.4.2-hb711507_0.con https://conda.anaconda.org/conda-forge/linux-64/xorg-libsm-1.2.6-he73a12e_0.conda#1c74ff8c35dcadf952a16f752ca5aa49 https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.12-h4f16b4b_0.conda#db038ce880f100acc74dba10302b5630 https://conda.anaconda.org/conda-forge/noarch/alabaster-1.0.0-pyhd8ed1ab_1.conda#1fd9696649f65fd6611fcdb4ffec738a +https://conda.anaconda.org/conda-forge/noarch/asttokens-3.0.0-pyhd8ed1ab_1.conda#8f587de4bcf981e26228f268df374a9b https://conda.anaconda.org/conda-forge/noarch/attrs-25.3.0-pyh71513ae_0.conda#a10d11958cadc13fdb43df75f8b1903f https://conda.anaconda.org/conda-forge/linux-64/brotli-1.1.0-hb9d3cd8_3.conda#5d08a0ac29e6a5a984817584775d4131 https://conda.anaconda.org/conda-forge/linux-64/brotli-python-1.1.0-py310hf71b8c6_3.conda#63d24a5dd21c738d706f91569dbd1892 @@ -116,14 +117,17 @@ https://conda.anaconda.org/conda-forge/noarch/charset-normalizer-3.4.2-pyhd8ed1a https://conda.anaconda.org/conda-forge/noarch/click-8.2.2-pyh707e725_0.conda#2cc16494e4ce28efc52fb29ec3c348a1 https://conda.anaconda.org/conda-forge/noarch/cloudpickle-3.1.1-pyhd8ed1ab_0.conda#364ba6c9fb03886ac979b482f39ebb92 https://conda.anaconda.org/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda#962b9857ee8e7018c22f2776ffa0b2d7 +https://conda.anaconda.org/conda-forge/noarch/comm-0.2.3-pyhe01879c_0.conda#2da13f2b299d8e1995bafbbe9689a2f7 https://conda.anaconda.org/conda-forge/linux-64/conda-gcc-specs-14.3.0-hb991d5c_4.conda#b6025bc20bf223d68402821f181707fb https://conda.anaconda.org/conda-forge/noarch/cpython-3.10.18-py310hd8ed1ab_0.conda#7004cb3fa62ad44d1cb70f3b080dfc8f https://conda.anaconda.org/conda-forge/noarch/cycler-0.12.1-pyhd8ed1ab_1.conda#44600c4667a319d67dbe0681fc0bc833 https://conda.anaconda.org/conda-forge/linux-64/cyrus-sasl-2.1.28-hd9c7081_0.conda#cae723309a49399d2949362f4ab5c9e4 https://conda.anaconda.org/conda-forge/linux-64/cython-3.1.2-py310had8cdd9_2.conda#be416b1d5ffef48c394cbbb04bc864ae +https://conda.anaconda.org/conda-forge/noarch/decorator-5.2.1-pyhd8ed1ab_0.conda#9ce473d1d1be1cc3810856a48b3fab32 https://conda.anaconda.org/conda-forge/noarch/defusedxml-0.7.1-pyhd8ed1ab_0.tar.bz2#961b3a227b437d82ad7054484cfa71b2 https://conda.anaconda.org/conda-forge/noarch/docutils-0.21.2-pyhd8ed1ab_1.conda#24c1ca34138ee57de72a943237cde4cc https://conda.anaconda.org/conda-forge/noarch/execnet-2.1.1-pyhd8ed1ab_1.conda#a71efeae2c160f6789900ba2631a2c90 +https://conda.anaconda.org/conda-forge/noarch/executing-2.2.0-pyhd8ed1ab_0.conda#81d30c08f9a3e556e8ca9e124b044d14 https://conda.anaconda.org/conda-forge/linux-64/gcc_linux-64-14.3.0-h1382650_11.conda#2e650506e6371ac4289c9bf7fc207f3b https://conda.anaconda.org/conda-forge/linux-64/gfortran_impl_linux-64-14.3.0-h7db7018_4.conda#4cb71ecc31f139f8bf96963c53b5b8a1 https://conda.anaconda.org/conda-forge/linux-64/gxx_impl_linux-64-14.3.0-he663afc_4.conda#1f7b059bae1fc5e72ae23883e04abc48 @@ -134,6 +138,7 @@ https://conda.anaconda.org/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.b https://conda.anaconda.org/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda#6837f3eff7dcea42ecd714ce1ac2b108 https://conda.anaconda.org/conda-forge/noarch/json5-0.12.0-pyhd8ed1ab_0.conda#56275442557b3b45752c10980abfe2db https://conda.anaconda.org/conda-forge/linux-64/jsonpointer-3.0.0-py310hff52083_1.conda#ce614a01b0aee1b29cee13d606bcb5d5 +https://conda.anaconda.org/conda-forge/noarch/jupyterlab_widgets-3.0.15-pyhd8ed1ab_0.conda#ad100d215fad890ab0ee10418f36876f https://conda.anaconda.org/conda-forge/linux-64/kiwisolver-1.4.8-py310h3788b33_1.conda#b70dd76da5231e6073fd44c42a1d78c5 https://conda.anaconda.org/conda-forge/noarch/lark-1.2.2-pyhd8ed1ab_1.conda#3a8063b25e603999188ed4bbf3485404 https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.17-h717163a_0.conda#000e85703f0fd9594c81710dd5066471 @@ -141,9 +146,9 @@ https://conda.anaconda.org/conda-forge/linux-64/libavif16-1.3.0-h766b0b6_0.conda https://conda.anaconda.org/conda-forge/linux-64/libblas-3.9.0-32_h59b9bed_openblas.conda#2af9f3d5c2e39f417ce040f5a35c40c6 https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-hb8b1518_5.conda#d4a250da4737ee127fb1fa6452a9002e https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.13.3-ha770c72_1.conda#51f5be229d83ecd401fb369ab96ae669 -https://conda.anaconda.org/conda-forge/linux-64/libglib-2.84.2-h3618099_0.conda#072ab14a02164b7c0c089055368ff776 +https://conda.anaconda.org/conda-forge/linux-64/libglib-2.84.3-hf39c6af_0.conda#467f23819b1ea2b89c3fc94d65082301 https://conda.anaconda.org/conda-forge/linux-64/libglx-1.7.0-ha4b6fd6_2.conda#c8013e438185f33b13814c5c488acd5c -https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.13.8-h4bc477f_0.conda#14dbe05b929e329dbaa6f2d0aa19466d +https://conda.anaconda.org/conda-forge/linux-64/libxml2-2.13.8-h04c0eec_1.conda#10bcbd05e1c1c9d652fccb42b776a9fa https://conda.anaconda.org/conda-forge/linux-64/markupsafe-3.0.2-py310h89163eb_1.conda#8ce3f0332fd6de0d737e2911d329523f https://conda.anaconda.org/conda-forge/noarch/mdurl-0.1.2-pyhd8ed1ab_1.conda#592132998493b3ff25fd7479396e8351 https://conda.anaconda.org/conda-forge/noarch/meson-1.8.3-pyhe01879c_0.conda#ed40b34242ec6d216605db54d19c6df5 @@ -151,15 +156,18 @@ https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyhd8ed1ab_1.conda#3 https://conda.anaconda.org/conda-forge/noarch/narwhals-2.0.1-pyhe01879c_0.conda#5f0dea40791cecf0f82882b9eea7f7c1 https://conda.anaconda.org/conda-forge/noarch/networkx-3.4.2-pyh267e887_2.conda#fd40bf7f7f4bc4b647dc8512053d9873 https://conda.anaconda.org/conda-forge/linux-64/openblas-0.3.30-pthreads_h6ec200e_1.conda#611fcf119d77a78439794c43f7667664 -https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.3-h5fbd93e_0.conda#9e5816bc95d285c115a3ebc2f8563564 +https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.3-h55fea9a_1.conda#01243c4aaf71bde0297966125aea4706 https://conda.anaconda.org/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda#58335b26c38bf4a20f399384c33cbcf9 https://conda.anaconda.org/conda-forge/noarch/pandocfilters-1.5.0-pyhd8ed1ab_0.tar.bz2#457c2c8c08e54905d6954e79cb5b5db9 +https://conda.anaconda.org/conda-forge/noarch/parso-0.8.4-pyhd8ed1ab_1.conda#5c092057b6badd30f75b06244ecd01c9 +https://conda.anaconda.org/conda-forge/noarch/pickleshare-0.7.5-pyhd8ed1ab_1004.conda#11a9d1d09a3615fc07c3faf79bc0b943 https://conda.anaconda.org/conda-forge/noarch/pkginfo-1.12.1.2-pyhd8ed1ab_0.conda#dc702b2fae7ebe770aff3c83adb16b63 https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.3.8-pyhe01879c_0.conda#424844562f5d337077b445ec6b1398a7 https://conda.anaconda.org/conda-forge/noarch/pluggy-1.6.0-pyhd8ed1ab_0.conda#7da7ccd349dbf6487a7778579d2bb971 https://conda.anaconda.org/conda-forge/noarch/prometheus_client-0.22.1-pyhd8ed1ab_0.conda#c64b77ccab10b822722904d889fa83b5 https://conda.anaconda.org/conda-forge/linux-64/psutil-7.0.0-py310ha75aee5_0.conda#da7d592394ff9084a23f62a1186451a2 https://conda.anaconda.org/conda-forge/noarch/ptyprocess-0.7.0-pyhd8ed1ab_1.conda#7d9daffbb8d8e0af0f769dbbcd173a54 +https://conda.anaconda.org/conda-forge/noarch/pure_eval-0.2.3-pyhd8ed1ab_1.conda#3bfdfb8dbcdc4af1ae3f9a8eb3948f04 https://conda.anaconda.org/conda-forge/noarch/pycparser-2.22-pyh29332c3_1.conda#12c566707c80111f9799308d9e265aef https://conda.anaconda.org/conda-forge/noarch/pygments-2.19.2-pyhd8ed1ab_0.conda#6b6ece66ebcae2d5f326c77ef2c5a066 https://conda.anaconda.org/conda-forge/noarch/pyparsing-3.2.3-pyhe01879c_2.conda#aa0028616c0750c773698fdc254b2b8d @@ -188,10 +196,12 @@ https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.14.1-pyhe01879 https://conda.anaconda.org/conda-forge/noarch/typing_utils-0.1.0-pyhd8ed1ab_1.conda#f6d7aa696c67756a650e91e15e88223c https://conda.anaconda.org/conda-forge/linux-64/unicodedata2-16.0.0-py310ha75aee5_0.conda#1d7a4b9202cdd10d56ecdd7f6c347190 https://conda.anaconda.org/conda-forge/noarch/uri-template-1.3.0-pyhd8ed1ab_1.conda#e7cb0f5745e4c5035a460248334af7eb +https://conda.anaconda.org/conda-forge/noarch/wcwidth-0.2.13-pyhd8ed1ab_1.conda#b68980f2495d096e71c7fd9d7ccf63e6 https://conda.anaconda.org/conda-forge/noarch/webcolors-24.11.1-pyhd8ed1ab_0.conda#b49f7b291e15494aafb0a7d74806f337 https://conda.anaconda.org/conda-forge/noarch/webencodings-0.5.1-pyhd8ed1ab_3.conda#2841eb5bfc75ce15e9a0054b98dcd64d https://conda.anaconda.org/conda-forge/noarch/websocket-client-1.8.0-pyhd8ed1ab_1.conda#84f8f77f0a9c6ef401ee96611745da8f https://conda.anaconda.org/conda-forge/noarch/wheel-0.45.1-pyhd8ed1ab_1.conda#75cb7132eb58d97896e173ef12ac9986 +https://conda.anaconda.org/conda-forge/noarch/widgetsnbextension-4.0.14-pyhd8ed1ab_0.conda#2f1f99b13b9d2a03570705030a0b3e7c https://conda.anaconda.org/conda-forge/linux-64/xcb-util-image-0.4.0-hb711507_2.conda#a0901183f08b6c7107aab109733a3c91 https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.45-hb9d3cd8_0.conda#397a013c2dc5145a70737871aaa87e98 https://conda.anaconda.org/conda-forge/linux-64/xorg-libxext-1.3.6-hb9d3cd8_0.conda#febbab7d15033c913d53c7a2c102309d @@ -215,6 +225,7 @@ https://conda.anaconda.org/conda-forge/linux-64/gxx_linux-64-14.3.0-ha7acb78_11. https://conda.anaconda.org/conda-forge/noarch/h2-4.2.0-pyhd8ed1ab_0.conda#b4754fb1bdcb70c8fd54f918301582c6 https://conda.anaconda.org/conda-forge/noarch/importlib-metadata-8.7.0-pyhe01879c_1.conda#63ccfdc3a3ce25b027b8767eb722fca8 https://conda.anaconda.org/conda-forge/noarch/importlib_resources-6.5.2-pyhd8ed1ab_0.conda#c85c76dc67d75619a92f51dfbce06992 +https://conda.anaconda.org/conda-forge/noarch/jedi-0.19.2-pyhd8ed1ab_1.conda#a4f4c5dc9b80bc50e0d3dc4e6e8f1bd9 https://conda.anaconda.org/conda-forge/noarch/jinja2-3.1.6-pyhd8ed1ab_0.conda#446bd6c8cb26050d528881df495ce646 https://conda.anaconda.org/conda-forge/noarch/joblib-1.5.1-pyhd8ed1ab_0.conda#fb1c14694de51a476ce8636d92b6f42c https://conda.anaconda.org/conda-forge/noarch/jupyter_core-5.8.1-pyh31011fe_0.conda#b7d89d860ebcda28a5303526cdee68ab @@ -226,13 +237,16 @@ https://conda.anaconda.org/conda-forge/linux-64/libllvm20-20.1.8-hecd9e04_0.cond https://conda.anaconda.org/conda-forge/linux-64/libxkbcommon-1.10.0-h65c71a3_0.conda#fedf6bfe5d21d21d2b1785ec00a8889a https://conda.anaconda.org/conda-forge/linux-64/libxslt-1.1.43-h7a3aeb2_0.conda#31059dc620fa57d787e3899ed0421e6d https://conda.anaconda.org/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_1.conda#fee3164ac23dfca50cfcc8b85ddefb81 +https://conda.anaconda.org/conda-forge/noarch/matplotlib-inline-0.1.7-pyhd8ed1ab_1.conda#af6ab708897df59bd6e7283ceab1b56b https://conda.anaconda.org/conda-forge/noarch/memory_profiler-0.61.0-pyhd8ed1ab_1.conda#71abbefb6f3b95e1668cd5e0af3affb9 https://conda.anaconda.org/conda-forge/noarch/mistune-3.1.3-pyh29332c3_0.conda#7ec6576e328bc128f4982cd646eeba85 https://conda.anaconda.org/conda-forge/linux-64/openldap-2.6.10-he970967_0.conda#2e5bf4f1da39c0b32778561c3c4e5878 https://conda.anaconda.org/conda-forge/noarch/overrides-7.7.0-pyhd8ed1ab_1.conda#e51f1e4089cad105b6cac64bd8166587 +https://conda.anaconda.org/conda-forge/noarch/pexpect-4.9.0-pyhd8ed1ab_1.conda#d0d408b1f18883a944376da5cf8101ea https://conda.anaconda.org/conda-forge/linux-64/pillow-11.3.0-py310h7e6dc6c_0.conda#e609995f031bc848be8ea159865e8afc https://conda.anaconda.org/conda-forge/noarch/pip-25.2-pyh8b19718_0.conda#dfce4b2af4bfe90cdcaf56ca0b28ddf5 https://conda.anaconda.org/conda-forge/noarch/plotly-6.2.0-pyhd8ed1ab_0.conda#8a9590843af49b36f37ac3dbcf5fc3d9 +https://conda.anaconda.org/conda-forge/noarch/prompt-toolkit-3.0.51-pyha770c72_0.conda#d17ae9db4dc594267181bd199bf9a551 https://conda.anaconda.org/conda-forge/noarch/pyproject-metadata-0.9.1-pyhd8ed1ab_0.conda#22ae7c6ea81e0c8661ef32168dda929b https://conda.anaconda.org/conda-forge/noarch/python-dateutil-2.9.0.post0-pyhe01879c_2.conda#5b8d21249ff20967101ffa321cab24e8 https://conda.anaconda.org/conda-forge/noarch/python-gil-3.10.18-hd8ed1ab_0.conda#a40e3a920f2c46f94e027bd599b88b17 @@ -240,6 +254,7 @@ https://conda.anaconda.org/conda-forge/linux-64/pyzmq-27.0.1-py310h9a5fd63_0.con https://conda.anaconda.org/conda-forge/noarch/referencing-0.36.2-pyh29332c3_0.conda#9140f1c09dd5489549c6a33931b943c7 https://conda.anaconda.org/conda-forge/noarch/rfc3339-validator-0.1.4-pyhd8ed1ab_1.conda#36de09a8d3e5d5e6f4ee63af49e59706 https://conda.anaconda.org/conda-forge/noarch/rfc3987-syntax-1.1.0-pyhe01879c_1.conda#7234f99325263a5af6d4cd195035e8f2 +https://conda.anaconda.org/conda-forge/noarch/stack_data-0.6.3-pyhd8ed1ab_1.conda#b1b505328da7a6b246787df4b5a49fbc https://conda.anaconda.org/conda-forge/noarch/terminado-0.18.1-pyh0d859eb_0.conda#efba281bbdae5f6b0a1d53c6d4a97c93 https://conda.anaconda.org/conda-forge/noarch/tinycss2-1.4.0-pyhd8ed1ab_0.conda#f1acf5fdefa8300de697982bcb1761c9 https://conda.anaconda.org/conda-forge/noarch/typing-extensions-4.14.1-h4440ef1_0.conda#75be1a943e0a7f99fcf118309092c635 @@ -251,7 +266,7 @@ https://conda.anaconda.org/conda-forge/linux-64/xorg-libxi-1.8.2-hb9d3cd8_0.cond https://conda.anaconda.org/conda-forge/linux-64/xorg-libxrandr-1.5.4-hb9d3cd8_0.conda#2de7f99d6581a4a7adbff607b5c278ca https://conda.anaconda.org/conda-forge/linux-64/xorg-libxxf86vm-1.1.6-hb9d3cd8_0.conda#5efa5fa6243a622445fdfd72aee15efa https://conda.anaconda.org/conda-forge/noarch/_python_abi3_support-1.0-hd8ed1ab_2.conda#aaa2a381ccc56eac91d63b6c1240312f -https://conda.anaconda.org/conda-forge/noarch/anyio-4.9.0-pyh29332c3_0.conda#9749a2c77a7c40d432ea0927662d7e52 +https://conda.anaconda.org/conda-forge/noarch/anyio-4.10.0-pyhe01879c_0.conda#cc2613bfa71dec0eb2113ee21ac9ccbf https://conda.anaconda.org/conda-forge/linux-64/argon2-cffi-bindings-25.1.0-py310h7c4b9e2_0.conda#3fd41ccdb9263ad51cf89b05cade6fb7 https://conda.anaconda.org/conda-forge/noarch/arrow-1.3.0-pyhd8ed1ab_1.conda#46b53236fdd990271b03c3978d4218a9 https://conda.anaconda.org/conda-forge/noarch/beautifulsoup4-4.13.4-pyha770c72_0.conda#9f07c4fc992adb2d6c30da7fab3959a7 @@ -263,6 +278,7 @@ https://conda.anaconda.org/conda-forge/noarch/fqdn-1.5.1-pyhd8ed1ab_1.conda#d354 https://conda.anaconda.org/conda-forge/linux-64/gfortran-14.3.0-he448592_4.conda#6f88c38cdf941173e9aec76f967d4d28 https://conda.anaconda.org/conda-forge/linux-64/gxx-14.3.0-he448592_4.conda#26ccfde67e88b646e57a7e56ce4ef56d https://conda.anaconda.org/conda-forge/noarch/importlib-resources-6.5.2-pyhd8ed1ab_0.conda#e376ea42e9ae40f3278b0f79c9bf9826 +https://conda.anaconda.org/conda-forge/noarch/ipython-8.37.0-pyh8f84b5b_0.conda#177cfa19fe3d74c87a8889286dc64090 https://conda.anaconda.org/conda-forge/noarch/jsonschema-specifications-2025.4.1-pyh29332c3_0.conda#41ff526b1083fde51fbdc93f29282e0e https://conda.anaconda.org/conda-forge/noarch/jupyter_client-8.6.3-pyhd8ed1ab_1.conda#4ebae00eae9705b0c3d6d1018a81d047 https://conda.anaconda.org/conda-forge/noarch/jupyter_server_terminals-0.5.3-pyhd8ed1ab_1.conda#2d983ff1b82a1ccb6f2e9d8784bdd6bd @@ -285,6 +301,7 @@ https://conda.anaconda.org/conda-forge/linux-64/cxx-compiler-1.11.0-hfcd1e18_0.c https://conda.anaconda.org/conda-forge/linux-64/fortran-compiler-1.11.0-h9bea470_0.conda#d5596f445a1273ddc5ea68864c01b69f https://conda.anaconda.org/conda-forge/linux-64/imagecodecs-2025.3.30-py310h4eb8eaf_2.conda#a9c921699d37e862f9bf8dcf9d343838 https://conda.anaconda.org/conda-forge/noarch/imageio-2.37.0-pyhfb79c49_0.conda#b5577bc2212219566578fd5af9993af6 +https://conda.anaconda.org/conda-forge/noarch/ipywidgets-8.1.7-pyhd8ed1ab_0.conda#7c9449eac5975ef2d7753da262a72707 https://conda.anaconda.org/conda-forge/noarch/isoduration-20.11.0-pyhd8ed1ab_1.conda#0b0154421989637d424ccf0f104be51a https://conda.anaconda.org/conda-forge/noarch/jsonschema-4.25.0-pyhe01879c_0.conda#c6e3fd94e058dba67d917f38a11b50ab https://conda.anaconda.org/conda-forge/noarch/jupyterlite-core-0.6.3-pyhe01879c_0.conda#36ebdbf67840763b491045b5a36a2b78 @@ -308,6 +325,7 @@ https://conda.anaconda.org/conda-forge/linux-64/pyamg-5.2.1-py310ha2bacc8_1.cond https://conda.anaconda.org/conda-forge/noarch/requests-2.32.4-pyhd8ed1ab_0.conda#f6082eae112814f1447b56a5e1f6ed05 https://conda.anaconda.org/conda-forge/linux-64/statsmodels-0.14.5-py310haaf2d95_0.conda#92b4b51b83f2cfded298f1b8c7a99e32 https://conda.anaconda.org/conda-forge/noarch/tifffile-2025.5.10-pyhd8ed1ab_0.conda#1fdb801f28bf4987294c49aaa314bf5e +https://conda.anaconda.org/conda-forge/noarch/ipympl-0.9.7-pyhd8ed1ab_1.conda#f5b7e17a56dd953c80c77cbdab378467 https://conda.anaconda.org/conda-forge/noarch/jupyter_events-0.12.0-pyh29332c3_0.conda#f56000b36f09ab7533877e695e4e8cb0 https://conda.anaconda.org/conda-forge/noarch/jupytext-1.17.2-pyh80e38bb_0.conda#6d0652a97ef103de0c77b9c610d0c20d https://conda.anaconda.org/conda-forge/noarch/nbclient-0.10.2-pyhd8ed1ab_0.conda#6bb0d77277061742744176ab555b723c diff --git a/build_tools/update_environments_and_lock_files.py b/build_tools/update_environments_and_lock_files.py index b99e0e8f8d416..758be24970a96 100644 --- a/build_tools/update_environments_and_lock_files.py +++ b/build_tools/update_environments_and_lock_files.py @@ -378,6 +378,7 @@ def remove_from(alist, to_remove): "towncrier", "jupyterlite-sphinx", "jupyterlite-pyodide-kernel", + "ipympl", ], "pip_dependencies": [ "sphinxcontrib-sass", diff --git a/examples/model_selection/plot_cost_sensitive_learning.py b/examples/model_selection/plot_cost_sensitive_learning.py index 6b5b651463b05..893e7cb4449fd 100644 --- a/examples/model_selection/plot_cost_sensitive_learning.py +++ b/examples/model_selection/plot_cost_sensitive_learning.py @@ -38,6 +38,9 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause +# %% +%matplotlib widget + # %% # Cost-sensitive learning with constant gains and costs # ----------------------------------------------------- From d0735d5f3ed1b48133814f72c3638071917d4d48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Fri, 8 Aug 2025 12:54:08 +0200 Subject: [PATCH 08/11] try sg config to add magic command --- doc/conf.py | 1 + examples/model_selection/plot_cost_sensitive_learning.py | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 71c9ec5bb60c3..999c2e197c3fa 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -745,6 +745,7 @@ def reset_sklearn_config(gallery_conf, fname): "plot_gallery": "True", "recommender": {"enable": True, "n_examples": 4, "min_df": 12}, "reset_modules": ("matplotlib", "seaborn", reset_sklearn_config), + "first_notebook_cell": "%matplotlib widget", } if with_jupyterlite: sphinx_gallery_conf["jupyterlite"] = { diff --git a/examples/model_selection/plot_cost_sensitive_learning.py b/examples/model_selection/plot_cost_sensitive_learning.py index 893e7cb4449fd..6b5b651463b05 100644 --- a/examples/model_selection/plot_cost_sensitive_learning.py +++ b/examples/model_selection/plot_cost_sensitive_learning.py @@ -38,9 +38,6 @@ # Authors: The scikit-learn developers # SPDX-License-Identifier: BSD-3-Clause -# %% -%matplotlib widget - # %% # Cost-sensitive learning with constant gains and costs # ----------------------------------------------------- From e916a77b2dc49a29f6603f43f035526717d4675a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Fri, 8 Aug 2025 16:10:13 +0200 Subject: [PATCH 09/11] add to pr curve --- .../metrics/_plot/precision_recall_curve.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index 3e64fd776ae16..4dc3e72de9dca 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -6,13 +6,14 @@ from sklearn.metrics._ranking import average_precision_score, precision_recall_curve from sklearn.utils._plotting import ( _BinaryClassifierCurveDisplayMixin, + _LineTooltipMixin, _deprecate_y_pred_parameter, _despine, _validate_style_kwargs, ) -class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin): +class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin, _LineTooltipMixin): """Precision Recall visualization. It is recommended to use @@ -34,6 +35,13 @@ class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin): recall : ndarray Recall values. + thresholds : ndarray or list of ndarrays, default=None + The thresholds at which the fpr and tpr have been computed. Each ndarray should + contain values for a single curve. If plotting multiple curves, list should be + of same length as `fpr` and `tpr`. + Only used to display the threshold values along the curve as a tooltip. If None, + only the fpr and tpr values are displayed. + average_precision : float, default=None Average precision. If None, the average precision is not shown. @@ -117,6 +125,7 @@ def __init__( precision, recall, *, + thresholds=None, average_precision=None, estimator_name=None, pos_label=None, @@ -125,6 +134,7 @@ def __init__( self.estimator_name = estimator_name self.precision = precision self.recall = recall + self.thresholds = thresholds self.average_precision = average_precision self.pos_label = pos_label self.prevalence_pos_label = prevalence_pos_label @@ -206,6 +216,10 @@ def plot( (self.line_,) = self.ax_.plot(self.recall, self.precision, **line_kwargs) + self._add_line_tooltip( + x_label="R", y_label="P", t_label="threshold", t_vals=[self.thresholds] + ) + info_pos_label = ( f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" ) @@ -537,7 +551,7 @@ def from_predictions( y_true, y_score, sample_weight=sample_weight, pos_label=pos_label, name=name ) - precision, recall, _ = precision_recall_curve( + precision, recall, thresholds = precision_recall_curve( y_true, y_score, pos_label=pos_label, @@ -554,6 +568,7 @@ def from_predictions( viz = cls( precision=precision, recall=recall, + thresholds=thresholds, average_precision=average_precision, estimator_name=name, pos_label=pos_label, From aea3c925a860a147bd4cc6fa0be0c33579d3158b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Fri, 8 Aug 2025 16:10:27 +0200 Subject: [PATCH 10/11] subclass Annotation instead --- sklearn/utils/_plotting.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 4101f98c64f52..fa16570b5c6a7 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -251,6 +251,16 @@ def _validate_curve_kwargs( return curve_kwargs_ +try: + from matplotlib.text import Annotation + + class LineTooltip(Annotation): + """Custom annotation class to be able identify it among the axes children.""" + +except: + pass + + class _LineTooltipMixin: """Mixin class to add a tooltip to a line in a plot. @@ -289,12 +299,9 @@ def _add_line_tooltip(self, *, x_label, y_label, t_label=None, t_vals=None): arrowprops=dict(arrowstyle="-"), zorder=10, # bring to front ) + self.line_tooltip_.__class__ = LineTooltip self.line_tooltip_.set_visible(False) - # Set an attribute on the axes annotation to be able to keep only one visible - # at a time when there are multiple display instances that share an axes. - setattr(self.line_tooltip_, "_skl_line_tooltip", True) - self.ax_.figure.canvas.mpl_connect( "motion_notify_event", partial(self._hover, t_vals=t_vals) ) @@ -333,7 +340,7 @@ def _hover(self, event, t_vals): # account for the case when multiple display instances share the same axes. found_visible = False for child in self.ax_.get_children(): - if hasattr(child, "_skl_line_tooltip") and child.get_visible(): + if isinstance(child, LineTooltip) and child.get_visible(): if not found_visible: found_visible = True else: From 7f7e30ebb8b7254179c360897611382f7b6c76b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Fri, 8 Aug 2025 16:12:56 +0200 Subject: [PATCH 11/11] lint --- sklearn/metrics/_plot/precision_recall_curve.py | 2 +- sklearn/utils/_plotting.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/precision_recall_curve.py b/sklearn/metrics/_plot/precision_recall_curve.py index 4dc3e72de9dca..8d0785115aa80 100644 --- a/sklearn/metrics/_plot/precision_recall_curve.py +++ b/sklearn/metrics/_plot/precision_recall_curve.py @@ -6,9 +6,9 @@ from sklearn.metrics._ranking import average_precision_score, precision_recall_curve from sklearn.utils._plotting import ( _BinaryClassifierCurveDisplayMixin, - _LineTooltipMixin, _deprecate_y_pred_parameter, _despine, + _LineTooltipMixin, _validate_style_kwargs, ) diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index fa16570b5c6a7..784a47009d417 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -257,7 +257,7 @@ def _validate_curve_kwargs( class LineTooltip(Annotation): """Custom annotation class to be able identify it among the axes children.""" -except: +except ImportError: pass