Skip to content

ENH accept class_of_interest in DecisionBoundaryDisplay to inspect multiclass classifiers #27291

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Oct 11, 2023

Conversation

glemaitre
Copy link
Member

@glemaitre glemaitre commented Sep 4, 2023

This PR proposes an improvement to the DecisionBoundaryDisplay to address a TODO from the code.

We expose a class_of_interest parameter allowing us to plot the output of predict_proba or decision_function with binary and multiclass classifiers.

@github-actions
Copy link

github-actions bot commented Sep 4, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: c990014. Link to the linter CI: here

Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no strong opinion as about which strategy between this PR and the one of #26995 is the best. However I think we choose the one that makes implementing what I suggest in the following the most natural:

@ogrisel
Copy link
Member

ogrisel commented Sep 5, 2023

BTW, I am ok to implement the max case in a follow-up PR but I would like to make sure that the design decisions made in this PR will not prevent a natural implementation of that case.

@glemaitre
Copy link
Member Author

@ogrisel I quickly implemented your suggestion and it provides the following diff:

diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py
index 6ac2816946..3a8b044bfd 100644
--- a/sklearn/inspection/_plot/decision_boundary.py
+++ b/sklearn/inspection/_plot/decision_boundary.py
@@ -41,16 +41,7 @@ def _check_boundary_response_method(estimator, response_method, class_of_interes
         msg = "Multi-label and multi-output multi-class classifiers are not supported"
         raise ValueError(msg)
 
-    if has_classes and len(estimator.classes_) > 2:
-        if response_method not in {"auto", "predict"} and class_of_interest is None:
-            msg = (
-                "Multiclass classifiers are only supported when response_method is"
-                " 'predict' or 'auto', or you must provide `class_of_interest` to "
-                " select a specific class to plot the decision boundary."
-            )
-            raise ValueError(msg)
-        prediction_method = "predict" if response_method == "auto" else response_method
-    elif response_method == "auto":
+    if response_method == "auto":
         prediction_method = ["decision_function", "predict_proba", "predict"]
     else:
         prediction_method = response_method
@@ -78,7 +69,8 @@ class DecisionBoundaryDisplay:
     xx1 : ndarray of shape (grid_resolution, grid_resolution)
         Second output of :func:`meshgrid <numpy.meshgrid>`.
 
-    response : ndarray of shape (grid_resolution, grid_resolution)
+    response : ndarray of shape (grid_resolution, grid_resolution) or \
+            (grid_resolution, grid_resolution, n_classes)
         Values of the response function.
 
     xlabel : str, default=None
@@ -89,7 +81,7 @@ class DecisionBoundaryDisplay:
 
     Attributes
     ----------
-    surface_ : matplotlib `QuadContourSet` or `QuadMesh`
+    surface_ : matplotlib `QuadContourSet` or `QuadMesh` or list of such objects
         If `plot_method` is 'contour' or 'contourf', `surface_` is a
         :class:`QuadContourSet <matplotlib.contour.QuadContourSet>`. If
         `plot_method` is 'pcolormesh', `surface_` is a
@@ -170,6 +162,7 @@ class DecisionBoundaryDisplay:
             Object that stores computed values.
         """
         check_matplotlib_support("DecisionBoundaryDisplay.plot")
+        import matplotlib as mpl
         import matplotlib.pyplot as plt  # noqa
 
         if plot_method not in ("contourf", "contour", "pcolormesh"):
@@ -181,7 +174,26 @@ class DecisionBoundaryDisplay:
             _, ax = plt.subplots()
 
         plot_func = getattr(ax, plot_method)
-        self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs)
+
+        if self.response.ndim == 2:
+            self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs)
+        else:  # self.response.ndim == 3
+            # create the colormap for each class
+            viridis = mpl.colormaps["viridis"].resampled(self.response.shape[-1])
+
+            self.surface_ = []
+            for class_idx, primary_color in enumerate(viridis.colors):
+                r, g, b, _ = primary_color
+                cmap = mpl.colors.LinearSegmentedColormap.from_list(
+                    f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
+                )
+                response = np.ma.array(
+                    self.response[:, :, class_idx],
+                    mask=~(self.response.argmax(axis=2) == class_idx),
+                )
+                self.surface_.append(
+                    plot_func(self.xx0, self.xx1, response, cmap=cmap, **kwargs)
+                )
 
         if xlabel is not None or not ax.get_xlabel():
             xlabel = self.xlabel if xlabel is None else xlabel
@@ -379,11 +391,16 @@ class DecisionBoundaryDisplay:
             if is_regressor(estimator):
                 raise ValueError("Multi-output regressors are not supported")
 
-            # For the multiclass case, `_get_response_values` returns the response
-            # as-is. Thus, we have a column per class and we need to select the column
-            # corresponding to the positive class.
-            col_idx = np.flatnonzero(estimator.classes_ == class_of_interest)[0]
-            response = response[:, col_idx]
+            if class_of_interest is not None:
+                # For the multiclass case, `_get_response_values` returns the response
+                # as-is. Thus, we have a column per class and we need to select the column
+                # corresponding to the positive class.
+                col_idx = np.flatnonzero(estimator.classes_ == class_of_interest)[0]
+                response = response[:, col_idx].reshape(*xx0.shape)
+            else:
+                response = response.reshape(*xx0.shape, response.shape[-1])
+        else:
+            response = response.reshape(*xx0.shape)
 
         if xlabel is None:
             xlabel = X.columns[0] if hasattr(X, "columns") else ""
@@ -394,7 +411,7 @@ class DecisionBoundaryDisplay:
         display = DecisionBoundaryDisplay(
             xx0=xx0,
             xx1=xx1,
-            response=response.reshape(xx0.shape),
+            response=response,
             xlabel=xlabel,
             ylabel=ylabel,
         )

I feed that the changes are quite reasonable and good news is that it works for "contour", "contourf", and "pcolormesh":

image

image

image

While the changes are minor, having it in an additional PR would be better. We need to think about exposing parameters for the colormap and it might not be straightforward. However, I am pretty happy with the default that I implemented (resampling viridis that is the default colormap used in a scatter plot)

@ogrisel
Copy link
Member

ogrisel commented Sep 6, 2023

This looks great, ok for a follow-up PR then.

@ogrisel ogrisel changed the title ENH accept pos_label in DecisionBoundaryDisplay ENH accept class_of_interest in DecisionBoundaryDisplay to inspect multiclass classifiers Sep 7, 2023
Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another pass of feedback. Besides, LGTM.

Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM once the following points are resolved.

@glemaitre
Copy link
Member Author

@adrinjalali Would you mind to have a look at this PR. It similar to a previous PR that we open in August after a PyLadies Berlin. As previously stated, it does yet handle one of the case that I will implement later.

@@ -248,6 +250,14 @@ def from_estimator(
For multiclass problems, :term:`predict` is selected when
`response_method="auto"`.

class_of_interest : int, float, bool or str, default=None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what float and bool here mean? do we accept floats for classification targets? is bool accepted when target is boolean?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we accept floats for classification targets?

(unfortunatelly) yes we do. The type reported here are the same than the one accepted as pos_label in thresholded metric.

Comment on lines 376 to 378
except AttributeError as exc:
# re-raise the AttributeError as a ValueError for backward compatibility
raise ValueError(str(exc)) from exc
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems bad to me. Attribute error should be an attribute error, not converted to a ValueError. I would just change this. These displays are used in interactive mode, the type of exception could be changed here IMO.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both ValueError and AttributeError are valid. AttributeError is valid since the estimator does not expose the attribute but the ValueError is valid because this is a parameter value passed by the user.

Since we are already raising a ValueError that is not semantically wrong, I agree with @ogrisel that this is a pity to break the backward compatibility.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dislike the kind of code which is only there because somebody did something in the past and it's there for historical reasons. Makes code less maintainable over time. And I don't consider this a major backward compatibility issue since it's not like we were working and we don't now. We were failing, and we still fail, but a different error type. If you really want, we can add a message here that the type of the error will be change to an AttributeError in 1.6, but I don't think we should keep such code in the long term in the code base.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So let's take the second option of the @ogrisel comment: #27291 (comment)

Let's acknowledge the change in the changelog only.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I buy the fact that the display are usually used in a more interactive mode and this is I don't foresee a use case where someone will catch the error to do something else.

Copy link
Member

@ArturoAmorQ ArturoAmorQ left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some comments about the docstrings. Otherwise LGTM :)

Comment on lines 47 to 49
"Multiclass classifiers are only supported when response_method is"
" 'predict' or 'auto'"
" 'predict' or 'auto', or you must provide `class_of_interest` to "
" select a specific class to plot the decision boundary."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

            "Multiclass classifiers are only supported when `response_method`"
            " is 'predict' or 'auto'. Else you must provide `class_of_interest`"
            " to plot the decision boundary of a specific class."

@adrinjalali adrinjalali merged commit 3597f0b into scikit-learn:main Oct 11, 2023
glemaitre added a commit to glemaitre/scikit-learn that referenced this pull request Oct 31, 2023
…lticlass classifiers (scikit-learn#27291)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Arturo Amor <86408019+ArturoAmorQ@users.noreply.github.com>
REDVM pushed a commit to REDVM/scikit-learn that referenced this pull request Nov 16, 2023
…lticlass classifiers (scikit-learn#27291)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Arturo Amor <86408019+ArturoAmorQ@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants