Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
- Fix multiple issues in the multiclass setting of :class:`inspection.DecisionBoundaryDisplay`:

- `contour` plotting now correctly shows the decision boundary.
- `cmap` and `colors` are now properly ignored in favor of `multiclass_colors`, and a warning is raised when either of them is passed.
- Linear segmented colormaps (colormaps without a `colors` attribute, e.g. `"Blues"`) are now fully supported.

By :user:`Yunjie Lin <jshn9515>`
82 changes: 44 additions & 38 deletions sklearn/inspection/_plot/decision_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,17 +221,22 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar
self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs)
else: # self.response.ndim == 3
n_responses = self.response.shape[-1]
if (
isinstance(self.multiclass_colors, str)
or self.multiclass_colors is None
for kwarg in ("cmap", "colors"):
if kwarg in kwargs:
warnings.warn(
f"'{kwarg}' is ignored in favor of 'multiclass_colors' "
"in the multiclass case when the response method is "
"'decision_function' or 'predict_proba'."
)
del kwargs[kwarg]

if self.multiclass_colors is None or isinstance(
self.multiclass_colors, str
):
if isinstance(self.multiclass_colors, str):
cmap = self.multiclass_colors
if self.multiclass_colors is None:
cmap = "tab10" if n_responses <= 10 else "gist_rainbow"
else:
if n_responses <= 10:
cmap = "tab10"
else:
cmap = "gist_rainbow"
cmap = self.multiclass_colors

# Special case for the tab10 and tab20 colormaps that encode a
# discrete set of colors that are easily distinguishable
Expand All @@ -241,40 +246,41 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar
elif cmap == "tab20" and n_responses <= 20:
colors = plt.get_cmap("tab20", 20).colors[:n_responses]
else:
colors = plt.get_cmap(cmap, n_responses).colors
elif isinstance(self.multiclass_colors, str):
colors = colors = plt.get_cmap(
self.multiclass_colors, n_responses
).colors
else:
cmap = plt.get_cmap(cmap, n_responses)
if not hasattr(cmap, "colors"):
# For LinearSegmentedColormap
colors = cmap(np.linspace(0, 1, n_responses))
else:
colors = cmap.colors
elif isinstance(self.multiclass_colors, list):
colors = [mpl.colors.to_rgba(color) for color in self.multiclass_colors]
else:
raise ValueError("'multiclass_colors' must be a list or a str.")

self.multiclass_colors_ = colors
multiclass_cmaps = [
mpl.colors.LinearSegmentedColormap.from_list(
f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
)
for class_idx, (r, g, b, _) in enumerate(colors)
]

self.surface_ = []
for class_idx, cmap in enumerate(multiclass_cmaps):
response = np.ma.array(
self.response[:, :, class_idx],
mask=~(self.response.argmax(axis=2) == class_idx),
if plot_method == "contour":
# Plot only argmax map for contour
class_map = self.response.argmax(axis=2)
self.surface_ = plot_func(
self.xx0, self.xx1, class_map, colors=colors, **kwargs
)
# `cmap` should not be in kwargs
safe_kwargs = kwargs.copy()
if "cmap" in safe_kwargs:
del safe_kwargs["cmap"]
warnings.warn(
"Plotting max class of multiclass 'decision_function' or "
"'predict_proba', thus 'multiclass_colors' used and "
"'cmap' kwarg ignored."
else:
multiclass_cmaps = [
mpl.colors.LinearSegmentedColormap.from_list(
f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
)
for class_idx, (r, g, b, _) in enumerate(colors)
]

self.surface_ = []
for class_idx, cmap in enumerate(multiclass_cmaps):
response = np.ma.array(
self.response[:, :, class_idx],
mask=~(self.response.argmax(axis=2) == class_idx),
)
self.surface_.append(
plot_func(self.xx0, self.xx1, response, cmap=cmap, **kwargs)
)
self.surface_.append(
plot_func(self.xx0, self.xx1, response, cmap=cmap, **safe_kwargs)
)

if xlabel is not None or not ax.get_xlabel():
xlabel = self.xlabel if xlabel is None else xlabel
Expand Down
59 changes: 42 additions & 17 deletions sklearn/inspection/_plot/tests/test_boundary_decision_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def test_input_validation_errors(pyplot, kwargs, error_msg, fitted_clf):
@pytest.mark.parametrize(
"kwargs, error_msg",
[
(
{"multiclass_colors": {"dict": "not_list"}},
"'multiclass_colors' must be a list or a str.",
),
({"multiclass_colors": "not_cmap"}, "it must be a valid Matplotlib colormap"),
({"multiclass_colors": ["red", "green"]}, "it must be of the same length"),
(
Expand Down Expand Up @@ -617,6 +621,7 @@ def test_multiclass_plot_max_class(pyplot, response_method):
"multiclass_colors",
[
"plasma",
"Blues",
["red", "green", "blue"],
],
)
Expand All @@ -642,31 +647,51 @@ def test_multiclass_colors_cmap(pyplot, plot_method, multiclass_colors):

if multiclass_colors == "plasma":
colors = mpl.pyplot.get_cmap(multiclass_colors, len(clf.classes_)).colors
elif multiclass_colors == "Blues":
cmap = mpl.pyplot.get_cmap(multiclass_colors, len(clf.classes_))
colors = cmap(np.linspace(0, 1, len(clf.classes_)))
else:
colors = [mpl.colors.to_rgba(color) for color in multiclass_colors]

cmaps = [
mpl.colors.LinearSegmentedColormap.from_list(
f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
)
for class_idx, (r, g, b, _) in enumerate(colors)
]

for idx, quad in enumerate(disp.surface_):
assert quad.cmap == cmaps[idx]
if plot_method != "contour":
cmaps = [
mpl.colors.LinearSegmentedColormap.from_list(
f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
)
for class_idx, (r, g, b, _) in enumerate(colors)
]
for idx, quad in enumerate(disp.surface_):
assert quad.cmap == cmaps[idx]
else:
assert_allclose(disp.surface_.colors, colors)


def test_multiclass_plot_max_class_cmap_kwarg(pyplot):
"""Check `cmap` kwarg ignored when using plotting max multiclass class."""
def test_cmap_and_colors_logic(pyplot):
"""Check the handling logic for `cmap` and `colors`."""
X, y = load_iris_2d_scaled()
clf = LogisticRegression().fit(X, y)

msg = (
"Plotting max class of multiclass 'decision_function' or 'predict_proba', "
"thus 'multiclass_colors' used and 'cmap' kwarg ignored."
)
with pytest.warns(UserWarning, match=msg):
DecisionBoundaryDisplay.from_estimator(clf, X, cmap="viridis")
with pytest.warns(
UserWarning,
match="'cmap' is ignored in favor of 'multiclass_colors'",
):
DecisionBoundaryDisplay.from_estimator(
clf,
X,
multiclass_colors="plasma",
cmap="Blues",
)

with pytest.warns(
UserWarning,
match="'colors' is ignored in favor of 'multiclass_colors'",
):
DecisionBoundaryDisplay.from_estimator(
clf,
X,
multiclass_colors="plasma",
colors="blue",
)


def test_subclass_named_constructors_return_type_is_subclass(pyplot):
Expand Down