From 66eb147cd5560691b8bfb386a7a097d1cef56750 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 4 Sep 2023 21:12:13 +0200 Subject: [PATCH 01/18] ENH accept in --- doc/whats_new/v1.4.rst | 9 ++ sklearn/inspection/_plot/decision_boundary.py | 67 +++++++++----- .../tests/test_boundary_decision_display.py | 40 +++----- sklearn/utils/_response.py | 18 +++- sklearn/utils/tests/test_response.py | 91 ++++++++++++++----- 5 files changed, 153 insertions(+), 72 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index c13922c6cb22e..09abf9485f249 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -157,6 +157,15 @@ Changelog - |Fix| :func:`feature_selection.mutual_info_regression` now correctly computes the result when `X` is of integer dtype. :pr:`26748` by :user:`Yao Xiao `. +:mod:`sklearn.inspection` +......................... + +- |Enhancement| :class:`inspection.DecisionBoundaryDisplay` now accepts a parameter + `pos_label` to select the class of interest when plotting the response provided by + `response_method="predict_proba"` or `response_method="decision_function"`. It allows + to plot the decision boundary for both binary and multiclass classifiers. + :pr:`xxxx` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.linear_model` ........................... diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index c9d2a52b6e9ab..9e0a4e0d6415c 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -1,10 +1,9 @@ -from functools import reduce - import numpy as np from ...base import is_regressor from ...preprocessing import LabelEncoder from ...utils import _safe_indexing, check_matplotlib_support +from ...utils._response import _get_response_values from ...utils.validation import ( _is_arraylike_not_scalar, _num_features, @@ -12,8 +11,8 @@ ) -def _check_boundary_response_method(estimator, response_method): - """Return prediction method from the `response_method` for decision boundary. +def _check_boundary_response_method(estimator, response_method, pos_label): + """Validate the response methods to be used with the fitted estimator. Parameters ---------- @@ -26,10 +25,17 @@ def _check_boundary_response_method(estimator, response_method): If set to 'auto', the response method is tried in the following order: :term:`decision_function`, :term:`predict_proba`, :term:`predict`. + pos_label : int, float, bool or str + The class considered as the positive class when plotting the decision. + If the label is specified, it then possible to plot the decision boundary in + multiclass settings. + + .. versionadded:: 1.4 + Returns ------- - prediction_method: callable - Prediction method of estimator. + prediction_method : list of str or str + The name of the response methods to use or a list of such names. """ has_classes = hasattr(estimator, "classes_") if has_classes and _is_arraylike_not_scalar(estimator.classes_[0]): @@ -37,25 +43,18 @@ def _check_boundary_response_method(estimator, response_method): raise ValueError(msg) if has_classes and len(estimator.classes_) > 2: - if response_method not in {"auto", "predict"}: + if response_method not in {"auto", "predict"} and pos_label is None: msg = ( "Multiclass classifiers are only supported when response_method is" - " 'predict' or 'auto'" + " 'predict' or 'auto', or you must provide `pos_label` to select a" + " specific class to plot the decision boundary." 
) raise ValueError(msg) - methods_list = ["predict"] + prediction_method = "predict" if response_method == "auto" else response_method elif response_method == "auto": - methods_list = ["decision_function", "predict_proba", "predict"] + prediction_method = ["decision_function", "predict_proba", "predict"] else: - methods_list = [response_method] - - prediction_method = [getattr(estimator, method, None) for method in methods_list] - prediction_method = reduce(lambda x, y: x or y, prediction_method) - if prediction_method is None: - raise ValueError( - f"{estimator.__class__.__name__} has none of the following attributes: " - f"{', '.join(methods_list)}." - ) + prediction_method = response_method return prediction_method @@ -206,6 +205,7 @@ def from_estimator( eps=1.0, plot_method="contourf", response_method="auto", + pos_label=None, xlabel=None, ylabel=None, ax=None, @@ -248,6 +248,12 @@ def from_estimator( For multiclass problems, :term:`predict` is selected when `response_method="auto"`. + pos_label : int, float, bool or str, default=None + The class considered as the positive class when plotting the decision. + By default, `estimators.classes_[1]` is considered as the positive class. + + .. versionadded:: 1.4 + xlabel : str, default=None The label used for the x-axis. If `None`, an attempt is made to extract a label from `X` if it is a dataframe, otherwise an empty @@ -342,11 +348,19 @@ def from_estimator( else: X_grid = np.c_[xx0.ravel(), xx1.ravel()] - pred_func = _check_boundary_response_method(estimator, response_method) - response = pred_func(X_grid) + prediction_method = _check_boundary_response_method( + estimator, response_method, pos_label + ) + response, _, response_method_used = _get_response_values( + estimator, + X_grid, + response_method=prediction_method, + pos_label=pos_label, + return_response_method_used=True, + ) # convert classes predictions into integers - if pred_func.__name__ == "predict" and hasattr(estimator, "classes_"): + if response_method_used == "predict" and hasattr(estimator, "classes_"): encoder = LabelEncoder() encoder.classes_ = estimator.classes_ response = encoder.transform(response) @@ -355,8 +369,13 @@ def from_estimator( if is_regressor(estimator): raise ValueError("Multi-output regressors are not supported") - # TODO: Support pos_label - response = response[:, 1] + # For the multiclass case, `_get_response_values` returns the response + # as-is. Thus, we have a column per class and we need to select the column + # corresponding to the positive class. 
+ if pos_label is None: + pos_label = estimator.classes_[1] + col_idx = np.flatnonzero(estimator.classes_ == pos_label)[0] + response = response[:, col_idx] if xlabel is None: xlabel = X.columns[0] if hasattr(X, "columns") else "" diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 47c21e4521c35..33eaab674e879 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -49,40 +49,26 @@ def test_input_data_dimension(pyplot): def test_check_boundary_response_method_auto(): """Check _check_boundary_response_method behavior with 'auto'.""" + expected_methods = ["decision_function", "predict_proba", "predict"] + class A: def decision_function(self): pass - a_inst = A() - method = _check_boundary_response_method(a_inst, "auto") - assert method == a_inst.decision_function - class B: def predict_proba(self): pass - b_inst = B() - method = _check_boundary_response_method(b_inst, "auto") - assert method == b_inst.predict_proba - - class C: - def predict_proba(self): - pass - - def decision_function(self): - pass - - c_inst = C() - method = _check_boundary_response_method(c_inst, "auto") - assert method == c_inst.decision_function + class C(A, B): + pass class D: def predict(self): pass - d_inst = D() - method = _check_boundary_response_method(d_inst, "auto") - assert method == d_inst.predict + for Klass in [A, B, C, D]: + methods = _check_boundary_response_method(Klass(), "auto", None) + assert methods == expected_methods @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) @@ -198,18 +184,21 @@ def test_decision_boundary_display(pyplot, fitted_clf, response_method, plot_met @pytest.mark.parametrize( - "response_method, msg", + "response_method, type_err, msg", [ ( "predict_proba", + AttributeError, "MyClassifier has none of the following attributes: predict_proba", ), ( "decision_function", + AttributeError, "MyClassifier has none of the following attributes: decision_function", ), ( "auto", + AttributeError, ( "MyClassifier has none of the following attributes: decision_function, " "predict_proba, predict" @@ -217,11 +206,12 @@ def test_decision_boundary_display(pyplot, fitted_clf, response_method, plot_met ), ( "bad_method", + AttributeError, "MyClassifier has none of the following attributes: bad_method", ), ], ) -def test_error_bad_response(pyplot, response_method, msg): +def test_error_bad_response(pyplot, response_method, type_err, msg): """Check errors for bad response.""" class MyClassifier(BaseEstimator, ClassifierMixin): @@ -232,7 +222,7 @@ def fit(self, X, y): clf = MyClassifier().fit(X, y) - with pytest.raises(ValueError, match=msg): + with pytest.raises(type_err, match=msg): DecisionBoundaryDisplay.from_estimator(clf, X, response_method=response_method) @@ -274,7 +264,7 @@ def test_multioutput_regressor_error(pyplot): y = np.asarray([[0, 1], [4, 1]]) tree = DecisionTreeRegressor().fit(X, y) with pytest.raises(ValueError, match="Multi-output regressors are not supported"): - DecisionBoundaryDisplay.from_estimator(tree, X) + DecisionBoundaryDisplay.from_estimator(tree, X, response_method="predict") @pytest.mark.filterwarnings( diff --git a/sklearn/utils/_response.py b/sklearn/utils/_response.py index e753ced045e1e..046f71287f4d3 100644 --- a/sklearn/utils/_response.py +++ b/sklearn/utils/_response.py @@ -13,6 +13,7 @@ def _get_response_values( X, response_method, 
pos_label=None, + return_response_method_used=False, ): """Compute the response values of a classifier or a regressor. @@ -49,6 +50,12 @@ def _get_response_values( the metrics. By default, `estimators.classes_[1]` is considered as the positive class. + return_response_method_used : bool, default=False + Whether to return the response method used to compute the response + values. + + .. versionadded:: 1.4 + Returns ------- y_pred : ndarray of shape (n_samples,) @@ -59,6 +66,12 @@ def _get_response_values( The class considered as the positive class when computing the metrics. Returns `None` if `estimator` is a regressor. + response_method_used : str + The response method used to compute the response values. Only returned + if `return_response_method_used` is `True`. + + .. versionadded:: 1.4 + Raises ------ ValueError @@ -106,8 +119,11 @@ def _get_response_values( "should be 'predict'. Got a regressor with response_method=" f"{response_method} instead." ) - y_pred, pos_label = estimator.predict(X), None + prediction_method = estimator.predict + y_pred, pos_label = prediction_method(X), None + if return_response_method_used: + return y_pred, pos_label, prediction_method.__name__ return y_pred, pos_label diff --git a/sklearn/utils/tests/test_response.py b/sklearn/utils/tests/test_response.py index a67346e5697ec..8eefa11fdea66 100644 --- a/sklearn/utils/tests/test_response.py +++ b/sklearn/utils/tests/test_response.py @@ -29,17 +29,21 @@ def test_get_response_values_regressor_error(response_method): _get_response_values(my_estimator, X, response_method=response_method) -def test_get_response_values_regressor(): +@pytest.mark.parametrize("return_response_method_used", [True, False]) +def test_get_response_values_regressor(return_response_method_used): """Check the behaviour of `_get_response_values` with regressor.""" X, y = make_regression(n_samples=10, random_state=0) regressor = LinearRegression().fit(X, y) - y_pred, pos_label = _get_response_values( + results = _get_response_values( regressor, X, response_method="predict", + return_response_method_used=return_response_method_used, ) - assert_array_equal(y_pred, regressor.predict(X)) - assert pos_label is None + assert_array_equal(results[0], regressor.predict(X)) + assert results[1] is None + if return_response_method_used: + assert results[2] == "predict" @pytest.mark.parametrize( @@ -78,7 +82,10 @@ def test_get_response_values_classifier_inconsistent_y_pred_for_binary_proba(): _get_response_values(classifier, X, response_method="predict_proba") -def test_get_response_values_binary_classifier_decision_function(): +@pytest.mark.parametrize("return_response_method_used", [True, False]) +def test_get_response_values_binary_classifier_decision_function( + return_response_method_used, +): """Check the behaviour of `_get_response_values` with `decision_function` and binary classifier.""" X, y = make_classification( @@ -91,27 +98,36 @@ def test_get_response_values_binary_classifier_decision_function(): response_method = "decision_function" # default `pos_label` - y_pred, pos_label = _get_response_values( + results = _get_response_values( classifier, X, response_method=response_method, pos_label=None, + return_response_method_used=return_response_method_used, ) - assert_allclose(y_pred, classifier.decision_function(X)) - assert pos_label == 1 + assert_allclose(results[0], classifier.decision_function(X)) + assert results[1] == 1 + if return_response_method_used: + assert results[2] == "decision_function" # when forcing 
`pos_label=classifier.classes_[0]` - y_pred, pos_label = _get_response_values( + results = _get_response_values( classifier, X, response_method=response_method, pos_label=classifier.classes_[0], + return_response_method_used=return_response_method_used, ) - assert_allclose(y_pred, classifier.decision_function(X) * -1) - assert pos_label == 0 + assert_allclose(results[0], classifier.decision_function(X) * -1) + assert results[1] == 0 + if return_response_method_used: + assert results[2] == "decision_function" -def test_get_response_values_binary_classifier_predict_proba(): +@pytest.mark.parametrize("return_response_method_used", [True, False]) +def test_get_response_values_binary_classifier_predict_proba( + return_response_method_used, +): """Check that `_get_response_values` with `predict_proba` and binary classifier.""" X, y = make_classification( @@ -124,24 +140,28 @@ def test_get_response_values_binary_classifier_predict_proba(): response_method = "predict_proba" # default `pos_label` - y_pred, pos_label = _get_response_values( + results = _get_response_values( classifier, X, response_method=response_method, pos_label=None, + return_response_method_used=return_response_method_used, ) - assert_allclose(y_pred, classifier.predict_proba(X)[:, 1]) - assert pos_label == 1 + assert_allclose(results[0], classifier.predict_proba(X)[:, 1]) + assert results[1] == 1 + if return_response_method_used: + assert results[2] == "predict_proba" # when forcing `pos_label=classifier.classes_[0]` - y_pred, pos_label = _get_response_values( + results = _get_response_values( classifier, X, response_method=response_method, pos_label=classifier.classes_[0], + return_response_method_used=return_response_method_used, ) - assert_allclose(y_pred, classifier.predict_proba(X)[:, 0]) - assert pos_label == 0 + assert_allclose(results[0], classifier.predict_proba(X)[:, 0]) + assert results[1] == 0 @pytest.mark.parametrize( @@ -184,13 +204,13 @@ def test_get_response_predict_proba(): y_proba, pos_label = _get_response_values_binary( classifier, X_binary, response_method="predict_proba" ) - np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 1]) + assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 1]) assert pos_label == 1 y_proba, pos_label = _get_response_values_binary( classifier, X_binary, response_method="predict_proba", pos_label=0 ) - np.testing.assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 0]) + assert_allclose(y_proba, classifier.predict_proba(X_binary)[:, 0]) assert pos_label == 0 @@ -200,13 +220,13 @@ def test_get_response_decision_function(): y_score, pos_label = _get_response_values_binary( classifier, X_binary, response_method="decision_function" ) - np.testing.assert_allclose(y_score, classifier.decision_function(X_binary)) + assert_allclose(y_score, classifier.decision_function(X_binary)) assert pos_label == 1 y_score, pos_label = _get_response_values_binary( classifier, X_binary, response_method="decision_function", pos_label=0 ) - np.testing.assert_allclose(y_score, classifier.decision_function(X_binary) * -1) + assert_allclose(y_score, classifier.decision_function(X_binary) * -1) assert pos_label == 0 @@ -230,3 +250,30 @@ def test_get_response_values_multiclass(estimator, response_method): assert predictions.shape == (X.shape[0], len(estimator.classes_)) if response_method == "predict_proba": assert np.logical_and(predictions >= 0, predictions <= 1).all() + + +def test_get_response_values_with_response_list(): + """Check the behaviour of passing a list 
of responses to `_get_response_values`.""" + classifier = LogisticRegression().fit(X_binary, y_binary) + + # it should use `predict_proba` + y_pred, pos_label, response_method = _get_response_values( + classifier, + X_binary, + response_method=["predict_proba", "decision_function"], + return_response_method_used=True, + ) + assert_allclose(y_pred, classifier.predict_proba(X_binary)[:, 1]) + assert pos_label == 1 + assert response_method == "predict_proba" + + # it should use `decision_function` + y_pred, pos_label, response_method = _get_response_values( + classifier, + X_binary, + response_method=["decision_function", "predict_proba"], + return_response_method_used=True, + ) + assert_allclose(y_pred, classifier.decision_function(X_binary)) + assert pos_label == 1 + assert response_method == "decision_function" From fad3a1c310e65ed803eeb07cb7d0671e30bfa553 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 4 Sep 2023 21:15:46 +0200 Subject: [PATCH 02/18] change pr number --- doc/whats_new/v1.4.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 09abf9485f249..be88cbbe8ecbd 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -164,7 +164,7 @@ Changelog `pos_label` to select the class of interest when plotting the response provided by `response_method="predict_proba"` or `response_method="decision_function"`. It allows to plot the decision boundary for both binary and multiclass classifiers. - :pr:`xxxx` by :user:`Guillaume Lemaitre `. + :pr:`27291` by :user:`Guillaume Lemaitre `. :mod:`sklearn.linear_model` ........................... From 82a41501e7eb6341dae802cf9cd4b2beb6d1f255 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 4 Sep 2023 21:17:23 +0200 Subject: [PATCH 03/18] less diff --- .../_plot/tests/test_boundary_decision_display.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 33eaab674e879..cffa3ff175ecf 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -184,21 +184,18 @@ def test_decision_boundary_display(pyplot, fitted_clf, response_method, plot_met @pytest.mark.parametrize( - "response_method, type_err, msg", + "response_method, msg", [ ( "predict_proba", - AttributeError, "MyClassifier has none of the following attributes: predict_proba", ), ( "decision_function", - AttributeError, "MyClassifier has none of the following attributes: decision_function", ), ( "auto", - AttributeError, ( "MyClassifier has none of the following attributes: decision_function, " "predict_proba, predict" @@ -206,12 +203,11 @@ def test_decision_boundary_display(pyplot, fitted_clf, response_method, plot_met ), ( "bad_method", - AttributeError, "MyClassifier has none of the following attributes: bad_method", ), ], ) -def test_error_bad_response(pyplot, response_method, type_err, msg): +def test_error_bad_response(pyplot, response_method, msg): """Check errors for bad response.""" class MyClassifier(BaseEstimator, ClassifierMixin): @@ -222,7 +218,7 @@ def fit(self, X, y): clf = MyClassifier().fit(X, y) - with pytest.raises(type_err, match=msg): + with pytest.raises(AttributeError, match=msg): DecisionBoundaryDisplay.from_estimator(clf, X, response_method=response_method) From 20461deb910fbeafe77b9f8c2cd8ee4ce516e017 Mon Sep 17 
00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 4 Sep 2023 21:46:07 +0200 Subject: [PATCH 04/18] modify example where the feature is useful --- .../plot_classification_probability.py | 61 +++++++++---------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/examples/classification/plot_classification_probability.py b/examples/classification/plot_classification_probability.py index ec5887b63914d..9191510143659 100644 --- a/examples/classification/plot_classification_probability.py +++ b/examples/classification/plot_classification_probability.py @@ -22,10 +22,12 @@ import matplotlib.pyplot as plt import numpy as np +from matplotlib import cm from sklearn import datasets from sklearn.gaussian_process import GaussianProcessClassifier from sklearn.gaussian_process.kernels import RBF +from sklearn.inspection import DecisionBoundaryDisplay from sklearn.linear_model import LogisticRegression from sklearn.metrics import accuracy_score from sklearn.svm import SVC @@ -56,40 +58,37 @@ n_classifiers = len(classifiers) -plt.figure(figsize=(3 * 2, n_classifiers * 2)) -plt.subplots_adjust(bottom=0.2, top=0.95) - -xx = np.linspace(3, 9, 100) -yy = np.linspace(1, 5, 100).T -xx, yy = np.meshgrid(xx, yy) -Xfull = np.c_[xx.ravel(), yy.ravel()] - -for index, (name, classifier) in enumerate(classifiers.items()): - classifier.fit(X, y) - - y_pred = classifier.predict(X) +fig, axes = plt.subplots( + nrows=n_classifiers, ncols=len(iris.target_names), figsize=(10, 16) +) +for classifier_idx, (name, classifier) in enumerate(classifiers.items()): + y_pred = classifier.fit(X, y).predict(X) accuracy = accuracy_score(y, y_pred) - print("Accuracy (train) for %s: %0.1f%% " % (name, accuracy * 100)) - - # View probabilities: - probas = classifier.predict_proba(Xfull) - n_classes = np.unique(y_pred).size - for k in range(n_classes): - plt.subplot(n_classifiers, n_classes, index * n_classes + k + 1) - plt.title("Class %d" % k) - if k == 0: - plt.ylabel(name) - imshow_handle = plt.imshow( - probas[:, k].reshape((100, 100)), extent=(3, 9, 1, 5), origin="lower" + print(f"Accuracy (train) for {name}: {accuracy:0.1%}") + for label in np.unique(y): + # plot the probability estimate provided by the classifier + disp = DecisionBoundaryDisplay.from_estimator( + classifier, + X, + response_method="predict_proba", + pos_label=label, + ax=axes[classifier_idx, label], + vmin=0, + vmax=1, + ) + axes[classifier_idx, label].set_title(f"Class {label}") + # plot the data that are predict to belong to the class + mask_y_pred = y_pred == label + axes[classifier_idx, label].scatter( + X[mask_y_pred, 0], X[mask_y_pred, 1], marker="o", c="w", edgecolor="k" ) - plt.xticks(()) - plt.yticks(()) - idx = y_pred == k - if idx.any(): - plt.scatter(X[idx, 0], X[idx, 1], marker="o", c="w", edgecolor="k") + axes[classifier_idx, label].set(xticks=(), yticks=()) + axes[classifier_idx, 0].set_ylabel(name) -ax = plt.axes([0.15, 0.04, 0.7, 0.05]) +ax = plt.axes([0.15, 0.04, 0.7, 0.02]) plt.title("Probability") -plt.colorbar(imshow_handle, cax=ax, orientation="horizontal") +_ = plt.colorbar( + cm.ScalarMappable(norm=None, cmap="viridis"), cax=ax, orientation="horizontal" +) plt.show() From a08fbf71ebff4dca61316c79f9500248c5f6a4ad Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 5 Sep 2023 14:22:47 +0200 Subject: [PATCH 05/18] TST add dedicated test --- .../tests/test_boundary_decision_display.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py 
b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index cffa3ff175ecf..7683b930dfffb 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -339,3 +339,48 @@ def test_dataframe_support(pyplot): # no warnings linked to feature names validation should be raised warnings.simplefilter("error", UserWarning) DecisionBoundaryDisplay.from_estimator(estimator, df, response_method="predict") + + +@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) +def test_pos_label(pyplot, response_method): + """Check the behaviour of passing `pos_label` for plotting the output of + `predict_proba` and `decision_function`. + """ + iris = load_iris() + X = iris.data[:, :2] + y = iris.target # the target are numerical labels + pos_label_idx = 2 + + estimator = LogisticRegression().fit(X, y) + disp = DecisionBoundaryDisplay.from_estimator( + estimator, X, response_method=response_method, pos_label=pos_label_idx + ) + + # we will check that we plot the expected values as response + grid = np.concatenate([disp.xx0.reshape(-1, 1), disp.xx1.reshape(-1, 1)], axis=1) + response = getattr(estimator, response_method)(grid)[:, pos_label_idx] + assert_allclose(response.reshape(*disp.response.shape), disp.response) + + # make the same test but this time using target as strings + y = iris.target_names[iris.target] + estimator = LogisticRegression().fit(X, y) + + disp = DecisionBoundaryDisplay.from_estimator( + estimator, + X, + response_method=response_method, + pos_label=iris.target_names[pos_label_idx], + ) + + grid = np.concatenate([disp.xx0.reshape(-1, 1), disp.xx1.reshape(-1, 1)], axis=1) + response = getattr(estimator, response_method)(grid)[:, pos_label_idx] + assert_allclose(response.reshape(*disp.response.shape), disp.response) + + # check that we raise an error for unknown labels + # this test should already be handled in `_get_response_values` but we can have this + # test here as well + err_msg = "pos_label=2 is not a valid label: It should be one of" + with pytest.raises(ValueError, match=err_msg): + DecisionBoundaryDisplay.from_estimator( + estimator, X, response_method=response_method, pos_label=pos_label_idx + ) From c825ac44059d95eaa79fe9a67b6c34f16628f9e6 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 6 Sep 2023 12:04:35 +0200 Subject: [PATCH 06/18] rename pos_label to class_of_interest --- doc/whats_new/v1.4.rst | 7 ++-- .../plot_classification_probability.py | 2 +- sklearn/inspection/_plot/decision_boundary.py | 37 ++++++++++--------- .../tests/test_boundary_decision_display.py | 22 +++++++---- 4 files changed, 39 insertions(+), 29 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index be88cbbe8ecbd..70b9d7f81f346 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -161,9 +161,10 @@ Changelog ......................... - |Enhancement| :class:`inspection.DecisionBoundaryDisplay` now accepts a parameter - `pos_label` to select the class of interest when plotting the response provided by - `response_method="predict_proba"` or `response_method="decision_function"`. It allows - to plot the decision boundary for both binary and multiclass classifiers. + `class_of_interest` to select the class of interest when plotting the response + provided by `response_method="predict_proba"` or + `response_method="decision_function"`. It allows to plot the decision boundary for + both binary and multiclass classifiers. 
:pr:`27291` by :user:`Guillaume Lemaitre `. :mod:`sklearn.linear_model` diff --git a/examples/classification/plot_classification_probability.py b/examples/classification/plot_classification_probability.py index 9191510143659..0499f49ac4855 100644 --- a/examples/classification/plot_classification_probability.py +++ b/examples/classification/plot_classification_probability.py @@ -71,7 +71,7 @@ classifier, X, response_method="predict_proba", - pos_label=label, + class_of_interest=label, ax=axes[classifier_idx, label], vmin=0, vmax=1, diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 9e0a4e0d6415c..fe773cc2217c2 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -11,7 +11,7 @@ ) -def _check_boundary_response_method(estimator, response_method, pos_label): +def _check_boundary_response_method(estimator, response_method, class_of_interest): """Validate the response methods to be used with the fitted estimator. Parameters @@ -25,10 +25,9 @@ def _check_boundary_response_method(estimator, response_method, pos_label): If set to 'auto', the response method is tried in the following order: :term:`decision_function`, :term:`predict_proba`, :term:`predict`. - pos_label : int, float, bool or str - The class considered as the positive class when plotting the decision. - If the label is specified, it then possible to plot the decision boundary in - multiclass settings. + class_of_interest : int, float, bool or str + The class considered when plotting the decision. If the label is specified, it + then possible to plot the decision boundary in multiclass settings. .. versionadded:: 1.4 @@ -43,11 +42,11 @@ def _check_boundary_response_method(estimator, response_method, pos_label): raise ValueError(msg) if has_classes and len(estimator.classes_) > 2: - if response_method not in {"auto", "predict"} and pos_label is None: + if response_method not in {"auto", "predict"} and class_of_interest is None: msg = ( "Multiclass classifiers are only supported when response_method is" - " 'predict' or 'auto', or you must provide `pos_label` to select a" - " specific class to plot the decision boundary." + " 'predict' or 'auto', or you must provide `class_of_interest` to " + " select a specific class to plot the decision boundary." ) raise ValueError(msg) prediction_method = "predict" if response_method == "auto" else response_method @@ -205,7 +204,7 @@ def from_estimator( eps=1.0, plot_method="contourf", response_method="auto", - pos_label=None, + class_of_interest=None, xlabel=None, ylabel=None, ax=None, @@ -248,9 +247,9 @@ def from_estimator( For multiclass problems, :term:`predict` is selected when `response_method="auto"`. - pos_label : int, float, bool or str, default=None - The class considered as the positive class when plotting the decision. - By default, `estimators.classes_[1]` is considered as the positive class. + class_of_interest : int, float, bool or str, default=None + The class considered when plotting the decision. By default, + `estimators.classes_[1]` is considered as the positive class. .. 
versionadded:: 1.4 @@ -349,13 +348,13 @@ def from_estimator( X_grid = np.c_[xx0.ravel(), xx1.ravel()] prediction_method = _check_boundary_response_method( - estimator, response_method, pos_label + estimator, response_method, class_of_interest ) response, _, response_method_used = _get_response_values( estimator, X_grid, response_method=prediction_method, - pos_label=pos_label, + pos_label=class_of_interest, return_response_method_used=True, ) @@ -372,9 +371,13 @@ def from_estimator( # For the multiclass case, `_get_response_values` returns the response # as-is. Thus, we have a column per class and we need to select the column # corresponding to the positive class. - if pos_label is None: - pos_label = estimator.classes_[1] - col_idx = np.flatnonzero(estimator.classes_ == pos_label)[0] + if class_of_interest is None and len(estimator.classes_) > 2: + raise ValueError( + "With multiclass classification, you must specify the class of " + "interest, via the `class_of_interest` parameter, to plot the " + "decision boundary." + ) + col_idx = np.flatnonzero(estimator.classes_ == class_of_interest)[0] response = response[:, col_idx] if xlabel is None: diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 7683b930dfffb..84136dbc02d21 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -342,23 +342,26 @@ def test_dataframe_support(pyplot): @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) -def test_pos_label(pyplot, response_method): - """Check the behaviour of passing `pos_label` for plotting the output of +def test_class_of_interest(pyplot, response_method): + """Check the behaviour of passing `class_of_interest` for plotting the output of `predict_proba` and `decision_function`. 
""" iris = load_iris() X = iris.data[:, :2] y = iris.target # the target are numerical labels - pos_label_idx = 2 + class_of_interest_idx = 2 estimator = LogisticRegression().fit(X, y) disp = DecisionBoundaryDisplay.from_estimator( - estimator, X, response_method=response_method, pos_label=pos_label_idx + estimator, + X, + response_method=response_method, + class_of_interest=class_of_interest_idx, ) # we will check that we plot the expected values as response grid = np.concatenate([disp.xx0.reshape(-1, 1), disp.xx1.reshape(-1, 1)], axis=1) - response = getattr(estimator, response_method)(grid)[:, pos_label_idx] + response = getattr(estimator, response_method)(grid)[:, class_of_interest_idx] assert_allclose(response.reshape(*disp.response.shape), disp.response) # make the same test but this time using target as strings @@ -369,11 +372,11 @@ def test_pos_label(pyplot, response_method): estimator, X, response_method=response_method, - pos_label=iris.target_names[pos_label_idx], + class_of_interest=iris.target_names[class_of_interest_idx], ) grid = np.concatenate([disp.xx0.reshape(-1, 1), disp.xx1.reshape(-1, 1)], axis=1) - response = getattr(estimator, response_method)(grid)[:, pos_label_idx] + response = getattr(estimator, response_method)(grid)[:, class_of_interest_idx] assert_allclose(response.reshape(*disp.response.shape), disp.response) # check that we raise an error for unknown labels @@ -382,5 +385,8 @@ def test_pos_label(pyplot, response_method): err_msg = "pos_label=2 is not a valid label: It should be one of" with pytest.raises(ValueError, match=err_msg): DecisionBoundaryDisplay.from_estimator( - estimator, X, response_method=response_method, pos_label=pos_label_idx + estimator, + X, + response_method=response_method, + class_of_interest=class_of_interest_idx, ) From 6cb3a55fc4cd619b57e3edb3b704cc13cb3bd0de Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 6 Sep 2023 14:04:13 +0200 Subject: [PATCH 07/18] better error message --- sklearn/inspection/_plot/decision_boundary.py | 31 +++++++++++-------- .../tests/test_boundary_decision_display.py | 2 +- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index fe773cc2217c2..6ac2816946669 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -350,13 +350,24 @@ def from_estimator( prediction_method = _check_boundary_response_method( estimator, response_method, class_of_interest ) - response, _, response_method_used = _get_response_values( - estimator, - X_grid, - response_method=prediction_method, - pos_label=class_of_interest, - return_response_method_used=True, - ) + try: + response, _, response_method_used = _get_response_values( + estimator, + X_grid, + response_method=prediction_method, + pos_label=class_of_interest, + return_response_method_used=True, + ) + except ValueError as exc: + if "is not a valid label" in str(exc): + # re-raise a more informative error message since `pos_label` is unknown + # to our user when interacting with + # `DecisionBoundaryDisplay.from_estimator` + raise ValueError( + f"class_of_interest={class_of_interest} is not a valid label: It " + f"should be one of {estimator.classes_}" + ) from exc + raise exc # convert classes predictions into integers if response_method_used == "predict" and hasattr(estimator, "classes_"): @@ -371,12 +382,6 @@ def from_estimator( # For the multiclass case, `_get_response_values` returns the response # 
as-is. Thus, we have a column per class and we need to select the column # corresponding to the positive class. - if class_of_interest is None and len(estimator.classes_) > 2: - raise ValueError( - "With multiclass classification, you must specify the class of " - "interest, via the `class_of_interest` parameter, to plot the " - "decision boundary." - ) col_idx = np.flatnonzero(estimator.classes_ == class_of_interest)[0] response = response[:, col_idx] diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 84136dbc02d21..f32fd32896d5b 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -382,7 +382,7 @@ def test_class_of_interest(pyplot, response_method): # check that we raise an error for unknown labels # this test should already be handled in `_get_response_values` but we can have this # test here as well - err_msg = "pos_label=2 is not a valid label: It should be one of" + err_msg = "class_of_interest=2 is not a valid label: It should be one of" with pytest.raises(ValueError, match=err_msg): DecisionBoundaryDisplay.from_estimator( estimator, From 687b098d5257ace335e961aecf81c0e4d16da96a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 6 Sep 2023 15:04:05 +0200 Subject: [PATCH 08/18] revert to original size --- examples/classification/plot_classification_probability.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/classification/plot_classification_probability.py b/examples/classification/plot_classification_probability.py index 0499f49ac4855..d6f49295c8607 100644 --- a/examples/classification/plot_classification_probability.py +++ b/examples/classification/plot_classification_probability.py @@ -59,7 +59,9 @@ n_classifiers = len(classifiers) fig, axes = plt.subplots( - nrows=n_classifiers, ncols=len(iris.target_names), figsize=(10, 16) + nrows=n_classifiers, + ncols=len(iris.target_names), + figsize=(3 * 2, n_classifiers * 2), ) for classifier_idx, (name, classifier) in enumerate(classifiers.items()): y_pred = classifier.fit(X, y).predict(X) From bcf6b529b32ee89af4a9ba22b5f11e7fff956565 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 7 Sep 2023 10:47:21 +0200 Subject: [PATCH 09/18] TST add separate test for binary and multiclass --- sklearn/inspection/_plot/decision_boundary.py | 4 +- .../tests/test_boundary_decision_display.py | 59 ++++++++++++++++++- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 6ac2816946669..9a74d275c5b5d 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -249,7 +249,9 @@ def from_estimator( class_of_interest : int, float, bool or str, default=None The class considered when plotting the decision. By default, - `estimators.classes_[1]` is considered as the positive class. + `estimators.classes_[1]` is considered as the positive class + for binary classifiers. For multiclass classifier, passing + and explicit value for `class_of_interest` is mandatory. .. 
versionadded:: 1.4 diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index f32fd32896d5b..e7cd4d1f7e492 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -342,9 +342,52 @@ def test_dataframe_support(pyplot): @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) -def test_class_of_interest(pyplot, response_method): +def test_class_of_interest_binary(pyplot, response_method): """Check the behaviour of passing `class_of_interest` for plotting the output of - `predict_proba` and `decision_function`. + `predict_proba` and `decision_function` in the binary case. + """ + iris = load_iris() + X = iris.data[:100, :2] + y = iris.target[:100] # the target are numerical labels + + estimator = LogisticRegression().fit(X, y) + # We will check that `class_of_interest=None` is equivalent to + # `class_of_interest=estimator.classes_[1]` + disp_default = DecisionBoundaryDisplay.from_estimator( + estimator, + X, + response_method=response_method, + class_of_interest=None, + ) + disp_class_1 = DecisionBoundaryDisplay.from_estimator( + estimator, + X, + response_method=response_method, + class_of_interest=estimator.classes_[1], + ) + + assert_allclose(disp_default.response, disp_class_1.response) + + # we can check that `_get_response_values` modifies the response when targeting + # the other class, i.e. 1 - p(y=1|x) for `predict_proba` and -decision_function + # for `decision_function`. + disp_class_0 = DecisionBoundaryDisplay.from_estimator( + estimator, + X, + response_method=response_method, + class_of_interest=estimator.classes_[0], + ) + + if response_method == "predict_proba": + assert_allclose(disp_default.response, 1 - disp_class_0.response) + else: + assert_allclose(disp_default.response, -disp_class_0.response) + + +@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) +def test_class_of_interest_multiclass(pyplot, response_method): + """Check the behaviour of passing `class_of_interest` for plotting the output of + `predict_proba` and `decision_function` in the multiclass case. """ iris = load_iris() X = iris.data[:, :2] @@ -390,3 +433,15 @@ def test_class_of_interest(pyplot, response_method): response_method=response_method, class_of_interest=class_of_interest_idx, ) + + # TODO: remove this test when we handle multiclass with class_of_interest=None + # by showing the max of the decision function or the max of the predicted + # probabilities. 
+ err_msg = "Multiclass classifiers are only supported" + with pytest.raises(ValueError, match=err_msg): + DecisionBoundaryDisplay.from_estimator( + estimator, + X, + response_method=response_method, + class_of_interest=None, + ) From 010a292dbb3c8f890bfbe7eba6f95279b6c85021 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 8 Sep 2023 17:35:01 +0200 Subject: [PATCH 10/18] Apply suggestions from code review Co-authored-by: Olivier Grisel --- sklearn/inspection/_plot/decision_boundary.py | 4 ++-- .../_plot/tests/test_boundary_decision_display.py | 4 +++- sklearn/utils/tests/test_response.py | 9 ++++++--- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 9a74d275c5b5d..8e6c8698942a5 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -25,7 +25,7 @@ def _check_boundary_response_method(estimator, response_method, class_of_interes If set to 'auto', the response method is tried in the following order: :term:`decision_function`, :term:`predict_proba`, :term:`predict`. - class_of_interest : int, float, bool or str + class_of_interest : int, float, bool, str or None The class considered when plotting the decision. If the label is specified, it then possible to plot the decision boundary in multiclass settings. @@ -369,7 +369,7 @@ def from_estimator( f"class_of_interest={class_of_interest} is not a valid label: It " f"should be one of {estimator.classes_}" ) from exc - raise exc + raise # convert classes predictions into integers if response_method_used == "predict" and hasattr(estimator, "classes_"): diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index e7cd4d1f7e492..d061af91862d5 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -348,7 +348,8 @@ def test_class_of_interest_binary(pyplot, response_method): """ iris = load_iris() X = iris.data[:100, :2] - y = iris.target[:100] # the target are numerical labels + y = iris.target[:100] + assert_array_equal(np.unique(y), [0, 1]) estimator = LogisticRegression().fit(X, y) # We will check that `class_of_interest=None` is equivalent to @@ -381,6 +382,7 @@ def test_class_of_interest_binary(pyplot, response_method): if response_method == "predict_proba": assert_allclose(disp_default.response, 1 - disp_class_0.response) else: + assert response_method == "decision_function" assert_allclose(disp_default.response, -disp_class_0.response) diff --git a/sklearn/utils/tests/test_response.py b/sklearn/utils/tests/test_response.py index 8eefa11fdea66..4e220c23a4e79 100644 --- a/sklearn/utils/tests/test_response.py +++ b/sklearn/utils/tests/test_response.py @@ -150,18 +150,21 @@ def test_get_response_values_binary_classifier_predict_proba( assert_allclose(results[0], classifier.predict_proba(X)[:, 1]) assert results[1] == 1 if return_response_method_used: + assert len(results) == 3 assert results[2] == "predict_proba" + else: + assert len(results) == 2 # when forcing `pos_label=classifier.classes_[0]` - results = _get_response_values( + y_pred, pos_label, *_ = _get_response_values( classifier, X, response_method=response_method, pos_label=classifier.classes_[0], return_response_method_used=return_response_method_used, ) - assert_allclose(results[0], classifier.predict_proba(X)[:, 0]) - 
assert results[1] == 0 + assert_allclose(y_pred, classifier.predict_proba(X)[:, 0]) + assert pos_label == 0 @pytest.mark.parametrize( From 38e76287e8549ab5cc299a9694b8e192306852da Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 8 Sep 2023 17:36:30 +0200 Subject: [PATCH 11/18] iter --- .../inspection/_plot/tests/test_boundary_decision_display.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index d061af91862d5..0d0c8021a08ed 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -2,7 +2,6 @@ import numpy as np import pytest -from numpy.testing import assert_allclose from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.datasets import ( @@ -14,6 +13,10 @@ from sklearn.inspection._plot.decision_boundary import _check_boundary_response_method from sklearn.linear_model import LogisticRegression from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor +from sklearn.utils._testing import ( + assert_allclose, + assert_array_equal, +) # TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved pytestmark = pytest.mark.filterwarnings( From 637757388e531189f7cea7e5d53bd8b50f15ef66 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 8 Sep 2023 18:15:50 +0200 Subject: [PATCH 12/18] iter --- sklearn/inspection/_plot/decision_boundary.py | 5 +++- .../tests/test_boundary_decision_display.py | 24 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 8e6c8698942a5..96a3f75656961 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -51,7 +51,10 @@ def _check_boundary_response_method(estimator, response_method, class_of_interes raise ValueError(msg) prediction_method = "predict" if response_method == "auto" else response_method elif response_method == "auto": - prediction_method = ["decision_function", "predict_proba", "predict"] + if is_regressor(estimator): + prediction_method = "predict" + else: + prediction_method = ["decision_function", "predict_proba", "predict"] else: prediction_method = response_method diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 0d0c8021a08ed..c8534e05371ae 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -5,6 +5,7 @@ from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.datasets import ( + load_diabetes, load_iris, make_classification, make_multilabel_classification, @@ -266,6 +267,29 @@ def test_multioutput_regressor_error(pyplot): DecisionBoundaryDisplay.from_estimator(tree, X, response_method="predict") +@pytest.mark.parametrize("response_method", ["auto", "predict"]) +def test_regressor(pyplot, response_method): + """Check that we can display the decision boundary for a regressor.""" + X, y = load_diabetes(return_X_y=True) + X = X[:, :2] + tree = DecisionTreeRegressor().fit(X, y) + DecisionBoundaryDisplay.from_estimator(tree, X, response_method=response_method) + + +@pytest.mark.parametrize( + "response_method", + ["predict_proba", "decision_function", 
["predict_proba", "predict"]], +) +def test_regressor_unsupported_response(pyplot, response_method): + """Check that we can display the decision boundary for a regressor.""" + X, y = load_diabetes(return_X_y=True) + X = X[:, :2] + tree = DecisionTreeRegressor().fit(X, y) + err_msg = "should either be a classifier to be used with response_method" + with pytest.raises(ValueError, match=err_msg): + DecisionBoundaryDisplay.from_estimator(tree, X, response_method=response_method) + + @pytest.mark.filterwarnings( # We expect to raise the following warning because the classifier is fit on a # NumPy array From 2ce357faaaf8b0b83179dd615b2ea32cfb81b404 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 8 Sep 2023 18:29:25 +0200 Subject: [PATCH 13/18] iter --- .../tests/test_boundary_decision_display.py | 52 +++++++++++++++---- 1 file changed, 42 insertions(+), 10 deletions(-) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index c8534e05371ae..530f2a61dc5d7 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -152,7 +152,9 @@ def test_display_plot_input_error(pyplot, fitted_clf): "response_method", ["auto", "predict", "predict_proba", "decision_function"] ) @pytest.mark.parametrize("plot_method", ["contourf", "contour"]) -def test_decision_boundary_display(pyplot, fitted_clf, response_method, plot_method): +def test_decision_boundary_display_classifier( + pyplot, fitted_clf, response_method, plot_method +): """Check that decision boundary is correct.""" fig, ax = pyplot.subplots() eps = 2.0 @@ -187,6 +189,45 @@ def test_decision_boundary_display(pyplot, fitted_clf, response_method, plot_met assert disp.figure_ == fig2 +@pytest.mark.parametrize("response_method", ["auto", "predict"]) +@pytest.mark.parametrize("plot_method", ["contourf", "contour"]) +def test_decision_boundary_display_regressor(pyplot, response_method, plot_method): + """Check that we can display the decision boundary for a regressor.""" + X, y = load_diabetes(return_X_y=True) + X = X[:, :2] + tree = DecisionTreeRegressor().fit(X, y) + fig, ax = pyplot.subplots() + eps = 2.0 + disp = DecisionBoundaryDisplay.from_estimator( + tree, + X, + response_method=response_method, + ax=ax, + eps=eps, + plot_method=plot_method, + ) + assert isinstance(disp.surface_, pyplot.matplotlib.contour.QuadContourSet) + assert disp.ax_ == ax + assert disp.figure_ == fig + + x0, x1 = X[:, 0], X[:, 1] + + x0_min, x0_max = x0.min() - eps, x0.max() + eps + x1_min, x1_max = x1.min() - eps, x1.max() + eps + + assert disp.xx0.min() == pytest.approx(x0_min) + assert disp.xx0.max() == pytest.approx(x0_max) + assert disp.xx1.min() == pytest.approx(x1_min) + assert disp.xx1.max() == pytest.approx(x1_max) + + fig2, ax2 = pyplot.subplots() + # change plotting method for second plot + disp.plot(plot_method="pcolormesh", ax=ax2, shading="auto") + assert isinstance(disp.surface_, pyplot.matplotlib.collections.QuadMesh) + assert disp.ax_ == ax2 + assert disp.figure_ == fig2 + + @pytest.mark.parametrize( "response_method, msg", [ @@ -267,15 +308,6 @@ def test_multioutput_regressor_error(pyplot): DecisionBoundaryDisplay.from_estimator(tree, X, response_method="predict") -@pytest.mark.parametrize("response_method", ["auto", "predict"]) -def test_regressor(pyplot, response_method): - """Check that we can display the decision boundary for a regressor.""" - X, y = 
load_diabetes(return_X_y=True) - X = X[:, :2] - tree = DecisionTreeRegressor().fit(X, y) - DecisionBoundaryDisplay.from_estimator(tree, X, response_method=response_method) - - @pytest.mark.parametrize( "response_method", ["predict_proba", "decision_function", ["predict_proba", "predict"]], From 07463259a6881cb0ac34786c5c55f0c53fcb8e57 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 8 Sep 2023 20:10:14 +0200 Subject: [PATCH 14/18] TST properly tests _check_boundary_decision_response_method --- .../tests/test_boundary_decision_display.py | 85 +++++++++++++++---- 1 file changed, 68 insertions(+), 17 deletions(-) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 530f2a61dc5d7..2e2a4d10246e2 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -13,6 +13,7 @@ from sklearn.inspection import DecisionBoundaryDisplay from sklearn.inspection._plot.decision_boundary import _check_boundary_response_method from sklearn.linear_model import LogisticRegression +from sklearn.preprocessing import scale from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor from sklearn.utils._testing import ( assert_allclose, @@ -35,6 +36,12 @@ ) +def load_iris_scaled(): + X, y = load_iris(return_X_y=True) + X = scale(X)[:, :2] + return X, y + + @pytest.fixture(scope="module") def fitted_clf(): return LogisticRegression().fit(X, y) @@ -50,29 +57,73 @@ def test_input_data_dimension(pyplot): DecisionBoundaryDisplay.from_estimator(estimator=clf, X=X) -def test_check_boundary_response_method_auto(): - """Check _check_boundary_response_method behavior with 'auto'.""" +def test_check_boundary_response_method_error(): + """Check that we raise an error for the cases not supported by + `_check_boundary_response_method`. 
+ """ - expected_methods = ["decision_function", "predict_proba", "predict"] + class MultiLabelClassifier: + classes_ = [np.array([0, 1]), np.array([0, 1])] - class A: - def decision_function(self): - pass + err_msg = "Multi-label and multi-output multi-class classifiers are not supported" + with pytest.raises(ValueError, match=err_msg): + _check_boundary_response_method(MultiLabelClassifier(), "predict", None) - class B: - def predict_proba(self): - pass + class MulticlassClassifier: + classes_ = [0, 1, 2] - class C(A, B): - pass + err_msg = "Multiclass classifiers are only supported when response_method is" + for response_method in ("predict_proba", "decision_function"): + with pytest.raises(ValueError, match=err_msg): + _check_boundary_response_method( + MulticlassClassifier(), response_method, None + ) - class D: - def predict(self): - pass - for Klass in [A, B, C, D]: - methods = _check_boundary_response_method(Klass(), "auto", None) - assert methods == expected_methods +@pytest.mark.parametrize( + "estimator, response_method, class_of_interest, expected_prediction_method", + [ + (DecisionTreeRegressor(), "predict", None, "predict"), + (DecisionTreeRegressor(), "auto", None, "predict"), + (LogisticRegression().fit(*load_iris_scaled()), "predict", None, "predict"), + (LogisticRegression().fit(*load_iris_scaled()), "auto", None, "predict"), + ( + LogisticRegression().fit(*load_iris_scaled()), + "predict_proba", + 0, + "predict_proba", + ), + ( + LogisticRegression().fit(*load_iris_scaled()), + "decision_function", + 0, + "decision_function", + ), + ( + LogisticRegression().fit(X, y), + "auto", + None, + ["decision_function", "predict_proba", "predict"], + ), + (LogisticRegression().fit(X, y), "predict", None, "predict"), + ( + LogisticRegression().fit(X, y), + ["predict_proba", "decision_function"], + None, + ["predict_proba", "decision_function"], + ), + ], +) +def test_check_boundary_response_method( + estimator, response_method, class_of_interest, expected_prediction_method +): + """Check the behaviour of `_check_boundary_response_method` for the supported + cases. 
+ """ + prediction_method = _check_boundary_response_method( + estimator, response_method, class_of_interest + ) + assert prediction_method == expected_prediction_method @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) From 52d877feb4ea799ee86af11678300ce80c34fc05 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 25 Sep 2023 11:49:13 +0200 Subject: [PATCH 15/18] address ogrisel comment --- sklearn/inspection/_plot/decision_boundary.py | 3 +++ .../_plot/tests/test_boundary_decision_display.py | 12 ++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 96a3f75656961..f669b11614ef5 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -373,6 +373,9 @@ def from_estimator( f"should be one of {estimator.classes_}" ) from exc raise + except AttributeError as exc: + # re-raise the AttributeError as a ValueError for backward compatibility + raise ValueError(str(exc)) from exc # convert classes predictions into integers if response_method_used == "predict" and hasattr(estimator, "classes_"): diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 2e2a4d10246e2..2308be3ad4848 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -36,7 +36,7 @@ ) -def load_iris_scaled(): +def load_iris_2d_scaled(): X, y = load_iris(return_X_y=True) X = scale(X)[:, :2] return X, y @@ -85,16 +85,16 @@ class MulticlassClassifier: [ (DecisionTreeRegressor(), "predict", None, "predict"), (DecisionTreeRegressor(), "auto", None, "predict"), - (LogisticRegression().fit(*load_iris_scaled()), "predict", None, "predict"), - (LogisticRegression().fit(*load_iris_scaled()), "auto", None, "predict"), + (LogisticRegression().fit(*load_iris_2d_scaled()), "predict", None, "predict"), + (LogisticRegression().fit(*load_iris_2d_scaled()), "auto", None, "predict"), ( - LogisticRegression().fit(*load_iris_scaled()), + LogisticRegression().fit(*load_iris_2d_scaled()), "predict_proba", 0, "predict_proba", ), ( - LogisticRegression().fit(*load_iris_scaled()), + LogisticRegression().fit(*load_iris_2d_scaled()), "decision_function", 0, "decision_function", @@ -314,7 +314,7 @@ def fit(self, X, y): clf = MyClassifier().fit(X, y) - with pytest.raises(AttributeError, match=msg): + with pytest.raises(ValueError, match=msg): DecisionBoundaryDisplay.from_estimator(clf, X, response_method=response_method) From d986e6ab8a8e1e99c200477ddffd7e5a93b53e59 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 10 Oct 2023 10:20:08 +0200 Subject: [PATCH 16/18] Apply suggestions from code review Co-authored-by: Arturo Amor <86408019+ArturoAmorQ@users.noreply.github.com> --- .../plot_classification_probability.py | 2 +- sklearn/inspection/_plot/decision_boundary.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/classification/plot_classification_probability.py b/examples/classification/plot_classification_probability.py index d6f49295c8607..4e8f0763d3b47 100644 --- a/examples/classification/plot_classification_probability.py +++ b/examples/classification/plot_classification_probability.py @@ -79,7 +79,7 @@ vmax=1, ) axes[classifier_idx, label].set_title(f"Class {label}") - # plot the data that are 
predict to belong to the class
+    # plot data predicted to belong to the given class
     mask_y_pred = y_pred == label
     axes[classifier_idx, label].scatter(
         X[mask_y_pred, 0], X[mask_y_pred, 1], marker="o", c="w", edgecolor="k"
diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py
index f669b11614ef5..6e00579bed08b 100644
--- a/sklearn/inspection/_plot/decision_boundary.py
+++ b/sklearn/inspection/_plot/decision_boundary.py
@@ -27,14 +27,14 @@ def _check_boundary_response_method(estimator, response_method, class_of_interes
     class_of_interest : int, float, bool, str or None
         The class considered when plotting the decision. If the label is specified, it
-        then possible to plot the decision boundary in multiclass settings.
+        is then possible to plot the decision boundary in multiclass settings.
 
         .. versionadded:: 1.4
 
     Returns
     -------
     prediction_method : list of str or str
-        The name of the response methods to use or a list of such names.
+        The name or list of names of the response methods to use.
     """
     has_classes = hasattr(estimator, "classes_")
     if has_classes and _is_arraylike_not_scalar(estimator.classes_[0]):
@@ -251,10 +251,10 @@ def from_estimator(
             `response_method="auto"`.
 
         class_of_interest : int, float, bool or str, default=None
-            The class considered when plotting the decision. By default,
-            `estimators.classes_[1]` is considered as the positive class
-            for binary classifiers. For multiclass classifier, passing
-            and explicit value for `class_of_interest` is mandatory.
+            The class considered when plotting the decision. If None,
+            `estimator.classes_[1]` is considered as the positive class
+            for binary classifiers. For multiclass classifiers, passing
+            an explicit value for `class_of_interest` is mandatory.
 
             .. versionadded:: 1.4
 
From 7084b23c5443aee6ef1159693a9034788097f77d Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre 
Date: Tue, 10 Oct 2023 10:20:22 +0200
Subject: [PATCH 17/18] add api change

---
 doc/whats_new/v1.4.rst                                       | 4 ++++
 sklearn/inspection/_plot/decision_boundary.py                | 3 ---
 .../inspection/_plot/tests/test_boundary_decision_display.py | 2 +-
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst
index 6f043ca411c19..c927204cd7fc1 100644
--- a/doc/whats_new/v1.4.rst
+++ b/doc/whats_new/v1.4.rst
@@ -208,6 +208,10 @@ Changelog
   both binary and multiclass classifiers.
   :pr:`27291` by :user:`Guillaume Lemaitre `.
 
+- |API| :class:`inspection.DecisionBoundaryDisplay` raises an `AttributeError` instead
+  of a `ValueError` when an estimator does not implement the requested response method.
+  :pr:`27291` by :user:`Guillaume Lemaitre `.
+
 :mod:`sklearn.linear_model`
 ...........................
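
The |API| entry above changes which exception surfaces to the user. A minimal
sketch of the new behaviour, assuming a build with this series applied; the
choice of `LinearSVC` (which implements `decision_function` but not
`predict_proba`) is illustrative and not part of the patch:

    from sklearn.datasets import load_iris
    from sklearn.inspection import DecisionBoundaryDisplay
    from sklearn.preprocessing import scale
    from sklearn.svm import LinearSVC

    X, y = load_iris(return_X_y=True)
    X = scale(X)[:, :2]  # the display supports exactly two features
    clf = LinearSVC().fit(X, y)

    try:
        # passing `class_of_interest` lets the multiclass validation pass, so
        # the missing response method is actually looked up on the estimator
        DecisionBoundaryDisplay.from_estimator(
            clf, X, response_method="predict_proba", class_of_interest=0
        )
    except AttributeError:
        # the AttributeError now propagates unchanged instead of being
        # re-wrapped as a ValueError
        pass
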
diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py
index 6e00579bed08b..b2e801c946b5f 100644
--- a/sklearn/inspection/_plot/decision_boundary.py
+++ b/sklearn/inspection/_plot/decision_boundary.py
@@ -373,9 +373,6 @@ def from_estimator(
                     f"should be one of {estimator.classes_}"
                 ) from exc
             raise
-        except AttributeError as exc:
-            # re-raise the AttributeError as a ValueError for backward compatibility
-            raise ValueError(str(exc)) from exc
 
         # convert classes predictions into integers
         if response_method_used == "predict" and hasattr(estimator, "classes_"):
diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py
index 2308be3ad4848..fa1a7733e5979 100644
--- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py
+++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py
@@ -314,7 +314,7 @@ def fit(self, X, y):
 
     clf = MyClassifier().fit(X, y)
 
-    with pytest.raises(ValueError, match=msg):
+    with pytest.raises(AttributeError, match=msg):
         DecisionBoundaryDisplay.from_estimator(clf, X, response_method=response_method)
 
From c99001472a20a29d9b8e395a876e862213b0acb5 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre 
Date: Tue, 10 Oct 2023 10:24:05 +0200
Subject: [PATCH 18/18] change error message

---
 sklearn/inspection/_plot/decision_boundary.py                | 6 +++---
 .../_plot/tests/test_boundary_decision_display.py            | 6 +++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py
index 60cedce21d437..a42e744261e0b 100644
--- a/sklearn/inspection/_plot/decision_boundary.py
+++ b/sklearn/inspection/_plot/decision_boundary.py
@@ -44,9 +44,9 @@ def _check_boundary_response_method(estimator, response_method, class_of_interes
     if has_classes and len(estimator.classes_) > 2:
         if response_method not in {"auto", "predict"} and class_of_interest is None:
             msg = (
-                "Multiclass classifiers are only supported when response_method is"
-                " 'predict' or 'auto', or you must provide `class_of_interest` to "
-                " select a specific class to plot the decision boundary."
+                "Multiclass classifiers are only supported when `response_method` is "
+                "'predict' or 'auto'. Otherwise you must provide `class_of_interest` "
+                "to plot the decision boundary of a specific class."
) raise ValueError(msg) prediction_method = "predict" if response_method == "auto" else response_method diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index fa1a7733e5979..e93534b3b9e13 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -72,7 +72,7 @@ class MultiLabelClassifier: class MulticlassClassifier: classes_ = [0, 1, 2] - err_msg = "Multiclass classifiers are only supported when response_method is" + err_msg = "Multiclass classifiers are only supported when `response_method` is" for response_method in ("predict_proba", "decision_function"): with pytest.raises(ValueError, match=err_msg): _check_boundary_response_method( @@ -134,8 +134,8 @@ def test_multiclass_error(pyplot, response_method): lr = LogisticRegression().fit(X, y) msg = ( - "Multiclass classifiers are only supported when response_method is 'predict' or" - " 'auto'" + "Multiclass classifiers are only supported when `response_method` is 'predict'" + " or 'auto'" ) with pytest.raises(ValueError, match=msg): DecisionBoundaryDisplay.from_estimator(lr, X, response_method=response_method)
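
Taken together, the series enables the following usage. A minimal sketch,
assuming the end state of the branch (scikit-learn 1.4 with the
`class_of_interest` parameter) and mirroring the scaled two-feature iris
setup used in the tests above; matplotlib is assumed to be installed:

    import matplotlib.pyplot as plt
    from sklearn.datasets import load_iris
    from sklearn.inspection import DecisionBoundaryDisplay
    from sklearn.linear_model import LogisticRegression
    from sklearn.preprocessing import scale

    X, y = load_iris(return_X_y=True)
    X = scale(X)[:, :2]
    clf = LogisticRegression().fit(X, y)

    # For this multiclass classifier, response_method="predict_proba" is only
    # accepted once a specific class is selected; here the probability surface
    # of class 2 is plotted.
    disp = DecisionBoundaryDisplay.from_estimator(
        clf, X, response_method="predict_proba", class_of_interest=2
    )
    disp.ax_.scatter(X[:, 0], X[:, 1], c=y, edgecolor="k")
    plt.show()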