From 1e0f324f6968dee798dce020792f5c535c32ec1e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 2 Aug 2023 15:11:27 +0200 Subject: [PATCH 1/5] ENH handle mutliclass with scores and probailities in DecisionBoundaryDisplay --- doc/whats_new/v1.4.rst | 8 +++ .../plot_classification_probability.py | 49 ++++++++++--------- sklearn/inspection/_plot/decision_boundary.py | 34 ++++++++++--- .../tests/test_boundary_decision_display.py | 46 ++++++++++++++--- 4 files changed, 97 insertions(+), 40 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 72fd30eb1050a..762ffb0159021 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -131,6 +131,14 @@ Changelog - |Fix| :func:`feature_selection.mutual_info_regression` now correctly computes the result when `X` is of integer dtype. :pr:`26748` by :user:`Yao Xiao `. +:mod:`sklearn.inspection` +......................... + +- |Enhancement| :class:`inspection.DecisionBoundaryDisplay` can be used with + `response_method` set to `"predict_proba"` or `"decision_function"` for multiclass + problem and by setting `class_label` to select the class to plot. + :pr:`xxx` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.linear_model` ........................... diff --git a/examples/classification/plot_classification_probability.py b/examples/classification/plot_classification_probability.py index ec5887b63914d..bb84ee1b7f143 100644 --- a/examples/classification/plot_classification_probability.py +++ b/examples/classification/plot_classification_probability.py @@ -22,10 +22,12 @@ import matplotlib.pyplot as plt import numpy as np +from matplotlib import cm from sklearn import datasets from sklearn.gaussian_process import GaussianProcessClassifier from sklearn.gaussian_process.kernels import RBF +from sklearn.inspection import DecisionBoundaryDisplay from sklearn.linear_model import LogisticRegression from sklearn.metrics import accuracy_score from sklearn.svm import SVC @@ -35,6 +37,7 @@ y = iris.target n_features = X.shape[1] +n_classes = len(np.unique(y)) C = 10 kernel = 1.0 * RBF([1.0, 1.0]) # for GPC @@ -56,13 +59,7 @@ n_classifiers = len(classifiers) -plt.figure(figsize=(3 * 2, n_classifiers * 2)) -plt.subplots_adjust(bottom=0.2, top=0.95) - -xx = np.linspace(3, 9, 100) -yy = np.linspace(1, 5, 100).T -xx, yy = np.meshgrid(xx, yy) -Xfull = np.c_[xx.ravel(), yy.ravel()] +fig, axs = plt.subplots(nrows=n_classifiers, ncols=n_classes, figsize=(6, 14)) for index, (name, classifier) in enumerate(classifiers.items()): classifier.fit(X, y) @@ -71,25 +68,29 @@ accuracy = accuracy_score(y, y_pred) print("Accuracy (train) for %s: %0.1f%% " % (name, accuracy * 100)) - # View probabilities: - probas = classifier.predict_proba(Xfull) - n_classes = np.unique(y_pred).size - for k in range(n_classes): - plt.subplot(n_classifiers, n_classes, index * n_classes + k + 1) - plt.title("Class %d" % k) - if k == 0: - plt.ylabel(name) - imshow_handle = plt.imshow( - probas[:, k].reshape((100, 100)), extent=(3, 9, 1, 5), origin="lower" + for k in classifier.classes_: + disp = DecisionBoundaryDisplay.from_estimator( + classifier, + X, + plot_method="pcolormesh", + response_method="predict_proba", + class_label=k, + ax=axs[index, k], + ) + axs[index, k].set( + xticks=(), yticks=(), ylabel=name if k == 0 else None, title=f"Class #{k}" ) - plt.xticks(()) - plt.yticks(()) idx = y_pred == k if idx.any(): - plt.scatter(X[idx, 0], X[idx, 1], marker="o", c="w", edgecolor="k") - -ax = plt.axes([0.15, 0.04, 0.7, 0.05]) -plt.title("Probability") -plt.colorbar(imshow_handle, cax=ax, orientation="horizontal") + axs[index, k].scatter( + X[idx, 0], X[idx, 1], marker="o", c="w", edgecolor="k" + ) + +fig.colorbar( + cm.ScalarMappable(norm=None, cmap="viridis"), + ax=axs, + orientation="horizontal", + label="Probability", +) plt.show() diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index c9d2a52b6e9ab..739afd119312d 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -12,7 +12,7 @@ ) -def _check_boundary_response_method(estimator, response_method): +def _check_boundary_response_method(estimator, response_method, class_label): """Return prediction method from the `response_method` for decision boundary. Parameters @@ -37,13 +37,18 @@ def _check_boundary_response_method(estimator, response_method): raise ValueError(msg) if has_classes and len(estimator.classes_) > 2: - if response_method not in {"auto", "predict"}: + if response_method in {"predict_proba", "decision_function"} and ( + class_label is None or class_label not in estimator.classes_ + ): msg = ( - "Multiclass classifiers are only supported when response_method is" - " 'predict' or 'auto'" + "When `response_method` is set to 'predict_proba' or " + "'decision_function' and the target is multiclass, you must define " + "the class label to be selected as class of interest. Got " + f"class_label={class_label} instead. Potential choices are: " + f"{estimator.classes_}." ) raise ValueError(msg) - methods_list = ["predict"] + methods_list = ["predict"] if response_method == "auto" else [response_method] elif response_method == "auto": methods_list = ["decision_function", "predict_proba", "predict"] else: @@ -206,6 +211,7 @@ def from_estimator( eps=1.0, plot_method="contourf", response_method="auto", + class_label=None, xlabel=None, ylabel=None, ax=None, @@ -248,6 +254,13 @@ def from_estimator( For multiclass problems, :term:`predict` is selected when `response_method="auto"`. + class_label : int, float or str, default=None + When `response_method` return several columns (i.e. `"predict_proba"` and + `"decision_function"`), `class_label` specifies which column to use for + plotting. + + .. versionadded:: 1.4 + xlabel : str, default=None The label used for the x-axis. If `None`, an attempt is made to extract a label from `X` if it is a dataframe, otherwise an empty @@ -342,7 +355,9 @@ def from_estimator( else: X_grid = np.c_[xx0.ravel(), xx1.ravel()] - pred_func = _check_boundary_response_method(estimator, response_method) + pred_func = _check_boundary_response_method( + estimator, response_method, class_label + ) response = pred_func(X_grid) # convert classes predictions into integers @@ -355,8 +370,11 @@ def from_estimator( if is_regressor(estimator): raise ValueError("Multi-output regressors are not supported") - # TODO: Support pos_label - response = response[:, 1] + if class_label is None: + response = response[:, 1] + else: + target_index = np.flatnonzero(estimator.classes_ == class_label)[0] + response = response[:, target_index] if xlabel is None: xlabel = X.columns[0] if hasattr(X, "columns") else "" diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 47c21e4521c35..0c5e4fff7cf28 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -54,7 +54,7 @@ def decision_function(self): pass a_inst = A() - method = _check_boundary_response_method(a_inst, "auto") + method = _check_boundary_response_method(a_inst, "auto", None) assert method == a_inst.decision_function class B: @@ -62,7 +62,7 @@ def predict_proba(self): pass b_inst = B() - method = _check_boundary_response_method(b_inst, "auto") + method = _check_boundary_response_method(b_inst, "auto", None) assert method == b_inst.predict_proba class C: @@ -73,7 +73,7 @@ def decision_function(self): pass c_inst = C() - method = _check_boundary_response_method(c_inst, "auto") + method = _check_boundary_response_method(c_inst, "auto", None) assert method == c_inst.decision_function class D: @@ -81,7 +81,7 @@ def predict(self): pass d_inst = D() - method = _check_boundary_response_method(d_inst, "auto") + method = _check_boundary_response_method(d_inst, "auto", None) assert method == d_inst.predict @@ -92,10 +92,7 @@ def test_multiclass_error(pyplot, response_method): X = X[:, [0, 1]] lr = LogisticRegression().fit(X, y) - msg = ( - "Multiclass classifiers are only supported when response_method is 'predict' or" - " 'auto'" - ) + msg = "you must define the class label to be selected" with pytest.raises(ValueError, match=msg): DecisionBoundaryDisplay.from_estimator(lr, X, response_method=response_method) @@ -125,6 +122,39 @@ def test_multiclass(pyplot, response_method): assert_allclose(disp.xx1, xx1) +@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) +@pytest.mark.parametrize("class_label", [0, 1, 2]) +def test_multiclass_class_label(pyplot, response_method, class_label): + """Check multiclass with decision function and probabilities provide the expected + results.""" + grid_resolution = 10 + eps = 1.0 + X, y = make_classification(n_classes=3, n_informative=3, random_state=0) + X = X[:, [0, 1]] + lr = LogisticRegression(random_state=0).fit(X, y) + + disp = DecisionBoundaryDisplay.from_estimator( + lr, + X, + response_method=response_method, + class_label=class_label, + grid_resolution=grid_resolution, + eps=1.0, + ) + + x0_min, x0_max = X[:, 0].min() - eps, X[:, 0].max() + eps + x1_min, x1_max = X[:, 1].min() - eps, X[:, 1].max() + eps + xx0, xx1 = np.meshgrid( + np.linspace(x0_min, x0_max, grid_resolution), + np.linspace(x1_min, x1_max, grid_resolution), + ) + response = getattr(lr, response_method)(np.c_[xx0.ravel(), xx1.ravel()]) + response = response[:, class_label] + assert_allclose(disp.response, response.reshape(xx0.shape)) + assert_allclose(disp.xx0, xx0) + assert_allclose(disp.xx1, xx1) + + @pytest.mark.parametrize( "kwargs, error_msg", [ From 901bba90da8c8ae3d9068d94f55db51ab4b5b04e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 2 Aug 2023 15:14:14 +0200 Subject: [PATCH 2/5] change the pr number --- doc/whats_new/v1.4.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 762ffb0159021..6fed4729ec706 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -137,7 +137,7 @@ Changelog - |Enhancement| :class:`inspection.DecisionBoundaryDisplay` can be used with `response_method` set to `"predict_proba"` or `"decision_function"` for multiclass problem and by setting `class_label` to select the class to plot. - :pr:`xxx` by :user:`Guillaume Lemaitre `. + :pr:`26995` by :user:`Guillaume Lemaitre `. :mod:`sklearn.linear_model` ........................... From b0fd28d1112a2d19ae29d5000539895916f16c3d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 3 Aug 2023 13:43:03 +0200 Subject: [PATCH 3/5] iter --- sklearn/inspection/_plot/decision_boundary.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 739afd119312d..f540b05c11d97 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -255,9 +255,14 @@ def from_estimator( `response_method="auto"`. class_label : int, float or str, default=None - When `response_method` return several columns (i.e. `"predict_proba"` and - `"decision_function"`), `class_label` specifies which column to use for - plotting. + When dealing with a multiclass problem, you can visualize one class against + the other classes by specifying `class_label`. This can be used in + combination with `"predict_proba"` and `"decision_function"` passed + ass `response_method`. + + See the example entitle + :ref:`sphx_glr_auto_examples_classification_plot_classification_probability.py.py` + that shows how to use this parameter. .. versionadded:: 1.4 From e612aec64cc25e2b90892eb8e6321fe3fc26a563 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 3 Aug 2023 13:45:15 +0200 Subject: [PATCH 4/5] fix test --- .../_plot/tests/test_boundary_decision_display.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 0c5e4fff7cf28..184c6cc52238c 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -129,8 +129,15 @@ def test_multiclass_class_label(pyplot, response_method, class_label): results.""" grid_resolution = 10 eps = 1.0 - X, y = make_classification(n_classes=3, n_informative=3, random_state=0) - X = X[:, [0, 1]] + X, y = make_classification( + n_features=2, + n_classes=3, + n_informative=2, + n_redundant=0, + n_repeated=0, + n_clusters_per_class=1, + random_state=0, + ) lr = LogisticRegression(random_state=0).fit(X, y) disp = DecisionBoundaryDisplay.from_estimator( From ecbc5497bb18415f82251b1584b622b398097bca Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 3 Aug 2023 14:55:34 +0200 Subject: [PATCH 5/5] iter --- .../plot_classification_probability.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/examples/classification/plot_classification_probability.py b/examples/classification/plot_classification_probability.py index bb84ee1b7f143..c17b407fe3bc9 100644 --- a/examples/classification/plot_classification_probability.py +++ b/examples/classification/plot_classification_probability.py @@ -76,18 +76,25 @@ response_method="predict_proba", class_label=k, ax=axs[index, k], + alpha=0.5, + cmap="RdBu", ) axs[index, k].set( xticks=(), yticks=(), ylabel=name if k == 0 else None, title=f"Class #{k}" ) - idx = y_pred == k - if idx.any(): - axs[index, k].scatter( - X[idx, 0], X[idx, 1], marker="o", c="w", edgecolor="k" - ) + scatter = axs[index, k].scatter( + X[:, 0], X[:, 1], marker="o", c=y_pred, edgecolor="k", alpha=0.7 + ) + +axs[4, 1].legend( + scatter.legend_elements()[0], + iris.target_names, + bbox_to_anchor=(1.03, -0.1), + title="Predicted classes", +) fig.colorbar( - cm.ScalarMappable(norm=None, cmap="viridis"), + cm.ScalarMappable(norm=None, cmap="RdBu"), ax=axs, orientation="horizontal", label="Probability",