Skip to content

ENH handle multiclass with scores and probabilities in DecisionBoundaryDisplay #26995

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ Changelog
- |Fix| :func:`feature_selection.mutual_info_regression` now correctly computes the
result when `X` is of integer dtype. :pr:`26748` by :user:`Yao Xiao <Charlie-XIAO>`.

:mod:`sklearn.inspection`
.........................

- |Enhancement| :class:`inspection.DecisionBoundaryDisplay` can now be used with
`response_method` set to `"predict_proba"` or `"decision_function"` for multiclass
problems, by setting `class_label` to select the class to plot.
:pr:`26995` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.linear_model`
...........................

Expand Down
60 changes: 34 additions & 26 deletions examples/classification/plot_classification_probability.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm

from sklearn import datasets
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.gaussian_process.kernels import RBF
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.svm import SVC
Expand All @@ -35,6 +37,7 @@
y = iris.target

n_features = X.shape[1]
n_classes = len(np.unique(y))

C = 10
kernel = 1.0 * RBF([1.0, 1.0]) # for GPC
Expand All @@ -56,13 +59,7 @@

n_classifiers = len(classifiers)

plt.figure(figsize=(3 * 2, n_classifiers * 2))
plt.subplots_adjust(bottom=0.2, top=0.95)

xx = np.linspace(3, 9, 100)
yy = np.linspace(1, 5, 100).T
xx, yy = np.meshgrid(xx, yy)
Xfull = np.c_[xx.ravel(), yy.ravel()]
fig, axs = plt.subplots(nrows=n_classifiers, ncols=n_classes, figsize=(6, 14))

for index, (name, classifier) in enumerate(classifiers.items()):
classifier.fit(X, y)
Expand All @@ -71,25 +68,36 @@
accuracy = accuracy_score(y, y_pred)
print("Accuracy (train) for %s: %0.1f%% " % (name, accuracy * 100))

# View probabilities:
probas = classifier.predict_proba(Xfull)
n_classes = np.unique(y_pred).size
for k in range(n_classes):
plt.subplot(n_classifiers, n_classes, index * n_classes + k + 1)
plt.title("Class %d" % k)
if k == 0:
plt.ylabel(name)
imshow_handle = plt.imshow(
probas[:, k].reshape((100, 100)), extent=(3, 9, 1, 5), origin="lower"
for k in classifier.classes_:
disp = DecisionBoundaryDisplay.from_estimator(
classifier,
X,
plot_method="pcolormesh",
response_method="predict_proba",
class_label=k,
ax=axs[index, k],
alpha=0.5,
cmap="RdBu",
)
axs[index, k].set(
xticks=(), yticks=(), ylabel=name if k == 0 else None, title=f"Class #{k}"
)
plt.xticks(())
plt.yticks(())
idx = y_pred == k
if idx.any():
plt.scatter(X[idx, 0], X[idx, 1], marker="o", c="w", edgecolor="k")

ax = plt.axes([0.15, 0.04, 0.7, 0.05])
plt.title("Probability")
plt.colorbar(imshow_handle, cax=ax, orientation="horizontal")
scatter = axs[index, k].scatter(
X[:, 0], X[:, 1], marker="o", c=y_pred, edgecolor="k", alpha=0.7
)

axs[4, 1].legend(
scatter.legend_elements()[0],
iris.target_names,
bbox_to_anchor=(1.03, -0.1),
title="Predicted classes",
)

fig.colorbar(
cm.ScalarMappable(norm=None, cmap="RdBu"),
ax=axs,
orientation="horizontal",
label="Probability",
)

plt.show()
39 changes: 31 additions & 8 deletions sklearn/inspection/_plot/decision_boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)


def _check_boundary_response_method(estimator, response_method):
def _check_boundary_response_method(estimator, response_method, class_label):
"""Return prediction method from the `response_method` for decision boundary.

Parameters
Expand All @@ -37,13 +37,18 @@ def _check_boundary_response_method(estimator, response_method):
raise ValueError(msg)

if has_classes and len(estimator.classes_) > 2:
if response_method not in {"auto", "predict"}:
if response_method in {"predict_proba", "decision_function"} and (
class_label is None or class_label not in estimator.classes_
):
msg = (
"Multiclass classifiers are only supported when response_method is"
" 'predict' or 'auto'"
"When `response_method` is set to 'predict_proba' or "
"'decision_function' and the target is multiclass, you must define "
"the class label to be selected as class of interest. Got "
f"class_label={class_label} instead. Potential choices are: "
f"{estimator.classes_}."
)
raise ValueError(msg)
methods_list = ["predict"]
methods_list = ["predict"] if response_method == "auto" else [response_method]
elif response_method == "auto":
methods_list = ["decision_function", "predict_proba", "predict"]
else:
Expand Down Expand Up @@ -206,6 +211,7 @@ def from_estimator(
eps=1.0,
plot_method="contourf",
response_method="auto",
class_label=None,
xlabel=None,
ylabel=None,
ax=None,
Expand Down Expand Up @@ -248,6 +254,18 @@ def from_estimator(
For multiclass problems, :term:`predict` is selected when
`response_method="auto"`.

class_label : int, float or str, default=None
When dealing with a multiclass problem, you can visualize one class against
the other classes by specifying `class_label`. This can be used in
combination with `"predict_proba"` and `"decision_function"` passed
as `response_method`.

See the example entitled
:ref:`sphx_glr_auto_examples_classification_plot_classification_probability.py`
that shows how to use this parameter.

.. versionadded:: 1.4

xlabel : str, default=None
The label used for the x-axis. If `None`, an attempt is made to
extract a label from `X` if it is a dataframe, otherwise an empty
Expand Down Expand Up @@ -342,7 +360,9 @@ def from_estimator(
else:
X_grid = np.c_[xx0.ravel(), xx1.ravel()]

pred_func = _check_boundary_response_method(estimator, response_method)
pred_func = _check_boundary_response_method(
estimator, response_method, class_label
)
response = pred_func(X_grid)

# convert classes predictions into integers
Expand All @@ -355,8 +375,11 @@ def from_estimator(
if is_regressor(estimator):
raise ValueError("Multi-output regressors are not supported")

# TODO: Support pos_label
response = response[:, 1]
if class_label is None:
response = response[:, 1]
else:
target_index = np.flatnonzero(estimator.classes_ == class_label)[0]
response = response[:, target_index]

if xlabel is None:
xlabel = X.columns[0] if hasattr(X, "columns") else ""
Expand Down
53 changes: 45 additions & 8 deletions sklearn/inspection/_plot/tests/test_boundary_decision_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ def decision_function(self):
pass

a_inst = A()
method = _check_boundary_response_method(a_inst, "auto")
method = _check_boundary_response_method(a_inst, "auto", None)
assert method == a_inst.decision_function

class B:
def predict_proba(self):
pass

b_inst = B()
method = _check_boundary_response_method(b_inst, "auto")
method = _check_boundary_response_method(b_inst, "auto", None)
assert method == b_inst.predict_proba

class C:
Expand All @@ -73,15 +73,15 @@ def decision_function(self):
pass

c_inst = C()
method = _check_boundary_response_method(c_inst, "auto")
method = _check_boundary_response_method(c_inst, "auto", None)
assert method == c_inst.decision_function

class D:
def predict(self):
pass

d_inst = D()
method = _check_boundary_response_method(d_inst, "auto")
method = _check_boundary_response_method(d_inst, "auto", None)
assert method == d_inst.predict


Expand All @@ -92,10 +92,7 @@ def test_multiclass_error(pyplot, response_method):
X = X[:, [0, 1]]
lr = LogisticRegression().fit(X, y)

msg = (
"Multiclass classifiers are only supported when response_method is 'predict' or"
" 'auto'"
)
msg = "you must define the class label to be selected"
with pytest.raises(ValueError, match=msg):
DecisionBoundaryDisplay.from_estimator(lr, X, response_method=response_method)

Expand Down Expand Up @@ -125,6 +122,46 @@ def test_multiclass(pyplot, response_method):
assert_allclose(disp.xx1, xx1)


@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize("class_label", [0, 1, 2])
def test_multiclass_class_label(pyplot, response_method, class_label):
    """Check the display response for one selected class of a multiclass problem.

    A 3-class `LogisticRegression` is fitted on a 2-feature dataset, then
    `DecisionBoundaryDisplay.from_estimator` is called with `response_method`
    set to `"predict_proba"` or `"decision_function"` and a given
    `class_label`. The grid is rebuilt by hand from the data range widened by
    `eps`, and the test checks that `disp.response` matches the estimator's
    output column for `class_label` and that `disp.xx0` / `disp.xx1` match the
    hand-built grid.
    """
    grid_resolution = 10
    eps = 1.0
    X, y = make_classification(
        n_features=2,
        n_classes=3,
        n_informative=2,
        n_redundant=0,
        n_repeated=0,
        n_clusters_per_class=1,
        random_state=0,
    )
    lr = LogisticRegression(random_state=0).fit(X, y)

    disp = DecisionBoundaryDisplay.from_estimator(
        lr,
        X,
        response_method=response_method,
        class_label=class_label,
        grid_resolution=grid_resolution,
        # Reuse the local `eps` so the display and the expected grid built
        # below cannot drift apart if the value is ever changed.
        eps=eps,
    )

    # Rebuild the evaluation grid independently of the display.
    x0_min, x0_max = X[:, 0].min() - eps, X[:, 0].max() + eps
    x1_min, x1_max = X[:, 1].min() - eps, X[:, 1].max() + eps
    xx0, xx1 = np.meshgrid(
        np.linspace(x0_min, x0_max, grid_resolution),
        np.linspace(x1_min, x1_max, grid_resolution),
    )
    response = getattr(lr, response_method)(np.c_[xx0.ravel(), xx1.ravel()])
    # The parametrized labels (0, 1, 2) are used directly as column indices;
    # this assumes the fitted `classes_` are exactly [0, 1, 2].
    response = response[:, class_label]
    assert_allclose(disp.response, response.reshape(xx0.shape))
    assert_allclose(disp.xx0, xx0)
    assert_allclose(disp.xx1, xx1)


@pytest.mark.parametrize(
"kwargs, error_msg",
[
Expand Down