diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index fecb9f0b95b60..0543407c75527 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -657,6 +657,7 @@ Plotting
    :toctree: generated/
    :template: class.rst

+   inspection.DecisionBoundaryDisplay
    inspection.PartialDependenceDisplay

 .. autosummary::
diff --git a/doc/visualizations.rst b/doc/visualizations.rst
index 31db6f3db03ce..0f0ec73549355 100644
--- a/doc/visualizations.rst
+++ b/doc/visualizations.rst
@@ -96,6 +96,7 @@ Display Objects

    calibration.CalibrationDisplay
    inspection.PartialDependenceDisplay
+   inspection.DecisionBoundaryDisplay
    metrics.ConfusionMatrixDisplay
    metrics.DetCurveDisplay
    metrics.PrecisionRecallDisplay
diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst
index fab4ec850f63d..b171872000381 100644
--- a/doc/whats_new/v1.1.rst
+++ b/doc/whats_new/v1.1.rst
@@ -534,6 +534,10 @@ Changelog
 :mod:`sklearn.inspection`
 .........................

+- |Feature| Add a display to plot the decision boundary of a classifier by
+  using the method :func:`inspection.DecisionBoundaryDisplay.from_estimator`.
+  :pr:`16061` by `Thomas Fan`_.
+
 - |Enhancement| In
   :meth:`~sklearn.inspection.PartialDependenceDisplay.from_estimator` and
   :meth:`~sklearn.inspection.PartialDependenceDisplay.from_predictions`, allow
diff --git a/examples/classification/plot_classifier_comparison.py b/examples/classification/plot_classifier_comparison.py
index 1c7112e5fa3d5..e4c52d9e2564a 100644
--- a/examples/classification/plot_classifier_comparison.py
+++ b/examples/classification/plot_classifier_comparison.py
@@ -40,8 +40,7 @@
 from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
 from sklearn.naive_bayes import GaussianNB
 from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
-
-h = 0.02  # step size in the mesh
+from sklearn.inspection import DecisionBoundaryDisplay

 names = [
     "Nearest Neighbors",
@@ -95,7 +94,6 @@

     x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
     y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
-    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

     # just plot the dataset first
     cm = plt.cm.RdBu
@@ -109,8 +107,8 @@
     ax.scatter(
         X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.6, edgecolors="k"
     )
-    ax.set_xlim(xx.min(), xx.max())
-    ax.set_ylim(yy.min(), yy.max())
+    ax.set_xlim(x_min, x_max)
+    ax.set_ylim(y_min, y_max)
     ax.set_xticks(())
     ax.set_yticks(())
     i += 1
@@ -120,17 +118,9 @@
         ax = plt.subplot(len(datasets), len(classifiers) + 1, i)
         clf.fit(X_train, y_train)
         score = clf.score(X_test, y_test)
-
-        # Plot the decision boundary. For that, we will assign a color to each
-        # point in the mesh [x_min, x_max]x[y_min, y_max].
- if hasattr(clf, "decision_function"): - Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]) - else: - Z = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1] - - # Put the result into a color plot - Z = Z.reshape(xx.shape) - ax.contourf(xx, yy, Z, cmap=cm, alpha=0.8) + DecisionBoundaryDisplay.from_estimator( + clf, X, cmap=cm, alpha=0.8, ax=ax, eps=0.5 + ) # Plot the training points ax.scatter( @@ -146,15 +136,15 @@ alpha=0.6, ) - ax.set_xlim(xx.min(), xx.max()) - ax.set_ylim(yy.min(), yy.max()) + ax.set_xlim(x_min, x_max) + ax.set_ylim(y_min, y_max) ax.set_xticks(()) ax.set_yticks(()) if ds_cnt == 0: ax.set_title(name) ax.text( - xx.max() - 0.3, - yy.min() + 0.3, + x_max - 0.3, + y_min + 0.3, ("%.2f" % score).lstrip("0"), size=15, horizontalalignment="right", diff --git a/examples/cluster/plot_inductive_clustering.py b/examples/cluster/plot_inductive_clustering.py index fd8adb8df189e..e395571a1caad 100644 --- a/examples/cluster/plot_inductive_clustering.py +++ b/examples/cluster/plot_inductive_clustering.py @@ -23,12 +23,12 @@ # Authors: Chirag Nagpal # Christos Aridas -import numpy as np import matplotlib.pyplot as plt from sklearn.base import BaseEstimator, clone from sklearn.cluster import AgglomerativeClustering from sklearn.datasets import make_blobs from sklearn.ensemble import RandomForestClassifier +from sklearn.inspection import DecisionBoundaryDisplay from sklearn.utils.metaestimators import available_if from sklearn.utils.validation import check_is_fitted @@ -116,19 +116,14 @@ def plot_scatter(X, color, alpha=0.5): probable_clusters = inductive_learner.predict(X_new) -plt.subplot(133) +ax = plt.subplot(133) plot_scatter(X, cluster_labels) plot_scatter(X_new, probable_clusters) # Plotting decision regions -x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 -y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 -xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1), np.arange(y_min, y_max, 0.1)) - -Z = inductive_learner.predict(np.c_[xx.ravel(), yy.ravel()]) -Z = Z.reshape(xx.shape) - -plt.contourf(xx, yy, Z, alpha=0.4) +DecisionBoundaryDisplay.from_estimator( + inductive_learner, X, response_method="predict", alpha=0.4, ax=ax +) plt.title("Classify unknown instances") plt.show() diff --git a/examples/ensemble/plot_adaboost_twoclass.py b/examples/ensemble/plot_adaboost_twoclass.py index 38e3e95ae96ef..19679c6285d3b 100644 --- a/examples/ensemble/plot_adaboost_twoclass.py +++ b/examples/ensemble/plot_adaboost_twoclass.py @@ -27,6 +27,7 @@ from sklearn.ensemble import AdaBoostClassifier from sklearn.tree import DecisionTreeClassifier from sklearn.datasets import make_gaussian_quantiles +from sklearn.inspection import DecisionBoundaryDisplay # Construct dataset @@ -53,16 +54,18 @@ plt.figure(figsize=(10, 5)) # Plot the decision boundaries -plt.subplot(121) -x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 -y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 -xx, yy = np.meshgrid( - np.arange(x_min, x_max, plot_step), np.arange(y_min, y_max, plot_step) +ax = plt.subplot(121) +disp = DecisionBoundaryDisplay.from_estimator( + bdt, + X, + cmap=plt.cm.Paired, + response_method="predict", + ax=ax, + xlabel="x", + ylabel="y", ) - -Z = bdt.predict(np.c_[xx.ravel(), yy.ravel()]) -Z = Z.reshape(xx.shape) -cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired) +x_min, x_max = disp.xx0.min(), disp.xx0.max() +y_min, y_max = disp.xx1.min(), disp.xx1.max() plt.axis("tight") # Plot the training points @@ -80,8 +83,7 @@ plt.xlim(x_min, x_max) plt.ylim(y_min, y_max) plt.legend(loc="upper right") 
-plt.xlabel("x") -plt.ylabel("y") + plt.title("Decision Boundary") # Plot the two-class decision scores diff --git a/examples/ensemble/plot_voting_decision_regions.py b/examples/ensemble/plot_voting_decision_regions.py index 58bcd2dfc7404..e6dc68eeadf98 100644 --- a/examples/ensemble/plot_voting_decision_regions.py +++ b/examples/ensemble/plot_voting_decision_regions.py @@ -25,7 +25,6 @@ from itertools import product -import numpy as np import matplotlib.pyplot as plt from sklearn import datasets @@ -33,6 +32,7 @@ from sklearn.neighbors import KNeighborsClassifier from sklearn.svm import SVC from sklearn.ensemble import VotingClassifier +from sklearn.inspection import DecisionBoundaryDisplay # Loading some example data iris = datasets.load_iris() @@ -55,22 +55,15 @@ eclf.fit(X, y) # Plotting decision regions -x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 -y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 -xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1), np.arange(y_min, y_max, 0.1)) - f, axarr = plt.subplots(2, 2, sharex="col", sharey="row", figsize=(10, 8)) - for idx, clf, tt in zip( product([0, 1], [0, 1]), [clf1, clf2, clf3, eclf], ["Decision Tree (depth=4)", "KNN (k=7)", "Kernel SVM", "Soft Voting"], ): - - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - Z = Z.reshape(xx.shape) - - axarr[idx[0], idx[1]].contourf(xx, yy, Z, alpha=0.4) + DecisionBoundaryDisplay.from_estimator( + clf, X, alpha=0.4, ax=axarr[idx[0], idx[1]], response_method="predict" + ) axarr[idx[0], idx[1]].scatter(X[:, 0], X[:, 1], c=y, s=20, edgecolor="k") axarr[idx[0], idx[1]].set_title(tt) diff --git a/examples/linear_model/plot_iris_logistic.py b/examples/linear_model/plot_iris_logistic.py index 88a1313662084..10a1f0f15ad79 100644 --- a/examples/linear_model/plot_iris_logistic.py +++ b/examples/linear_model/plot_iris_logistic.py @@ -15,10 +15,10 @@ # Modified for documentation by Jaques Grobler # License: BSD 3 clause -import numpy as np import matplotlib.pyplot as plt from sklearn.linear_model import LogisticRegression from sklearn import datasets +from sklearn.inspection import DecisionBoundaryDisplay # import some data to play with iris = datasets.load_iris() @@ -29,26 +29,24 @@ logreg = LogisticRegression(C=1e5) logreg.fit(X, Y) -# Plot the decision boundary. For that, we will assign a color to each -# point in the mesh [x_min, x_max]x[y_min, y_max]. 
-x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5 -y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5 -h = 0.02 # step size in the mesh -xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) -Z = logreg.predict(np.c_[xx.ravel(), yy.ravel()]) - -# Put the result into a color plot -Z = Z.reshape(xx.shape) -plt.figure(1, figsize=(4, 3)) -plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired) +_, ax = plt.subplots(figsize=(4, 3)) +DecisionBoundaryDisplay.from_estimator( + logreg, + X, + cmap=plt.cm.Paired, + ax=ax, + response_method="predict", + plot_method="pcolormesh", + shading="auto", + xlabel="Sepal length", + ylabel="Sepal width", + eps=0.5, +) # Plot also the training points plt.scatter(X[:, 0], X[:, 1], c=Y, edgecolors="k", cmap=plt.cm.Paired) -plt.xlabel("Sepal length") -plt.ylabel("Sepal width") -plt.xlim(xx.min(), xx.max()) -plt.ylim(yy.min(), yy.max()) + plt.xticks(()) plt.yticks(()) diff --git a/examples/linear_model/plot_logistic_multinomial.py b/examples/linear_model/plot_logistic_multinomial.py index 143e946b76d58..814eeadaa68c4 100644 --- a/examples/linear_model/plot_logistic_multinomial.py +++ b/examples/linear_model/plot_logistic_multinomial.py @@ -16,6 +16,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import make_blobs from sklearn.linear_model import LogisticRegression +from sklearn.inspection import DecisionBoundaryDisplay # make 3-class dataset for classification centers = [[-5, 0], [0, 1.5], [5, -1]] @@ -31,19 +32,10 @@ # print the training scores print("training score : %.3f (%s)" % (clf.score(X, y), multi_class)) - # create a mesh to plot in - h = 0.02 # step size in the mesh - x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 - y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) - - # Plot the decision boundary. For that, we will assign a color to each - # point in the mesh [x_min, x_max]x[y_min, y_max]. - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - # Put the result into a color plot - Z = Z.reshape(xx.shape) - plt.figure() - plt.contourf(xx, yy, Z, cmap=plt.cm.Paired) + _, ax = plt.subplots() + DecisionBoundaryDisplay.from_estimator( + clf, X, response_method="predict", cmap=plt.cm.Paired, ax=ax + ) plt.title("Decision surface of LogisticRegression (%s)" % multi_class) plt.axis("tight") diff --git a/examples/linear_model/plot_sgd_iris.py b/examples/linear_model/plot_sgd_iris.py index 0113c259d7afa..64dca07396d54 100644 --- a/examples/linear_model/plot_sgd_iris.py +++ b/examples/linear_model/plot_sgd_iris.py @@ -13,6 +13,7 @@ import matplotlib.pyplot as plt from sklearn import datasets from sklearn.linear_model import SGDClassifier +from sklearn.inspection import DecisionBoundaryDisplay # import some data to play with iris = datasets.load_iris() @@ -35,21 +36,17 @@ std = X.std(axis=0) X = (X - mean) / std -h = 0.02 # step size in the mesh - clf = SGDClassifier(alpha=0.001, max_iter=100).fit(X, y) - -# create a mesh to plot in -x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 -y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 -xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) - -# Plot the decision boundary. For that, we will assign a color to each -# point in the mesh [x_min, x_max]x[y_min, y_max]. 
-Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) -# Put the result into a color plot -Z = Z.reshape(xx.shape) -cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired) +ax = plt.gca() +DecisionBoundaryDisplay.from_estimator( + clf, + X, + cmap=plt.cm.Paired, + ax=ax, + response_method="predict", + xlabel=iris.feature_names[0], + ylabel=iris.feature_names[1], +) plt.axis("tight") # Plot also the training points diff --git a/examples/neighbors/plot_classification.py b/examples/neighbors/plot_classification.py index 59d6285d9dcd0..cc4f0864ba926 100644 --- a/examples/neighbors/plot_classification.py +++ b/examples/neighbors/plot_classification.py @@ -8,11 +8,11 @@ """ -import numpy as np import matplotlib.pyplot as plt import seaborn as sns from matplotlib.colors import ListedColormap from sklearn import neighbors, datasets +from sklearn.inspection import DecisionBoundaryDisplay n_neighbors = 15 @@ -24,8 +24,6 @@ X = iris.data[:, :2] y = iris.target -h = 0.02 # step size in the mesh - # Create color maps cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"]) cmap_bold = ["darkorange", "c", "darkblue"] @@ -35,17 +33,18 @@ clf = neighbors.KNeighborsClassifier(n_neighbors, weights=weights) clf.fit(X, y) - # Plot the decision boundary. For that, we will assign a color to each - # point in the mesh [x_min, x_max]x[y_min, y_max]. - x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 - y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - - # Put the result into a color plot - Z = Z.reshape(xx.shape) - plt.figure(figsize=(8, 6)) - plt.contourf(xx, yy, Z, cmap=cmap_light) + _, ax = plt.subplots() + DecisionBoundaryDisplay.from_estimator( + clf, + X, + cmap=cmap_light, + ax=ax, + response_method="predict", + plot_method="pcolormesh", + xlabel=iris.feature_names[0], + ylabel=iris.feature_names[1], + shading="auto", + ) # Plot also the training points sns.scatterplot( @@ -56,12 +55,8 @@ alpha=1.0, edgecolor="black", ) - plt.xlim(xx.min(), xx.max()) - plt.ylim(yy.min(), yy.max()) plt.title( "3-Class classification (k = %i, weights = '%s')" % (n_neighbors, weights) ) - plt.xlabel(iris.feature_names[0]) - plt.ylabel(iris.feature_names[1]) plt.show() diff --git a/examples/neighbors/plot_nca_classification.py b/examples/neighbors/plot_nca_classification.py index 45f25c37f235e..17e6a667fcb3b 100644 --- a/examples/neighbors/plot_nca_classification.py +++ b/examples/neighbors/plot_nca_classification.py @@ -17,7 +17,6 @@ # License: BSD 3 clause -import numpy as np import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap from sklearn import datasets @@ -25,6 +24,7 @@ from sklearn.preprocessing import StandardScaler from sklearn.neighbors import KNeighborsClassifier, NeighborhoodComponentsAnalysis from sklearn.pipeline import Pipeline +from sklearn.inspection import DecisionBoundaryDisplay n_neighbors = 1 @@ -64,28 +64,25 @@ ), ] -x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 -y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 -xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) - for name, clf in zip(names, classifiers): clf.fit(X_train, y_train) score = clf.score(X_test, y_test) - # Plot the decision boundary. For that, we will assign a color to each - # point in the mesh [x_min, x_max]x[y_min, y_max]. 
- Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - - # Put the result into a color plot - Z = Z.reshape(xx.shape) - plt.figure() - plt.pcolormesh(xx, yy, Z, cmap=cmap_light, alpha=0.8) + _, ax = plt.subplots() + DecisionBoundaryDisplay.from_estimator( + clf, + X, + cmap=cmap_light, + alpha=0.8, + ax=ax, + response_method="predict", + plot_method="pcolormesh", + shading="auto", + ) # Plot also the training and testing points plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor="k", s=20) - plt.xlim(xx.min(), xx.max()) - plt.ylim(yy.min(), yy.max()) plt.title("{} (k = {})".format(name, n_neighbors)) plt.text( 0.9, diff --git a/examples/neighbors/plot_nearest_centroid.py b/examples/neighbors/plot_nearest_centroid.py index a2d0bea5623d7..0ea3c0c6b1209 100644 --- a/examples/neighbors/plot_nearest_centroid.py +++ b/examples/neighbors/plot_nearest_centroid.py @@ -13,6 +13,7 @@ from matplotlib.colors import ListedColormap from sklearn import datasets from sklearn.neighbors import NearestCentroid +from sklearn.inspection import DecisionBoundaryDisplay n_neighbors = 15 @@ -23,8 +24,6 @@ X = iris.data[:, :2] y = iris.target -h = 0.02 # step size in the mesh - # Create color maps cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"]) cmap_bold = ListedColormap(["darkorange", "c", "darkblue"]) @@ -35,17 +34,11 @@ clf.fit(X, y) y_pred = clf.predict(X) print(shrinkage, np.mean(y == y_pred)) - # Plot the decision boundary. For that, we will assign a color to each - # point in the mesh [x_min, x_max]x[y_min, y_max]. - x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 - y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - - # Put the result into a color plot - Z = Z.reshape(xx.shape) - plt.figure() - plt.pcolormesh(xx, yy, Z, cmap=cmap_light) + + _, ax = plt.subplots() + DecisionBoundaryDisplay.from_estimator( + clf, X, cmap=cmap_light, ax=ax, response_method="predict" + ) # Plot also the training points plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor="k", s=20) diff --git a/examples/svm/plot_custom_kernel.py b/examples/svm/plot_custom_kernel.py index 1dd0f2af6e145..c2c3bc6e6ba28 100644 --- a/examples/svm/plot_custom_kernel.py +++ b/examples/svm/plot_custom_kernel.py @@ -11,6 +11,7 @@ import numpy as np import matplotlib.pyplot as plt from sklearn import svm, datasets +from sklearn.inspection import DecisionBoundaryDisplay # import some data to play with iris = datasets.load_iris() @@ -37,16 +38,16 @@ def my_kernel(X, Y): clf = svm.SVC(kernel=my_kernel) clf.fit(X, Y) -# Plot the decision boundary. For that, we will assign a color to each -# point in the mesh [x_min, x_max]x[y_min, y_max]. 
-x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 -y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 -xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) -Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - -# Put the result into a color plot -Z = Z.reshape(xx.shape) -plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired) +ax = plt.gca() +DecisionBoundaryDisplay.from_estimator( + clf, + X, + cmap=plt.cm.Paired, + ax=ax, + response_method="predict", + plot_method="pcolormesh", + shading="auto", +) # Plot also the training points plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired, edgecolors="k") diff --git a/examples/svm/plot_iris_svc.py b/examples/svm/plot_iris_svc.py index c4e5dc8314784..5931ad57c263f 100644 --- a/examples/svm/plot_iris_svc.py +++ b/examples/svm/plot_iris_svc.py @@ -34,45 +34,9 @@ """ -import numpy as np import matplotlib.pyplot as plt from sklearn import svm, datasets - - -def make_meshgrid(x, y, h=0.02): - """Create a mesh of points to plot in - - Parameters - ---------- - x: data to base x-axis meshgrid on - y: data to base y-axis meshgrid on - h: stepsize for meshgrid, optional - - Returns - ------- - xx, yy : ndarray - """ - x_min, x_max = x.min() - 1, x.max() + 1 - y_min, y_max = y.min() - 1, y.max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) - return xx, yy - - -def plot_contours(ax, clf, xx, yy, **params): - """Plot the decision boundaries for a classifier. - - Parameters - ---------- - ax: matplotlib axes object - clf: a classifier - xx: meshgrid ndarray - yy: meshgrid ndarray - params: dictionary of params to pass to contourf, optional - """ - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - Z = Z.reshape(xx.shape) - out = ax.contourf(xx, yy, Z, **params) - return out +from sklearn.inspection import DecisionBoundaryDisplay # import some data to play with @@ -105,15 +69,19 @@ def plot_contours(ax, clf, xx, yy, **params): plt.subplots_adjust(wspace=0.4, hspace=0.4) X0, X1 = X[:, 0], X[:, 1] -xx, yy = make_meshgrid(X0, X1) for clf, title, ax in zip(models, titles, sub.flatten()): - plot_contours(ax, clf, xx, yy, cmap=plt.cm.coolwarm, alpha=0.8) + disp = DecisionBoundaryDisplay.from_estimator( + clf, + X, + response_method="predict", + cmap=plt.cm.coolwarm, + alpha=0.8, + ax=ax, + xlabel=iris.feature_names[0], + ylabel=iris.feature_names[1], + ) ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors="k") - ax.set_xlim(xx.min(), xx.max()) - ax.set_ylim(yy.min(), yy.max()) - ax.set_xlabel("Sepal length") - ax.set_ylabel("Sepal width") ax.set_xticks(()) ax.set_yticks(()) ax.set_title(title) diff --git a/examples/svm/plot_linearsvc_support_vectors.py b/examples/svm/plot_linearsvc_support_vectors.py index 298ec5e2419fb..7fdfea416013f 100644 --- a/examples/svm/plot_linearsvc_support_vectors.py +++ b/examples/svm/plot_linearsvc_support_vectors.py @@ -13,6 +13,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import make_blobs from sklearn.svm import LinearSVC +from sklearn.inspection import DecisionBoundaryDisplay X, y = make_blobs(n_samples=40, centers=2, random_state=0) @@ -32,17 +33,12 @@ plt.subplot(1, 2, i + 1) plt.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=plt.cm.Paired) ax = plt.gca() - xlim = ax.get_xlim() - ylim = ax.get_ylim() - xx, yy = np.meshgrid( - np.linspace(xlim[0], xlim[1], 50), np.linspace(ylim[0], ylim[1], 50) - ) - Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]) - Z = Z.reshape(xx.shape) - plt.contour( - xx, - yy, - Z, + DecisionBoundaryDisplay.from_estimator( + clf, 
+ X, + ax=ax, + grid_resolution=50, + plot_method="contour", colors="k", levels=[-1, 0, 1], alpha=0.5, diff --git a/examples/svm/plot_separating_hyperplane.py b/examples/svm/plot_separating_hyperplane.py index 627aa3b8a2e72..45bacff6a2b97 100644 --- a/examples/svm/plot_separating_hyperplane.py +++ b/examples/svm/plot_separating_hyperplane.py @@ -9,10 +9,10 @@ """ -import numpy as np import matplotlib.pyplot as plt from sklearn import svm from sklearn.datasets import make_blobs +from sklearn.inspection import DecisionBoundaryDisplay # we create 40 separable points @@ -26,19 +26,15 @@ # plot the decision function ax = plt.gca() -xlim = ax.get_xlim() -ylim = ax.get_ylim() - -# create grid to evaluate model -xx = np.linspace(xlim[0], xlim[1], 30) -yy = np.linspace(ylim[0], ylim[1], 30) -YY, XX = np.meshgrid(yy, xx) -xy = np.vstack([XX.ravel(), YY.ravel()]).T -Z = clf.decision_function(xy).reshape(XX.shape) - -# plot decision boundary and margins -ax.contour( - XX, YY, Z, colors="k", levels=[-1, 0, 1], alpha=0.5, linestyles=["--", "-", "--"] +DecisionBoundaryDisplay.from_estimator( + clf, + X, + plot_method="contour", + colors="k", + levels=[-1, 0, 1], + alpha=0.5, + linestyles=["--", "-", "--"], + ax=ax, ) # plot support vectors ax.scatter( diff --git a/examples/svm/plot_separating_hyperplane_unbalanced.py b/examples/svm/plot_separating_hyperplane_unbalanced.py index b810016ffc74c..fe71420ffd0b3 100644 --- a/examples/svm/plot_separating_hyperplane_unbalanced.py +++ b/examples/svm/plot_separating_hyperplane_unbalanced.py @@ -25,10 +25,10 @@ """ -import numpy as np import matplotlib.pyplot as plt from sklearn import svm from sklearn.datasets import make_blobs +from sklearn.inspection import DecisionBoundaryDisplay # we create two clusters of random points n_samples_1 = 1000 @@ -56,29 +56,31 @@ # plot the decision functions for both classifiers ax = plt.gca() -xlim = ax.get_xlim() -ylim = ax.get_ylim() - -# create grid to evaluate model -xx = np.linspace(xlim[0], xlim[1], 30) -yy = np.linspace(ylim[0], ylim[1], 30) -YY, XX = np.meshgrid(yy, xx) -xy = np.vstack([XX.ravel(), YY.ravel()]).T - -# get the separating hyperplane -Z = clf.decision_function(xy).reshape(XX.shape) - -# plot decision boundary and margins -a = ax.contour(XX, YY, Z, colors="k", levels=[0], alpha=0.5, linestyles=["-"]) - -# get the separating hyperplane for weighted classes -Z = wclf.decision_function(xy).reshape(XX.shape) +disp = DecisionBoundaryDisplay.from_estimator( + clf, + X, + plot_method="contour", + colors="k", + levels=[0], + alpha=0.5, + linestyles=["-"], + ax=ax, +) # plot decision boundary and margins for weighted classes -b = ax.contour(XX, YY, Z, colors="r", levels=[0], alpha=0.5, linestyles=["-"]) +wdisp = DecisionBoundaryDisplay.from_estimator( + wclf, + X, + plot_method="contour", + colors="r", + levels=[0], + alpha=0.5, + linestyles=["-"], + ax=ax, +) plt.legend( - [a.collections[0], b.collections[0]], + [disp.surface_.collections[0], wdisp.surface_.collections[0]], ["non weighted", "weighted"], loc="upper right", ) diff --git a/examples/tree/plot_iris_dtc.py b/examples/tree/plot_iris_dtc.py index ef9ea0232a2ca..14f6506b5810f 100644 --- a/examples/tree/plot_iris_dtc.py +++ b/examples/tree/plot_iris_dtc.py @@ -25,7 +25,11 @@ # Display the decision functions of trees trained on all pairs of features. 
import numpy as np import matplotlib.pyplot as plt + +from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier +from sklearn.inspection import DecisionBoundaryDisplay + # Parameters n_classes = 3 @@ -42,21 +46,17 @@ clf = DecisionTreeClassifier().fit(X, y) # Plot the decision boundary - plt.subplot(2, 3, pairidx + 1) - - x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 - y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 - xx, yy = np.meshgrid( - np.arange(x_min, x_max, plot_step), np.arange(y_min, y_max, plot_step) - ) + ax = plt.subplot(2, 3, pairidx + 1) plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5) - - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - Z = Z.reshape(xx.shape) - cs = plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu) - - plt.xlabel(iris.feature_names[pair[0]]) - plt.ylabel(iris.feature_names[pair[1]]) + DecisionBoundaryDisplay.from_estimator( + clf, + X, + cmap=plt.cm.RdYlBu, + response_method="predict", + ax=ax, + xlabel=iris.feature_names[pair[0]], + ylabel=iris.feature_names[pair[1]], + ) # Plot the training points for i, color in zip(range(n_classes), plot_colors): diff --git a/sklearn/inspection/__init__.py b/sklearn/inspection/__init__.py index 70e6c48a2998b..76c44ea81bbbe 100644 --- a/sklearn/inspection/__init__.py +++ b/sklearn/inspection/__init__.py @@ -2,6 +2,7 @@ from ._permutation_importance import permutation_importance +from ._plot.decision_boundary import DecisionBoundaryDisplay from ._partial_dependence import partial_dependence from ._plot.partial_dependence import plot_partial_dependence @@ -13,4 +14,5 @@ "plot_partial_dependence", "permutation_importance", "PartialDependenceDisplay", + "DecisionBoundaryDisplay", ] diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py new file mode 100644 index 0000000000000..78a8b16bd577a --- /dev/null +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -0,0 +1,331 @@ +from functools import reduce + +import numpy as np + +from ...preprocessing import LabelEncoder +from ...utils import check_matplotlib_support +from ...utils import _safe_indexing +from ...base import is_regressor +from ...utils.validation import check_is_fitted, _is_arraylike_not_scalar + + +def _check_boundary_response_method(estimator, response_method): + """Return prediction method from the `response_method` for decision boundary. + + Parameters + ---------- + estimator : object + Fitted estimator to check. + + response_method : {'auto', 'predict_proba', 'decision_function', 'predict'} + Specifies whether to use :term:`predict_proba`, + :term:`decision_function`, :term:`predict` as the target response. + If set to 'auto', the response method is tried in the following order: + :term:`decision_function`, :term:`predict_proba`, :term:`predict`. + + Returns + ------- + prediction_method: callable + Prediction method of estimator. 
+ """ + has_classes = hasattr(estimator, "classes_") + if has_classes and _is_arraylike_not_scalar(estimator.classes_[0]): + msg = "Multi-label and multi-output multi-class classifiers are not supported" + raise ValueError(msg) + + if has_classes and len(estimator.classes_) > 2: + if response_method not in {"auto", "predict"}: + msg = ( + "Multiclass classifiers are only supported when response_method is" + " 'predict' or 'auto'" + ) + raise ValueError(msg) + methods_list = ["predict"] + elif response_method == "auto": + methods_list = ["decision_function", "predict_proba", "predict"] + else: + methods_list = [response_method] + + prediction_method = [getattr(estimator, method, None) for method in methods_list] + prediction_method = reduce(lambda x, y: x or y, prediction_method) + if prediction_method is None: + raise ValueError( + f"{estimator.__class__.__name__} has none of the following attributes: " + f"{', '.join(methods_list)}." + ) + + return prediction_method + + +class DecisionBoundaryDisplay: + """Decisions boundary visualization. + + It is recommended to use + :func:`~sklearn.inspection.DecisionBoundaryDisplay.from_estimator` + to create a :class:`DecisionBoundaryDisplay`. All parameters are stored as + attributes. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.1 + + Parameters + ---------- + xx0 : ndarray of shape (grid_resolution, grid_resolution) + First output of :func:`meshgrid `. + + xx1 : ndarray of shape (grid_resolution, grid_resolution) + Second output of :func:`meshgrid `. + + response : ndarray of shape (grid_resolution, grid_resolution) + Values of the response function. + + xlabel : str, default=None + Default label to place on x axis. + + ylabel : str, default=None + Default label to place on y axis. + + Attributes + ---------- + surface_ : matplotlib `QuadContourSet` or `QuadMesh` + If `plot_method` is 'contour' or 'contourf', `surface_` is a + :class:`QuadContourSet `. If + `plot_method is `pcolormesh`, `surface_` is a + :class:`QuadMesh `. + + ax_ : matplotlib Axes + Axes with confusion matrix. + + figure_ : matplotlib Figure + Figure containing the confusion matrix. + """ + + def __init__(self, *, xx0, xx1, response, xlabel=None, ylabel=None): + self.xx0 = xx0 + self.xx1 = xx1 + self.response = response + self.xlabel = xlabel + self.ylabel = ylabel + + def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwargs): + """Plot visualization. + + Parameters + ---------- + plot_method : {'contourf', 'contour', 'pcolormesh'}, default='contourf' + Plotting method to call when plotting the response. Please refer + to the following matplotlib documentation for details: + :func:`contourf `, + :func:`contour `, + :func:`pcolomesh `. + + ax : Matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + xlabel : str, default=None + Overwrite the x-axis label. + + ylabel : str, default=None + Overwrite the y-axis label. + + **kwargs : dict + Additional keyword arguments to be passed to the `plot_method`. 
+
+        Returns
+        -------
+        display : :class:`~sklearn.inspection.DecisionBoundaryDisplay`
+            Object that stores computed values.
+        """
+        check_matplotlib_support("DecisionBoundaryDisplay.plot")
+        import matplotlib.pyplot as plt  # noqa
+
+        if plot_method not in ("contourf", "contour", "pcolormesh"):
+            raise ValueError(
+                "plot_method must be 'contourf', 'contour', or 'pcolormesh'"
+            )
+
+        if ax is None:
+            _, ax = plt.subplots()
+
+        plot_func = getattr(ax, plot_method)
+        self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs)
+
+        if xlabel is not None or not ax.get_xlabel():
+            xlabel = self.xlabel if xlabel is None else xlabel
+            ax.set_xlabel(xlabel)
+        if ylabel is not None or not ax.get_ylabel():
+            ylabel = self.ylabel if ylabel is None else ylabel
+            ax.set_ylabel(ylabel)
+
+        self.ax_ = ax
+        self.figure_ = ax.figure
+        return self
+
+    @classmethod
+    def from_estimator(
+        cls,
+        estimator,
+        X,
+        *,
+        grid_resolution=100,
+        eps=1.0,
+        plot_method="contourf",
+        response_method="auto",
+        xlabel=None,
+        ylabel=None,
+        ax=None,
+        **kwargs,
+    ):
+        """Plot decision boundary given an estimator.
+
+        Read more in the :ref:`User Guide <visualizations>`.
+
+        Parameters
+        ----------
+        estimator : object
+            Trained estimator used to plot the decision boundary.
+
+        X : {array-like, sparse matrix, dataframe} of shape (n_samples, 2)
+            Input data that should be only 2-dimensional.
+
+        grid_resolution : int, default=100
+            Number of grid points to use for plotting decision boundary.
+            Higher values will make the plot look nicer but be slower to
+            render.
+
+        eps : float, default=1.0
+            Extends the minimum and maximum values of X for evaluating the
+            response function.
+
+        plot_method : {'contourf', 'contour', 'pcolormesh'}, default='contourf'
+            Plotting method to call when plotting the response. Please refer
+            to the following matplotlib documentation for details:
+            :func:`contourf <matplotlib.pyplot.contourf>`,
+            :func:`contour <matplotlib.pyplot.contour>`,
+            :func:`pcolormesh <matplotlib.pyplot.pcolormesh>`.
+
+        response_method : {'auto', 'predict_proba', 'decision_function', \
+                'predict'}, default='auto'
+            Specifies whether to use :term:`predict_proba`,
+            :term:`decision_function`, :term:`predict` as the target response.
+            If set to 'auto', the response method is tried in the following order:
+            :term:`decision_function`, :term:`predict_proba`, :term:`predict`.
+            For multiclass problems, :term:`predict` is selected when
+            `response_method="auto"`.
+
+        xlabel : str, default=None
+            The label used for the x-axis. If `None`, an attempt is made to
+            extract a label from `X` if it is a dataframe, otherwise an empty
+            string is used.
+
+        ylabel : str, default=None
+            The label used for the y-axis. If `None`, an attempt is made to
+            extract a label from `X` if it is a dataframe, otherwise an empty
+            string is used.
+
+        ax : Matplotlib axes, default=None
+            Axes object to plot on. If `None`, a new figure and axes is
+            created.
+
+        **kwargs : dict
+            Additional keyword arguments to be passed to the
+            `plot_method`.
+
+        Returns
+        -------
+        display : :class:`~sklearn.inspection.DecisionBoundaryDisplay`
+            Object that stores the result.
+
+        See Also
+        --------
+        DecisionBoundaryDisplay : Decision boundary visualization.
+        ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix
+            given an estimator, the data, and the label.
+        ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix
+            given the true and predicted labels.
+
+        Examples
+        --------
+        >>> import matplotlib.pyplot as plt
+        >>> from sklearn.datasets import load_iris
+        >>> from sklearn.linear_model import LogisticRegression
+        >>> from sklearn.inspection import DecisionBoundaryDisplay
+        >>> iris = load_iris()
+        >>> X = iris.data[:, :2]
+        >>> classifier = LogisticRegression().fit(X, iris.target)
+        >>> disp = DecisionBoundaryDisplay.from_estimator(
+        ...     classifier, X, response_method="predict",
+        ...     xlabel=iris.feature_names[0], ylabel=iris.feature_names[1],
+        ...     alpha=0.5,
+        ... )
+        >>> disp.ax_.scatter(X[:, 0], X[:, 1], c=iris.target, edgecolor="k")
+        <...>
+        >>> plt.show()
+        """
+        check_matplotlib_support(f"{cls.__name__}.from_estimator")
+        check_is_fitted(estimator)
+
+        if not grid_resolution > 1:
+            raise ValueError(
+                "grid_resolution must be greater than 1. Got"
+                f" {grid_resolution} instead."
+            )
+
+        if not eps >= 0:
+            raise ValueError(
+                f"eps must be greater than or equal to 0. Got {eps} instead."
+            )
+
+        possible_plot_methods = ("contourf", "contour", "pcolormesh")
+        if plot_method not in possible_plot_methods:
+            available_methods = ", ".join(possible_plot_methods)
+            raise ValueError(
+                f"plot_method must be one of {available_methods}. "
+                f"Got {plot_method} instead."
+            )
+
+        x0, x1 = _safe_indexing(X, 0, axis=1), _safe_indexing(X, 1, axis=1)
+
+        x0_min, x0_max = x0.min() - eps, x0.max() + eps
+        x1_min, x1_max = x1.min() - eps, x1.max() + eps
+
+        xx0, xx1 = np.meshgrid(
+            np.linspace(x0_min, x0_max, grid_resolution),
+            np.linspace(x1_min, x1_max, grid_resolution),
+        )
+
+        pred_func = _check_boundary_response_method(estimator, response_method)
+        response = pred_func(np.c_[xx0.ravel(), xx1.ravel()])
+
+        # convert class predictions into integers
+        if pred_func.__name__ == "predict" and hasattr(estimator, "classes_"):
+            encoder = LabelEncoder()
+            encoder.classes_ = estimator.classes_
+            response = encoder.transform(response)
+
+        if response.ndim != 1:
+            if is_regressor(estimator):
+                raise ValueError("Multi-output regressors are not supported")
+
+            # TODO: Support pos_label
+            response = response[:, 1]
+
+        if xlabel is None:
+            xlabel = X.columns[0] if hasattr(X, "columns") else ""
+
+        if ylabel is None:
+            ylabel = X.columns[1] if hasattr(X, "columns") else ""
+
+        display = DecisionBoundaryDisplay(
+            xx0=xx0,
+            xx1=xx1,
+            response=response.reshape(xx0.shape),
+            xlabel=xlabel,
+            ylabel=ylabel,
+        )
+        return display.plot(ax=ax, plot_method=plot_method, **kwargs)
diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py
new file mode 100644
index 0000000000000..955deb33331d6
--- /dev/null
+++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py
@@ -0,0 +1,321 @@
+import pytest
+import numpy as np
+from numpy.testing import assert_allclose
+
+from sklearn.base import BaseEstimator
+from sklearn.base import ClassifierMixin
+from sklearn.datasets import make_classification
+from sklearn.linear_model import LogisticRegression
+from sklearn.datasets import load_iris
+from sklearn.datasets import make_multilabel_classification
+from sklearn.tree import DecisionTreeRegressor
+from sklearn.tree import DecisionTreeClassifier
+
+from sklearn.inspection import DecisionBoundaryDisplay
+from sklearn.inspection._plot.decision_boundary import _check_boundary_response_method
+
+
+# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved
+pytestmark = pytest.mark.filterwarnings(
+    
"ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:" + "matplotlib.*" +) + + +X, y = make_classification( + n_informative=1, + n_redundant=1, + n_clusters_per_class=1, + n_features=2, + random_state=42, +) + + +@pytest.fixture(scope="module") +def fitted_clf(): + return LogisticRegression().fit(X, y) + + +def test_check_boundary_response_method_auto(): + """Check _check_boundary_response_method behavior with 'auto'.""" + + class A: + def decision_function(self): + pass + + a_inst = A() + method = _check_boundary_response_method(a_inst, "auto") + assert method == a_inst.decision_function + + class B: + def predict_proba(self): + pass + + b_inst = B() + method = _check_boundary_response_method(b_inst, "auto") + assert method == b_inst.predict_proba + + class C: + def predict_proba(self): + pass + + def decision_function(self): + pass + + c_inst = C() + method = _check_boundary_response_method(c_inst, "auto") + assert method == c_inst.decision_function + + class D: + def predict(self): + pass + + d_inst = D() + method = _check_boundary_response_method(d_inst, "auto") + assert method == d_inst.predict + + +@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) +def test_multiclass_error(pyplot, response_method): + """Check multiclass errors.""" + X, y = make_classification(n_classes=3, n_informative=3, random_state=0) + X = X[:, [0, 1]] + lr = LogisticRegression().fit(X, y) + + msg = ( + "Multiclass classifiers are only supported when response_method is 'predict' or" + " 'auto'" + ) + with pytest.raises(ValueError, match=msg): + DecisionBoundaryDisplay.from_estimator(lr, X, response_method=response_method) + + +@pytest.mark.parametrize("response_method", ["auto", "predict"]) +def test_multiclass(pyplot, response_method): + """Check multiclass gives expected results.""" + grid_resolution = 10 + eps = 1.0 + X, y = make_classification(n_classes=3, n_informative=3, random_state=0) + X = X[:, [0, 1]] + lr = LogisticRegression(random_state=0).fit(X, y) + + disp = DecisionBoundaryDisplay.from_estimator( + lr, X, response_method=response_method, grid_resolution=grid_resolution, eps=1.0 + ) + + x0_min, x0_max = X[:, 0].min() - eps, X[:, 0].max() + eps + x1_min, x1_max = X[:, 1].min() - eps, X[:, 1].max() + eps + xx0, xx1 = np.meshgrid( + np.linspace(x0_min, x0_max, grid_resolution), + np.linspace(x1_min, x1_max, grid_resolution), + ) + response = lr.predict(np.c_[xx0.ravel(), xx1.ravel()]) + assert_allclose(disp.response, response.reshape(xx0.shape)) + assert_allclose(disp.xx0, xx0) + assert_allclose(disp.xx1, xx1) + + +@pytest.mark.parametrize( + "kwargs, error_msg", + [ + ( + {"plot_method": "hello_world"}, + r"plot_method must be one of contourf, contour, pcolormesh. Got hello_world" + r" instead.", + ), + ( + {"grid_resolution": 1}, + r"grid_resolution must be greater than 1. Got 1 instead", + ), + ( + {"grid_resolution": -1}, + r"grid_resolution must be greater than 1. Got -1 instead", + ), + ({"eps": -1.1}, r"eps must be greater than or equal to 0. 
Got -1.1 instead"), + ], +) +def test_input_validation_errors(pyplot, kwargs, error_msg, fitted_clf): + """Check input validation from_estimator.""" + with pytest.raises(ValueError, match=error_msg): + DecisionBoundaryDisplay.from_estimator(fitted_clf, X, **kwargs) + + +def test_display_plot_input_error(pyplot, fitted_clf): + """Check input validation for `plot`.""" + disp = DecisionBoundaryDisplay.from_estimator(fitted_clf, X, grid_resolution=5) + + with pytest.raises(ValueError, match="plot_method must be 'contourf'"): + disp.plot(plot_method="hello_world") + + +@pytest.mark.parametrize( + "response_method", ["auto", "predict", "predict_proba", "decision_function"] +) +@pytest.mark.parametrize("plot_method", ["contourf", "contour"]) +def test_decision_boundary_display(pyplot, fitted_clf, response_method, plot_method): + """Check that decision boundary is correct.""" + fig, ax = pyplot.subplots() + eps = 2.0 + disp = DecisionBoundaryDisplay.from_estimator( + fitted_clf, + X, + grid_resolution=5, + response_method=response_method, + plot_method=plot_method, + eps=eps, + ax=ax, + ) + assert isinstance(disp.surface_, pyplot.matplotlib.contour.QuadContourSet) + assert disp.ax_ == ax + assert disp.figure_ == fig + + x0, x1 = X[:, 0], X[:, 1] + + x0_min, x0_max = x0.min() - eps, x0.max() + eps + x1_min, x1_max = x1.min() - eps, x1.max() + eps + + assert disp.xx0.min() == pytest.approx(x0_min) + assert disp.xx0.max() == pytest.approx(x0_max) + assert disp.xx1.min() == pytest.approx(x1_min) + assert disp.xx1.max() == pytest.approx(x1_max) + + fig2, ax2 = pyplot.subplots() + # change plotting method for second plot + disp.plot(plot_method="pcolormesh", ax=ax2, shading="auto") + assert isinstance(disp.surface_, pyplot.matplotlib.collections.QuadMesh) + assert disp.ax_ == ax2 + assert disp.figure_ == fig2 + + +@pytest.mark.parametrize( + "response_method, msg", + [ + ( + "predict_proba", + "MyClassifier has none of the following attributes: predict_proba", + ), + ( + "decision_function", + "MyClassifier has none of the following attributes: decision_function", + ), + ( + "auto", + "MyClassifier has none of the following attributes: decision_function, " + "predict_proba, predict", + ), + ( + "bad_method", + "MyClassifier has none of the following attributes: bad_method", + ), + ], +) +def test_error_bad_response(pyplot, response_method, msg): + """Check errors for bad response.""" + + class MyClassifier(BaseEstimator, ClassifierMixin): + def fit(self, X, y): + self.fitted_ = True + self.classes_ = [0, 1] + return self + + clf = MyClassifier().fit(X, y) + + with pytest.raises(ValueError, match=msg): + DecisionBoundaryDisplay.from_estimator(clf, X, response_method=response_method) + + +@pytest.mark.parametrize("response_method", ["auto", "predict", "predict_proba"]) +def test_multilabel_classifier_error(pyplot, response_method): + """Check that multilabel classifier raises correct error.""" + X, y = make_multilabel_classification(random_state=0) + X = X[:, :2] + tree = DecisionTreeClassifier().fit(X, y) + + msg = "Multi-label and multi-output multi-class classifiers are not supported" + with pytest.raises(ValueError, match=msg): + DecisionBoundaryDisplay.from_estimator( + tree, + X, + response_method=response_method, + ) + + +@pytest.mark.parametrize("response_method", ["auto", "predict", "predict_proba"]) +def test_multi_output_multi_class_classifier_error(pyplot, response_method): + """Check that multi-output multi-class classifier raises correct error.""" + X = np.asarray([[0, 1], [1, 2]]) + y = 
np.asarray([["tree", "cat"], ["cat", "tree"]]) + tree = DecisionTreeClassifier().fit(X, y) + + msg = "Multi-label and multi-output multi-class classifiers are not supported" + with pytest.raises(ValueError, match=msg): + DecisionBoundaryDisplay.from_estimator( + tree, + X, + response_method=response_method, + ) + + +def test_multioutput_regressor_error(pyplot): + """Check that multioutput regressor raises correct error.""" + X = np.asarray([[0, 1], [1, 2]]) + y = np.asarray([[0, 1], [4, 1]]) + tree = DecisionTreeRegressor().fit(X, y) + with pytest.raises(ValueError, match="Multi-output regressors are not supported"): + DecisionBoundaryDisplay.from_estimator(tree, X) + + +def test_dataframe_labels_used(pyplot, fitted_clf): + """Check that column names are used for pandas.""" + pd = pytest.importorskip("pandas") + df = pd.DataFrame(X, columns=["col_x", "col_y"]) + + # pandas column names are used by default + _, ax = pyplot.subplots() + disp = DecisionBoundaryDisplay.from_estimator(fitted_clf, df, ax=ax) + assert ax.get_xlabel() == "col_x" + assert ax.get_ylabel() == "col_y" + + # second call to plot will have the names + fig, ax = pyplot.subplots() + disp.plot(ax=ax) + assert ax.get_xlabel() == "col_x" + assert ax.get_ylabel() == "col_y" + + # axes with a label will not get overridden + fig, ax = pyplot.subplots() + ax.set(xlabel="hello", ylabel="world") + disp.plot(ax=ax) + assert ax.get_xlabel() == "hello" + assert ax.get_ylabel() == "world" + + # labels get overriden only if provided to the `plot` method + disp.plot(ax=ax, xlabel="overwritten_x", ylabel="overwritten_y") + assert ax.get_xlabel() == "overwritten_x" + assert ax.get_ylabel() == "overwritten_y" + + # labels do not get inferred if provided to `from_estimator` + _, ax = pyplot.subplots() + disp = DecisionBoundaryDisplay.from_estimator( + fitted_clf, df, ax=ax, xlabel="overwritten_x", ylabel="overwritten_y" + ) + assert ax.get_xlabel() == "overwritten_x" + assert ax.get_ylabel() == "overwritten_y" + + +def test_string_target(pyplot): + """Check that decision boundary works with classifiers trained on string labels.""" + iris = load_iris() + X = iris.data[:, [0, 1]] + + # Use strings as target + y = iris.target_names[iris.target] + log_reg = LogisticRegression().fit(X, y) + + # Does not raise + DecisionBoundaryDisplay.from_estimator( + log_reg, + X, + grid_resolution=5, + response_method="predict", + )