ENH accept class_of_interest in DecisionBoundaryDisplay to inspect multiclass classifiers #27291
Conversation
I have no strong opinion about which strategy, this PR's or the one in #26995, is best. However, I think we should choose the one that makes implementing what I suggest below the most natural.
BTW, I am OK with implementing the max case in a follow-up PR, but I would like to make sure that the design decisions made in this PR will not prevent a natural implementation of that case.
@ogrisel I quickly implemented your suggestion and it gives the following diff:

diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py
index 6ac2816946..3a8b044bfd 100644
--- a/sklearn/inspection/_plot/decision_boundary.py
+++ b/sklearn/inspection/_plot/decision_boundary.py
@@ -41,16 +41,7 @@ def _check_boundary_response_method(estimator, response_method, class_of_interes
msg = "Multi-label and multi-output multi-class classifiers are not supported"
raise ValueError(msg)
- if has_classes and len(estimator.classes_) > 2:
- if response_method not in {"auto", "predict"} and class_of_interest is None:
- msg = (
- "Multiclass classifiers are only supported when response_method is"
- " 'predict' or 'auto', or you must provide `class_of_interest` to "
- " select a specific class to plot the decision boundary."
- )
- raise ValueError(msg)
- prediction_method = "predict" if response_method == "auto" else response_method
- elif response_method == "auto":
+ if response_method == "auto":
prediction_method = ["decision_function", "predict_proba", "predict"]
else:
prediction_method = response_method
@@ -78,7 +69,8 @@ class DecisionBoundaryDisplay:
xx1 : ndarray of shape (grid_resolution, grid_resolution)
Second output of :func:`meshgrid <numpy.meshgrid>`.
- response : ndarray of shape (grid_resolution, grid_resolution)
+ response : ndarray of shape (grid_resolution, grid_resolution) or \
+ (grid_resolution, grid_resolution, n_classes)
Values of the response function.
xlabel : str, default=None
@@ -89,7 +81,7 @@ class DecisionBoundaryDisplay:
Attributes
----------
- surface_ : matplotlib `QuadContourSet` or `QuadMesh`
+ surface_ : matplotlib `QuadContourSet` or `QuadMesh` or list of such objects
If `plot_method` is 'contour' or 'contourf', `surface_` is a
:class:`QuadContourSet <matplotlib.contour.QuadContourSet>`. If
`plot_method` is 'pcolormesh', `surface_` is a
@@ -170,6 +162,7 @@ class DecisionBoundaryDisplay:
Object that stores computed values.
"""
check_matplotlib_support("DecisionBoundaryDisplay.plot")
+ import matplotlib as mpl
import matplotlib.pyplot as plt # noqa
if plot_method not in ("contourf", "contour", "pcolormesh"):
@@ -181,7 +174,26 @@ class DecisionBoundaryDisplay:
_, ax = plt.subplots()
plot_func = getattr(ax, plot_method)
- self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs)
+
+ if self.response.ndim == 2:
+ self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs)
+ else: # self.response.ndim == 3
+ # create the colormap for each class
+ viridis = mpl.colormaps["viridis"].resampled(self.response.shape[-1])
+
+ self.surface_ = []
+ for class_idx, primary_color in enumerate(viridis.colors):
+ r, g, b, _ = primary_color
+ cmap = mpl.colors.LinearSegmentedColormap.from_list(
+ f"colormap_{class_idx}", [(1.0, 1.0, 1.0, 1.0), (r, g, b, 1.0)]
+ )
+ response = np.ma.array(
+ self.response[:, :, class_idx],
+ mask=~(self.response.argmax(axis=2) == class_idx),
+ )
+ self.surface_.append(
+ plot_func(self.xx0, self.xx1, response, cmap=cmap, **kwargs)
+ )
if xlabel is not None or not ax.get_xlabel():
xlabel = self.xlabel if xlabel is None else xlabel
@@ -379,11 +391,16 @@ class DecisionBoundaryDisplay:
if is_regressor(estimator):
raise ValueError("Multi-output regressors are not supported")
- # For the multiclass case, `_get_response_values` returns the response
- # as-is. Thus, we have a column per class and we need to select the column
- # corresponding to the positive class.
- col_idx = np.flatnonzero(estimator.classes_ == class_of_interest)[0]
- response = response[:, col_idx]
+ if class_of_interest is not None:
+ # For the multiclass case, `_get_response_values` returns the response
+ # as-is. Thus, we have a column per class and we need to select the column
+ # corresponding to the positive class.
+ col_idx = np.flatnonzero(estimator.classes_ == class_of_interest)[0]
+ response = response[:, col_idx].reshape(*xx0.shape)
+ else:
+ response = response.reshape(*xx0.shape, response.shape[-1])
+ else:
+ response = response.reshape(*xx0.shape)
if xlabel is None:
xlabel = X.columns[0] if hasattr(X, "columns") else ""
@@ -394,7 +411,7 @@ class DecisionBoundaryDisplay:
display = DecisionBoundaryDisplay(
xx0=xx0,
xx1=xx1,
- response=response.reshape(xx0.shape),
+ response=response,
xlabel=xlabel,
ylabel=ylabel,
        )

I feel that the changes are quite reasonable, and the good news is that it works for "contour", "contourf", and "pcolormesh". While the changes are minor, having them in an additional PR would be better. We need to think about exposing parameters for the colormap, and it might not be straightforward. However, I am pretty happy with the default that I implemented (resampling viridis, which is the default colormap used in a scatter plot).
This looks great, OK for a follow-up PR then.
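The per-class plotting in the diff above rests on one numpy idiom: for a `(H, W, n_classes)` response grid, each class's surface is drawn from a masked array where every cell that the class does not win (by argmax) is hidden. A minimal self-contained sketch of that masking step, on toy data rather than a real estimator's output:

```python
import numpy as np

rng = np.random.default_rng(0)
response = rng.random((4, 5, 3))  # toy (H, W, n_classes) response grid

# For each class k, keep only the cells where class k has the largest value;
# everything else is masked so the plot of class k covers only its own region.
winner = response.argmax(axis=2)
masked = [
    np.ma.array(response[:, :, k], mask=(winner != k))
    for k in range(response.shape[-1])
]

# Each grid cell is unmasked in exactly one of the per-class arrays.
unmasked_counts = sum((~m.mask).astype(int) for m in masked)
print(unmasked_counts.min(), unmasked_counts.max())  # 1 1
```

Passing each masked array to `contourf`/`pcolormesh` with its own single-hue colormap then tiles the plane without overlap, which is why the three plot methods all work unchanged.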
Another pass of feedback. Otherwise, LGTM.
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
LGTM once the following points are resolved.
@adrinjalali Would you mind having a look at this PR? It is similar to a previous PR that we opened in August after PyLadies Berlin. As previously stated, it does not yet handle one of the cases, which I will implement later.
@@ -248,6 +250,14 @@ def from_estimator(
        For multiclass problems, :term:`predict` is selected when
        `response_method="auto"`.

    class_of_interest : int, float, bool or str, default=None
I'm not sure what `float` and `bool` mean here. Do we accept floats for classification targets? Is `bool` accepted when the target is boolean?
> do we accept floats for classification targets?

(unfortunately) yes, we do. The types reported here are the same as those accepted as `pos_label` in thresholded metrics.
except AttributeError as exc:
    # re-raise the AttributeError as a ValueError for backward compatibility
    raise ValueError(str(exc)) from exc
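For context on what this re-raise pattern preserves: chaining with `from exc` keeps the original `AttributeError` reachable as `__cause__`, so only the surfaced type changes. A standalone sketch with a hypothetical toy class and helper (not the actual sklearn internals):

```python
class _NoDecisionFunction:
    """Toy estimator deliberately lacking decision_function (hypothetical)."""

def _get_response_method(estimator, name):
    try:
        return getattr(estimator, name)
    except AttributeError as exc:
        # re-raise as a ValueError for backward compatibility; the original
        # AttributeError stays attached as __cause__ via exception chaining
        raise ValueError(str(exc)) from exc

caught = None
try:
    _get_response_method(_NoDecisionFunction(), "decision_function")
except ValueError as err:
    caught = err

print(type(caught.__cause__).__name__)  # AttributeError
```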
This seems bad to me: an `AttributeError` should stay an `AttributeError`, not be converted to a `ValueError`. I would just change this. These displays are used in interactive mode, so the type of exception could be changed here IMO.
Both `ValueError` and `AttributeError` are valid: `AttributeError` because the estimator does not expose the attribute, and `ValueError` because this is a parameter value passed by the user. Since we are already raising a `ValueError` that is not semantically wrong, I agree with @ogrisel that it would be a pity to break backward compatibility.
I dislike the kind of code that is only there because somebody did something in the past and it stays for historical reasons; it makes the code less maintainable over time. And I don't consider this a major backward-compatibility issue, since it's not as if we were working before and are broken now: we were failing, and we still fail, just with a different error type. If you really want, we can add a message here that the type of the error will be changed to an `AttributeError` in 1.6, but I don't think we should keep such code in the code base long term.
So let's take the second option from @ogrisel's comment: #27291 (comment). Let's acknowledge the change in the changelog only.
I buy the fact that the displays are usually used in interactive mode, and thus I don't foresee a use case where someone would catch the error to do something else.
Some comments about the docstrings. Otherwise LGTM :)
"Multiclass classifiers are only supported when response_method is" | ||
" 'predict' or 'auto'" | ||
" 'predict' or 'auto', or you must provide `class_of_interest` to " | ||
" select a specific class to plot the decision boundary." |
"Multiclass classifiers are only supported when `response_method`"
" is 'predict' or 'auto'. Else you must provide `class_of_interest`"
" to plot the decision boundary of a specific class."
Co-authored-by: Arturo Amor <86408019+ArturoAmorQ@users.noreply.github.com>
…lticlass classifiers (scikit-learn#27291) Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Arturo Amor <86408019+ArturoAmorQ@users.noreply.github.com>
This PR proposes an improvement to the `DecisionBoundaryDisplay` to address a TODO in the code. We expose a `class_of_interest` parameter allowing us to plot the output of `predict_proba` or `decision_function` with binary and multiclass classifiers.