-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
ENH accept class_of_interest in DecisionBoundaryDisplay to inspect multiclass classifiers #27291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
66eb147
fad3a1c
82a4150
20461de
a08fbf7
c825ac4
6cb3a55
687b098
bcf6b52
010a292
38e7628
6377573
2ce357f
0746325
6b43da5
52d877f
d3a9bc4
d986e6a
7084b23
cac74dd
c990014
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,18 @@ | ||
from functools import reduce | ||
|
||
import numpy as np | ||
|
||
from ...base import is_regressor | ||
from ...preprocessing import LabelEncoder | ||
from ...utils import _safe_indexing, check_matplotlib_support | ||
from ...utils._response import _get_response_values | ||
from ...utils.validation import ( | ||
_is_arraylike_not_scalar, | ||
_num_features, | ||
check_is_fitted, | ||
) | ||
|
||
|
||
def _check_boundary_response_method(estimator, response_method): | ||
"""Return prediction method from the `response_method` for decision boundary. | ||
def _check_boundary_response_method(estimator, response_method, class_of_interest): | ||
"""Validate the response methods to be used with the fitted estimator. | ||
|
||
Parameters | ||
---------- | ||
|
@@ -26,36 +25,38 @@ def _check_boundary_response_method(estimator, response_method): | |
If set to 'auto', the response method is tried in the following order: | ||
:term:`decision_function`, :term:`predict_proba`, :term:`predict`. | ||
|
||
class_of_interest : int, float, bool, str or None | ||
The class considered when plotting the decision. If the label is specified, it | ||
is then possible to plot the decision boundary in multiclass settings. | ||
|
||
.. versionadded:: 1.4 | ||
|
||
Returns | ||
------- | ||
prediction_method: callable | ||
Prediction method of estimator. | ||
prediction_method : list of str or str | ||
The name or list of names of the response methods to use. | ||
""" | ||
has_classes = hasattr(estimator, "classes_") | ||
if has_classes and _is_arraylike_not_scalar(estimator.classes_[0]): | ||
msg = "Multi-label and multi-output multi-class classifiers are not supported" | ||
raise ValueError(msg) | ||
|
||
if has_classes and len(estimator.classes_) > 2: | ||
if response_method not in {"auto", "predict"}: | ||
if response_method not in {"auto", "predict"} and class_of_interest is None: | ||
msg = ( | ||
"Multiclass classifiers are only supported when response_method is" | ||
" 'predict' or 'auto'" | ||
"Multiclass classifiers are only supported when `response_method` is " | ||
"'predict' or 'auto'. Else you must provide `class_of_interest` to " | ||
"plot the decision boundary of a specific class." | ||
) | ||
raise ValueError(msg) | ||
methods_list = ["predict"] | ||
prediction_method = "predict" if response_method == "auto" else response_method | ||
ogrisel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
elif response_method == "auto": | ||
methods_list = ["decision_function", "predict_proba", "predict"] | ||
if is_regressor(estimator): | ||
prediction_method = "predict" | ||
else: | ||
prediction_method = ["decision_function", "predict_proba", "predict"] | ||
else: | ||
methods_list = [response_method] | ||
|
||
prediction_method = [getattr(estimator, method, None) for method in methods_list] | ||
prediction_method = reduce(lambda x, y: x or y, prediction_method) | ||
if prediction_method is None: | ||
raise ValueError( | ||
f"{estimator.__class__.__name__} has none of the following attributes: " | ||
f"{', '.join(methods_list)}." | ||
) | ||
prediction_method = response_method | ||
|
||
return prediction_method | ||
|
||
|
@@ -206,6 +207,7 @@ def from_estimator( | |
eps=1.0, | ||
plot_method="contourf", | ||
response_method="auto", | ||
class_of_interest=None, | ||
xlabel=None, | ||
ylabel=None, | ||
ax=None, | ||
|
@@ -248,6 +250,14 @@ def from_estimator( | |
For multiclass problems, :term:`predict` is selected when | ||
`response_method="auto"`. | ||
|
||
class_of_interest : int, float, bool or str, default=None | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure what There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
(unfortunately) yes we do. The types reported here are the same as the ones accepted as |
||
The class considered when plotting the decision. If None, | ||
`estimator.classes_[1]` is considered as the positive class | ||
for binary classifiers. For multiclass classifiers, passing | ||
an explicit value for `class_of_interest` is mandatory. | ||
|
||
.. versionadded:: 1.4 | ||
|
||
xlabel : str, default=None | ||
The label used for the x-axis. If `None`, an attempt is made to | ||
extract a label from `X` if it is a dataframe, otherwise an empty | ||
|
@@ -342,11 +352,30 @@ def from_estimator( | |
else: | ||
X_grid = np.c_[xx0.ravel(), xx1.ravel()] | ||
|
||
pred_func = _check_boundary_response_method(estimator, response_method) | ||
response = pred_func(X_grid) | ||
prediction_method = _check_boundary_response_method( | ||
estimator, response_method, class_of_interest | ||
) | ||
try: | ||
response, _, response_method_used = _get_response_values( | ||
estimator, | ||
X_grid, | ||
response_method=prediction_method, | ||
pos_label=class_of_interest, | ||
return_response_method_used=True, | ||
) | ||
except ValueError as exc: | ||
if "is not a valid label" in str(exc): | ||
# re-raise a more informative error message since `pos_label` is unknown | ||
# to our user when interacting with | ||
# `DecisionBoundaryDisplay.from_estimator` | ||
raise ValueError( | ||
f"class_of_interest={class_of_interest} is not a valid label: It " | ||
f"should be one of {estimator.classes_}" | ||
) from exc | ||
raise | ||
|
||
# convert classes predictions into integers | ||
if pred_func.__name__ == "predict" and hasattr(estimator, "classes_"): | ||
if response_method_used == "predict" and hasattr(estimator, "classes_"): | ||
encoder = LabelEncoder() | ||
encoder.classes_ = estimator.classes_ | ||
response = encoder.transform(response) | ||
|
@@ -355,8 +384,11 @@ def from_estimator( | |
if is_regressor(estimator): | ||
raise ValueError("Multi-output regressors are not supported") | ||
|
||
# TODO: Support pos_label | ||
response = response[:, 1] | ||
# For the multiclass case, `_get_response_values` returns the response | ||
# as-is. Thus, we have a column per class and we need to select the column | ||
# corresponding to the positive class. | ||
col_idx = np.flatnonzero(estimator.classes_ == class_of_interest)[0] | ||
response = response[:, col_idx] | ||
|
||
if xlabel is None: | ||
xlabel = X.columns[0] if hasattr(X, "columns") else "" | ||
|
Uh oh!
There was an error while loading. Please reload this page.