Skip to content

Commit 08eebc6

Browse files
authored
FIX Avoid setting legend when labels are None in RocCurveDisplay kwargs (#29727)
1 parent 3b1a8ca commit 08eebc6

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- :class:`metrics.RocCurveDisplay` will no longer set a legend when
2+
`label` is `None` in both the `line_kwargs` and the `chance_level_kw`.
3+
By :user:`Arturo Amor <ArturoAmorQ>`

sklearn/metrics/_plot/roc_curve.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,10 @@ def plot(
185185
if despine:
186186
_despine(self.ax_)
187187

188-
if "label" in line_kwargs or "label" in chance_level_kw:
188+
if (
189+
line_kwargs.get("label") is not None
190+
or chance_level_kw.get("label") is not None
191+
):
189192
self.ax_.legend(loc="lower right")
190193

191194
return self

sklearn/metrics/_plot/tests/test_roc_curve_display.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,14 @@ def test_roc_curve_display_plotting(
127127

128128

129129
@pytest.mark.parametrize("plot_chance_level", [True, False])
130+
@pytest.mark.parametrize("label", [None, "Test Label"])
130131
@pytest.mark.parametrize(
131132
"chance_level_kw",
132133
[
133134
None,
134135
{"linewidth": 1, "color": "red", "linestyle": "-", "label": "DummyEstimator"},
135136
{"lw": 1, "c": "red", "ls": "-", "label": "DummyEstimator"},
137+
{"lw": 1, "color": "blue", "ls": "-", "label": None},
136138
],
137139
)
138140
@pytest.mark.parametrize(
@@ -144,6 +146,7 @@ def test_roc_curve_chance_level_line(
144146
data_binary,
145147
plot_chance_level,
146148
chance_level_kw,
149+
label,
147150
constructor_name,
148151
):
149152
"""Check the chance level line plotting behaviour."""
@@ -160,6 +163,7 @@ def test_roc_curve_chance_level_line(
160163
lr,
161164
X,
162165
y,
166+
label=label,
163167
alpha=0.8,
164168
plot_chance_level=plot_chance_level,
165169
chance_level_kw=chance_level_kw,
@@ -168,6 +172,7 @@ def test_roc_curve_chance_level_line(
168172
display = RocCurveDisplay.from_predictions(
169173
y,
170174
y_pred,
175+
label=label,
171176
alpha=0.8,
172177
plot_chance_level=plot_chance_level,
173178
chance_level_kw=chance_level_kw,
@@ -193,7 +198,6 @@ def test_roc_curve_chance_level_line(
193198
assert display.chance_level_.get_linestyle() == "--"
194199
assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)"
195200
elif plot_chance_level:
196-
assert display.chance_level_.get_label() == chance_level_kw["label"]
197201
if "c" in chance_level_kw:
198202
assert display.chance_level_.get_color() == chance_level_kw["c"]
199203
else:
@@ -206,6 +210,17 @@ def test_roc_curve_chance_level_line(
206210
assert display.chance_level_.get_linestyle() == chance_level_kw["ls"]
207211
else:
208212
assert display.chance_level_.get_linestyle() == chance_level_kw["linestyle"]
213+
# Checking for legend behaviour
214+
if label is not None or chance_level_kw.get("label") is not None:
215+
legend = display.ax_.get_legend()
216+
assert legend is not None # Legend should be present if any label is set
217+
legend_labels = [text.get_text() for text in legend.get_texts()]
218+
if label is not None:
219+
assert label in legend_labels
220+
if chance_level_kw.get("label") is not None:
221+
assert chance_level_kw["label"] in legend_labels
222+
else:
223+
assert display.ax_.get_legend() is None
209224

210225

211226
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)