Skip to content

Commit fe6960b

Browse files
jshn9515jeremiedbb
andauthored
FIX: Regression in DecisionBoundaryDisplay.from_estimator with colors and plot_method='contour' (#31553)
Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
1 parent bab34a0 commit fe6960b

File tree

3 files changed

+93
-55
lines changed

3 files changed

+93
-55
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
- Fix multiple issues in the multiclass setting of :class:`inspection.DecisionBoundaryDisplay`:
2+
3+
- `contour` plotting now correctly shows the decision boundary.
4+
- `cmap` and `colors` are now properly ignored in favor of `multiclass_colors`.
5+
- Linear segmented colormaps are now fully supported.
6+
7+
By :user:`Yunjie Lin <jshn9515>`

sklearn/inspection/_plot/decision_boundary.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -221,17 +221,22 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar
221221
self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs)
222222
else: # self.response.ndim == 3
223223
n_responses = self.response.shape[-1]
224-
if (
225-
isinstance(self.multiclass_colors, str)
226-
or self.multiclass_colors is None
224+
for kwarg in ("cmap", "colors"):
225+
if kwarg in kwargs:
226+
warnings.warn(
227+
f"'{kwarg}' is ignored in favor of 'multiclass_colors' "
228+
"in the multiclass case when the response method is "
229+
"'decision_function' or 'predict_proba'."
230+
)
231+
del kwargs[kwarg]
232+
233+
if self.multiclass_colors is None or isinstance(
234+
self.multiclass_colors, str
227235
):
228-
if isinstance(self.multiclass_colors, str):
229-
cmap = self.multiclass_colors
236+
if self.multiclass_colors is None:
237+
cmap = "tab10" if n_responses <= 10 else "gist_rainbow"
230238
else:
231-
if n_responses <= 10:
232-
cmap = "tab10"
233-
else:
234-
cmap = "gist_rainbow"
239+
cmap = self.multiclass_colors
235240

236241
# Special case for the tab10 and tab20 colormaps that encode a
237242
# discrete set of colors that are easily distinguishable
@@ -241,40 +246,41 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar
241246
elif cmap == "tab20" and n_responses <= 20:
242247
colors = plt.get_cmap("tab20", 20).colors[:n_responses]
243248
else:
244-
colors = plt.get_cmap(cmap, n_responses).colors
245-
elif isinstance(self.multiclass_colors, str):
246-
colors = colors = plt.get_cmap(
247-
self.multiclass_colors, n_responses
248-
).colors
249-
else:
249+
cmap = plt.get_cmap(cmap, n_responses)
250+
if not hasattr(cmap, "colors"):
251+
# For LinearSegmentedColormap
252+
colors = cmap(np.linspace(0, 1, n_responses))
253+
else:
254+
colors = cmap.colors
255+
elif isinstance(self.multiclass_colors, list):
250256
colors = [mpl.colors.to_rgba(color) for color in self.multiclass_colors]
257+
else:
258+
raise ValueError("'multiclass_colors' must be a list or a str.")
251259

252260
self.multiclass_colors_ = colors
253-
multiclass_cmaps = [
254-
mpl.colors.LinearSegmentedColormap.from_list(
255-
f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
256-
)
257-
for class_idx, (r, g, b, _) in enumerate(colors)
258-
]
259-
260-
self.surface_ = []
261-
for class_idx, cmap in enumerate(multiclass_cmaps):
262-
response = np.ma.array(
263-
self.response[:, :, class_idx],
264-
mask=~(self.response.argmax(axis=2) == class_idx),
261+
if plot_method == "contour":
262+
# Plot only argmax map for contour
263+
class_map = self.response.argmax(axis=2)
264+
self.surface_ = plot_func(
265+
self.xx0, self.xx1, class_map, colors=colors, **kwargs
265266
)
266-
# `cmap` should not be in kwargs
267-
safe_kwargs = kwargs.copy()
268-
if "cmap" in safe_kwargs:
269-
del safe_kwargs["cmap"]
270-
warnings.warn(
271-
"Plotting max class of multiclass 'decision_function' or "
272-
"'predict_proba', thus 'multiclass_colors' used and "
273-
"'cmap' kwarg ignored."
267+
else:
268+
multiclass_cmaps = [
269+
mpl.colors.LinearSegmentedColormap.from_list(
270+
f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
271+
)
272+
for class_idx, (r, g, b, _) in enumerate(colors)
273+
]
274+
275+
self.surface_ = []
276+
for class_idx, cmap in enumerate(multiclass_cmaps):
277+
response = np.ma.array(
278+
self.response[:, :, class_idx],
279+
mask=~(self.response.argmax(axis=2) == class_idx),
280+
)
281+
self.surface_.append(
282+
plot_func(self.xx0, self.xx1, response, cmap=cmap, **kwargs)
274283
)
275-
self.surface_.append(
276-
plot_func(self.xx0, self.xx1, response, cmap=cmap, **safe_kwargs)
277-
)
278284

279285
if xlabel is not None or not ax.get_xlabel():
280286
xlabel = self.xlabel if xlabel is None else xlabel

sklearn/inspection/_plot/tests/test_boundary_decision_display.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ def test_input_validation_errors(pyplot, kwargs, error_msg, fitted_clf):
169169
@pytest.mark.parametrize(
170170
"kwargs, error_msg",
171171
[
172+
(
173+
{"multiclass_colors": {"dict": "not_list"}},
174+
"'multiclass_colors' must be a list or a str.",
175+
),
172176
({"multiclass_colors": "not_cmap"}, "it must be a valid Matplotlib colormap"),
173177
({"multiclass_colors": ["red", "green"]}, "it must be of the same length"),
174178
(
@@ -617,6 +621,7 @@ def test_multiclass_plot_max_class(pyplot, response_method):
617621
"multiclass_colors",
618622
[
619623
"plasma",
624+
"Blues",
620625
["red", "green", "blue"],
621626
],
622627
)
@@ -642,31 +647,51 @@ def test_multiclass_colors_cmap(pyplot, plot_method, multiclass_colors):
642647

643648
if multiclass_colors == "plasma":
644649
colors = mpl.pyplot.get_cmap(multiclass_colors, len(clf.classes_)).colors
650+
elif multiclass_colors == "Blues":
651+
cmap = mpl.pyplot.get_cmap(multiclass_colors, len(clf.classes_))
652+
colors = cmap(np.linspace(0, 1, len(clf.classes_)))
645653
else:
646654
colors = [mpl.colors.to_rgba(color) for color in multiclass_colors]
647655

648-
cmaps = [
649-
mpl.colors.LinearSegmentedColormap.from_list(
650-
f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
651-
)
652-
for class_idx, (r, g, b, _) in enumerate(colors)
653-
]
654-
655-
for idx, quad in enumerate(disp.surface_):
656-
assert quad.cmap == cmaps[idx]
656+
if plot_method != "contour":
657+
cmaps = [
658+
mpl.colors.LinearSegmentedColormap.from_list(
659+
f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
660+
)
661+
for class_idx, (r, g, b, _) in enumerate(colors)
662+
]
663+
for idx, quad in enumerate(disp.surface_):
664+
assert quad.cmap == cmaps[idx]
665+
else:
666+
assert_allclose(disp.surface_.colors, colors)
657667

658668

659-
def test_multiclass_plot_max_class_cmap_kwarg(pyplot):
660-
"""Check `cmap` kwarg ignored when using plotting max multiclass class."""
669+
def test_cmap_and_colors_logic(pyplot):
670+
"""Check the handling logic for `cmap` and `colors`."""
661671
X, y = load_iris_2d_scaled()
662672
clf = LogisticRegression().fit(X, y)
663673

664-
msg = (
665-
"Plotting max class of multiclass 'decision_function' or 'predict_proba', "
666-
"thus 'multiclass_colors' used and 'cmap' kwarg ignored."
667-
)
668-
with pytest.warns(UserWarning, match=msg):
669-
DecisionBoundaryDisplay.from_estimator(clf, X, cmap="viridis")
674+
with pytest.warns(
675+
UserWarning,
676+
match="'cmap' is ignored in favor of 'multiclass_colors'",
677+
):
678+
DecisionBoundaryDisplay.from_estimator(
679+
clf,
680+
X,
681+
multiclass_colors="plasma",
682+
cmap="Blues",
683+
)
684+
685+
with pytest.warns(
686+
UserWarning,
687+
match="'colors' is ignored in favor of 'multiclass_colors'",
688+
):
689+
DecisionBoundaryDisplay.from_estimator(
690+
clf,
691+
X,
692+
multiclass_colors="plasma",
693+
colors="blue",
694+
)
670695

671696

672697
def test_subclass_named_constructors_return_type_is_subclass(pyplot):

0 commit comments

Comments
 (0)