@@ -127,12 +127,14 @@ def test_roc_curve_display_plotting(
127
127
128
128
129
129
@pytest .mark .parametrize ("plot_chance_level" , [True , False ])
130
+ @pytest .mark .parametrize ("label" , [None , "Test Label" ])
130
131
@pytest .mark .parametrize (
131
132
"chance_level_kw" ,
132
133
[
133
134
None ,
134
135
{"linewidth" : 1 , "color" : "red" , "linestyle" : "-" , "label" : "DummyEstimator" },
135
136
{"lw" : 1 , "c" : "red" , "ls" : "-" , "label" : "DummyEstimator" },
137
+ {"lw" : 1 , "color" : "blue" , "ls" : "-" , "label" : None },
136
138
],
137
139
)
138
140
@pytest .mark .parametrize (
@@ -144,6 +146,7 @@ def test_roc_curve_chance_level_line(
144
146
data_binary ,
145
147
plot_chance_level ,
146
148
chance_level_kw ,
149
+ label ,
147
150
constructor_name ,
148
151
):
149
152
"""Check the chance level line plotting behaviour."""
@@ -160,6 +163,7 @@ def test_roc_curve_chance_level_line(
160
163
lr ,
161
164
X ,
162
165
y ,
166
+ label = label ,
163
167
alpha = 0.8 ,
164
168
plot_chance_level = plot_chance_level ,
165
169
chance_level_kw = chance_level_kw ,
@@ -168,6 +172,7 @@ def test_roc_curve_chance_level_line(
168
172
display = RocCurveDisplay .from_predictions (
169
173
y ,
170
174
y_pred ,
175
+ label = label ,
171
176
alpha = 0.8 ,
172
177
plot_chance_level = plot_chance_level ,
173
178
chance_level_kw = chance_level_kw ,
@@ -193,7 +198,6 @@ def test_roc_curve_chance_level_line(
193
198
assert display .chance_level_ .get_linestyle () == "--"
194
199
assert display .chance_level_ .get_label () == "Chance level (AUC = 0.5)"
195
200
elif plot_chance_level :
196
- assert display .chance_level_ .get_label () == chance_level_kw ["label" ]
197
201
if "c" in chance_level_kw :
198
202
assert display .chance_level_ .get_color () == chance_level_kw ["c" ]
199
203
else :
@@ -206,6 +210,17 @@ def test_roc_curve_chance_level_line(
206
210
assert display .chance_level_ .get_linestyle () == chance_level_kw ["ls" ]
207
211
else :
208
212
assert display .chance_level_ .get_linestyle () == chance_level_kw ["linestyle" ]
213
+ # Checking for legend behaviour
214
+ if label is not None or chance_level_kw .get ("label" ) is not None :
215
+ legend = display .ax_ .get_legend ()
216
+ assert legend is not None # Legend should be present if any label is set
217
+ legend_labels = [text .get_text () for text in legend .get_texts ()]
218
+ if label is not None :
219
+ assert label in legend_labels
220
+ if chance_level_kw .get ("label" ) is not None :
221
+ assert chance_level_kw ["label" ] in legend_labels
222
+ else :
223
+ assert display .ax_ .get_legend () is None
209
224
210
225
211
226
@pytest .mark .parametrize (
0 commit comments