From 1afc804b71ae5f68e6d81f4ca6ac4ed51f01a08e Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Mon, 16 Jun 2025 13:20:40 +0800 Subject: [PATCH 01/17] Fix: DecisionBoundaryDisplay for `contour` and `cmap` (issue scikit-learn#31546 --- sklearn/inspection/_plot/decision_boundary.py | 104 ++++++++++-------- 1 file changed, 58 insertions(+), 46 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index bc28708d7c488..7922907854a11 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -221,60 +221,72 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs) else: # self.response.ndim == 3 n_responses = self.response.shape[-1] - if ( - isinstance(self.multiclass_colors, str) - or self.multiclass_colors is None - ): - if isinstance(self.multiclass_colors, str): - cmap = self.multiclass_colors - else: - if n_responses <= 10: - cmap = "tab10" + if self.multiclass_colors is None: + if "cmap" in kwargs and "colors" in kwargs: + raise ValueError( + "Cannot specify both 'cmap' and 'colors' in kwargs. " + "Please use only one of them." + ) + if "cmap" in kwargs: + cmap = plt.get_cmap(kwargs.pop("cmap"), n_responses) + if not hasattr(cmap, "colors"): + # For LinearSegmentedColormap + colors = cmap(np.linspace(0, 1, n_responses)) else: - cmap = "gist_rainbow" - - # Special case for the tab10 and tab20 colormaps that encode a - # discrete set of colors that are easily distinguishable - # contrary to other colormaps that are continuous. 
- if cmap == "tab10" and n_responses <= 10: - colors = plt.get_cmap("tab10", 10).colors[:n_responses] - elif cmap == "tab20" and n_responses <= 20: - colors = plt.get_cmap("tab20", 20).colors[:n_responses] + colors = cmap.colors + elif "colors" in kwargs: + colors = mpl.colors.to_rgba(kwargs.pop("colors")) + colors = [colors for _ in range(n_responses)] else: + cmap = "tab10" if n_responses <= 10 else "gist_rainbow" colors = plt.get_cmap(cmap, n_responses).colors - elif isinstance(self.multiclass_colors, str): - colors = colors = plt.get_cmap( - self.multiclass_colors, n_responses - ).colors else: - colors = [mpl.colors.to_rgba(color) for color in self.multiclass_colors] + if "cmap" in kwargs: + warnings.warn("'cmap' is ignored when 'multiclass_colors' is set.") + del kwargs["cmap"] + if "colors" in kwargs: + warnings.warn( + "'colors' is ignored when 'multiclass_colors' is set." + ) + del kwargs["colors"] + if isinstance(self.multiclass_colors, str): + cmap = plt.get_cmap(self.multiclass_colors, n_responses) + if not hasattr(cmap, "colors"): + # For LinearSegmentedColormap + colors = cmap(np.linspace(0, 1, n_responses)) + else: + colors = cmap.colors + elif isinstance(self.multiclass_colors, list): + colors = [ + mpl.colors.to_rgba(color) for color in self.multiclass_colors + ] + else: + raise ValueError("'multiclass_colors' must be a list or a str.") self.multiclass_colors_ = colors - multiclass_cmaps = [ - mpl.colors.LinearSegmentedColormap.from_list( - f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] - ) - for class_idx, (r, g, b, _) in enumerate(colors) - ] - - self.surface_ = [] - for class_idx, cmap in enumerate(multiclass_cmaps): - response = np.ma.array( - self.response[:, :, class_idx], - mask=~(self.response.argmax(axis=2) == class_idx), + if plot_method == "contour": + # Plot only argmax map for contour + class_map = self.response.argmax(axis=2) + self.surface_ = plot_func( + self.xx0, self.xx1, class_map, colors=colors, **kwargs ) - # 
`cmap` should not be in kwargs - safe_kwargs = kwargs.copy() - if "cmap" in safe_kwargs: - del safe_kwargs["cmap"] - warnings.warn( - "Plotting max class of multiclass 'decision_function' or " - "'predict_proba', thus 'multiclass_colors' used and " - "'cmap' kwarg ignored." + else: + multiclass_cmaps = [ + mpl.colors.LinearSegmentedColormap.from_list( + f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + ) + for class_idx, (r, g, b, _) in enumerate(colors) + ] + + self.surface_ = [] + for class_idx, cmap in enumerate(multiclass_cmaps): + response = np.ma.array( + self.response[:, :, class_idx], + mask=~(self.response.argmax(axis=2) == class_idx), + ) + self.surface_.append( + plot_func(self.xx0, self.xx1, response, cmap=cmap, **kwargs) ) - self.surface_.append( - plot_func(self.xx0, self.xx1, response, cmap=cmap, **safe_kwargs) - ) if xlabel is not None or not ax.get_xlabel(): xlabel = self.xlabel if xlabel is None else xlabel From a80ef4c62530229bc4686953a312ab3f4fd5e428 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Mon, 16 Jun 2025 13:21:40 +0800 Subject: [PATCH 02/17] Update corresponding tests --- .../tests/test_boundary_decision_display.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 3284f42241fa5..06943d41bd4ec 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -652,21 +652,9 @@ def test_multiclass_colors_cmap(pyplot, plot_method, multiclass_colors): for class_idx, (r, g, b, _) in enumerate(colors) ] - for idx, quad in enumerate(disp.surface_): - assert quad.cmap == cmaps[idx] - - -def test_multiclass_plot_max_class_cmap_kwarg(pyplot): - """Check `cmap` kwarg ignored when using plotting max multiclass class.""" - X, y = load_iris_2d_scaled() - clf = 
LogisticRegression().fit(X, y) - - msg = ( - "Plotting max class of multiclass 'decision_function' or 'predict_proba', " - "thus 'multiclass_colors' used and 'cmap' kwarg ignored." - ) - with pytest.warns(UserWarning, match=msg): - DecisionBoundaryDisplay.from_estimator(clf, X, cmap="viridis") + if plot_method != 'contour': + for idx, quad in enumerate(disp.surface_): + assert quad.cmap == cmaps[idx] def test_subclass_named_constructors_return_type_is_subclass(pyplot): From c72e425d3d3b0e6f4d4090dad4cf4df9080fcfd9 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Mon, 16 Jun 2025 13:55:36 +0800 Subject: [PATCH 03/17] Reformat with ruff --- .../inspection/_plot/tests/test_boundary_decision_display.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 06943d41bd4ec..9e4b7bf45c12e 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -652,7 +652,7 @@ def test_multiclass_colors_cmap(pyplot, plot_method, multiclass_colors): for class_idx, (r, g, b, _) in enumerate(colors) ] - if plot_method != 'contour': + if plot_method != "contour": for idx, quad in enumerate(disp.surface_): assert quad.cmap == cmaps[idx] From 1be1a98845dd162853e9bb2183cb5d1b4438c274 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Sat, 21 Jun 2025 22:54:24 +0800 Subject: [PATCH 04/17] Fix: add handling logic when `colors` is a list --- sklearn/inspection/_plot/decision_boundary.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 7922907854a11..98c6d792622d3 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -235,8 +235,12 @@ def plot(self, plot_method="contourf", ax=None, 
xlabel=None, ylabel=None, **kwar else: colors = cmap.colors elif "colors" in kwargs: - colors = mpl.colors.to_rgba(kwargs.pop("colors")) - colors = [colors for _ in range(n_responses)] + if isinstance(kwargs["colors"], str): + colors = mpl.colors.to_rgba(kwargs.pop("colors")) + colors = [colors for _ in range(n_responses)] + else: + colors = kwargs.pop("colors") + colors = [mpl.colors.to_rgba(color) for color in colors] else: cmap = "tab10" if n_responses <= 10 else "gist_rainbow" colors = plt.get_cmap(cmap, n_responses).colors From 7ca868f67a6fb23f35d33554f177f98af80a4f53 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Sat, 21 Jun 2025 22:56:35 +0800 Subject: [PATCH 05/17] ENH: add tests for cmap and colors handling in DecisionBoundaryDisplay --- .../tests/test_boundary_decision_display.py | 121 +++++++++++++++++- 1 file changed, 116 insertions(+), 5 deletions(-) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 9e4b7bf45c12e..109b8a66e2114 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -645,16 +645,127 @@ def test_multiclass_colors_cmap(pyplot, plot_method, multiclass_colors): else: colors = [mpl.colors.to_rgba(color) for color in multiclass_colors] - cmaps = [ - mpl.colors.LinearSegmentedColormap.from_list( - f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + if plot_method != "contour": + cmaps = [ + mpl.colors.LinearSegmentedColormap.from_list( + f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + ) + for class_idx, (r, g, b, _) in enumerate(colors) + ] + for idx, quad in enumerate(disp.surface_): + assert quad.cmap == cmaps[idx] + else: + assert_allclose(disp.surface_.colors, colors) + + +def test_cmap_and_colors_logic(pyplot): + """Check the handling logic for `cmap` and `colors`.""" + X, y = load_iris_2d_scaled() + clf = 
LogisticRegression().fit(X, y) + + with pytest.raises( + ValueError, + match="Cannot specify both 'cmap' and 'colors' in kwargs.", + ): + DecisionBoundaryDisplay.from_estimator( + clf, + X, + colors="black", + cmap="Blues", + ) + + with pytest.warns( + UserWarning, + match="'cmap' is ignored when 'multiclass_colors' is set.", + ): + DecisionBoundaryDisplay.from_estimator( + clf, + X, + multiclass_colors="plasma", + cmap="Blues", ) - for class_idx, (r, g, b, _) in enumerate(colors) - ] + + with pytest.warns( + UserWarning, + match="'colors' is ignored when 'multiclass_colors' is set.", + ): + DecisionBoundaryDisplay.from_estimator( + clf, + X, + multiclass_colors="plasma", + colors="blue", + ) + + +@pytest.mark.parametrize("plot_method", ["contourf", "contour", "pcolormesh"]) +@pytest.mark.parametrize("kwargs", [{"cmap": "tab10"}, {"cmap": "Blues"}]) +def test_multiclass_cmap(pyplot, plot_method, kwargs): + """Check that `cmap` is correctly applied to DecisionBoundaryDisplay.""" + import matplotlib as mpl + import matplotlib.pyplot as plt + + X, y = load_iris_2d_scaled() + clf = LogisticRegression().fit(X, y) + + disp = DecisionBoundaryDisplay.from_estimator( + clf, + X, + cmap=kwargs["cmap"], + plot_method=plot_method, + ) + + cmap = plt.get_cmap(kwargs["cmap"], len(clf.classes_)) + if not hasattr(cmap, "colors"): + colors = cmap(np.linspace(0, 1, len(clf.classes_))) + else: + colors = cmap.colors if plot_method != "contour": + cmaps = [ + mpl.colors.LinearSegmentedColormap.from_list( + f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + ) + for class_idx, (r, g, b, _) in enumerate(colors) + ] for idx, quad in enumerate(disp.surface_): assert quad.cmap == cmaps[idx] + else: + assert_allclose(disp.surface_.colors, colors) + + +@pytest.mark.parametrize("plot_method", ["contourf", "contour", "pcolormesh"]) +@pytest.mark.parametrize("kwargs", [{"colors": "black"}, {"colors": ["r", "g", "b"]}]) +def test_multiclass_colors(pyplot, plot_method, kwargs): + 
"""Check that `cmap` is correctly applied to DecisionBoundaryDisplay.""" + import matplotlib as mpl + + X, y = load_iris_2d_scaled() + clf = LogisticRegression().fit(X, y) + + disp = DecisionBoundaryDisplay.from_estimator( + clf, + X, + colors=kwargs["colors"], + plot_method=plot_method, + ) + + if isinstance(kwargs["colors"], str): + colors = mpl.colors.to_rgba(kwargs["colors"]) + colors = [colors for _ in range(len(clf.classes_))] + else: + colors = [mpl.colors.to_rgba(color) for color in kwargs["colors"]] + + if plot_method != "contour": + cmaps = [ + mpl.colors.LinearSegmentedColormap.from_list( + f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] + ) + for class_idx, (r, g, b, _) in enumerate(colors) + ] + for idx, quad in enumerate(disp.surface_): + assert quad.cmap == cmaps[idx] + else: + assert_allclose(disp.surface_.colors, colors) def test_subclass_named_constructors_return_type_is_subclass(pyplot): From 04b3528ab06a02baf12cb554d512233a671eb331 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Thu, 10 Jul 2025 10:32:33 +0800 Subject: [PATCH 06/17] Remove conflicting check for 'cmap' and 'colors' --- sklearn/inspection/_plot/decision_boundary.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 98c6d792622d3..1cdc72c03a478 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -222,11 +222,6 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar else: # self.response.ndim == 3 n_responses = self.response.shape[-1] if self.multiclass_colors is None: - if "cmap" in kwargs and "colors" in kwargs: - raise ValueError( - "Cannot specify both 'cmap' and 'colors' in kwargs. " - "Please use only one of them." 
- ) if "cmap" in kwargs: cmap = plt.get_cmap(kwargs.pop("cmap"), n_responses) if not hasattr(cmap, "colors"): From e97c8578c17b1bf4219f612ff68b647cb620d073 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Thu, 10 Jul 2025 10:52:23 +0800 Subject: [PATCH 07/17] Ignore 'cmap' and 'colors' arguments in multiclass case --- sklearn/inspection/_plot/decision_boundary.py | 33 +++++-------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 1cdc72c03a478..41418bcbd472c 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -221,33 +221,16 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs) else: # self.response.ndim == 3 n_responses = self.response.shape[-1] + if "cmap" in kwargs: + warnings.warn("'cmap' is ignored in multiclass case.") + del kwargs["cmap"] + if "colors" in kwargs: + warnings.warn("'colors' is ignored in multiclass case.") + del kwargs["colors"] if self.multiclass_colors is None: - if "cmap" in kwargs: - cmap = plt.get_cmap(kwargs.pop("cmap"), n_responses) - if not hasattr(cmap, "colors"): - # For LinearSegmentedColormap - colors = cmap(np.linspace(0, 1, n_responses)) - else: - colors = cmap.colors - elif "colors" in kwargs: - if isinstance(kwargs["colors"], str): - colors = mpl.colors.to_rgba(kwargs.pop("colors")) - colors = [colors for _ in range(n_responses)] - else: - colors = kwargs.pop("colors") - colors = [mpl.colors.to_rgba(color) for color in colors] - else: - cmap = "tab10" if n_responses <= 10 else "gist_rainbow" - colors = plt.get_cmap(cmap, n_responses).colors + cmap = "tab10" if n_responses <= 10 else "gist_rainbow" + colors = plt.get_cmap(cmap, n_responses).colors else: - if "cmap" in kwargs: - warnings.warn("'cmap' is ignored when 'multiclass_colors' is set.") - del 
kwargs["cmap"] - if "colors" in kwargs: - warnings.warn( - "'colors' is ignored when 'multiclass_colors' is set." - ) - del kwargs["colors"] if isinstance(self.multiclass_colors, str): cmap = plt.get_cmap(self.multiclass_colors, n_responses) if not hasattr(cmap, "colors"): From 752df3b073c5f10b5eccd3679ff9abd396050681 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Thu, 10 Jul 2025 11:19:54 +0800 Subject: [PATCH 08/17] Update corresponding tests --- .../tests/test_boundary_decision_display.py | 94 ++----------------- 1 file changed, 10 insertions(+), 84 deletions(-) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 109b8a66e2114..bac47c8ac00a2 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -169,6 +169,10 @@ def test_input_validation_errors(pyplot, kwargs, error_msg, fitted_clf): @pytest.mark.parametrize( "kwargs, error_msg", [ + ( + {"multiclass_colors": {"dict": "not_list"}}, + "'multiclass_colors' must be a list or a str.", + ), ({"multiclass_colors": "not_cmap"}, "it must be a valid Matplotlib colormap"), ({"multiclass_colors": ["red", "green"]}, "it must be of the same length"), ( @@ -617,6 +621,7 @@ def test_multiclass_plot_max_class(pyplot, response_method): "multiclass_colors", [ "plasma", + "Blues", ["red", "green", "blue"], ], ) @@ -642,6 +647,9 @@ def test_multiclass_colors_cmap(pyplot, plot_method, multiclass_colors): if multiclass_colors == "plasma": colors = mpl.pyplot.get_cmap(multiclass_colors, len(clf.classes_)).colors + elif multiclass_colors == "Blues": + cmap = mpl.pyplot.get_cmap(multiclass_colors, len(clf.classes_)) + colors = cmap(np.linspace(0, 1, len(clf.classes_))) else: colors = [mpl.colors.to_rgba(color) for color in multiclass_colors] @@ -663,20 +671,9 @@ def test_cmap_and_colors_logic(pyplot): X, y = load_iris_2d_scaled() clf = 
LogisticRegression().fit(X, y) - with pytest.raises( - ValueError, - match="Cannot specify both 'cmap' and 'colors' in kwargs.", - ): - DecisionBoundaryDisplay.from_estimator( - clf, - X, - colors="black", - cmap="Blues", - ) - with pytest.warns( UserWarning, - match="'cmap' is ignored when 'multiclass_colors' is set.", + match="'cmap' is ignored in multiclass case.", ): DecisionBoundaryDisplay.from_estimator( clf, @@ -687,7 +684,7 @@ def test_cmap_and_colors_logic(pyplot): with pytest.warns( UserWarning, - match="'colors' is ignored when 'multiclass_colors' is set.", + match="'colors' is ignored in multiclass case.", ): DecisionBoundaryDisplay.from_estimator( clf, @@ -697,77 +694,6 @@ def test_cmap_and_colors_logic(pyplot): ) -@pytest.mark.parametrize("plot_method", ["contourf", "contour", "pcolormesh"]) -@pytest.mark.parametrize("kwargs", [{"cmap": "tab10"}, {"cmap": "Blues"}]) -def test_multiclass_cmap(pyplot, plot_method, kwargs): - """Check that `cmap` is correctly applied to DecisionBoundaryDisplay.""" - import matplotlib as mpl - import matplotlib.pyplot as plt - - X, y = load_iris_2d_scaled() - clf = LogisticRegression().fit(X, y) - - disp = DecisionBoundaryDisplay.from_estimator( - clf, - X, - cmap=kwargs["cmap"], - plot_method=plot_method, - ) - - cmap = plt.get_cmap(kwargs["cmap"], len(clf.classes_)) - if not hasattr(cmap, "colors"): - colors = cmap(np.linspace(0, 1, len(clf.classes_))) - else: - colors = cmap.colors - - if plot_method != "contour": - cmaps = [ - mpl.colors.LinearSegmentedColormap.from_list( - f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] - ) - for class_idx, (r, g, b, _) in enumerate(colors) - ] - for idx, quad in enumerate(disp.surface_): - assert quad.cmap == cmaps[idx] - else: - assert_allclose(disp.surface_.colors, colors) - - -@pytest.mark.parametrize("plot_method", ["contourf", "contour", "pcolormesh"]) -@pytest.mark.parametrize("kwargs", [{"colors": "black"}, {"colors": ["r", "g", "b"]}]) -def 
test_multiclass_colors(pyplot, plot_method, kwargs): - """Check that `cmap` is correctly applied to DecisionBoundaryDisplay.""" - import matplotlib as mpl - - X, y = load_iris_2d_scaled() - clf = LogisticRegression().fit(X, y) - - disp = DecisionBoundaryDisplay.from_estimator( - clf, - X, - colors=kwargs["colors"], - plot_method=plot_method, - ) - - if isinstance(kwargs["colors"], str): - colors = mpl.colors.to_rgba(kwargs["colors"]) - colors = [colors for _ in range(len(clf.classes_))] - else: - colors = [mpl.colors.to_rgba(color) for color in kwargs["colors"]] - - if plot_method != "contour": - cmaps = [ - mpl.colors.LinearSegmentedColormap.from_list( - f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)] - ) - for class_idx, (r, g, b, _) in enumerate(colors) - ] - for idx, quad in enumerate(disp.surface_): - assert quad.cmap == cmaps[idx] - else: - assert_allclose(disp.surface_.colors, colors) - - def test_subclass_named_constructors_return_type_is_subclass(pyplot): """Check that named constructors return the correct type when subclassed. 
From 714d82420580a6d8b8881c83b83f38422453587a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Thu, 10 Jul 2025 14:34:46 +0200 Subject: [PATCH 09/17] add back the tab10/tab20 logic --- sklearn/inspection/_plot/decision_boundary.py | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 41418bcbd472c..132b3390c243a 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -227,23 +227,34 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar if "colors" in kwargs: warnings.warn("'colors' is ignored in multiclass case.") del kwargs["colors"] - if self.multiclass_colors is None: - cmap = "tab10" if n_responses <= 10 else "gist_rainbow" - colors = plt.get_cmap(cmap, n_responses).colors - else: - if isinstance(self.multiclass_colors, str): - cmap = plt.get_cmap(self.multiclass_colors, n_responses) + + if self.multiclass_colors is None or isinstance( + self.multiclass_colors, str + ): + if self.multiclass_colors is None: + cmap = "tab10" if n_responses <= 10 else "gist_rainbow" + else: + cmap = self.multiclass_colors + + # Special case for the tab10 and tab20 colormaps that encode a + # discrete set of colors that are easily distinguishable + # contrary to other colormaps that are continuous. 
+ if cmap == "tab10" and n_responses <= 10: + colors = plt.get_cmap("tab10", 10).colors[:n_responses] + elif cmap == "tab20" and n_responses <= 20: + colors = plt.get_cmap("tab20", 20).colors[:n_responses] + else: + cmap = plt.get_cmap(cmap, n_responses) if not hasattr(cmap, "colors"): # For LinearSegmentedColormap colors = cmap(np.linspace(0, 1, n_responses)) + print("linear segmented") else: colors = cmap.colors - elif isinstance(self.multiclass_colors, list): - colors = [ - mpl.colors.to_rgba(color) for color in self.multiclass_colors - ] - else: - raise ValueError("'multiclass_colors' must be a list or a str.") + elif isinstance(self.multiclass_colors, list): + colors = [mpl.colors.to_rgba(color) for color in self.multiclass_colors] + else: + raise ValueError("'multiclass_colors' must be a list or a str.") self.multiclass_colors_ = colors if plot_method == "contour": From 0b51a481a1e3ce498e48e6f2a9684335bcc4d443 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= Date: Thu, 10 Jul 2025 16:39:54 +0200 Subject: [PATCH 10/17] Update sklearn/inspection/_plot/decision_boundary.py --- sklearn/inspection/_plot/decision_boundary.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 132b3390c243a..c3166e6aafa0a 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -248,7 +248,6 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar if not hasattr(cmap, "colors"): # For LinearSegmentedColormap colors = cmap(np.linspace(0, 1, n_responses)) - print("linear segmented") else: colors = cmap.colors elif isinstance(self.multiclass_colors, list): From e3c0d875ce678abe24490b4382687b5b9e7e62d3 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Thu, 10 Jul 2025 22:51:30 +0800 Subject: [PATCH 11/17] Update warnings for 'cmap' and 'colors' in multiclass case --- 
sklearn/inspection/_plot/decision_boundary.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index c3166e6aafa0a..a596d4219bd54 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -222,10 +222,18 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar else: # self.response.ndim == 3 n_responses = self.response.shape[-1] if "cmap" in kwargs: - warnings.warn("'cmap' is ignored in multiclass case.") + warnings.warn( + "'cmap' is ignored in favor of 'multiclass_colors' " + "in the multiclass case when the response method is " + "'decision_function' or 'predict_proba'." + ) del kwargs["cmap"] if "colors" in kwargs: - warnings.warn("'colors' is ignored in multiclass case.") + warnings.warn( + "'colors' is ignored in favor of 'multiclass_colors' " + "in the multiclass case when the response method is " + "'decision_function' or 'predict_proba'." 
+ ) del kwargs["colors"] if self.multiclass_colors is None or isinstance( From 3351a8d1711d43beb776bcaac1b4110f8c22d252 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Thu, 10 Jul 2025 22:53:56 +0800 Subject: [PATCH 12/17] Update corresponding tests --- .../inspection/_plot/tests/test_boundary_decision_display.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index bac47c8ac00a2..f409a50ab58c0 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -673,7 +673,7 @@ def test_cmap_and_colors_logic(pyplot): with pytest.warns( UserWarning, - match="'cmap' is ignored in multiclass case.", + match="'cmap' is ignored in favor of 'multiclass_colors'", ): DecisionBoundaryDisplay.from_estimator( clf, @@ -684,7 +684,7 @@ def test_cmap_and_colors_logic(pyplot): with pytest.warns( UserWarning, - match="'colors' is ignored in multiclass case.", + match="'colors' is ignored in favor of 'multiclass_colors'", ): DecisionBoundaryDisplay.from_estimator( clf, From 61065bc9e6817529b517b6f0fd1f0b2f179f4ae0 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Thu, 10 Jul 2025 22:55:46 +0800 Subject: [PATCH 13/17] Update changelog --- .../upcoming_changes/sklearn.inspection/31553.fix.rst | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst diff --git a/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst b/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst new file mode 100644 index 0000000000000..aed168e860644 --- /dev/null +++ b/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst @@ -0,0 +1,5 @@ +- Fix multiple issues in :class:`inspection.DecisionBoundaryDisplay`: + - `contour` plotting now correctly shows the decision boundary. 
- `cmap` and `colors` are now handled consistently in multiclass settings. + - Linear segmented colormaps are now fully supported. + By :user:`Yunjie Lin <jshn9515>` From c163647230f79d610db07d3a164c73eaf4a6ead6 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Fri, 11 Jul 2025 18:17:23 +0800 Subject: [PATCH 14/17] Reduce code duplication for 'cmap' and 'colors' MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger --- sklearn/inspection/_plot/decision_boundary.py | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index a596d4219bd54..2ef8538058393 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -221,20 +221,14 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs) else: # self.response.ndim == 3 n_responses = self.response.shape[-1] - if "cmap" in kwargs: - warnings.warn( - "'cmap' is ignored in favor of 'multiclass_colors' " - "in the multiclass case when the response method is " - "'decision_function' or 'predict_proba'." - ) - del kwargs["cmap"] - if "colors" in kwargs: - warnings.warn( - "'colors' is ignored in favor of 'multiclass_colors' " - "in the multiclass case when the response method is " - "'decision_function' or 'predict_proba'." - ) - del kwargs["colors"] + for kwarg in ("cmap", "colors"): + if kwarg in kwargs: + warnings.warn( + f"'{kwarg}' is ignored in favor of 'multiclass_colors' " + "in the multiclass case when the response method is " + "'decision_function' or 'predict_proba'." 
+ ) + del kwargs[kwarg] if self.multiclass_colors is None or isinstance( self.multiclass_colors, str From 192910369592eab59d7d62f5d0726c970481af67 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Fri, 11 Jul 2025 18:21:41 +0800 Subject: [PATCH 15/17] Clarify fixes apply to multiclass case in DecisionBoundaryDisplay MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger --- doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst b/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst index aed168e860644..e29e86161bae8 100644 --- a/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst +++ b/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst @@ -1,4 +1,4 @@ -- Fix multiple issues in :class:`inspection.DecisionBoundaryDisplay`: +- Fix multiple issues in the multiclass setting of :class:`inspection.DecisionBoundaryDisplay`: - `contour` plotting now correctly shows the decision boundary. - `cmap` and `colors` are now handled consistently in multiclass settings. - Linear segmented colormaps are now fully supported. 
From eac9900d96e5b6722344ce6f289105b3800ff407 Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Fri, 11 Jul 2025 18:23:17 +0800 Subject: [PATCH 16/17] Clarify that 'cmap' and 'colors' are ignored in favor of 'multiclass_colors' MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger --- doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst b/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst index e29e86161bae8..19428a8dc432a 100644 --- a/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst +++ b/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst @@ -1,5 +1,5 @@ - Fix multiple issues in the multiclass setting of :class:`inspection.DecisionBoundaryDisplay`: - `contour` plotting now correctly shows the decision boundary. - - `cmap` and `colors` are now handled consistently in multiclass settings. + - `cmap` and `colors` are now properly ignored in favor of `multiclass_colors`. - Linear segmented colormaps are now fully supported. By :user:`Yunjie Lin <jshn9515>` From 94cbfdebb30374a3f2bf7e6f1222a1ed3d3a244a Mon Sep 17 00:00:00 2001 From: jshn9515 Date: Fri, 11 Jul 2025 18:43:40 +0800 Subject: [PATCH 17/17] Add blank lines in changelog --- doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst b/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst index 19428a8dc432a..bd9bb339bb68c 100644 --- a/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst +++ b/doc/whats_new/upcoming_changes/sklearn.inspection/31553.fix.rst @@ -1,5 +1,7 @@ - Fix multiple issues in the multiclass setting of :class:`inspection.DecisionBoundaryDisplay`: + - `contour` plotting now correctly shows the decision boundary. 
- `cmap` and `colors` are now properly ignored in favor of `multiclass_colors`. - Linear segmented colormaps are now fully supported. + By :user:`Yunjie Lin <jshn9515>`