From c5b9aa0fcdc15ba70938f9387dd7102884ab098c Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Mon, 2 Dec 2019 12:18:03 -0500 Subject: [PATCH 01/48] WIP --- examples/svm/plot_iris_svc.py | 46 ++---------------- sklearn/inspection/__init__.py | 4 +- sklearn/inspection/_plot/__init__.py | 0 sklearn/inspection/_plot/decision_boundary.py | 47 +++++++++++++++++++ sklearn/inspection/_plot/tests/__init__.py | 0 .../tests/test_plot_decision_boundary.py | 0 sklearn/setup.py | 2 + 7 files changed, 55 insertions(+), 44 deletions(-) create mode 100644 sklearn/inspection/_plot/__init__.py create mode 100644 sklearn/inspection/_plot/decision_boundary.py create mode 100644 sklearn/inspection/_plot/tests/__init__.py create mode 100644 sklearn/inspection/_plot/tests/test_plot_decision_boundary.py diff --git a/examples/svm/plot_iris_svc.py b/examples/svm/plot_iris_svc.py index ab7860296985c..fa230c7f9ca88 100644 --- a/examples/svm/plot_iris_svc.py +++ b/examples/svm/plot_iris_svc.py @@ -35,46 +35,9 @@ """ print(__doc__) -import numpy as np import matplotlib.pyplot as plt from sklearn import svm, datasets - - -def make_meshgrid(x, y, h=.02): - """Create a mesh of points to plot in - - Parameters - ---------- - x: data to base x-axis meshgrid on - y: data to base y-axis meshgrid on - h: stepsize for meshgrid, optional - - Returns - ------- - xx, yy : ndarray - """ - x_min, x_max = x.min() - 1, x.max() + 1 - y_min, y_max = y.min() - 1, y.max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, h), - np.arange(y_min, y_max, h)) - return xx, yy - - -def plot_contours(ax, clf, xx, yy, **params): - """Plot the decision boundaries for a classifier. - - Parameters - ---------- - ax: matplotlib axes object - clf: a classifier - xx: meshgrid ndarray - yy: meshgrid ndarray - params: dictionary of params to pass to contourf, optional - """ - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - Z = Z.reshape(xx.shape) - out = ax.contourf(xx, yy, Z, **params) - return out +from sklearn.inspection import plot_decision_boundary # import some data to play with @@ -103,14 +66,11 @@ def plot_contours(ax, clf, xx, yy, **params): plt.subplots_adjust(wspace=0.4, hspace=0.4) X0, X1 = X[:, 0], X[:, 1] -xx, yy = make_meshgrid(X0, X1) for clf, title, ax in zip(models, titles, sub.flatten()): - plot_contours(ax, clf, xx, yy, - cmap=plt.cm.coolwarm, alpha=0.8) + disp = plot_decision_boundary(clf, X, response_method='predict', + cmap=plt.cm.coolwarm, alpha=0.8, ax=ax) ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k') - ax.set_xlim(xx.min(), xx.max()) - ax.set_ylim(yy.min(), yy.max()) ax.set_xlabel('Sepal length') ax.set_ylabel('Sepal width') ax.set_xticks(()) diff --git a/sklearn/inspection/__init__.py b/sklearn/inspection/__init__.py index 04d9d84ecaf02..dfe046a8a24bb 100644 --- a/sklearn/inspection/__init__.py +++ b/sklearn/inspection/__init__.py @@ -3,10 +3,12 @@ from ._partial_dependence import plot_partial_dependence from ._partial_dependence import PartialDependenceDisplay from ._permutation_importance import permutation_importance +from ._plot.decision_boundary import plot_decision_boundary __all__ = [ 'partial_dependence', 'plot_partial_dependence', 'permutation_importance', - 'PartialDependenceDisplay' + 'PartialDependenceDisplay', + 'plot_decision_boundary' ] diff --git a/sklearn/inspection/_plot/__init__.py b/sklearn/inspection/_plot/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py new file mode 100644 index 0000000000000..83acd4511bb6f --- /dev/null +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -0,0 +1,47 @@ +import numpy as np +from ...utils import check_matplotlib_support + + +class DecisionBoundaryDisplay: + def __init__(self, xx0, xx1, response): + self.xx0 = xx0 + self.xx1 = xx1 + self.response = response + + def plot(self, ax=None, plot_method='contourf', **kwargs): + check_matplotlib_support('DecisionBoundaryDisplay.plot') + import matplotlib.pyplot as plt # noqa + + if ax is None: + _, ax = plt.subplots() + + plot_func = getattr(ax, plot_method) + self.surface_ = plot_func(self.xx0, self.xx1, + self.response, **kwargs) + self.ax_ = ax + self.figure_ = ax.figure + return self + + +def plot_decision_boundary(est, X, grid_resolution=100, + features=(0, 1), + response_method='decision_function', + plot_method='contourf', + ax=None, + **kwargs): + check_matplotlib_support('plot_decision_boundary') + + x0, x1 = X[:, features[0]], X[:, features[1]] + + x0_min, x0_max = x0.min() - 1, x0.max() + 1 + x1_min, x1_max = x1.min() - 1, x1.max() + 1 + + xx0, xx1 = np.meshgrid(np.linspace(x0_min, x0_max, grid_resolution), + np.linspace(x1_min, x1_max, grid_resolution)) + + response_func = getattr(est, response_method) + response = response_func(np.c_[xx0.ravel(), xx1.ravel()]) + display = DecisionBoundaryDisplay(xx0=xx0, xx1=xx1, + response=response.reshape(xx0.shape)) + + return display.plot(ax=ax, plot_method=plot_method, **kwargs) diff --git a/sklearn/inspection/_plot/tests/__init__.py b/sklearn/inspection/_plot/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sklearn/setup.py b/sklearn/setup.py index cc257c30e6f43..7f97e515f0a1d 100644 --- a/sklearn/setup.py +++ b/sklearn/setup.py @@ -39,6 +39,8 @@ def configuration(parent_package='', top_path=None): config.add_subpackage('impute/tests') config.add_subpackage('inspection') config.add_subpackage('inspection/tests') + config.add_subpackage('inspection/_plot') + config.add_subpackage('inspection/_plot/tests') config.add_subpackage('mixture') config.add_subpackage('mixture/tests') config.add_subpackage('model_selection') From 6dd153bb9aacbd6ff46156840e4c866b92abc7cc Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Mon, 6 Jan 2020 15:03:27 -0500 Subject: [PATCH 02/48] ENH Completely adds decision boundary --- doc/modules/classes.rst | 2 + doc/visualizations.rst | 2 + .../plot_classifier_comparison.py | 27 +-- examples/cluster/plot_inductive_clustering.py | 15 +- examples/ensemble/plot_adaboost_twoclass.py | 15 +- .../ensemble/plot_voting_decision_regions.py | 15 +- examples/linear_model/plot_iris_logistic.py | 19 +- .../linear_model/plot_logistic_multinomial.py | 18 +- examples/linear_model/plot_sgd_iris.py | 18 +- examples/neighbors/plot_classification.py | 21 +-- examples/neighbors/plot_nca_classification.py | 19 +- examples/neighbors/plot_nearest_centroid.py | 19 +- .../plot_label_propagation_versus_svm_iris.py | 18 +- examples/svm/plot_custom_kernel.py | 14 +- .../svm/plot_linearsvc_support_vectors.py | 13 +- examples/svm/plot_separating_hyperplane.py | 18 +- .../plot_separating_hyperplane_unbalanced.py | 27 +-- examples/tree/plot_iris_dtc.py | 14 +- sklearn/inspection/__init__.py | 4 +- sklearn/inspection/_plot/decision_boundary.py | 174 ++++++++++++++++-- .../tests/test_plot_decision_boundary.py | 119 ++++++++++++ 21 files changed, 358 insertions(+), 233 deletions(-) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index f8e5195cc9174..5f39c11bcf29c 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -653,12 +653,14 @@ Plotting :toctree: generated/ :template: class.rst + inspection.DecisionBoundaryDisplay inspection.PartialDependenceDisplay .. autosummary:: :toctree: generated/ :template: function.rst + inspection.plot_decision_boundary inspection.plot_partial_dependence .. _isotonic_ref: diff --git a/doc/visualizations.rst b/doc/visualizations.rst index 47d826602b62f..f6e41632b12c1 100644 --- a/doc/visualizations.rst +++ b/doc/visualizations.rst @@ -72,6 +72,7 @@ Functions .. autosummary:: inspection.plot_partial_dependence + inspection.plot_decision_boundary metrics.plot_confusion_matrix metrics.plot_precision_recall_curve metrics.plot_roc_curve @@ -85,6 +86,7 @@ Display Objects .. autosummary:: inspection.PartialDependenceDisplay + inspection.DecisionBoundaryDisplay metrics.ConfusionMatrixDisplay metrics.PrecisionRecallDisplay metrics.RocCurveDisplay diff --git a/examples/classification/plot_classifier_comparison.py b/examples/classification/plot_classifier_comparison.py index 83019e821dae5..a6ea51a484658 100644 --- a/examples/classification/plot_classifier_comparison.py +++ b/examples/classification/plot_classifier_comparison.py @@ -43,8 +43,7 @@ from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier from sklearn.naive_bayes import GaussianNB from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis - -h = .02 # step size in the mesh +from sklearn.inspection import plot_decision_boundary names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", "Gaussian Process", "Decision Tree", "Random Forest", "Neural Net", "AdaBoost", @@ -85,8 +84,6 @@ x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5 y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5 - xx, yy = np.meshgrid(np.arange(x_min, x_max, h), - np.arange(y_min, y_max, h)) # just plot the dataset first cm = plt.cm.RdBu @@ -100,8 +97,8 @@ # Plot the testing points ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.6, edgecolors='k') - ax.set_xlim(xx.min(), xx.max()) - ax.set_ylim(yy.min(), yy.max()) + ax.set_xlim(x_min, x_max) + ax.set_ylim(y_min, y_max) ax.set_xticks(()) ax.set_yticks(()) i += 1 @@ -111,17 +108,7 @@ ax = plt.subplot(len(datasets), len(classifiers) + 1, i) clf.fit(X_train, y_train) score = clf.score(X_test, y_test) - - # Plot the decision boundary. For that, we will assign a color to each - # point in the mesh [x_min, x_max]x[y_min, y_max]. - if hasattr(clf, "decision_function"): - Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]) - else: - Z = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1] - - # Put the result into a color plot - Z = Z.reshape(xx.shape) - ax.contourf(xx, yy, Z, cmap=cm, alpha=.8) + plot_decision_boundary(clf, X, cmap=cm, alpha=0.8, ax=ax, eps=0.5) # Plot the training points ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, @@ -130,13 +117,13 @@ ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, edgecolors='k', alpha=0.6) - ax.set_xlim(xx.min(), xx.max()) - ax.set_ylim(yy.min(), yy.max()) + ax.set_xlim(x_min, x_max) + ax.set_ylim(y_min, y_max) ax.set_xticks(()) ax.set_yticks(()) if ds_cnt == 0: ax.set_title(name) - ax.text(xx.max() - .3, yy.min() + .3, ('%.2f' % score).lstrip('0'), + ax.text(x_max - .3, y_min + .3, ('%.2f' % score).lstrip('0'), size=15, horizontalalignment='right') i += 1 diff --git a/examples/cluster/plot_inductive_clustering.py b/examples/cluster/plot_inductive_clustering.py index c5a51db5ef577..623b8da3144ce 100644 --- a/examples/cluster/plot_inductive_clustering.py +++ b/examples/cluster/plot_inductive_clustering.py @@ -29,7 +29,7 @@ from sklearn.datasets import make_blobs from sklearn.ensemble import RandomForestClassifier from sklearn.utils.metaestimators import if_delegate_has_method - +from sklearn.inspection import plot_decision_boundary N_SAMPLES = 5000 RANDOM_STATE = 42 @@ -101,20 +101,13 @@ def plot_scatter(X, color, alpha=0.5): probable_clusters = inductive_learner.predict(X_new) -plt.subplot(133) +ax = plt.subplot(133) plot_scatter(X, cluster_labels) plot_scatter(X_new, probable_clusters) # Plotting decision regions -x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 -y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 -xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1), - np.arange(y_min, y_max, 0.1)) - -Z = inductive_learner.predict(np.c_[xx.ravel(), yy.ravel()]) -Z = Z.reshape(xx.shape) - -plt.contourf(xx, yy, Z, alpha=0.4) +plot_decision_boundary(inductive_learner, X, response_method='predict', + alpha=0.4, ax=ax) plt.title("Classify unknown instances") plt.show() diff --git a/examples/ensemble/plot_adaboost_twoclass.py b/examples/ensemble/plot_adaboost_twoclass.py index edb4cbb1a97b3..296390abb76d7 100644 --- a/examples/ensemble/plot_adaboost_twoclass.py +++ b/examples/ensemble/plot_adaboost_twoclass.py @@ -28,6 +28,7 @@ from sklearn.ensemble import AdaBoostClassifier from sklearn.tree import DecisionTreeClassifier from sklearn.datasets import make_gaussian_quantiles +from sklearn.inspection import plot_decision_boundary # Construct dataset @@ -54,15 +55,11 @@ plt.figure(figsize=(10, 5)) # Plot the decision boundaries -plt.subplot(121) -x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 -y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 -xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step), - np.arange(y_min, y_max, plot_step)) - -Z = bdt.predict(np.c_[xx.ravel(), yy.ravel()]) -Z = Z.reshape(xx.shape) -cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired) +ax = plt.subplot(121) +disp = plot_decision_boundary(bdt, X, cmap=plt.cm.Paired, + response_method='predict', ax=ax) +x_min, x_max = disp.xx0.min(), disp.xx0.max() +y_min, y_max = disp.xx1.min(), disp.xx1.max() plt.axis("tight") # Plot the training points diff --git a/examples/ensemble/plot_voting_decision_regions.py b/examples/ensemble/plot_voting_decision_regions.py index fdfda74947f5f..8684494d2d763 100644 --- a/examples/ensemble/plot_voting_decision_regions.py +++ b/examples/ensemble/plot_voting_decision_regions.py @@ -26,7 +26,6 @@ from itertools import product -import numpy as np import matplotlib.pyplot as plt from sklearn import datasets @@ -34,6 +33,7 @@ from sklearn.neighbors import KNeighborsClassifier from sklearn.svm import SVC from sklearn.ensemble import VotingClassifier +from sklearn.inspection import plot_decision_boundary # Loading some example data iris = datasets.load_iris() @@ -54,22 +54,13 @@ eclf.fit(X, y) # Plotting decision regions -x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 -y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 -xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1), - np.arange(y_min, y_max, 0.1)) - f, axarr = plt.subplots(2, 2, sharex='col', sharey='row', figsize=(10, 8)) - for idx, clf, tt in zip(product([0, 1], [0, 1]), [clf1, clf2, clf3, eclf], ['Decision Tree (depth=4)', 'KNN (k=7)', 'Kernel SVM', 'Soft Voting']): - - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - Z = Z.reshape(xx.shape) - - axarr[idx[0], idx[1]].contourf(xx, yy, Z, alpha=0.4) + plot_decision_boundary(clf, X, alpha=0.4, ax=axarr[idx[0], idx[1]], + response_method='predict') axarr[idx[0], idx[1]].scatter(X[:, 0], X[:, 1], c=y, s=20, edgecolor='k') axarr[idx[0], idx[1]].set_title(tt) diff --git a/examples/linear_model/plot_iris_logistic.py b/examples/linear_model/plot_iris_logistic.py index 2c4fd48d62ff3..2c73fc21303de 100644 --- a/examples/linear_model/plot_iris_logistic.py +++ b/examples/linear_model/plot_iris_logistic.py @@ -18,10 +18,10 @@ # Modified for documentation by Jaques Grobler # License: BSD 3 clause -import numpy as np import matplotlib.pyplot as plt from sklearn.linear_model import LogisticRegression from sklearn import datasets +from sklearn.inspection import plot_decision_boundary # import some data to play with iris = datasets.load_iris() @@ -33,26 +33,15 @@ # Create an instance of Logistic Regression Classifier and fit the data. logreg.fit(X, Y) -# Plot the decision boundary. For that, we will assign a color to each -# point in the mesh [x_min, x_max]x[y_min, y_max]. -x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5 -y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5 -h = .02 # step size in the mesh -xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) -Z = logreg.predict(np.c_[xx.ravel(), yy.ravel()]) - -# Put the result into a color plot -Z = Z.reshape(xx.shape) -plt.figure(1, figsize=(4, 3)) -plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired) +_, ax = plt.subplots(figsize=(4, 3)) +plot_decision_boundary(logreg, X, cmap=plt.cm.Paired, ax=ax, + response_method='predict', plot_method='pcolormesh') # Plot also the training points plt.scatter(X[:, 0], X[:, 1], c=Y, edgecolors='k', cmap=plt.cm.Paired) plt.xlabel('Sepal length') plt.ylabel('Sepal width') -plt.xlim(xx.min(), xx.max()) -plt.ylim(yy.min(), yy.max()) plt.xticks(()) plt.yticks(()) diff --git a/examples/linear_model/plot_logistic_multinomial.py b/examples/linear_model/plot_logistic_multinomial.py index 518a2aeade61c..6ecc23fc73c0c 100644 --- a/examples/linear_model/plot_logistic_multinomial.py +++ b/examples/linear_model/plot_logistic_multinomial.py @@ -15,6 +15,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import make_blobs from sklearn.linear_model import LogisticRegression +from sklearn.inspection import plot_decision_boundary # make 3-class dataset for classification centers = [[-5, 0], [0, 1.5], [5, -1]] @@ -29,20 +30,9 @@ # print the training scores print("training score : %.3f (%s)" % (clf.score(X, y), multi_class)) - # create a mesh to plot in - h = .02 # step size in the mesh - x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 - y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, h), - np.arange(y_min, y_max, h)) - - # Plot the decision boundary. For that, we will assign a color to each - # point in the mesh [x_min, x_max]x[y_min, y_max]. - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - # Put the result into a color plot - Z = Z.reshape(xx.shape) - plt.figure() - plt.contourf(xx, yy, Z, cmap=plt.cm.Paired) + _, ax = plt.subplots() + plot_decision_boundary(clf, X, response_method='predict', + cmap=plt.cm.Paired, ax=ax) plt.title("Decision surface of LogisticRegression (%s)" % multi_class) plt.axis('tight') diff --git a/examples/linear_model/plot_sgd_iris.py b/examples/linear_model/plot_sgd_iris.py index 0dddf7475728d..e5c3e5b8b7aab 100644 --- a/examples/linear_model/plot_sgd_iris.py +++ b/examples/linear_model/plot_sgd_iris.py @@ -14,6 +14,7 @@ import matplotlib.pyplot as plt from sklearn import datasets from sklearn.linear_model import SGDClassifier +from sklearn.inspection import plot_decision_boundary # import some data to play with iris = datasets.load_iris() @@ -36,22 +37,9 @@ std = X.std(axis=0) X = (X - mean) / std -h = .02 # step size in the mesh - clf = SGDClassifier(alpha=0.001, max_iter=100).fit(X, y) - -# create a mesh to plot in -x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 -y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 -xx, yy = np.meshgrid(np.arange(x_min, x_max, h), - np.arange(y_min, y_max, h)) - -# Plot the decision boundary. For that, we will assign a color to each -# point in the mesh [x_min, x_max]x[y_min, y_max]. -Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) -# Put the result into a color plot -Z = Z.reshape(xx.shape) -cs = plt.contourf(xx, yy, Z, cmap=plt.cm.Paired) +ax = plt.gca() +plot_decision_boundary(clf, X, cmap=plt.cm.Paired, ax=ax) plt.axis('tight') # Plot also the training points diff --git a/examples/neighbors/plot_classification.py b/examples/neighbors/plot_classification.py index 14cbb732df1a7..aba6583df894f 100644 --- a/examples/neighbors/plot_classification.py +++ b/examples/neighbors/plot_classification.py @@ -8,10 +8,10 @@ """ print(__doc__) -import numpy as np import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap from sklearn import neighbors, datasets +from sklearn.inspection import plot_decision_boundary n_neighbors = 15 @@ -23,8 +23,6 @@ X = iris.data[:, :2] y = iris.target -h = .02 # step size in the mesh - # Create color maps cmap_light = ListedColormap(['orange', 'cyan', 'cornflowerblue']) cmap_bold = ListedColormap(['darkorange', 'c', 'darkblue']) @@ -34,24 +32,13 @@ clf = neighbors.KNeighborsClassifier(n_neighbors, weights=weights) clf.fit(X, y) - # Plot the decision boundary. For that, we will assign a color to each - # point in the mesh [x_min, x_max]x[y_min, y_max]. - x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 - y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, h), - np.arange(y_min, y_max, h)) - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - - # Put the result into a color plot - Z = Z.reshape(xx.shape) - plt.figure() - plt.pcolormesh(xx, yy, Z, cmap=cmap_light) + _, ax = plt.subplots() + plot_decision_boundary(clf, X, cmap=cmap_light, ax=ax, + response_method='predict', plot_method='pcolormesh') # Plot also the training points plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor='k', s=20) - plt.xlim(xx.min(), xx.max()) - plt.ylim(yy.min(), yy.max()) plt.title("3-Class classification (k = %i, weights = '%s')" % (n_neighbors, weights)) diff --git a/examples/neighbors/plot_nca_classification.py b/examples/neighbors/plot_nca_classification.py index 5536e8eb69e89..58380ba4839d7 100644 --- a/examples/neighbors/plot_nca_classification.py +++ b/examples/neighbors/plot_nca_classification.py @@ -25,6 +25,7 @@ from sklearn.neighbors import (KNeighborsClassifier, NeighborhoodComponentsAnalysis) from sklearn.pipeline import Pipeline +from sklearn.inspection import plot_decision_boundary print(__doc__) @@ -58,29 +59,17 @@ ]) ] -x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 -y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 -xx, yy = np.meshgrid(np.arange(x_min, x_max, h), - np.arange(y_min, y_max, h)) - for name, clf in zip(names, classifiers): clf.fit(X_train, y_train) score = clf.score(X_test, y_test) - # Plot the decision boundary. For that, we will assign a color to each - # point in the mesh [x_min, x_max]x[y_min, y_max]. - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - - # Put the result into a color plot - Z = Z.reshape(xx.shape) - plt.figure() - plt.pcolormesh(xx, yy, Z, cmap=cmap_light, alpha=.8) + _, ax = plt.subplots() + plot_decision_boundary(clf, X, cmap=cmap_light, alpha=0.8, ax=ax, + response_method='predict', plot_method='pcolormesh') # Plot also the training and testing points plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor='k', s=20) - plt.xlim(xx.min(), xx.max()) - plt.ylim(yy.min(), yy.max()) plt.title("{} (k = {})".format(name, n_neighbors)) plt.text(0.9, 0.1, '{:.2f}'.format(score), size=15, ha='center', va='center', transform=plt.gca().transAxes) diff --git a/examples/neighbors/plot_nearest_centroid.py b/examples/neighbors/plot_nearest_centroid.py index 04a105c0e07fd..b21cdfda7da00 100644 --- a/examples/neighbors/plot_nearest_centroid.py +++ b/examples/neighbors/plot_nearest_centroid.py @@ -13,6 +13,7 @@ from matplotlib.colors import ListedColormap from sklearn import datasets from sklearn.neighbors import NearestCentroid +from sklearn.inspection import plot_decision_boundary n_neighbors = 15 @@ -23,8 +24,6 @@ X = iris.data[:, :2] y = iris.target -h = .02 # step size in the mesh - # Create color maps cmap_light = ListedColormap(['orange', 'cyan', 'cornflowerblue']) cmap_bold = ListedColormap(['darkorange', 'c', 'darkblue']) @@ -35,18 +34,10 @@ clf.fit(X, y) y_pred = clf.predict(X) print(shrinkage, np.mean(y == y_pred)) - # Plot the decision boundary. For that, we will assign a color to each - # point in the mesh [x_min, x_max]x[y_min, y_max]. - x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 - y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, h), - np.arange(y_min, y_max, h)) - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - - # Put the result into a color plot - Z = Z.reshape(xx.shape) - plt.figure() - plt.pcolormesh(xx, yy, Z, cmap=cmap_light) + + _, ax = plt.subplots() + plot_decision_boundary(clf, X, cmap=cmap_light, ax=ax, + response_method='predict') # Plot also the training points plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, diff --git a/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py b/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py index caabc46cb0cc1..e7af77b332333 100644 --- a/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py +++ b/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py @@ -20,6 +20,7 @@ from sklearn import datasets from sklearn import svm from sklearn.semi_supervised import LabelSpreading +from sklearn.inspection import plot_decision_boundary rng = np.random.RandomState(0) @@ -28,9 +29,6 @@ X = iris.data[:, :2] y = iris.target -# step size in the mesh -h = .02 - y_30 = np.copy(y) y_30[rng.rand(len(y)) < 0.3] = -1 y_50 = np.copy(y) @@ -42,12 +40,6 @@ ls100 = (LabelSpreading().fit(X, y), y) rbf_svc = (svm.SVC(kernel='rbf', gamma=.5).fit(X, y), y) -# create a mesh to plot in -x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 -y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 -xx, yy = np.meshgrid(np.arange(x_min, x_max, h), - np.arange(y_min, y_max, h)) - # title for the plots titles = ['Label Spreading 30% data', 'Label Spreading 50% data', @@ -59,12 +51,10 @@ for i, (clf, y_train) in enumerate((ls30, ls50, ls100, rbf_svc)): # Plot the decision boundary. For that, we will assign a color to each # point in the mesh [x_min, x_max]x[y_min, y_max]. - plt.subplot(2, 2, i + 1) - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) + ax = plt.subplot(2, 2, i + 1) + plot_decision_boundary(clf, X, cmap=plt.cm.Paired, ax=ax, + response_method='predict') - # Put the result into a color plot - Z = Z.reshape(xx.shape) - plt.contourf(xx, yy, Z, cmap=plt.cm.Paired) plt.axis('off') # Plot also the training points diff --git a/examples/svm/plot_custom_kernel.py b/examples/svm/plot_custom_kernel.py index 28641cd35f8cb..f5aaaec1eb7ac 100644 --- a/examples/svm/plot_custom_kernel.py +++ b/examples/svm/plot_custom_kernel.py @@ -12,6 +12,7 @@ import numpy as np import matplotlib.pyplot as plt from sklearn import svm, datasets +from sklearn.inspection import plot_decision_boundary # import some data to play with iris = datasets.load_iris() @@ -38,16 +39,9 @@ def my_kernel(X, Y): clf = svm.SVC(kernel=my_kernel) clf.fit(X, Y) -# Plot the decision boundary. For that, we will assign a color to each -# point in the mesh [x_min, x_max]x[y_min, y_max]. -x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 -y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 -xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) -Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - -# Put the result into a color plot -Z = Z.reshape(xx.shape) -plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired) +ax = plt.gca() +plot_decision_boundary(clf, X, cmap=plt.cm.Paired, ax=ax, + response_method='predict', plot_method='pcolormesh') # Plot also the training points plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=plt.cm.Paired, edgecolors='k') diff --git a/examples/svm/plot_linearsvc_support_vectors.py b/examples/svm/plot_linearsvc_support_vectors.py index e2737d47033e6..c8ab2a21f62de 100644 --- a/examples/svm/plot_linearsvc_support_vectors.py +++ b/examples/svm/plot_linearsvc_support_vectors.py @@ -13,6 +13,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import make_blobs from sklearn.svm import LinearSVC +from sklearn.inspection import plot_decision_boundary X, y = make_blobs(n_samples=40, centers=2, random_state=0) @@ -30,14 +31,10 @@ plt.subplot(1, 2, i + 1) plt.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=plt.cm.Paired) ax = plt.gca() - xlim = ax.get_xlim() - ylim = ax.get_ylim() - xx, yy = np.meshgrid(np.linspace(xlim[0], xlim[1], 50), - np.linspace(ylim[0], ylim[1], 50)) - Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]) - Z = Z.reshape(xx.shape) - plt.contour(xx, yy, Z, colors='k', levels=[-1, 0, 1], alpha=0.5, - linestyles=['--', '-', '--']) + plot_decision_boundary(clf, X, ax=ax, grid_resolution=50, + plot_method='contour', + colors='k', levels=[-1, 0, 1], alpha=0.5, + linestyles=['--', '-', '--']) plt.scatter(support_vectors[:, 0], support_vectors[:, 1], s=100, linewidth=1, facecolors='none', edgecolors='k') plt.title("C=" + str(C)) diff --git a/examples/svm/plot_separating_hyperplane.py b/examples/svm/plot_separating_hyperplane.py index cbd61abad53e6..becb5599a9c63 100644 --- a/examples/svm/plot_separating_hyperplane.py +++ b/examples/svm/plot_separating_hyperplane.py @@ -9,10 +9,10 @@ """ print(__doc__) -import numpy as np import matplotlib.pyplot as plt from sklearn import svm from sklearn.datasets import make_blobs +from sklearn.inspection import plot_decision_boundary # we create 40 separable points @@ -26,19 +26,9 @@ # plot the decision function ax = plt.gca() -xlim = ax.get_xlim() -ylim = ax.get_ylim() - -# create grid to evaluate model -xx = np.linspace(xlim[0], xlim[1], 30) -yy = np.linspace(ylim[0], ylim[1], 30) -YY, XX = np.meshgrid(yy, xx) -xy = np.vstack([XX.ravel(), YY.ravel()]).T -Z = clf.decision_function(xy).reshape(XX.shape) - -# plot decision boundary and margins -ax.contour(XX, YY, Z, colors='k', levels=[-1, 0, 1], alpha=0.5, - linestyles=['--', '-', '--']) +plot_decision_boundary(clf, X, plot_method='contour', + colors='k', levels=[-1, 0, 1], alpha=0.5, + linestyles=['--', '-', '--'], ax=ax) # plot support vectors ax.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=100, linewidth=1, facecolors='none', edgecolors='k') diff --git a/examples/svm/plot_separating_hyperplane_unbalanced.py b/examples/svm/plot_separating_hyperplane_unbalanced.py index 2a0540fead310..039fc5d7341c5 100644 --- a/examples/svm/plot_separating_hyperplane_unbalanced.py +++ b/examples/svm/plot_separating_hyperplane_unbalanced.py @@ -26,10 +26,10 @@ """ print(__doc__) -import numpy as np import matplotlib.pyplot as plt from sklearn import svm from sklearn.datasets import make_blobs +from sklearn.inspection import plot_decision_boundary # we create two clusters of random points n_samples_1 = 1000 @@ -54,27 +54,14 @@ # plot the decision functions for both classifiers ax = plt.gca() -xlim = ax.get_xlim() -ylim = ax.get_ylim() - -# create grid to evaluate model -xx = np.linspace(xlim[0], xlim[1], 30) -yy = np.linspace(ylim[0], ylim[1], 30) -YY, XX = np.meshgrid(yy, xx) -xy = np.vstack([XX.ravel(), YY.ravel()]).T - -# get the separating hyperplane -Z = clf.decision_function(xy).reshape(XX.shape) - -# plot decision boundary and margins -a = ax.contour(XX, YY, Z, colors='k', levels=[0], alpha=0.5, linestyles=['-']) - -# get the separating hyperplane for weighted classes -Z = wclf.decision_function(xy).reshape(XX.shape) +disp = plot_decision_boundary(clf, X, plot_method='contour', colors='k', + levels=[0], alpha=0.5, linestyles=['-'], ax=ax) # plot decision boundary and margins for weighted classes -b = ax.contour(XX, YY, Z, colors='r', levels=[0], alpha=0.5, linestyles=['-']) +wdisp = plot_decision_boundary(wclf, X, plot_method='contour', colors='r', + levels=[0], alpha=0.5, linestyles=['-'], ax=ax) -plt.legend([a.collections[0], b.collections[0]], ["non weighted", "weighted"], +plt.legend([disp.surface_.collections[0], wdisp.surface_.collections[0]], + ["non weighted", "weighted"], loc="upper right") plt.show() diff --git a/examples/tree/plot_iris_dtc.py b/examples/tree/plot_iris_dtc.py index 60328c4f90d4f..1cf9f718e9413 100644 --- a/examples/tree/plot_iris_dtc.py +++ b/examples/tree/plot_iris_dtc.py @@ -21,6 +21,7 @@ from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier, plot_tree +from sklearn.inspection import plot_decision_boundary # Parameters n_classes = 3 @@ -40,17 +41,10 @@ clf = DecisionTreeClassifier().fit(X, y) # Plot the decision boundary - plt.subplot(2, 3, pairidx + 1) - - x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 - y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step), - np.arange(y_min, y_max, plot_step)) + ax = plt.subplot(2, 3, pairidx + 1) plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5) - - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - Z = Z.reshape(xx.shape) - cs = plt.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu) + plot_decision_boundary(clf, X, cmap=plt.cm.RdYlBu, + response_method='predict', ax=ax) plt.xlabel(iris.feature_names[pair[0]]) plt.ylabel(iris.feature_names[pair[1]]) diff --git a/sklearn/inspection/__init__.py b/sklearn/inspection/__init__.py index d843ce18dcbfe..19982e56ed953 100644 --- a/sklearn/inspection/__init__.py +++ b/sklearn/inspection/__init__.py @@ -15,6 +15,7 @@ from ._partial_dependence import plot_partial_dependence # noqa from ._partial_dependence import PartialDependenceDisplay # noqa from ._permutation_importance import permutation_importance # noqa +from ._plot.decision_boundary import DecisionBoundaryDisplay # noqa from ._plot.decision_boundary import plot_decision_boundary # noqa @@ -23,5 +24,6 @@ 'plot_partial_dependence', 'permutation_importance', 'PartialDependenceDisplay', - 'plot_decision_boundary' + 'plot_decision_boundary', + 'DecisionBoundaryDisplay' ] diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 83acd4511bb6f..9d7222314b83c 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -2,46 +2,192 @@ from ...utils import check_matplotlib_support +def _check_boundary_response_method(estimator, response_method): + """Return prediction method from the response_method for decision boundary + + Parameters + ---------- + estimator: object + Estimator to check + + response_method: {'auto', 'predict_proba', 'decision_function', 'predict'} + Specifies whether to use :term:`predict_proba` or + :term:`decision_function` as the target response. If set to 'auto', + :term:`predict_proba` is tried first and if it does not exist + :term:`decision_function` is tried next. + + Returns + ------- + prediction_method: callable + prediction method of estimator + """ + + if response_method not in ("predict_proba", "decision_function", + "auto", "predict"): + raise ValueError("response_method must be 'predict_proba', " + "'decision_function', 'predict', or 'auto'") + + error_msg = "response method {} is not defined in {}" + if response_method != "auto": + if not hasattr(estimator, response_method): + raise ValueError(error_msg.format(response_method, + estimator.__class__.__name__)) + return getattr(estimator, response_method) + elif hasattr(estimator, 'predict_proba'): + return getattr(estimator, 'predict_proba') + elif hasattr(estimator, 'decision_function'): + return getattr(estimator, 'decision_function') + + raise ValueError(error_msg.format( + "decision_function or predict_proba", + estimator.__class__.__name__)) + + class DecisionBoundaryDisplay: + """Decisions Boundary visualization. + + It is recommend to use :func:`~sklearn.inspection.plot_decision_boundary` to + create a :class:`DecisionBoundaryDisplay`. All parameters are stored as + attributes. + + Read more in the :ref:`User Guide `. + + Parameters + ---------- + xx0 : ndarray of shape (grid_resolution, grid_resolution) + First output of meshgrid. + + xx1 : ndarray of shape (grid_resolution, grid_resolution) + Second output of meshgrid. + + response : ndarray of shape (grid_resolution, grid_resolution) + Values of the response function. + + Attributes + ---------- + surface_ : matplotlib `QuadContourSet` or `QuadMesh` + If `plot_method` is 'contour' or 'contourf', `surface_` is a + `QuadContourSet`. If `plot_method is `pcolormesh`, `surface_` is a + `QuadMesh`. + + ax_ : matplotlib Axes + Axes with confusion matrix. + + figure_ : matplotlib Figure + Figure containing the confusion matrix. + """ def __init__(self, xx0, xx1, response): self.xx0 = xx0 self.xx1 = xx1 self.response = response - def plot(self, ax=None, plot_method='contourf', **kwargs): + def plot(self, plot_method='contourf', ax=None, **kwargs): + """Plot visualization. + + Parameters + ---------- + plot_method : {'contourf', 'contour', 'pcolormesh'}, default='contourf' + Plotting method to call when plotting the response. + + ax : Matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + **kwargs : dict + Additional keyword arguments to be pased to the `plot_method`. + + Returns + ------- + display: :class:`~sklearn.inspection.DecisionBoundaryDisplay` + """ check_matplotlib_support('DecisionBoundaryDisplay.plot') import matplotlib.pyplot as plt # noqa + if plot_method not in ('contourf', 'contour', 'pcolormesh'): + raise ValueError("plot_method must be 'contourf', 'contour', or " + "'pcolormesh'") + if ax is None: _, ax = plt.subplots() plot_func = getattr(ax, plot_method) - self.surface_ = plot_func(self.xx0, self.xx1, - self.response, **kwargs) + self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs) self.ax_ = ax self.figure_ = ax.figure return self -def plot_decision_boundary(est, X, grid_resolution=100, - features=(0, 1), - response_method='decision_function', +def plot_decision_boundary(estimator, X, grid_resolution=100, eps=1.0, plot_method='contourf', - ax=None, - **kwargs): + response_method='auto', + ax=None, **kwargs): + """Plot Decision Boundary. + + Parameters + ---------- + estimator : estimator instance + Trained estimator. + + X : array-like of shape (n_samples, 2) + Input values. + + grid_resolution : int, default=100 + The number of equally spaced points to evaluate the response function. + + eps : float, default=1.0 + Extends the minimum and maximum values of X for evaluating the + response function. + + plot_method : {'contourf', 'contour', 'pcolormesh'}, default='contourf' + Plotting method to call when plotting the response. + + response_method : {'auto', 'predict_proba', 'decision_function', \ + 'predict'}, defaul='auto' + Specifies whether to use :term:`predict_proba` or + :term:`decision_function` as the target response. If set to 'auto', + :term:`predict_proba` is tried first and if it does not exist + :term:`decision_function` is tried next. + + ax : Matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. + + **kwargs : dict + Additional keyword arguments to be pased to the `plot_method`. + + Returns + ------- + display: :class:`~sklearn.inspection.DecisionBoundaryDisplay` + """ check_matplotlib_support('plot_decision_boundary') - x0, x1 = X[:, features[0]], X[:, features[1]] + if not grid_resolution > 1: + raise ValueError("grid_resolution must be greater than 1") + + if not eps >= 0: + raise ValueError("eps must be greater than or equal to 0") + + if plot_method not in ('contourf', 'contour', 'pcolormesh'): + raise ValueError("plot_method must be 'contourf', 'contour', or " + "'pcolormesh'") - x0_min, x0_max = x0.min() - 1, x0.max() + 1 - x1_min, x1_max = x1.min() - 1, x1.max() + 1 + pred_func = _check_boundary_response_method(estimator, response_method) + + x0, x1 = X[:, 0], X[:, 1] + + x0_min, x0_max = x0.min() - eps, x0.max() + eps + x1_min, x1_max = x1.min() - eps, x1.max() + eps xx0, xx1 = np.meshgrid(np.linspace(x0_min, x0_max, grid_resolution), np.linspace(x1_min, x1_max, grid_resolution)) + response = pred_func(np.c_[xx0.ravel(), xx1.ravel()]) + + if response.ndim != 1: + if response.shape[1] != 2: + raise ValueError("multiclass classifers are only supported when " + "response_method='predict'") + response = response[:, 1] - response_func = getattr(est, response_method) - response = response_func(np.c_[xx0.ravel(), xx1.ravel()]) display = DecisionBoundaryDisplay(xx0=xx0, xx1=xx1, response=response.reshape(xx0.shape)) - return display.plot(ax=ax, plot_method=plot_method, **kwargs) diff --git a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py index e69de29bb2d1d..9e36151203ea9 100644 --- a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py +++ b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py @@ -0,0 +1,119 @@ +import pytest + +from sklearn.base import BaseEstimator +from sklearn.base import ClassifierMixin +from sklearn.inspection import plot_decision_boundary +from sklearn.datasets import make_classification +from sklearn.linear_model import LogisticRegression + + +# TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved +pytestmark = pytest.mark.filterwarnings( + "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:" + "matplotlib.*") + + +@pytest.fixture(scope="module") +def data(): + X, y = make_classification(n_informative=1, n_redundant=1, + n_clusters_per_class=1, n_features=2) + return X, y + + +@pytest.fixture(scope="module") +def fitted_clf(data): + return LogisticRegression().fit(*data) + + +@pytest.mark.parametrize("response_method", + ['auto', 'predict_proba', 'decision_function']) +def test_multiclass_error(pyplot, response_method): + X, y = make_classification(n_classes=3, n_informative=3, random_state=0) + X = X[:, [0, 1]] + lr = LogisticRegression().fit(X, y) + + msg = ("multiclass classifers are only supported when " + "response_method='predict'") + with pytest.raises(ValueError, match=msg): + plot_decision_boundary(lr, X, response_method=response_method) + + +@pytest.mark.parametrize("kwargs, error_msg", [ + ({"plot_method": "hello_world"}, + r"plot_method must be 'contourf',"), + ({"grid_resolution": 1}, + r"grid_resolution must be greater than 1"), + ({"grid_resolution": -1}, + r"grid_resolution must be greater than 1"), + ({"eps": -1.1}, + r"eps must be greater than or equal to 0") +]) +def test_input_validation_errors(pyplot, kwargs, error_msg, fitted_clf, data): + X, _ = data + with pytest.raises(ValueError, match=error_msg): + plot_decision_boundary(fitted_clf, X, **kwargs) + + +def test_display_plot_input_error(pyplot, fitted_clf, data): + X, y = data + disp = plot_decision_boundary(fitted_clf, X, grid_resolution=5) + + with pytest.raises(ValueError, match="plot_method must be 'contourf'"): + disp.plot(plot_method="hello_world") + + +@pytest.mark.parametrize("response_method", ['auto', 'predict', + 'predict_proba', + 'decision_function']) +@pytest.mark.parametrize("plot_method", ['contourf', 'contour']) +def test_plot_decision_boundary(pyplot, fitted_clf, data, + response_method, plot_method): + fig, ax = pyplot.subplots() + eps = 2.0 + X, y = data + disp = plot_decision_boundary(fitted_clf, X, grid_resolution=5, + response_method=response_method, + plot_method=plot_method, + eps=eps, ax=ax) + assert isinstance(disp.surface_, pyplot.matplotlib.contour.QuadContourSet) + assert disp.ax_ == ax + assert disp.figure_ == fig + + x0, x1 = X[:, 0], X[:, 1] + + x0_min, x0_max = x0.min() - eps, x0.max() + eps + x1_min, x1_max = x1.min() - eps, x1.max() + eps + + assert disp.xx0.min() == pytest.approx(x0_min) + assert disp.xx0.max() == pytest.approx(x0_max) + assert disp.xx1.min() == pytest.approx(x1_min) + assert disp.xx1.max() == pytest.approx(x1_max) + + # change plotting method for second plot + disp.plot(plot_method='pcolormesh') + assert isinstance(disp.surface_, pyplot.matplotlib.collections.QuadMesh) + + +@pytest.mark.parametrize( + "response_method, msg", + [("predict_proba", "response method predict_proba is not defined in " + "MyClassifier"), + ("decision_function", "response method decision_function is not defined " + "in MyClassifier"), + ("auto", "response method decision_function or predict_proba is not " + "defined in MyClassifier"), + ("bad_method", "response_method must be 'predict_proba', " + "'decision_function', 'predict', or 'auto'")]) +def test_error_bad_response(pyplot, response_method, msg): + X, y = make_classification(n_classes=2, n_samples=50, random_state=0) + + class MyClassifier(BaseEstimator, ClassifierMixin): + def fit(self, X, y): + self.fitted_ = True + self.classes_ = [0, 1] + return self + + clf = MyClassifier().fit(X, y) + + with pytest.raises(ValueError, match=msg): + plot_decision_boundary(clf, X, response_method=response_method) From ed1cae079cee16490cfdacff4803f63a5a83db26 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Mon, 6 Jan 2020 15:25:15 -0500 Subject: [PATCH 03/48] TST Adds more tests --- .../inspection/_plot/tests/test_plot_decision_boundary.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py index 9e36151203ea9..b7dc26a3626be 100644 --- a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py +++ b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py @@ -89,9 +89,12 @@ def test_plot_decision_boundary(pyplot, fitted_clf, data, assert disp.xx1.min() == pytest.approx(x1_min) assert disp.xx1.max() == pytest.approx(x1_max) + fig2, ax2 = pyplot.subplots() # change plotting method for second plot - disp.plot(plot_method='pcolormesh') + disp.plot(plot_method='pcolormesh', ax=ax2) assert isinstance(disp.surface_, pyplot.matplotlib.collections.QuadMesh) + assert disp.ax_ == ax2 + assert disp.figure_ == fig2 @pytest.mark.parametrize( From c7ee84e984c629f71532f66f2ef8c7209d2258e8 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Mon, 6 Jan 2020 17:20:37 -0500 Subject: [PATCH 04/48] STY Linting --- sklearn/inspection/_plot/decision_boundary.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 9d7222314b83c..613769556c046 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -46,8 +46,8 @@ def _check_boundary_response_method(estimator, response_method): class DecisionBoundaryDisplay: """Decisions Boundary visualization. - It is recommend to use :func:`~sklearn.inspection.plot_decision_boundary` to - create a :class:`DecisionBoundaryDisplay`. All parameters are stored as + It is recommend to use :func:`~sklearn.inspection.plot_decision_boundary` + to create a :class:`DecisionBoundaryDisplay`. All parameters are stored as attributes. Read more in the :ref:`User Guide `. From a8d4e1162aa67fe6fd4fd470ec684622276dff39 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Wed, 8 Jan 2020 16:46:09 -0500 Subject: [PATCH 05/48] BUG Fix --- examples/linear_model/plot_sgd_iris.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/linear_model/plot_sgd_iris.py b/examples/linear_model/plot_sgd_iris.py index e5c3e5b8b7aab..35206d83b74d5 100644 --- a/examples/linear_model/plot_sgd_iris.py +++ b/examples/linear_model/plot_sgd_iris.py @@ -39,7 +39,8 @@ clf = SGDClassifier(alpha=0.001, max_iter=100).fit(X, y) ax = plt.gca() -plot_decision_boundary(clf, X, cmap=plt.cm.Paired, ax=ax) +plot_decision_boundary(clf, X, cmap=plt.cm.Paired, ax=ax, + response_method='predict') plt.axis('tight') # Plot also the training points From 3c633d6a7f81ba7b0443a4dcce3ef469782130ea Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Fri, 24 Jan 2020 10:36:35 -0500 Subject: [PATCH 06/48] CLN Update response_method order --- sklearn/inspection/_plot/decision_boundary.py | 24 ++++++----- .../tests/test_plot_decision_boundary.py | 40 ++++++++++++++++++- 2 files changed, 51 insertions(+), 13 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 613769556c046..608a99ffd0594 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -11,10 +11,10 @@ def _check_boundary_response_method(estimator, response_method): Estimator to check response_method: {'auto', 'predict_proba', 'decision_function', 'predict'} - Specifies whether to use :term:`predict_proba` or - :term:`decision_function` as the target response. If set to 'auto', - :term:`predict_proba` is tried first and if it does not exist - :term:`decision_function` is tried next. + Specifies whether to use :term:`predict_proba`, + :term:`decision_function`, :term:`predict` as the target response. + If set to 'auto', the response method is tried in the following order: + :term:`predict_proba`, :term:`decision_function`, :term:`predict`. Returns ------- @@ -33,13 +33,15 @@ def _check_boundary_response_method(estimator, response_method): raise ValueError(error_msg.format(response_method, estimator.__class__.__name__)) return getattr(estimator, response_method) - elif hasattr(estimator, 'predict_proba'): - return getattr(estimator, 'predict_proba') elif hasattr(estimator, 'decision_function'): return getattr(estimator, 'decision_function') + elif hasattr(estimator, 'predict_proba'): + return getattr(estimator, 'predict_proba') + elif hasattr(estimator, 'predict'): + return getattr(estimator, 'predict') raise ValueError(error_msg.format( - "decision_function or predict_proba", + "decision_function, predict_proba, or predict", estimator.__class__.__name__)) @@ -143,10 +145,10 @@ def plot_decision_boundary(estimator, X, grid_resolution=100, eps=1.0, response_method : {'auto', 'predict_proba', 'decision_function', \ 'predict'}, defaul='auto' - Specifies whether to use :term:`predict_proba` or - :term:`decision_function` as the target response. If set to 'auto', - :term:`predict_proba` is tried first and if it does not exist - :term:`decision_function` is tried next. + Specifies whether to use :term:`predict_proba`, + :term:`decision_function`, :term:`predict` as the target response. + If set to 'auto', the response method is tried in the following order: + :term:`predict_proba`, :term:`decision_function`, :term:`predict`. ax : Matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is diff --git a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py index b7dc26a3626be..4772cbc842b19 100644 --- a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py +++ b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py @@ -5,6 +5,8 @@ from sklearn.inspection import plot_decision_boundary from sklearn.datasets import make_classification from sklearn.linear_model import LogisticRegression +from sklearn.inspection._plot.decision_boundary import ( + _check_boundary_response_method) # TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved @@ -25,6 +27,40 @@ def fitted_clf(data): return LogisticRegression().fit(*data) +def test_check_boundary_response_method_auto(): + class A: + def decision_function(self): + pass + + a_inst = A() + method = _check_boundary_response_method(a_inst, 'auto') + assert method == a_inst.decision_function + + class B: + def predict_proba(self): + pass + b_inst = B() + method = _check_boundary_response_method(b_inst, 'auto') + assert method == b_inst.predict_proba + + class C: + def predict_proba(self): + pass + + def decision_function(self): + pass + c_inst = C() + method = _check_boundary_response_method(c_inst, 'auto') + assert method == c_inst.decision_function + + class D: + def predict(self): + pass + d_inst = D() + method = _check_boundary_response_method(d_inst, 'auto') + assert method == d_inst.predict + + @pytest.mark.parametrize("response_method", ['auto', 'predict_proba', 'decision_function']) def test_multiclass_error(pyplot, response_method): @@ -103,8 +139,8 @@ def test_plot_decision_boundary(pyplot, fitted_clf, data, "MyClassifier"), ("decision_function", "response method decision_function is not defined " "in MyClassifier"), - ("auto", "response method decision_function or predict_proba is not " - "defined in MyClassifier"), + ("auto", "response method decision_function, predict_proba, or predict " + "is not defined in MyClassifier"), ("bad_method", "response_method must be 'predict_proba', " "'decision_function', 'predict', or 'auto'")]) def test_error_bad_response(pyplot, response_method, msg): From ae93c3fa18a2142d707a9d2c4a7a217dc825c619 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Fri, 24 Jan 2020 13:22:06 -0500 Subject: [PATCH 07/48] CLN Adds links to external libraries --- sklearn/inspection/_plot/decision_boundary.py | 23 +++++++++++++------ .../tests/test_plot_decision_boundary.py | 2 +- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 608a99ffd0594..10d43ebbd03d7 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -57,10 +57,10 @@ class DecisionBoundaryDisplay: Parameters ---------- xx0 : ndarray of shape (grid_resolution, grid_resolution) - First output of meshgrid. + First output of :func:`meshgrid `. xx1 : ndarray of shape (grid_resolution, grid_resolution) - Second output of meshgrid. + Second output of :func:`meshgrid `. response : ndarray of shape (grid_resolution, grid_resolution) Values of the response function. @@ -89,7 +89,11 @@ def plot(self, plot_method='contourf', ax=None, **kwargs): Parameters ---------- plot_method : {'contourf', 'contour', 'pcolormesh'}, default='contourf' - Plotting method to call when plotting the response. + Plotting method to call when plotting the response. Please refer + to the following matplotlib documentation for details: + :func:`contourf `, + :func:`contour `, + :func:`pcolomesh `. ax : Matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is @@ -120,11 +124,12 @@ def plot(self, plot_method='contourf', ax=None, **kwargs): def plot_decision_boundary(estimator, X, grid_resolution=100, eps=1.0, - plot_method='contourf', - response_method='auto', + plot_method='contourf', response_method='auto', ax=None, **kwargs): """Plot Decision Boundary. + Please see examples below for usage. + Parameters ---------- estimator : estimator instance @@ -141,7 +146,11 @@ def plot_decision_boundary(estimator, X, grid_resolution=100, eps=1.0, response function. plot_method : {'contourf', 'contour', 'pcolormesh'}, default='contourf' - Plotting method to call when plotting the response. + Plotting method to call when plotting the response. Please refer + to the following matplotlib documentation for details: + :func:`contourf `, + :func:`contour `, + :func:`pcolomesh `. response_method : {'auto', 'predict_proba', 'decision_function', \ 'predict'}, defaul='auto' @@ -186,7 +195,7 @@ def plot_decision_boundary(estimator, X, grid_resolution=100, eps=1.0, if response.ndim != 1: if response.shape[1] != 2: - raise ValueError("multiclass classifers are only supported when " + raise ValueError("multiclass classifiers are only supported when " "response_method='predict'") response = response[:, 1] diff --git a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py index 4772cbc842b19..fa6a7e1ac0f67 100644 --- a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py +++ b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py @@ -68,7 +68,7 @@ def test_multiclass_error(pyplot, response_method): X = X[:, [0, 1]] lr = LogisticRegression().fit(X, y) - msg = ("multiclass classifers are only supported when " + msg = ("multiclass classifiers are only supported when " "response_method='predict'") with pytest.raises(ValueError, match=msg): plot_decision_boundary(lr, X, response_method=response_method) From 67149a4b3177a3ff73b04d76cef0b7045c0ee813 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Fri, 24 Jan 2020 13:25:07 -0500 Subject: [PATCH 08/48] CLN Adds reference to quad* --- sklearn/inspection/_plot/decision_boundary.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 10d43ebbd03d7..50e1002810eb7 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -69,8 +69,9 @@ class DecisionBoundaryDisplay: ---------- surface_ : matplotlib `QuadContourSet` or `QuadMesh` If `plot_method` is 'contour' or 'contourf', `surface_` is a - `QuadContourSet`. If `plot_method is `pcolormesh`, `surface_` is a - `QuadMesh`. + :class:`QuadContourSet `. If + `plot_method is `pcolormesh`, `surface_` is a + :class:`QuadMesh `. ax_ : matplotlib Axes Axes with confusion matrix. From b525638425015003697473fbc3fba1d17546c7b2 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Thu, 20 Feb 2020 14:04:44 -0500 Subject: [PATCH 09/48] CLN Address comments --- sklearn/inspection/_plot/tests/test_plot_decision_boundary.py | 3 ++- sklearn/setup.py | 2 -- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py index fa6a7e1ac0f67..7a7890ce43af8 100644 --- a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py +++ b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py @@ -18,7 +18,8 @@ @pytest.fixture(scope="module") def data(): X, y = make_classification(n_informative=1, n_redundant=1, - n_clusters_per_class=1, n_features=2) + n_clusters_per_class=1, n_features=2, + random_state=42) return X, y diff --git a/sklearn/setup.py b/sklearn/setup.py index 7f97e515f0a1d..cc257c30e6f43 100644 --- a/sklearn/setup.py +++ b/sklearn/setup.py @@ -39,8 +39,6 @@ def configuration(parent_package='', top_path=None): config.add_subpackage('impute/tests') config.add_subpackage('inspection') config.add_subpackage('inspection/tests') - config.add_subpackage('inspection/_plot') - config.add_subpackage('inspection/_plot/tests') config.add_subpackage('mixture') config.add_subpackage('mixture/tests') config.add_subpackage('model_selection') From 69fbabfa1035b4bcbbbf35f5053fe6c0c791fa85 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Thu, 20 Feb 2020 16:42:41 -0500 Subject: [PATCH 10/48] CLN Address comments --- sklearn/inspection/_plot/decision_boundary.py | 29 ++++++++++++++++--- .../tests/test_plot_decision_boundary.py | 28 ++++++++++++++++-- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 50e1002810eb7..7b85bb6c9046a 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -1,5 +1,6 @@ import numpy as np from ...utils import check_matplotlib_support +from ...utils import _safe_indexing def _check_boundary_response_method(estimator, response_method): @@ -65,6 +66,12 @@ class DecisionBoundaryDisplay: response : ndarray of shape (grid_resolution, grid_resolution) Values of the response function. + xlabel : str, default="" + Default label to place on x axis. + + ylabel : str, default="" + DEfault label to place on y axis. + Attributes ---------- surface_ : matplotlib `QuadContourSet` or `QuadMesh` @@ -79,10 +86,12 @@ class DecisionBoundaryDisplay: figure_ : matplotlib Figure Figure containing the confusion matrix. """ - def __init__(self, xx0, xx1, response): + def __init__(self, xx0, xx1, response, xlabel=None, ylabel=None): self.xx0 = xx0 self.xx1 = xx1 self.response = response + self.xlabel = xlabel + self.ylabel = ylabel def plot(self, plot_method='contourf', ax=None, **kwargs): """Plot visualization. @@ -119,6 +128,12 @@ def plot(self, plot_method='contourf', ax=None, **kwargs): plot_func = getattr(ax, plot_method) self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs) + + if not ax.get_xlabel(): + ax.set_xlabel(self.xlabel) + if not ax.get_ylabel(): + ax.set_ylabel(self.ylabel) + self.ax_ = ax self.figure_ = ax.figure return self @@ -136,7 +151,7 @@ def plot_decision_boundary(estimator, X, grid_resolution=100, eps=1.0, estimator : estimator instance Trained estimator. - X : array-like of shape (n_samples, 2) + X : ndarray or pandas dataframe of shape (n_samples, 2) Input values. grid_resolution : int, default=100 @@ -185,7 +200,7 @@ def plot_decision_boundary(estimator, X, grid_resolution=100, eps=1.0, pred_func = _check_boundary_response_method(estimator, response_method) - x0, x1 = X[:, 0], X[:, 1] + x0, x1 = _safe_indexing(X, 0), _safe_indexing(X, 1) x0_min, x0_max = x0.min() - eps, x0.max() + eps x1_min, x1_max = x1.min() - eps, x1.max() + eps @@ -200,6 +215,12 @@ def plot_decision_boundary(estimator, X, grid_resolution=100, eps=1.0, "response_method='predict'") response = response[:, 1] + if hasattr(X, "columns"): + xlabel, ylabel = X.columns[0], X.columns[1] + else: + xlabel, ylabel = "", "" + display = DecisionBoundaryDisplay(xx0=xx0, xx1=xx1, - response=response.reshape(xx0.shape)) + response=response.reshape(xx0.shape), + xlabel=xlabel, ylabel=ylabel) return display.plot(ax=ax, plot_method=plot_method, **kwargs) diff --git a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py index 7a7890ce43af8..b118f520ca1ca 100644 --- a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py +++ b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py @@ -144,8 +144,8 @@ def test_plot_decision_boundary(pyplot, fitted_clf, data, "is not defined in MyClassifier"), ("bad_method", "response_method must be 'predict_proba', " "'decision_function', 'predict', or 'auto'")]) -def test_error_bad_response(pyplot, response_method, msg): - X, y = make_classification(n_classes=2, n_samples=50, random_state=0) +def test_error_bad_response(pyplot, response_method, msg, data): + X, y = data class MyClassifier(BaseEstimator, ClassifierMixin): def fit(self, X, y): @@ -157,3 +157,27 @@ def fit(self, X, y): with pytest.raises(ValueError, match=msg): plot_decision_boundary(clf, X, response_method=response_method) + + +def test_dataframe_labels_used(pyplot, data, fitted_clf): + pd = pytest.importorskip("pandas") + df = pd.DataFrame(data[0], columns=['col_x', 'col_y']) + + # pandas column names are used by default + _, ax = pyplot.subplots() + disp = plot_decision_boundary(fitted_clf, df, ax=ax) + assert ax.get_xlabel() == "col_x" + assert ax.get_ylabel() == "col_y" + + # second call to plot will have the names + fig, ax = pyplot.subplots() + disp.plot(ax=ax) + assert ax.get_xlabel() == "col_x" + assert ax.get_ylabel() == "col_y" + + # axes with a label will not get overridden + fig, ax = pyplot.subplots() + ax.set(xlabel="hello", ylabel="world") + disp.plot(ax=ax) + assert ax.get_xlabel() == "hello" + assert ax.get_ylabel() == "world" From a0addef73619e303ce294637b5dcf1f3892d708e Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Thu, 20 Feb 2020 17:17:21 -0500 Subject: [PATCH 11/48] BUG Fix --- sklearn/inspection/_plot/decision_boundary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 7b85bb6c9046a..abe0196d2d11c 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -200,7 +200,7 @@ def plot_decision_boundary(estimator, X, grid_resolution=100, eps=1.0, pred_func = _check_boundary_response_method(estimator, response_method) - x0, x1 = _safe_indexing(X, 0), _safe_indexing(X, 1) + x0, x1 = _safe_indexing(X, 0, axis=1), _safe_indexing(X, 1, axis=1) x0_min, x0_max = x0.min() - eps, x0.max() + eps x1_min, x1_max = x1.min() - eps, x1.max() + eps From 731bc233440560c6979434b244af91b7fa6f9c70 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Wed, 22 Apr 2020 11:48:35 -0400 Subject: [PATCH 12/48] CLN Move to utils --- doc/visualizations.rst | 4 ++-- examples/classification/plot_classifier_comparison.py | 2 +- examples/cluster/plot_inductive_clustering.py | 2 +- examples/ensemble/plot_adaboost_twoclass.py | 2 +- examples/ensemble/plot_voting_decision_regions.py | 2 +- examples/linear_model/plot_iris_logistic.py | 2 +- examples/linear_model/plot_logistic_multinomial.py | 2 +- examples/linear_model/plot_sgd_iris.py | 2 +- examples/neighbors/plot_classification.py | 2 +- examples/neighbors/plot_nca_classification.py | 2 +- examples/neighbors/plot_nearest_centroid.py | 2 +- .../plot_label_propagation_versus_svm_iris.py | 2 +- examples/svm/plot_custom_kernel.py | 2 +- examples/svm/plot_iris_svc.py | 2 +- examples/svm/plot_linearsvc_support_vectors.py | 2 +- examples/svm/plot_separating_hyperplane.py | 2 +- examples/svm/plot_separating_hyperplane_unbalanced.py | 2 +- examples/tree/plot_iris_dtc.py | 2 +- sklearn/inspection/__init__.py | 4 ---- sklearn/utils/__init__.py | 5 ++++- sklearn/utils/_plot/__init__.py | 0 sklearn/{inspection => utils}/_plot/decision_boundary.py | 6 +++--- .../_plot/tests/test_plot_decision_boundary.py | 4 ++-- sklearn/utils/setup.py | 2 ++ 24 files changed, 30 insertions(+), 29 deletions(-) create mode 100644 sklearn/utils/_plot/__init__.py rename sklearn/{inspection => utils}/_plot/decision_boundary.py (98%) rename sklearn/{inspection => utils}/_plot/tests/test_plot_decision_boundary.py (98%) diff --git a/doc/visualizations.rst b/doc/visualizations.rst index f6e41632b12c1..3caa2ba258d7b 100644 --- a/doc/visualizations.rst +++ b/doc/visualizations.rst @@ -72,7 +72,7 @@ Functions .. autosummary:: inspection.plot_partial_dependence - inspection.plot_decision_boundary + utils.plot_decision_boundary metrics.plot_confusion_matrix metrics.plot_precision_recall_curve metrics.plot_roc_curve @@ -86,7 +86,7 @@ Display Objects .. autosummary:: inspection.PartialDependenceDisplay - inspection.DecisionBoundaryDisplay + utils.DecisionBoundaryDisplay metrics.ConfusionMatrixDisplay metrics.PrecisionRecallDisplay metrics.RocCurveDisplay diff --git a/examples/classification/plot_classifier_comparison.py b/examples/classification/plot_classifier_comparison.py index a6ea51a484658..30a05dde27d4e 100644 --- a/examples/classification/plot_classifier_comparison.py +++ b/examples/classification/plot_classifier_comparison.py @@ -43,7 +43,7 @@ from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier from sklearn.naive_bayes import GaussianNB from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", "Gaussian Process", "Decision Tree", "Random Forest", "Neural Net", "AdaBoost", diff --git a/examples/cluster/plot_inductive_clustering.py b/examples/cluster/plot_inductive_clustering.py index 623b8da3144ce..840b59880925d 100644 --- a/examples/cluster/plot_inductive_clustering.py +++ b/examples/cluster/plot_inductive_clustering.py @@ -29,7 +29,7 @@ from sklearn.datasets import make_blobs from sklearn.ensemble import RandomForestClassifier from sklearn.utils.metaestimators import if_delegate_has_method -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary N_SAMPLES = 5000 RANDOM_STATE = 42 diff --git a/examples/ensemble/plot_adaboost_twoclass.py b/examples/ensemble/plot_adaboost_twoclass.py index 296390abb76d7..d5ecfee009444 100644 --- a/examples/ensemble/plot_adaboost_twoclass.py +++ b/examples/ensemble/plot_adaboost_twoclass.py @@ -28,7 +28,7 @@ from sklearn.ensemble import AdaBoostClassifier from sklearn.tree import DecisionTreeClassifier from sklearn.datasets import make_gaussian_quantiles -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary # Construct dataset diff --git a/examples/ensemble/plot_voting_decision_regions.py b/examples/ensemble/plot_voting_decision_regions.py index 8684494d2d763..c74ed82556b77 100644 --- a/examples/ensemble/plot_voting_decision_regions.py +++ b/examples/ensemble/plot_voting_decision_regions.py @@ -33,7 +33,7 @@ from sklearn.neighbors import KNeighborsClassifier from sklearn.svm import SVC from sklearn.ensemble import VotingClassifier -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary # Loading some example data iris = datasets.load_iris() diff --git a/examples/linear_model/plot_iris_logistic.py b/examples/linear_model/plot_iris_logistic.py index 2c73fc21303de..af34acaec320c 100644 --- a/examples/linear_model/plot_iris_logistic.py +++ b/examples/linear_model/plot_iris_logistic.py @@ -21,7 +21,7 @@ import matplotlib.pyplot as plt from sklearn.linear_model import LogisticRegression from sklearn import datasets -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary # import some data to play with iris = datasets.load_iris() diff --git a/examples/linear_model/plot_logistic_multinomial.py b/examples/linear_model/plot_logistic_multinomial.py index 6ecc23fc73c0c..0f8c027421d0b 100644 --- a/examples/linear_model/plot_logistic_multinomial.py +++ b/examples/linear_model/plot_logistic_multinomial.py @@ -15,7 +15,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import make_blobs from sklearn.linear_model import LogisticRegression -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary # make 3-class dataset for classification centers = [[-5, 0], [0, 1.5], [5, -1]] diff --git a/examples/linear_model/plot_sgd_iris.py b/examples/linear_model/plot_sgd_iris.py index 35206d83b74d5..1ee5b97ea585c 100644 --- a/examples/linear_model/plot_sgd_iris.py +++ b/examples/linear_model/plot_sgd_iris.py @@ -14,7 +14,7 @@ import matplotlib.pyplot as plt from sklearn import datasets from sklearn.linear_model import SGDClassifier -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary # import some data to play with iris = datasets.load_iris() diff --git a/examples/neighbors/plot_classification.py b/examples/neighbors/plot_classification.py index aba6583df894f..49adca7d0f8a3 100644 --- a/examples/neighbors/plot_classification.py +++ b/examples/neighbors/plot_classification.py @@ -11,7 +11,7 @@ import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap from sklearn import neighbors, datasets -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary n_neighbors = 15 diff --git a/examples/neighbors/plot_nca_classification.py b/examples/neighbors/plot_nca_classification.py index 58380ba4839d7..b0ad2b018b5e6 100644 --- a/examples/neighbors/plot_nca_classification.py +++ b/examples/neighbors/plot_nca_classification.py @@ -25,7 +25,7 @@ from sklearn.neighbors import (KNeighborsClassifier, NeighborhoodComponentsAnalysis) from sklearn.pipeline import Pipeline -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary print(__doc__) diff --git a/examples/neighbors/plot_nearest_centroid.py b/examples/neighbors/plot_nearest_centroid.py index b21cdfda7da00..853a7a01b2ab1 100644 --- a/examples/neighbors/plot_nearest_centroid.py +++ b/examples/neighbors/plot_nearest_centroid.py @@ -13,7 +13,7 @@ from matplotlib.colors import ListedColormap from sklearn import datasets from sklearn.neighbors import NearestCentroid -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary n_neighbors = 15 diff --git a/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py b/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py index e7af77b332333..012487b357777 100644 --- a/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py +++ b/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py @@ -20,7 +20,7 @@ from sklearn import datasets from sklearn import svm from sklearn.semi_supervised import LabelSpreading -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary rng = np.random.RandomState(0) diff --git a/examples/svm/plot_custom_kernel.py b/examples/svm/plot_custom_kernel.py index f5aaaec1eb7ac..40d45c8260482 100644 --- a/examples/svm/plot_custom_kernel.py +++ b/examples/svm/plot_custom_kernel.py @@ -12,7 +12,7 @@ import numpy as np import matplotlib.pyplot as plt from sklearn import svm, datasets -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary # import some data to play with iris = datasets.load_iris() diff --git a/examples/svm/plot_iris_svc.py b/examples/svm/plot_iris_svc.py index fa230c7f9ca88..ab28038ac0ebf 100644 --- a/examples/svm/plot_iris_svc.py +++ b/examples/svm/plot_iris_svc.py @@ -37,7 +37,7 @@ import matplotlib.pyplot as plt from sklearn import svm, datasets -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary # import some data to play with diff --git a/examples/svm/plot_linearsvc_support_vectors.py b/examples/svm/plot_linearsvc_support_vectors.py index c8ab2a21f62de..c5f7cac85440d 100644 --- a/examples/svm/plot_linearsvc_support_vectors.py +++ b/examples/svm/plot_linearsvc_support_vectors.py @@ -13,7 +13,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import make_blobs from sklearn.svm import LinearSVC -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary X, y = make_blobs(n_samples=40, centers=2, random_state=0) diff --git a/examples/svm/plot_separating_hyperplane.py b/examples/svm/plot_separating_hyperplane.py index becb5599a9c63..a41f3a430cdde 100644 --- a/examples/svm/plot_separating_hyperplane.py +++ b/examples/svm/plot_separating_hyperplane.py @@ -12,7 +12,7 @@ import matplotlib.pyplot as plt from sklearn import svm from sklearn.datasets import make_blobs -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary # we create 40 separable points diff --git a/examples/svm/plot_separating_hyperplane_unbalanced.py b/examples/svm/plot_separating_hyperplane_unbalanced.py index 039fc5d7341c5..9aba49a92be13 100644 --- a/examples/svm/plot_separating_hyperplane_unbalanced.py +++ b/examples/svm/plot_separating_hyperplane_unbalanced.py @@ -29,7 +29,7 @@ import matplotlib.pyplot as plt from sklearn import svm from sklearn.datasets import make_blobs -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary # we create two clusters of random points n_samples_1 = 1000 diff --git a/examples/tree/plot_iris_dtc.py b/examples/tree/plot_iris_dtc.py index 1cf9f718e9413..b9420f5acc06d 100644 --- a/examples/tree/plot_iris_dtc.py +++ b/examples/tree/plot_iris_dtc.py @@ -21,7 +21,7 @@ from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier, plot_tree -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary # Parameters n_classes = 3 diff --git a/sklearn/inspection/__init__.py b/sklearn/inspection/__init__.py index 5e5479a2a8737..bfa28f2b3a4f8 100644 --- a/sklearn/inspection/__init__.py +++ b/sklearn/inspection/__init__.py @@ -13,8 +13,6 @@ from .partial_dependence import partial_dependence from ._permutation_importance import permutation_importance # noqa -from ._plot.decision_boundary import DecisionBoundaryDisplay # noqa -from ._plot.decision_boundary import plot_decision_boundary # noqa from ._plot.partial_dependence import plot_partial_dependence # noqa from ._plot.partial_dependence import PartialDependenceDisplay # noqa @@ -25,6 +23,4 @@ 'plot_partial_dependence', 'permutation_importance', 'PartialDependenceDisplay', - 'plot_decision_boundary', - 'DecisionBoundaryDisplay' ] diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index aac6e292a198a..5d4493f141ed9 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -30,6 +30,8 @@ check_random_state, column_or_1d, check_array, check_consistent_length, check_X_y, indexable, check_symmetric, check_scalar) +from ._plot.decision_boundary import plot_decision_boundary +from ._plot.decision_boundary import DecisionBoundaryDisplay from .. import get_config @@ -51,7 +53,8 @@ "check_symmetric", "indices_to_mask", "deprecated", "parallel_backend", "register_parallel_backend", "resample", "shuffle", "check_matplotlib_support", "all_estimators", - "DataConversionWarning" + "DataConversionWarning", "plot_decision_boundary", + "DecisionBoundaryDisplay" ] IS_PYPY = platform.python_implementation() == 'PyPy' diff --git a/sklearn/utils/_plot/__init__.py b/sklearn/utils/_plot/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/utils/_plot/decision_boundary.py similarity index 98% rename from sklearn/inspection/_plot/decision_boundary.py rename to sklearn/utils/_plot/decision_boundary.py index abe0196d2d11c..65338c79558cc 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/utils/_plot/decision_boundary.py @@ -1,6 +1,6 @@ import numpy as np -from ...utils import check_matplotlib_support -from ...utils import _safe_indexing +from .utils import check_matplotlib_support +from .utils import _safe_indexing def _check_boundary_response_method(estimator, response_method): @@ -49,7 +49,7 @@ def _check_boundary_response_method(estimator, response_method): class DecisionBoundaryDisplay: """Decisions Boundary visualization. - It is recommend to use :func:`~sklearn.inspection.plot_decision_boundary` + It is recommend to use :func:`~sklearn.utils.plot_decision_boundary` to create a :class:`DecisionBoundaryDisplay`. All parameters are stored as attributes. diff --git a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py b/sklearn/utils/_plot/tests/test_plot_decision_boundary.py similarity index 98% rename from sklearn/inspection/_plot/tests/test_plot_decision_boundary.py rename to sklearn/utils/_plot/tests/test_plot_decision_boundary.py index b118f520ca1ca..1de56ffd46690 100644 --- a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py +++ b/sklearn/utils/_plot/tests/test_plot_decision_boundary.py @@ -2,10 +2,10 @@ from sklearn.base import BaseEstimator from sklearn.base import ClassifierMixin -from sklearn.inspection import plot_decision_boundary +from sklearn.utils import plot_decision_boundary from sklearn.datasets import make_classification from sklearn.linear_model import LogisticRegression -from sklearn.inspection._plot.decision_boundary import ( +from sklearn.utils._plot.decision_boundary import ( _check_boundary_response_method) diff --git a/sklearn/utils/setup.py b/sklearn/utils/setup.py index 098adeeccab09..c1e65eae20e73 100644 --- a/sklearn/utils/setup.py +++ b/sklearn/utils/setup.py @@ -70,6 +70,8 @@ def configuration(parent_package='', top_path=None): include_dirs=[numpy.get_include()], libraries=libraries) + config.add_subpackage("_plot") + config.add_subpackage("_plot/tests") config.add_subpackage('tests') return config From e1e3df2f27238d9ca9c98981dd8f3b1defc7e77f Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Wed, 22 Apr 2020 16:48:55 -0400 Subject: [PATCH 13/48] BUG Fix --- examples/cluster/plot_inductive_clustering.py | 1 - examples/neighbors/plot_nca_classification.py | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/cluster/plot_inductive_clustering.py b/examples/cluster/plot_inductive_clustering.py index 840b59880925d..9fa6855b08d12 100644 --- a/examples/cluster/plot_inductive_clustering.py +++ b/examples/cluster/plot_inductive_clustering.py @@ -22,7 +22,6 @@ # Christos Aridas print(__doc__) -import numpy as np import matplotlib.pyplot as plt from sklearn.base import BaseEstimator, clone from sklearn.cluster import AgglomerativeClustering diff --git a/examples/neighbors/plot_nca_classification.py b/examples/neighbors/plot_nca_classification.py index b0ad2b018b5e6..0610ba574948b 100644 --- a/examples/neighbors/plot_nca_classification.py +++ b/examples/neighbors/plot_nca_classification.py @@ -16,7 +16,6 @@ # License: BSD 3 clause -import numpy as np import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap from sklearn import datasets From afb8594097c0a62e1effa517b9bdfea6006bdc9d Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Wed, 22 Apr 2020 17:00:36 -0400 Subject: [PATCH 14/48] ENH Move to utils --- examples/classification/plot_classifier_comparison.py | 2 +- examples/cluster/plot_inductive_clustering.py | 2 +- examples/ensemble/plot_adaboost_twoclass.py | 2 +- examples/ensemble/plot_voting_decision_regions.py | 2 +- examples/linear_model/plot_iris_logistic.py | 2 +- examples/linear_model/plot_logistic_multinomial.py | 2 +- examples/linear_model/plot_sgd_iris.py | 2 +- examples/neighbors/plot_classification.py | 2 +- examples/neighbors/plot_nca_classification.py | 2 +- examples/neighbors/plot_nearest_centroid.py | 2 +- .../plot_label_propagation_versus_svm_iris.py | 2 +- examples/svm/plot_custom_kernel.py | 2 +- examples/svm/plot_iris_svc.py | 2 +- examples/svm/plot_linearsvc_support_vectors.py | 2 +- examples/svm/plot_separating_hyperplane.py | 2 +- examples/svm/plot_separating_hyperplane_unbalanced.py | 2 +- examples/tree/plot_iris_dtc.py | 2 +- sklearn/utils/__init__.py | 5 +---- sklearn/utils/_plot/__init__.py | 0 sklearn/utils/plot/__init__.py | 7 +++++++ .../decision_boundary.py => plot/_decision_boundary.py} | 4 ++-- .../{_plot => plot}/tests/test_plot_decision_boundary.py | 4 ++-- 22 files changed, 29 insertions(+), 25 deletions(-) delete mode 100644 sklearn/utils/_plot/__init__.py create mode 100644 sklearn/utils/plot/__init__.py rename sklearn/utils/{_plot/decision_boundary.py => plot/_decision_boundary.py} (99%) rename sklearn/utils/{_plot => plot}/tests/test_plot_decision_boundary.py (98%) diff --git a/examples/classification/plot_classifier_comparison.py b/examples/classification/plot_classifier_comparison.py index 30a05dde27d4e..f0dd1f9ae654c 100644 --- a/examples/classification/plot_classifier_comparison.py +++ b/examples/classification/plot_classifier_comparison.py @@ -43,7 +43,7 @@ from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier from sklearn.naive_bayes import GaussianNB from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", "Gaussian Process", "Decision Tree", "Random Forest", "Neural Net", "AdaBoost", diff --git a/examples/cluster/plot_inductive_clustering.py b/examples/cluster/plot_inductive_clustering.py index 9fa6855b08d12..bd0ada47cd675 100644 --- a/examples/cluster/plot_inductive_clustering.py +++ b/examples/cluster/plot_inductive_clustering.py @@ -28,7 +28,7 @@ from sklearn.datasets import make_blobs from sklearn.ensemble import RandomForestClassifier from sklearn.utils.metaestimators import if_delegate_has_method -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary N_SAMPLES = 5000 RANDOM_STATE = 42 diff --git a/examples/ensemble/plot_adaboost_twoclass.py b/examples/ensemble/plot_adaboost_twoclass.py index d5ecfee009444..f9688e055874d 100644 --- a/examples/ensemble/plot_adaboost_twoclass.py +++ b/examples/ensemble/plot_adaboost_twoclass.py @@ -28,7 +28,7 @@ from sklearn.ensemble import AdaBoostClassifier from sklearn.tree import DecisionTreeClassifier from sklearn.datasets import make_gaussian_quantiles -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary # Construct dataset diff --git a/examples/ensemble/plot_voting_decision_regions.py b/examples/ensemble/plot_voting_decision_regions.py index c74ed82556b77..be48a2047b43f 100644 --- a/examples/ensemble/plot_voting_decision_regions.py +++ b/examples/ensemble/plot_voting_decision_regions.py @@ -33,7 +33,7 @@ from sklearn.neighbors import KNeighborsClassifier from sklearn.svm import SVC from sklearn.ensemble import VotingClassifier -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary # Loading some example data iris = datasets.load_iris() diff --git a/examples/linear_model/plot_iris_logistic.py b/examples/linear_model/plot_iris_logistic.py index af34acaec320c..69a327ab52e37 100644 --- a/examples/linear_model/plot_iris_logistic.py +++ b/examples/linear_model/plot_iris_logistic.py @@ -21,7 +21,7 @@ import matplotlib.pyplot as plt from sklearn.linear_model import LogisticRegression from sklearn import datasets -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary # import some data to play with iris = datasets.load_iris() diff --git a/examples/linear_model/plot_logistic_multinomial.py b/examples/linear_model/plot_logistic_multinomial.py index 0f8c027421d0b..485887f509e33 100644 --- a/examples/linear_model/plot_logistic_multinomial.py +++ b/examples/linear_model/plot_logistic_multinomial.py @@ -15,7 +15,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import make_blobs from sklearn.linear_model import LogisticRegression -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary # make 3-class dataset for classification centers = [[-5, 0], [0, 1.5], [5, -1]] diff --git a/examples/linear_model/plot_sgd_iris.py b/examples/linear_model/plot_sgd_iris.py index 1ee5b97ea585c..eba48879c0247 100644 --- a/examples/linear_model/plot_sgd_iris.py +++ b/examples/linear_model/plot_sgd_iris.py @@ -14,7 +14,7 @@ import matplotlib.pyplot as plt from sklearn import datasets from sklearn.linear_model import SGDClassifier -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary # import some data to play with iris = datasets.load_iris() diff --git a/examples/neighbors/plot_classification.py b/examples/neighbors/plot_classification.py index 49adca7d0f8a3..c6952838a83e8 100644 --- a/examples/neighbors/plot_classification.py +++ b/examples/neighbors/plot_classification.py @@ -11,7 +11,7 @@ import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap from sklearn import neighbors, datasets -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary n_neighbors = 15 diff --git a/examples/neighbors/plot_nca_classification.py b/examples/neighbors/plot_nca_classification.py index 0610ba574948b..c495bb5dd5469 100644 --- a/examples/neighbors/plot_nca_classification.py +++ b/examples/neighbors/plot_nca_classification.py @@ -24,7 +24,7 @@ from sklearn.neighbors import (KNeighborsClassifier, NeighborhoodComponentsAnalysis) from sklearn.pipeline import Pipeline -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary print(__doc__) diff --git a/examples/neighbors/plot_nearest_centroid.py b/examples/neighbors/plot_nearest_centroid.py index 853a7a01b2ab1..7bcb2fa6e836a 100644 --- a/examples/neighbors/plot_nearest_centroid.py +++ b/examples/neighbors/plot_nearest_centroid.py @@ -13,7 +13,7 @@ from matplotlib.colors import ListedColormap from sklearn import datasets from sklearn.neighbors import NearestCentroid -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary n_neighbors = 15 diff --git a/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py b/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py index 012487b357777..7d0803df2605b 100644 --- a/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py +++ b/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py @@ -20,7 +20,7 @@ from sklearn import datasets from sklearn import svm from sklearn.semi_supervised import LabelSpreading -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary rng = np.random.RandomState(0) diff --git a/examples/svm/plot_custom_kernel.py b/examples/svm/plot_custom_kernel.py index 40d45c8260482..cb5f72fefaaaa 100644 --- a/examples/svm/plot_custom_kernel.py +++ b/examples/svm/plot_custom_kernel.py @@ -12,7 +12,7 @@ import numpy as np import matplotlib.pyplot as plt from sklearn import svm, datasets -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary # import some data to play with iris = datasets.load_iris() diff --git a/examples/svm/plot_iris_svc.py b/examples/svm/plot_iris_svc.py index ab28038ac0ebf..2aa9ccd2478e8 100644 --- a/examples/svm/plot_iris_svc.py +++ b/examples/svm/plot_iris_svc.py @@ -37,7 +37,7 @@ import matplotlib.pyplot as plt from sklearn import svm, datasets -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary # import some data to play with diff --git a/examples/svm/plot_linearsvc_support_vectors.py b/examples/svm/plot_linearsvc_support_vectors.py index c5f7cac85440d..5d7a03a7aec46 100644 --- a/examples/svm/plot_linearsvc_support_vectors.py +++ b/examples/svm/plot_linearsvc_support_vectors.py @@ -13,7 +13,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import make_blobs from sklearn.svm import LinearSVC -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary X, y = make_blobs(n_samples=40, centers=2, random_state=0) diff --git a/examples/svm/plot_separating_hyperplane.py b/examples/svm/plot_separating_hyperplane.py index a41f3a430cdde..e699e0a5989d0 100644 --- a/examples/svm/plot_separating_hyperplane.py +++ b/examples/svm/plot_separating_hyperplane.py @@ -12,7 +12,7 @@ import matplotlib.pyplot as plt from sklearn import svm from sklearn.datasets import make_blobs -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary # we create 40 separable points diff --git a/examples/svm/plot_separating_hyperplane_unbalanced.py b/examples/svm/plot_separating_hyperplane_unbalanced.py index 9aba49a92be13..ff52e6b1981f1 100644 --- a/examples/svm/plot_separating_hyperplane_unbalanced.py +++ b/examples/svm/plot_separating_hyperplane_unbalanced.py @@ -29,7 +29,7 @@ import matplotlib.pyplot as plt from sklearn import svm from sklearn.datasets import make_blobs -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary # we create two clusters of random points n_samples_1 = 1000 diff --git a/examples/tree/plot_iris_dtc.py b/examples/tree/plot_iris_dtc.py index b9420f5acc06d..e4f23be6bf288 100644 --- a/examples/tree/plot_iris_dtc.py +++ b/examples/tree/plot_iris_dtc.py @@ -21,7 +21,7 @@ from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier, plot_tree -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary # Parameters n_classes = 3 diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 5d4493f141ed9..9834f4624bf76 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -30,8 +30,6 @@ check_random_state, column_or_1d, check_array, check_consistent_length, check_X_y, indexable, check_symmetric, check_scalar) -from ._plot.decision_boundary import plot_decision_boundary -from ._plot.decision_boundary import DecisionBoundaryDisplay from .. import get_config @@ -53,8 +51,7 @@ "check_symmetric", "indices_to_mask", "deprecated", "parallel_backend", "register_parallel_backend", "resample", "shuffle", "check_matplotlib_support", "all_estimators", - "DataConversionWarning", "plot_decision_boundary", - "DecisionBoundaryDisplay" + "DataConversionWarning", ] IS_PYPY = platform.python_implementation() == 'PyPy' diff --git a/sklearn/utils/_plot/__init__.py b/sklearn/utils/_plot/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sklearn/utils/plot/__init__.py b/sklearn/utils/plot/__init__.py new file mode 100644 index 0000000000000..735c8e1ca2493 --- /dev/null +++ b/sklearn/utils/plot/__init__.py @@ -0,0 +1,7 @@ +from ._decision_boundary import plot_decision_boundary +from ._decision_boundary import DecisionBoundaryDisplay + +__all__ = [ + "plot_decision_boundary", + "DecisionBoundaryDisplay", +] diff --git a/sklearn/utils/_plot/decision_boundary.py b/sklearn/utils/plot/_decision_boundary.py similarity index 99% rename from sklearn/utils/_plot/decision_boundary.py rename to sklearn/utils/plot/_decision_boundary.py index 65338c79558cc..552d05a52f432 100644 --- a/sklearn/utils/_plot/decision_boundary.py +++ b/sklearn/utils/plot/_decision_boundary.py @@ -1,6 +1,6 @@ import numpy as np -from .utils import check_matplotlib_support -from .utils import _safe_indexing +from .. import check_matplotlib_support +from .. import _safe_indexing def _check_boundary_response_method(estimator, response_method): diff --git a/sklearn/utils/_plot/tests/test_plot_decision_boundary.py b/sklearn/utils/plot/tests/test_plot_decision_boundary.py similarity index 98% rename from sklearn/utils/_plot/tests/test_plot_decision_boundary.py rename to sklearn/utils/plot/tests/test_plot_decision_boundary.py index 1de56ffd46690..1a91a60833cc7 100644 --- a/sklearn/utils/_plot/tests/test_plot_decision_boundary.py +++ b/sklearn/utils/plot/tests/test_plot_decision_boundary.py @@ -2,10 +2,10 @@ from sklearn.base import BaseEstimator from sklearn.base import ClassifierMixin -from sklearn.utils import plot_decision_boundary +from sklearn.utils.plot import plot_decision_boundary from sklearn.datasets import make_classification from sklearn.linear_model import LogisticRegression -from sklearn.utils._plot.decision_boundary import ( +from sklearn.utils.plot._decision_boundary import ( _check_boundary_response_method) From 4d4ffe7879dc81a40aba5c13e7bd8854fcbc6e6e Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Wed, 22 Apr 2020 17:02:43 -0400 Subject: [PATCH 15/48] BLD Fixes build error --- sklearn/utils/setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/setup.py b/sklearn/utils/setup.py index c1e65eae20e73..91898e152a1d0 100644 --- a/sklearn/utils/setup.py +++ b/sklearn/utils/setup.py @@ -70,8 +70,8 @@ def configuration(parent_package='', top_path=None): include_dirs=[numpy.get_include()], libraries=libraries) - config.add_subpackage("_plot") - config.add_subpackage("_plot/tests") + config.add_subpackage("plot") + config.add_subpackage("plot/tests") config.add_subpackage('tests') return config From 571047342f2f090d2ae78dd1a8440f90d4738e95 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Wed, 22 Apr 2020 18:30:21 -0400 Subject: [PATCH 16/48] FIX Bug --- sklearn/utils/plot/tests/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 sklearn/utils/plot/tests/__init__.py diff --git a/sklearn/utils/plot/tests/__init__.py b/sklearn/utils/plot/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d From 1bfb6cf0abec2e1c6af62936fe5f943f07601546 Mon Sep 17 00:00:00 2001 From: Thomas J Fan Date: Thu, 23 Apr 2020 10:55:38 -0400 Subject: [PATCH 17/48] CLN Move back to inspection --- doc/visualizations.rst | 4 ++-- examples/classification/plot_classifier_comparison.py | 2 +- examples/cluster/plot_inductive_clustering.py | 2 +- examples/ensemble/plot_adaboost_twoclass.py | 2 +- examples/ensemble/plot_voting_decision_regions.py | 2 +- examples/linear_model/plot_iris_logistic.py | 2 +- examples/linear_model/plot_logistic_multinomial.py | 2 +- examples/linear_model/plot_sgd_iris.py | 2 +- examples/neighbors/plot_classification.py | 2 +- examples/neighbors/plot_nca_classification.py | 2 +- examples/neighbors/plot_nearest_centroid.py | 2 +- .../plot_label_propagation_versus_svm_iris.py | 2 +- examples/svm/plot_custom_kernel.py | 2 +- examples/svm/plot_iris_svc.py | 2 +- examples/svm/plot_linearsvc_support_vectors.py | 2 +- examples/svm/plot_separating_hyperplane.py | 2 +- examples/svm/plot_separating_hyperplane_unbalanced.py | 2 +- examples/tree/plot_iris_dtc.py | 2 +- sklearn/inspection/__init__.py | 4 ++++ .../_plot/decision_boundary.py} | 6 +++--- .../_plot}/tests/test_plot_decision_boundary.py | 4 ++-- sklearn/utils/__init__.py | 2 +- sklearn/utils/plot/__init__.py | 7 ------- sklearn/utils/plot/tests/__init__.py | 0 sklearn/utils/setup.py | 2 -- 25 files changed, 29 insertions(+), 34 deletions(-) rename sklearn/{utils/plot/_decision_boundary.py => inspection/_plot/decision_boundary.py} (98%) rename sklearn/{utils/plot => inspection/_plot}/tests/test_plot_decision_boundary.py (98%) delete mode 100644 sklearn/utils/plot/__init__.py delete mode 100644 sklearn/utils/plot/tests/__init__.py diff --git a/doc/visualizations.rst b/doc/visualizations.rst index 3caa2ba258d7b..f6e41632b12c1 100644 --- a/doc/visualizations.rst +++ b/doc/visualizations.rst @@ -72,7 +72,7 @@ Functions .. autosummary:: inspection.plot_partial_dependence - utils.plot_decision_boundary + inspection.plot_decision_boundary metrics.plot_confusion_matrix metrics.plot_precision_recall_curve metrics.plot_roc_curve @@ -86,7 +86,7 @@ Display Objects .. autosummary:: inspection.PartialDependenceDisplay - utils.DecisionBoundaryDisplay + inspection.DecisionBoundaryDisplay metrics.ConfusionMatrixDisplay metrics.PrecisionRecallDisplay metrics.RocCurveDisplay diff --git a/examples/classification/plot_classifier_comparison.py b/examples/classification/plot_classifier_comparison.py index f0dd1f9ae654c..a6ea51a484658 100644 --- a/examples/classification/plot_classifier_comparison.py +++ b/examples/classification/plot_classifier_comparison.py @@ -43,7 +43,7 @@ from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier from sklearn.naive_bayes import GaussianNB from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", "Gaussian Process", "Decision Tree", "Random Forest", "Neural Net", "AdaBoost", diff --git a/examples/cluster/plot_inductive_clustering.py b/examples/cluster/plot_inductive_clustering.py index bd0ada47cd675..ba7df596aff59 100644 --- a/examples/cluster/plot_inductive_clustering.py +++ b/examples/cluster/plot_inductive_clustering.py @@ -28,7 +28,7 @@ from sklearn.datasets import make_blobs from sklearn.ensemble import RandomForestClassifier from sklearn.utils.metaestimators import if_delegate_has_method -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary N_SAMPLES = 5000 RANDOM_STATE = 42 diff --git a/examples/ensemble/plot_adaboost_twoclass.py b/examples/ensemble/plot_adaboost_twoclass.py index f9688e055874d..296390abb76d7 100644 --- a/examples/ensemble/plot_adaboost_twoclass.py +++ b/examples/ensemble/plot_adaboost_twoclass.py @@ -28,7 +28,7 @@ from sklearn.ensemble import AdaBoostClassifier from sklearn.tree import DecisionTreeClassifier from sklearn.datasets import make_gaussian_quantiles -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary # Construct dataset diff --git a/examples/ensemble/plot_voting_decision_regions.py b/examples/ensemble/plot_voting_decision_regions.py index be48a2047b43f..8684494d2d763 100644 --- a/examples/ensemble/plot_voting_decision_regions.py +++ b/examples/ensemble/plot_voting_decision_regions.py @@ -33,7 +33,7 @@ from sklearn.neighbors import KNeighborsClassifier from sklearn.svm import SVC from sklearn.ensemble import VotingClassifier -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary # Loading some example data iris = datasets.load_iris() diff --git a/examples/linear_model/plot_iris_logistic.py b/examples/linear_model/plot_iris_logistic.py index 69a327ab52e37..2c73fc21303de 100644 --- a/examples/linear_model/plot_iris_logistic.py +++ b/examples/linear_model/plot_iris_logistic.py @@ -21,7 +21,7 @@ import matplotlib.pyplot as plt from sklearn.linear_model import LogisticRegression from sklearn import datasets -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary # import some data to play with iris = datasets.load_iris() diff --git a/examples/linear_model/plot_logistic_multinomial.py b/examples/linear_model/plot_logistic_multinomial.py index 485887f509e33..6ecc23fc73c0c 100644 --- a/examples/linear_model/plot_logistic_multinomial.py +++ b/examples/linear_model/plot_logistic_multinomial.py @@ -15,7 +15,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import make_blobs from sklearn.linear_model import LogisticRegression -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary # make 3-class dataset for classification centers = [[-5, 0], [0, 1.5], [5, -1]] diff --git a/examples/linear_model/plot_sgd_iris.py b/examples/linear_model/plot_sgd_iris.py index eba48879c0247..35206d83b74d5 100644 --- a/examples/linear_model/plot_sgd_iris.py +++ b/examples/linear_model/plot_sgd_iris.py @@ -14,7 +14,7 @@ import matplotlib.pyplot as plt from sklearn import datasets from sklearn.linear_model import SGDClassifier -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary # import some data to play with iris = datasets.load_iris() diff --git a/examples/neighbors/plot_classification.py b/examples/neighbors/plot_classification.py index c6952838a83e8..aba6583df894f 100644 --- a/examples/neighbors/plot_classification.py +++ b/examples/neighbors/plot_classification.py @@ -11,7 +11,7 @@ import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap from sklearn import neighbors, datasets -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary n_neighbors = 15 diff --git a/examples/neighbors/plot_nca_classification.py b/examples/neighbors/plot_nca_classification.py index c495bb5dd5469..694a865bfd1e9 100644 --- a/examples/neighbors/plot_nca_classification.py +++ b/examples/neighbors/plot_nca_classification.py @@ -24,7 +24,7 @@ from sklearn.neighbors import (KNeighborsClassifier, NeighborhoodComponentsAnalysis) from sklearn.pipeline import Pipeline -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary print(__doc__) diff --git a/examples/neighbors/plot_nearest_centroid.py b/examples/neighbors/plot_nearest_centroid.py index 7bcb2fa6e836a..b21cdfda7da00 100644 --- a/examples/neighbors/plot_nearest_centroid.py +++ b/examples/neighbors/plot_nearest_centroid.py @@ -13,7 +13,7 @@ from matplotlib.colors import ListedColormap from sklearn import datasets from sklearn.neighbors import NearestCentroid -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary n_neighbors = 15 diff --git a/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py b/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py index 7d0803df2605b..e7af77b332333 100644 --- a/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py +++ b/examples/semi_supervised/plot_label_propagation_versus_svm_iris.py @@ -20,7 +20,7 @@ from sklearn import datasets from sklearn import svm from sklearn.semi_supervised import LabelSpreading -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary rng = np.random.RandomState(0) diff --git a/examples/svm/plot_custom_kernel.py b/examples/svm/plot_custom_kernel.py index cb5f72fefaaaa..f5aaaec1eb7ac 100644 --- a/examples/svm/plot_custom_kernel.py +++ b/examples/svm/plot_custom_kernel.py @@ -12,7 +12,7 @@ import numpy as np import matplotlib.pyplot as plt from sklearn import svm, datasets -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary # import some data to play with iris = datasets.load_iris() diff --git a/examples/svm/plot_iris_svc.py b/examples/svm/plot_iris_svc.py index 2aa9ccd2478e8..fa230c7f9ca88 100644 --- a/examples/svm/plot_iris_svc.py +++ b/examples/svm/plot_iris_svc.py @@ -37,7 +37,7 @@ import matplotlib.pyplot as plt from sklearn import svm, datasets -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary # import some data to play with diff --git a/examples/svm/plot_linearsvc_support_vectors.py b/examples/svm/plot_linearsvc_support_vectors.py index 5d7a03a7aec46..c8ab2a21f62de 100644 --- a/examples/svm/plot_linearsvc_support_vectors.py +++ b/examples/svm/plot_linearsvc_support_vectors.py @@ -13,7 +13,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import make_blobs from sklearn.svm import LinearSVC -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary X, y = make_blobs(n_samples=40, centers=2, random_state=0) diff --git a/examples/svm/plot_separating_hyperplane.py b/examples/svm/plot_separating_hyperplane.py index e699e0a5989d0..becb5599a9c63 100644 --- a/examples/svm/plot_separating_hyperplane.py +++ b/examples/svm/plot_separating_hyperplane.py @@ -12,7 +12,7 @@ import matplotlib.pyplot as plt from sklearn import svm from sklearn.datasets import make_blobs -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary # we create 40 separable points diff --git a/examples/svm/plot_separating_hyperplane_unbalanced.py b/examples/svm/plot_separating_hyperplane_unbalanced.py index ff52e6b1981f1..039fc5d7341c5 100644 --- a/examples/svm/plot_separating_hyperplane_unbalanced.py +++ b/examples/svm/plot_separating_hyperplane_unbalanced.py @@ -29,7 +29,7 @@ import matplotlib.pyplot as plt from sklearn import svm from sklearn.datasets import make_blobs -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary # we create two clusters of random points n_samples_1 = 1000 diff --git a/examples/tree/plot_iris_dtc.py b/examples/tree/plot_iris_dtc.py index e4f23be6bf288..1cf9f718e9413 100644 --- a/examples/tree/plot_iris_dtc.py +++ b/examples/tree/plot_iris_dtc.py @@ -21,7 +21,7 @@ from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier, plot_tree -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary # Parameters n_classes = 3 diff --git a/sklearn/inspection/__init__.py b/sklearn/inspection/__init__.py index bfa28f2b3a4f8..5e5479a2a8737 100644 --- a/sklearn/inspection/__init__.py +++ b/sklearn/inspection/__init__.py @@ -13,6 +13,8 @@ from .partial_dependence import partial_dependence from ._permutation_importance import permutation_importance # noqa +from ._plot.decision_boundary import DecisionBoundaryDisplay # noqa +from ._plot.decision_boundary import plot_decision_boundary # noqa from ._plot.partial_dependence import plot_partial_dependence # noqa from ._plot.partial_dependence import PartialDependenceDisplay # noqa @@ -23,4 +25,6 @@ 'plot_partial_dependence', 'permutation_importance', 'PartialDependenceDisplay', + 'plot_decision_boundary', + 'DecisionBoundaryDisplay' ] diff --git a/sklearn/utils/plot/_decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py similarity index 98% rename from sklearn/utils/plot/_decision_boundary.py rename to sklearn/inspection/_plot/decision_boundary.py index 552d05a52f432..abe0196d2d11c 100644 --- a/sklearn/utils/plot/_decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -1,6 +1,6 @@ import numpy as np -from .. import check_matplotlib_support -from .. import _safe_indexing +from ...utils import check_matplotlib_support +from ...utils import _safe_indexing def _check_boundary_response_method(estimator, response_method): @@ -49,7 +49,7 @@ def _check_boundary_response_method(estimator, response_method): class DecisionBoundaryDisplay: """Decisions Boundary visualization. - It is recommend to use :func:`~sklearn.utils.plot_decision_boundary` + It is recommend to use :func:`~sklearn.inspection.plot_decision_boundary` to create a :class:`DecisionBoundaryDisplay`. All parameters are stored as attributes. diff --git a/sklearn/utils/plot/tests/test_plot_decision_boundary.py b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py similarity index 98% rename from sklearn/utils/plot/tests/test_plot_decision_boundary.py rename to sklearn/inspection/_plot/tests/test_plot_decision_boundary.py index 1a91a60833cc7..b118f520ca1ca 100644 --- a/sklearn/utils/plot/tests/test_plot_decision_boundary.py +++ b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py @@ -2,10 +2,10 @@ from sklearn.base import BaseEstimator from sklearn.base import ClassifierMixin -from sklearn.utils.plot import plot_decision_boundary +from sklearn.inspection import plot_decision_boundary from sklearn.datasets import make_classification from sklearn.linear_model import LogisticRegression -from sklearn.utils.plot._decision_boundary import ( +from sklearn.inspection._plot.decision_boundary import ( _check_boundary_response_method) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 9834f4624bf76..aac6e292a198a 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -51,7 +51,7 @@ "check_symmetric", "indices_to_mask", "deprecated", "parallel_backend", "register_parallel_backend", "resample", "shuffle", "check_matplotlib_support", "all_estimators", - "DataConversionWarning", + "DataConversionWarning" ] IS_PYPY = platform.python_implementation() == 'PyPy' diff --git a/sklearn/utils/plot/__init__.py b/sklearn/utils/plot/__init__.py deleted file mode 100644 index 735c8e1ca2493..0000000000000 --- a/sklearn/utils/plot/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from ._decision_boundary import plot_decision_boundary -from ._decision_boundary import DecisionBoundaryDisplay - -__all__ = [ - "plot_decision_boundary", - "DecisionBoundaryDisplay", -] diff --git a/sklearn/utils/plot/tests/__init__.py b/sklearn/utils/plot/tests/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sklearn/utils/setup.py b/sklearn/utils/setup.py index 91898e152a1d0..098adeeccab09 100644 --- a/sklearn/utils/setup.py +++ b/sklearn/utils/setup.py @@ -70,8 +70,6 @@ def configuration(parent_package='', top_path=None): include_dirs=[numpy.get_include()], libraries=libraries) - config.add_subpackage("plot") - config.add_subpackage("plot/tests") config.add_subpackage('tests') return config From eb044be550f57e23849aff6dd0173c42a1dbdcac Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 9 Jul 2020 14:04:41 -0400 Subject: [PATCH 18/48] DOC Adds whats new --- doc/whats_new/v0.24.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst index 85b5a12e7b20a..7ca5f59e885c9 100644 --- a/doc/whats_new/v0.24.rst +++ b/doc/whats_new/v0.24.rst @@ -165,6 +165,9 @@ Changelog ``kind`` parameter. :pr:`16619` by :user:`Madhura Jayratne `. +- |Feature| Adds :func:`inspection.plot_decision_boundary` for plotting + decision boundaries. :pr:`16061` by `Thomas Fan`. + :mod:`sklearn.isotonic` ....................... From d504414ef884d7ad29290466c5dd5bce32e911f1 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 6 Aug 2021 15:34:48 +0200 Subject: [PATCH 19/48] API move to new plotting API --- .pre-commit-config.yaml | 1 + doc/modules/classes.rst | 1 - doc/whats_new/v1.0.rst | 4 +- .../plot_classifier_comparison.py | 6 +- examples/cluster/plot_inductive_clustering.py | 7 +- examples/ensemble/plot_adaboost_twoclass.py | 7 +- .../ensemble/plot_voting_decision_regions.py | 7 +- examples/linear_model/plot_iris_logistic.py | 4 +- .../linear_model/plot_logistic_multinomial.py | 7 +- examples/linear_model/plot_sgd_iris.py | 7 +- examples/neighbors/plot_classification.py | 12 +- examples/neighbors/plot_nca_classification.py | 13 +- examples/neighbors/plot_nearest_centroid.py | 7 +- examples/svm/plot_custom_kernel.py | 4 +- examples/svm/plot_iris_svc.py | 7 +- .../svm/plot_linearsvc_support_vectors.py | 10 +- examples/svm/plot_separating_hyperplane.py | 9 +- .../plot_separating_hyperplane_unbalanced.py | 14 +- examples/tree/plot_iris_dtc.py | 7 +- sklearn/inspection/__init__.py | 2 - sklearn/inspection/_plot/decision_boundary.py | 247 +++++++++++------- .../tests/test_plot_decision_boundary.py | 134 ++++++---- 22 files changed, 310 insertions(+), 207 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index db2c5084cbbb2..54c1349263bf8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,6 +9,7 @@ repos: rev: 21.6b0 hooks: - id: black + exclude: ^examples - repo: https://gitlab.com/pycqa/flake8 rev: 3.9.2 hooks: diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index ecc0d2b0d8c62..2da0b6c6a0362 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -664,7 +664,6 @@ Plotting :toctree: generated/ :template: function.rst - inspection.plot_decision_boundary inspection.plot_partial_dependence .. _isotonic_ref: diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index d03dd3e2ebeec..730aa98b64176 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -390,8 +390,8 @@ Changelog :mod:`sklearn.inspection` ......................... -- |Feature| Adds :func:`inspection.plot_decision_boundary` for plotting - decision boundaries. +- |Feature| Add a display to plot the boundary decision of a classifier by + using the method :func:`inspection.DecisionBoundaryDisplay.from_estimator`. :pr:`16061` by `Thomas Fan`_. - |Fix| Allow multiple scorers input to diff --git a/examples/classification/plot_classifier_comparison.py b/examples/classification/plot_classifier_comparison.py index a6ea51a484658..18a5c693b1840 100644 --- a/examples/classification/plot_classifier_comparison.py +++ b/examples/classification/plot_classifier_comparison.py @@ -43,7 +43,7 @@ from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier from sklearn.naive_bayes import GaussianNB from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay names = ["Nearest Neighbors", "Linear SVM", "RBF SVM", "Gaussian Process", "Decision Tree", "Random Forest", "Neural Net", "AdaBoost", @@ -108,7 +108,9 @@ ax = plt.subplot(len(datasets), len(classifiers) + 1, i) clf.fit(X_train, y_train) score = clf.score(X_test, y_test) - plot_decision_boundary(clf, X, cmap=cm, alpha=0.8, ax=ax, eps=0.5) + DecisionBoundaryDisplay.from_estimator( + clf, X, cmap=cm, alpha=0.8, ax=ax, eps=0.5 + ) # Plot the training points ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, diff --git a/examples/cluster/plot_inductive_clustering.py b/examples/cluster/plot_inductive_clustering.py index ba7df596aff59..6a15073931232 100644 --- a/examples/cluster/plot_inductive_clustering.py +++ b/examples/cluster/plot_inductive_clustering.py @@ -28,7 +28,7 @@ from sklearn.datasets import make_blobs from sklearn.ensemble import RandomForestClassifier from sklearn.utils.metaestimators import if_delegate_has_method -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay N_SAMPLES = 5000 RANDOM_STATE = 42 @@ -105,8 +105,9 @@ def plot_scatter(X, color, alpha=0.5): plot_scatter(X_new, probable_clusters) # Plotting decision regions -plot_decision_boundary(inductive_learner, X, response_method='predict', - alpha=0.4, ax=ax) +DecisionBoundaryDisplay.from_estimator( + inductive_learner, X, response_method='predict', alpha=0.4, ax=ax +) plt.title("Classify unknown instances") plt.show() diff --git a/examples/ensemble/plot_adaboost_twoclass.py b/examples/ensemble/plot_adaboost_twoclass.py index 296390abb76d7..9f8a277bd1ef3 100644 --- a/examples/ensemble/plot_adaboost_twoclass.py +++ b/examples/ensemble/plot_adaboost_twoclass.py @@ -28,7 +28,7 @@ from sklearn.ensemble import AdaBoostClassifier from sklearn.tree import DecisionTreeClassifier from sklearn.datasets import make_gaussian_quantiles -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay # Construct dataset @@ -56,8 +56,9 @@ # Plot the decision boundaries ax = plt.subplot(121) -disp = plot_decision_boundary(bdt, X, cmap=plt.cm.Paired, - response_method='predict', ax=ax) +disp = DecisionBoundaryDisplay.from_estimator( + bdt, X, cmap=plt.cm.Paired, response_method='predict', ax=ax +) x_min, x_max = disp.xx0.min(), disp.xx0.max() y_min, y_max = disp.xx1.min(), disp.xx1.max() plt.axis("tight") diff --git a/examples/ensemble/plot_voting_decision_regions.py b/examples/ensemble/plot_voting_decision_regions.py index 8684494d2d763..eba3eafe838c4 100644 --- a/examples/ensemble/plot_voting_decision_regions.py +++ b/examples/ensemble/plot_voting_decision_regions.py @@ -33,7 +33,7 @@ from sklearn.neighbors import KNeighborsClassifier from sklearn.svm import SVC from sklearn.ensemble import VotingClassifier -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay # Loading some example data iris = datasets.load_iris() @@ -59,8 +59,9 @@ [clf1, clf2, clf3, eclf], ['Decision Tree (depth=4)', 'KNN (k=7)', 'Kernel SVM', 'Soft Voting']): - plot_decision_boundary(clf, X, alpha=0.4, ax=axarr[idx[0], idx[1]], - response_method='predict') + DecisionBoundaryDisplay.from_estimator( + clf, X, alpha=0.4, ax=axarr[idx[0], idx[1]], response_method='predict' + ) axarr[idx[0], idx[1]].scatter(X[:, 0], X[:, 1], c=y, s=20, edgecolor='k') axarr[idx[0], idx[1]].set_title(tt) diff --git a/examples/linear_model/plot_iris_logistic.py b/examples/linear_model/plot_iris_logistic.py index 9ff6a80e037f2..ab7e41195b21f 100644 --- a/examples/linear_model/plot_iris_logistic.py +++ b/examples/linear_model/plot_iris_logistic.py @@ -21,7 +21,7 @@ import matplotlib.pyplot as plt from sklearn.linear_model import LogisticRegression from sklearn import datasets -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay # import some data to play with iris = datasets.load_iris() @@ -33,7 +33,7 @@ logreg.fit(X, Y) _, ax = plt.subplots(figsize=(4, 3)) -plot_decision_boundary( +DecisionBoundaryDisplay.from_estimator( logreg, X, cmap=plt.cm.Paired, diff --git a/examples/linear_model/plot_logistic_multinomial.py b/examples/linear_model/plot_logistic_multinomial.py index 6ecc23fc73c0c..56f2f85c9c0e7 100644 --- a/examples/linear_model/plot_logistic_multinomial.py +++ b/examples/linear_model/plot_logistic_multinomial.py @@ -15,7 +15,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import make_blobs from sklearn.linear_model import LogisticRegression -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay # make 3-class dataset for classification centers = [[-5, 0], [0, 1.5], [5, -1]] @@ -31,8 +31,9 @@ print("training score : %.3f (%s)" % (clf.score(X, y), multi_class)) _, ax = plt.subplots() - plot_decision_boundary(clf, X, response_method='predict', - cmap=plt.cm.Paired, ax=ax) + DecisionBoundaryDisplay.from_estimator( + clf, X, response_method='predict', cmap=plt.cm.Paired, ax=ax + ) plt.title("Decision surface of LogisticRegression (%s)" % multi_class) plt.axis('tight') diff --git a/examples/linear_model/plot_sgd_iris.py b/examples/linear_model/plot_sgd_iris.py index 35206d83b74d5..1a4a13d86a6eb 100644 --- a/examples/linear_model/plot_sgd_iris.py +++ b/examples/linear_model/plot_sgd_iris.py @@ -14,7 +14,7 @@ import matplotlib.pyplot as plt from sklearn import datasets from sklearn.linear_model import SGDClassifier -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay # import some data to play with iris = datasets.load_iris() @@ -39,8 +39,9 @@ clf = SGDClassifier(alpha=0.001, max_iter=100).fit(X, y) ax = plt.gca() -plot_decision_boundary(clf, X, cmap=plt.cm.Paired, ax=ax, - response_method='predict') +DecisionBoundaryDisplay.from_estimator( + clf, X, cmap=plt.cm.Paired, ax=ax, response_method='predict' +) plt.axis('tight') # Plot also the training points diff --git a/examples/neighbors/plot_classification.py b/examples/neighbors/plot_classification.py index 45792741b3904..d7c4f7fbe3df0 100644 --- a/examples/neighbors/plot_classification.py +++ b/examples/neighbors/plot_classification.py @@ -12,7 +12,7 @@ import seaborn as sns from matplotlib.colors import ListedColormap from sklearn import neighbors, datasets -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay n_neighbors = 15 @@ -34,8 +34,14 @@ clf.fit(X, y) _, ax = plt.subplots() - plot_decision_boundary(clf, X, cmap=cmap_light, ax=ax, - response_method='predict', plot_method='pcolormesh') + DecisionBoundaryDisplay.from_estimator( + clf, + X, + cmap=cmap_light, + ax=ax, + response_method='predict', + plot_method='pcolormesh' + ) # Plot also the training points sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=iris.target_names[y], diff --git a/examples/neighbors/plot_nca_classification.py b/examples/neighbors/plot_nca_classification.py index 694a865bfd1e9..5f8ad51e7045a 100644 --- a/examples/neighbors/plot_nca_classification.py +++ b/examples/neighbors/plot_nca_classification.py @@ -24,7 +24,7 @@ from sklearn.neighbors import (KNeighborsClassifier, NeighborhoodComponentsAnalysis) from sklearn.pipeline import Pipeline -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay print(__doc__) @@ -64,8 +64,15 @@ score = clf.score(X_test, y_test) _, ax = plt.subplots() - plot_decision_boundary(clf, X, cmap=cmap_light, alpha=0.8, ax=ax, - response_method='predict', plot_method='pcolormesh') + DecisionBoundaryDisplay.from_estimator( + clf, + X, + cmap=cmap_light, + alpha=0.8, + ax=ax, + response_method='predict', + plot_method='pcolormesh', + ) # Plot also the training and testing points plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor='k', s=20) diff --git a/examples/neighbors/plot_nearest_centroid.py b/examples/neighbors/plot_nearest_centroid.py index b21cdfda7da00..71e7646e31a44 100644 --- a/examples/neighbors/plot_nearest_centroid.py +++ b/examples/neighbors/plot_nearest_centroid.py @@ -13,7 +13,7 @@ from matplotlib.colors import ListedColormap from sklearn import datasets from sklearn.neighbors import NearestCentroid -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay n_neighbors = 15 @@ -36,8 +36,9 @@ print(shrinkage, np.mean(y == y_pred)) _, ax = plt.subplots() - plot_decision_boundary(clf, X, cmap=cmap_light, ax=ax, - response_method='predict') + DecisionBoundaryDisplay.from_estimator( + clf, X, cmap=cmap_light, ax=ax, response_method='predict' + ) # Plot also the training points plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, diff --git a/examples/svm/plot_custom_kernel.py b/examples/svm/plot_custom_kernel.py index 6e1a16fecedfd..f66e746fa0605 100644 --- a/examples/svm/plot_custom_kernel.py +++ b/examples/svm/plot_custom_kernel.py @@ -12,7 +12,7 @@ import numpy as np import matplotlib.pyplot as plt from sklearn import svm, datasets -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay # import some data to play with iris = datasets.load_iris() @@ -40,7 +40,7 @@ def my_kernel(X, Y): clf.fit(X, Y) ax = plt.gca() -plot_decision_boundary( +DecisionBoundaryDisplay.from_estimator( clf, X, cmap=plt.cm.Paired, diff --git a/examples/svm/plot_iris_svc.py b/examples/svm/plot_iris_svc.py index fa230c7f9ca88..08c43bab098d4 100644 --- a/examples/svm/plot_iris_svc.py +++ b/examples/svm/plot_iris_svc.py @@ -37,7 +37,7 @@ import matplotlib.pyplot as plt from sklearn import svm, datasets -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay # import some data to play with @@ -68,8 +68,9 @@ X0, X1 = X[:, 0], X[:, 1] for clf, title, ax in zip(models, titles, sub.flatten()): - disp = plot_decision_boundary(clf, X, response_method='predict', - cmap=plt.cm.coolwarm, alpha=0.8, ax=ax) + disp = DecisionBoundaryDisplay.from_estimator( + clf, X, response_method='predict', cmap=plt.cm.coolwarm, alpha=0.8, ax=ax + ) ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k') ax.set_xlabel('Sepal length') ax.set_ylabel('Sepal width') diff --git a/examples/svm/plot_linearsvc_support_vectors.py b/examples/svm/plot_linearsvc_support_vectors.py index da885683e42d3..a4311f302e061 100644 --- a/examples/svm/plot_linearsvc_support_vectors.py +++ b/examples/svm/plot_linearsvc_support_vectors.py @@ -13,7 +13,7 @@ import matplotlib.pyplot as plt from sklearn.datasets import make_blobs from sklearn.svm import LinearSVC -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay X, y = make_blobs(n_samples=40, centers=2, random_state=0) @@ -34,10 +34,10 @@ plt.subplot(1, 2, i + 1) plt.scatter(X[:, 0], X[:, 1], c=y, s=30, cmap=plt.cm.Paired) ax = plt.gca() - plot_decision_boundary(clf, X, ax=ax, grid_resolution=50, - plot_method='contour', - colors='k', levels=[-1, 0, 1], alpha=0.5, - linestyles=['--', '-', '--']) + DecisionBoundaryDisplay.from_estimator( + clf, X, ax=ax, grid_resolution=50, plot_method='contour', colors='k', + levels=[-1, 0, 1], alpha=0.5, linestyles=['--', '-', '--'] + ) plt.scatter(support_vectors[:, 0], support_vectors[:, 1], s=100, linewidth=1, facecolors='none', edgecolors='k') plt.title("C=" + str(C)) diff --git a/examples/svm/plot_separating_hyperplane.py b/examples/svm/plot_separating_hyperplane.py index becb5599a9c63..ab6ba2a2e6b29 100644 --- a/examples/svm/plot_separating_hyperplane.py +++ b/examples/svm/plot_separating_hyperplane.py @@ -12,7 +12,7 @@ import matplotlib.pyplot as plt from sklearn import svm from sklearn.datasets import make_blobs -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay # we create 40 separable points @@ -26,9 +26,10 @@ # plot the decision function ax = plt.gca() -plot_decision_boundary(clf, X, plot_method='contour', - colors='k', levels=[-1, 0, 1], alpha=0.5, - linestyles=['--', '-', '--'], ax=ax) +DecisionBoundaryDisplay.from_estimator( + clf, X, plot_method='contour', colors='k', levels=[-1, 0, 1], alpha=0.5, + linestyles=['--', '-', '--'], ax=ax +) # plot support vectors ax.scatter(clf.support_vectors_[:, 0], clf.support_vectors_[:, 1], s=100, linewidth=1, facecolors='none', edgecolors='k') diff --git a/examples/svm/plot_separating_hyperplane_unbalanced.py b/examples/svm/plot_separating_hyperplane_unbalanced.py index 039fc5d7341c5..4096410ad1de6 100644 --- a/examples/svm/plot_separating_hyperplane_unbalanced.py +++ b/examples/svm/plot_separating_hyperplane_unbalanced.py @@ -29,7 +29,7 @@ import matplotlib.pyplot as plt from sklearn import svm from sklearn.datasets import make_blobs -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay # we create two clusters of random points n_samples_1 = 1000 @@ -54,12 +54,16 @@ # plot the decision functions for both classifiers ax = plt.gca() -disp = plot_decision_boundary(clf, X, plot_method='contour', colors='k', - levels=[0], alpha=0.5, linestyles=['-'], ax=ax) +disp = DecisionBoundaryDisplay.from_estimator( + clf, X, plot_method='contour', colors='k', levels=[0], alpha=0.5, + linestyles=['-'], ax=ax +) # plot decision boundary and margins for weighted classes -wdisp = plot_decision_boundary(wclf, X, plot_method='contour', colors='r', - levels=[0], alpha=0.5, linestyles=['-'], ax=ax) +wdisp = DecisionBoundaryDisplay( + wclf, X, plot_method='contour', colors='r', levels=[0], alpha=0.5, + linestyles=['-'], ax=ax +) plt.legend([disp.surface_.collections[0], wdisp.surface_.collections[0]], ["non weighted", "weighted"], diff --git a/examples/tree/plot_iris_dtc.py b/examples/tree/plot_iris_dtc.py index 1cf9f718e9413..d65eb959123a3 100644 --- a/examples/tree/plot_iris_dtc.py +++ b/examples/tree/plot_iris_dtc.py @@ -21,7 +21,7 @@ from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier, plot_tree -from sklearn.inspection import plot_decision_boundary +from sklearn.inspection import DecisionBoundaryDisplay # Parameters n_classes = 3 @@ -43,8 +43,9 @@ # Plot the decision boundary ax = plt.subplot(2, 3, pairidx + 1) plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5) - plot_decision_boundary(clf, X, cmap=plt.cm.RdYlBu, - response_method='predict', ax=ax) + DecisionBoundaryDisplay.from_estimator( + clf, X, cmap=plt.cm.RdYlBu, response_method='predict', ax=ax + ) plt.xlabel(iris.feature_names[pair[0]]) plt.ylabel(iris.feature_names[pair[1]]) diff --git a/sklearn/inspection/__init__.py b/sklearn/inspection/__init__.py index 8068c59e8d74f..76c44ea81bbbe 100644 --- a/sklearn/inspection/__init__.py +++ b/sklearn/inspection/__init__.py @@ -3,7 +3,6 @@ from ._permutation_importance import permutation_importance from ._plot.decision_boundary import DecisionBoundaryDisplay -from ._plot.decision_boundary import plot_decision_boundary from ._partial_dependence import partial_dependence from ._plot.partial_dependence import plot_partial_dependence @@ -15,6 +14,5 @@ "plot_partial_dependence", "permutation_importance", "PartialDependenceDisplay", - "plot_decision_boundary", "DecisionBoundaryDisplay", ] diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index abe0196d2d11c..7b95609fe3e33 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -1,15 +1,16 @@ import numpy as np + from ...utils import check_matplotlib_support from ...utils import _safe_indexing def _check_boundary_response_method(estimator, response_method): - """Return prediction method from the response_method for decision boundary + """Return prediction method from the response_method for decision boundary. Parameters ---------- estimator: object - Estimator to check + Estimator to check. response_method: {'auto', 'predict_proba', 'decision_function', 'predict'} Specifies whether to use :term:`predict_proba`, @@ -20,41 +21,53 @@ def _check_boundary_response_method(estimator, response_method): Returns ------- prediction_method: callable - prediction method of estimator + Prediction method of estimator. """ - if response_method not in ("predict_proba", "decision_function", - "auto", "predict"): - raise ValueError("response_method must be 'predict_proba', " - "'decision_function', 'predict', or 'auto'") + possible_response_methods = ( + "predict_proba", + "decision_function", + "auto", + "predict", + ) + if response_method not in possible_response_methods: + raise ValueError( + f"response_method must be one of {', '.join(possible_response_methods)}" + ) error_msg = "response method {} is not defined in {}" if response_method != "auto": if not hasattr(estimator, response_method): - raise ValueError(error_msg.format(response_method, - estimator.__class__.__name__)) + raise ValueError( + error_msg.format(response_method, estimator.__class__.__name__) + ) return getattr(estimator, response_method) - elif hasattr(estimator, 'decision_function'): - return getattr(estimator, 'decision_function') - elif hasattr(estimator, 'predict_proba'): - return getattr(estimator, 'predict_proba') - elif hasattr(estimator, 'predict'): - return getattr(estimator, 'predict') + elif hasattr(estimator, "decision_function"): + return getattr(estimator, "decision_function") + elif hasattr(estimator, "predict_proba"): + return getattr(estimator, "predict_proba") + elif hasattr(estimator, "predict"): + return getattr(estimator, "predict") - raise ValueError(error_msg.format( - "decision_function, predict_proba, or predict", - estimator.__class__.__name__)) + raise ValueError( + error_msg.format( + "decision_function, predict_proba, or predict", estimator.__class__.__name__ + ) + ) class DecisionBoundaryDisplay: - """Decisions Boundary visualization. + """Decisions boundary visualization. - It is recommend to use :func:`~sklearn.inspection.plot_decision_boundary` + It is recommend to use + :func:`~sklearn.inspection.DecisionBoundaryDisplay.from_estimator` to create a :class:`DecisionBoundaryDisplay`. All parameters are stored as attributes. Read more in the :ref:`User Guide `. + .. versionadded:: 1.0 + Parameters ---------- xx0 : ndarray of shape (grid_resolution, grid_resolution) @@ -86,6 +99,7 @@ class DecisionBoundaryDisplay: figure_ : matplotlib Figure Figure containing the confusion matrix. """ + def __init__(self, xx0, xx1, response, xlabel=None, ylabel=None): self.xx0 = xx0 self.xx1 = xx1 @@ -93,7 +107,7 @@ def __init__(self, xx0, xx1, response, xlabel=None, ylabel=None): self.xlabel = xlabel self.ylabel = ylabel - def plot(self, plot_method='contourf', ax=None, **kwargs): + def plot(self, plot_method="contourf", ax=None, **kwargs): """Plot visualization. Parameters @@ -116,12 +130,13 @@ def plot(self, plot_method='contourf', ax=None, **kwargs): ------- display: :class:`~sklearn.inspection.DecisionBoundaryDisplay` """ - check_matplotlib_support('DecisionBoundaryDisplay.plot') + check_matplotlib_support("DecisionBoundaryDisplay.plot") import matplotlib.pyplot as plt # noqa - if plot_method not in ('contourf', 'contour', 'pcolormesh'): - raise ValueError("plot_method must be 'contourf', 'contour', or " - "'pcolormesh'") + if plot_method not in ("contourf", "contour", "pcolormesh"): + raise ValueError( + "plot_method must be 'contourf', 'contour', or 'pcolormesh'" + ) if ax is None: _, ax = plt.subplots() @@ -138,89 +153,121 @@ def plot(self, plot_method='contourf', ax=None, **kwargs): self.figure_ = ax.figure return self + @classmethod + def from_estimator( + cls, + estimator, + X, + *, + grid_resolution=100, + eps=1.0, + plot_method="contourf", + response_method="auto", + ax=None, + **kwargs, + ): + """Plot decision boundary given an estimator. + + Read more in the :ref:`User Guide `. + + .. versionadded:: 1.0 -def plot_decision_boundary(estimator, X, grid_resolution=100, eps=1.0, - plot_method='contourf', response_method='auto', - ax=None, **kwargs): - """Plot Decision Boundary. - - Please see examples below for usage. - - Parameters - ---------- - estimator : estimator instance - Trained estimator. - - X : ndarray or pandas dataframe of shape (n_samples, 2) - Input values. - - grid_resolution : int, default=100 - The number of equally spaced points to evaluate the response function. - - eps : float, default=1.0 - Extends the minimum and maximum values of X for evaluating the - response function. - - plot_method : {'contourf', 'contour', 'pcolormesh'}, default='contourf' - Plotting method to call when plotting the response. Please refer - to the following matplotlib documentation for details: - :func:`contourf `, - :func:`contour `, - :func:`pcolomesh `. - - response_method : {'auto', 'predict_proba', 'decision_function', \ - 'predict'}, defaul='auto' - Specifies whether to use :term:`predict_proba`, - :term:`decision_function`, :term:`predict` as the target response. - If set to 'auto', the response method is tried in the following order: - :term:`predict_proba`, :term:`decision_function`, :term:`predict`. - - ax : Matplotlib axes, default=None - Axes object to plot on. If `None`, a new figure and axes is - created. - - **kwargs : dict - Additional keyword arguments to be pased to the `plot_method`. - - Returns - ------- - display: :class:`~sklearn.inspection.DecisionBoundaryDisplay` - """ - check_matplotlib_support('plot_decision_boundary') + Parameters + ---------- + estimator : object + Trained estimator used to plot the decision boundary. - if not grid_resolution > 1: - raise ValueError("grid_resolution must be greater than 1") + X : {array-like, sparse matrix, dataframe} of shape (n_samples, 2) + Input data that should be only 2-dimensional. - if not eps >= 0: - raise ValueError("eps must be greater than or equal to 0") + grid_resolution : int, default=100 + Number of grid points to use for plotting decision boundary. + Higher values will make the plot look nicer but be slower to + render. - if plot_method not in ('contourf', 'contour', 'pcolormesh'): - raise ValueError("plot_method must be 'contourf', 'contour', or " - "'pcolormesh'") + eps : float, default=1.0 + Extends the minimum and maximum values of X for evaluating the + response function. - pred_func = _check_boundary_response_method(estimator, response_method) + plot_method : {'contourf', 'contour', 'pcolormesh'}, default='contourf' + Plotting method to call when plotting the response. Please refer + to the following matplotlib documentation for details: + :func:`contourf `, + :func:`contour `, + :func:`pcolomesh `. - x0, x1 = _safe_indexing(X, 0, axis=1), _safe_indexing(X, 1, axis=1) + response_method : {'auto', 'predict_proba', 'decision_function', \ + 'predict'}, default='auto' + Specifies whether to use :term:`predict_proba`, + :term:`decision_function`, :term:`predict` as the target response. + If set to 'auto', the response method is tried in the following order: + :term:`predict_proba`, :term:`decision_function`, :term:`predict`. - x0_min, x0_max = x0.min() - eps, x0.max() + eps - x1_min, x1_max = x1.min() - eps, x1.max() + eps + ax : Matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is + created. - xx0, xx1 = np.meshgrid(np.linspace(x0_min, x0_max, grid_resolution), - np.linspace(x1_min, x1_max, grid_resolution)) - response = pred_func(np.c_[xx0.ravel(), xx1.ravel()]) + **kwargs : dict + Additional keyword arguments to be passed to the + `plot_method`. - if response.ndim != 1: - if response.shape[1] != 2: - raise ValueError("multiclass classifiers are only supported when " - "response_method='predict'") - response = response[:, 1] + Returns + ------- + display : :class:`~sklearn.inspection.DecisionBoundaryDisplay` + Object that stores the result. - if hasattr(X, "columns"): - xlabel, ylabel = X.columns[0], X.columns[1] - else: - xlabel, ylabel = "", "" + See Also + -------- - display = DecisionBoundaryDisplay(xx0=xx0, xx1=xx1, - response=response.reshape(xx0.shape), - xlabel=xlabel, ylabel=ylabel) - return display.plot(ax=ax, plot_method=plot_method, **kwargs) + Examples + -------- + """ + check_matplotlib_support(f"{cls.__name__}.from_estimator") + + if not grid_resolution > 1: + raise ValueError("grid_resolution must be greater than 1") + + if not eps >= 0: + raise ValueError("eps must be greater than or equal to 0") + + possible_plot_method = ("contourf", "contour", "pcolormesh") + if plot_method not in possible_plot_method: + raise ValueError( + f"plot_method must be one of {', '.join(possible_plot_method)}. " + f"Got {plot_method} instead." + ) + + x0, x1 = _safe_indexing(X, 0, axis=1), _safe_indexing(X, 1, axis=1) + + x0_min, x0_max = x0.min() - eps, x0.max() + eps + x1_min, x1_max = x1.min() - eps, x1.max() + eps + + xx0, xx1 = np.meshgrid( + np.linspace(x0_min, x0_max, grid_resolution), + np.linspace(x1_min, x1_max, grid_resolution), + ) + + pred_func = _check_boundary_response_method(estimator, response_method) + response = pred_func(np.c_[xx0.ravel(), xx1.ravel()]) + + if response.ndim != 1: + if response.shape[1] != 2: + raise ValueError( + "Multiclass classifiers are only supported when " + "response_method='predict'" + ) + response = response[:, 1] + + if hasattr(X, "columns"): + xlabel, ylabel = X.columns[0], X.columns[1] + else: + xlabel, ylabel = "", "" + + display = DecisionBoundaryDisplay( + xx0=xx0, + xx1=xx1, + response=response.reshape(xx0.shape), + xlabel=xlabel, + ylabel=ylabel, + ) + return display.plot(ax=ax, plot_method=plot_method, **kwargs) diff --git a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py index b118f520ca1ca..d4f7c396dd7da 100644 --- a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py +++ b/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py @@ -2,24 +2,29 @@ from sklearn.base import BaseEstimator from sklearn.base import ClassifierMixin -from sklearn.inspection import plot_decision_boundary from sklearn.datasets import make_classification from sklearn.linear_model import LogisticRegression -from sklearn.inspection._plot.decision_boundary import ( - _check_boundary_response_method) + +from sklearn.inspection import DecisionBoundaryDisplay +from sklearn.inspection._plot.decision_boundary import _check_boundary_response_method # TODO: Remove when https://github.com/numpy/numpy/issues/14397 is resolved pytestmark = pytest.mark.filterwarnings( "ignore:In future, it will be an error for 'np.bool_':DeprecationWarning:" - "matplotlib.*") + "matplotlib.*" +) @pytest.fixture(scope="module") def data(): - X, y = make_classification(n_informative=1, n_redundant=1, - n_clusters_per_class=1, n_features=2, - random_state=42) + X, y = make_classification( + n_informative=1, + n_redundant=1, + n_clusters_per_class=1, + n_features=2, + random_state=42, + ) return X, y @@ -34,14 +39,15 @@ def decision_function(self): pass a_inst = A() - method = _check_boundary_response_method(a_inst, 'auto') + method = _check_boundary_response_method(a_inst, "auto") assert method == a_inst.decision_function class B: def predict_proba(self): pass + b_inst = B() - method = _check_boundary_response_method(b_inst, 'auto') + method = _check_boundary_response_method(b_inst, "auto") assert method == b_inst.predict_proba class C: @@ -50,68 +56,79 @@ def predict_proba(self): def decision_function(self): pass + c_inst = C() - method = _check_boundary_response_method(c_inst, 'auto') + method = _check_boundary_response_method(c_inst, "auto") assert method == c_inst.decision_function class D: def predict(self): pass + d_inst = D() - method = _check_boundary_response_method(d_inst, 'auto') + method = _check_boundary_response_method(d_inst, "auto") assert method == d_inst.predict -@pytest.mark.parametrize("response_method", - ['auto', 'predict_proba', 'decision_function']) +@pytest.mark.parametrize( + "response_method", ["auto", "predict_proba", "decision_function"] +) def test_multiclass_error(pyplot, response_method): X, y = make_classification(n_classes=3, n_informative=3, random_state=0) X = X[:, [0, 1]] lr = LogisticRegression().fit(X, y) - msg = ("multiclass classifiers are only supported when " - "response_method='predict'") + msg = "Multiclass classifiers are only supported when response_method='predict'" with pytest.raises(ValueError, match=msg): - plot_decision_boundary(lr, X, response_method=response_method) - - -@pytest.mark.parametrize("kwargs, error_msg", [ - ({"plot_method": "hello_world"}, - r"plot_method must be 'contourf',"), - ({"grid_resolution": 1}, - r"grid_resolution must be greater than 1"), - ({"grid_resolution": -1}, - r"grid_resolution must be greater than 1"), - ({"eps": -1.1}, - r"eps must be greater than or equal to 0") -]) + DecisionBoundaryDisplay.from_estimator(lr, X, response_method=response_method) + + +@pytest.mark.parametrize( + "kwargs, error_msg", + [ + ( + {"plot_method": "hello_world"}, + r"plot_method must be one of contourf, contour, pcolormesh. Got hello_world" + r" instead.", + ), + ({"grid_resolution": 1}, r"grid_resolution must be greater than 1"), + ({"grid_resolution": -1}, r"grid_resolution must be greater than 1"), + ({"eps": -1.1}, r"eps must be greater than or equal to 0"), + ], +) def test_input_validation_errors(pyplot, kwargs, error_msg, fitted_clf, data): X, _ = data with pytest.raises(ValueError, match=error_msg): - plot_decision_boundary(fitted_clf, X, **kwargs) + DecisionBoundaryDisplay.from_estimator(fitted_clf, X, **kwargs) def test_display_plot_input_error(pyplot, fitted_clf, data): X, y = data - disp = plot_decision_boundary(fitted_clf, X, grid_resolution=5) + disp = DecisionBoundaryDisplay.from_estimator(fitted_clf, X, grid_resolution=5) with pytest.raises(ValueError, match="plot_method must be 'contourf'"): disp.plot(plot_method="hello_world") -@pytest.mark.parametrize("response_method", ['auto', 'predict', - 'predict_proba', - 'decision_function']) -@pytest.mark.parametrize("plot_method", ['contourf', 'contour']) -def test_plot_decision_boundary(pyplot, fitted_clf, data, - response_method, plot_method): +@pytest.mark.parametrize( + "response_method", ["auto", "predict", "predict_proba", "decision_function"] +) +@pytest.mark.parametrize("plot_method", ["contourf", "contour"]) +def test_decision_boundary_display( + pyplot, fitted_clf, data, response_method, plot_method +): fig, ax = pyplot.subplots() eps = 2.0 X, y = data - disp = plot_decision_boundary(fitted_clf, X, grid_resolution=5, - response_method=response_method, - plot_method=plot_method, - eps=eps, ax=ax) + disp = DecisionBoundaryDisplay.from_estimator( + fitted_clf, + X, + grid_resolution=5, + response_method=response_method, + plot_method=plot_method, + eps=eps, + ax=ax, + ) assert isinstance(disp.surface_, pyplot.matplotlib.contour.QuadContourSet) assert disp.ax_ == ax assert disp.figure_ == fig @@ -128,7 +145,7 @@ def test_plot_decision_boundary(pyplot, fitted_clf, data, fig2, ax2 = pyplot.subplots() # change plotting method for second plot - disp.plot(plot_method='pcolormesh', ax=ax2) + disp.plot(plot_method="pcolormesh", ax=ax2) assert isinstance(disp.surface_, pyplot.matplotlib.collections.QuadMesh) assert disp.ax_ == ax2 assert disp.figure_ == fig2 @@ -136,14 +153,27 @@ def test_plot_decision_boundary(pyplot, fitted_clf, data, @pytest.mark.parametrize( "response_method, msg", - [("predict_proba", "response method predict_proba is not defined in " - "MyClassifier"), - ("decision_function", "response method decision_function is not defined " - "in MyClassifier"), - ("auto", "response method decision_function, predict_proba, or predict " - "is not defined in MyClassifier"), - ("bad_method", "response_method must be 'predict_proba', " - "'decision_function', 'predict', or 'auto'")]) + [ + ( + "predict_proba", + "response method predict_proba is not defined in MyClassifier", + ), + ( + "decision_function", + "response method decision_function is not defined in MyClassifier", + ), + ( + "auto", + "response method decision_function, predict_proba, or predict " + "is not defined in MyClassifier", + ), + ( + "bad_method", + "response_method must be one of predict_proba, decision_function, auto," + " predict", + ), + ], +) def test_error_bad_response(pyplot, response_method, msg, data): X, y = data @@ -156,16 +186,16 @@ def fit(self, X, y): clf = MyClassifier().fit(X, y) with pytest.raises(ValueError, match=msg): - plot_decision_boundary(clf, X, response_method=response_method) + DecisionBoundaryDisplay.from_estimator(clf, X, response_method=response_method) def test_dataframe_labels_used(pyplot, data, fitted_clf): pd = pytest.importorskip("pandas") - df = pd.DataFrame(data[0], columns=['col_x', 'col_y']) + df = pd.DataFrame(data[0], columns=["col_x", "col_y"]) # pandas column names are used by default _, ax = pyplot.subplots() - disp = plot_decision_boundary(fitted_clf, df, ax=ax) + disp = DecisionBoundaryDisplay.from_estimator(fitted_clf, df, ax=ax) assert ax.get_xlabel() == "col_x" assert ax.get_ylabel() == "col_y" From 7fd218106a435336108bc097168215a9e61b6285 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 6 Aug 2021 16:35:50 +0200 Subject: [PATCH 20/48] iter --- examples/svm/plot_separating_hyperplane_unbalanced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/svm/plot_separating_hyperplane_unbalanced.py b/examples/svm/plot_separating_hyperplane_unbalanced.py index 4096410ad1de6..8336adca97380 100644 --- a/examples/svm/plot_separating_hyperplane_unbalanced.py +++ b/examples/svm/plot_separating_hyperplane_unbalanced.py @@ -60,7 +60,7 @@ ) # plot decision boundary and margins for weighted classes -wdisp = DecisionBoundaryDisplay( +wdisp = DecisionBoundaryDisplay.from_estimator( wclf, X, plot_method='contour', colors='r', levels=[0], alpha=0.5, linestyles=['-'], ax=ax ) From a69315ae8f07e25d2ebe921c934c779776b0143d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 6 Aug 2021 16:59:34 +0200 Subject: [PATCH 21/48] rename file and avoid warning --- examples/linear_model/plot_iris_logistic.py | 1 + examples/svm/plot_custom_kernel.py | 1 + sklearn/inspection/_plot/decision_boundary.py | 20 +++++++++++++------ ...y.py => test_boundary_decision_display.py} | 2 +- 4 files changed, 17 insertions(+), 7 deletions(-) rename sklearn/inspection/_plot/tests/{test_plot_decision_boundary.py => test_boundary_decision_display.py} (98%) diff --git a/examples/linear_model/plot_iris_logistic.py b/examples/linear_model/plot_iris_logistic.py index ab7e41195b21f..bff11201aed16 100644 --- a/examples/linear_model/plot_iris_logistic.py +++ b/examples/linear_model/plot_iris_logistic.py @@ -40,6 +40,7 @@ ax=ax, response_method="predict", plot_method="pcolormesh", + shading="auto", ) # Plot also the training points diff --git a/examples/svm/plot_custom_kernel.py b/examples/svm/plot_custom_kernel.py index f66e746fa0605..092a678e99f52 100644 --- a/examples/svm/plot_custom_kernel.py +++ b/examples/svm/plot_custom_kernel.py @@ -47,6 +47,7 @@ def my_kernel(X, Y): ax=ax, response_method="predict", plot_method="pcolormesh", + shading="auto", ) # Plot also the training points diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 7b95609fe3e33..7638c66086ebd 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -83,7 +83,7 @@ class DecisionBoundaryDisplay: Default label to place on x axis. ylabel : str, default="" - DEfault label to place on y axis. + Default label to place on y axis. Attributes ---------- @@ -100,14 +100,14 @@ class DecisionBoundaryDisplay: Figure containing the confusion matrix. """ - def __init__(self, xx0, xx1, response, xlabel=None, ylabel=None): + def __init__(self, *, xx0, xx1, response, xlabel=None, ylabel=None): self.xx0 = xx0 self.xx1 = xx1 self.response = response self.xlabel = xlabel self.ylabel = ylabel - def plot(self, plot_method="contourf", ax=None, **kwargs): + def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwargs): """Plot visualization. Parameters @@ -123,8 +123,14 @@ def plot(self, plot_method="contourf", ax=None, **kwargs): Axes object to plot on. If `None`, a new figure and axes is created. + xlabel : str, default=None + Overwrite the x-axis label. + + ylabel : str, default=None + Overwrite the y-axis label. + **kwargs : dict - Additional keyword arguments to be pased to the `plot_method`. + Additional keyword arguments to be passed to the `plot_method`. Returns ------- @@ -144,10 +150,12 @@ def plot(self, plot_method="contourf", ax=None, **kwargs): plot_func = getattr(ax, plot_method) self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs) + xlabel = self.xlabel if xlabel is None else xlabel + ylabel = self.ylabel if ylabel is None else ylabel if not ax.get_xlabel(): - ax.set_xlabel(self.xlabel) + ax.set_xlabel(xlabel) if not ax.get_ylabel(): - ax.set_ylabel(self.ylabel) + ax.set_ylabel(ylabel) self.ax_ = ax self.figure_ = ax.figure diff --git a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py similarity index 98% rename from sklearn/inspection/_plot/tests/test_plot_decision_boundary.py rename to sklearn/inspection/_plot/tests/test_boundary_decision_display.py index d4f7c396dd7da..32905ecf49b72 100644 --- a/sklearn/inspection/_plot/tests/test_plot_decision_boundary.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -145,7 +145,7 @@ def test_decision_boundary_display( fig2, ax2 = pyplot.subplots() # change plotting method for second plot - disp.plot(plot_method="pcolormesh", ax=ax2) + disp.plot(plot_method="pcolormesh", ax=ax2, shading="auto") assert isinstance(disp.surface_, pyplot.matplotlib.collections.QuadMesh) assert disp.ax_ == ax2 assert disp.figure_ == fig2 From 8235a74bf07f02edf1f8e019592e14995bc040a6 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 6 Aug 2021 17:10:29 +0200 Subject: [PATCH 22/48] FEA allow to set x-/y-label --- sklearn/inspection/_plot/decision_boundary.py | 31 ++++++++++++++----- .../tests/test_boundary_decision_display.py | 13 ++++++++ 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 7638c66086ebd..8b60d52a5dbba 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -150,11 +150,11 @@ def plot(self, plot_method="contourf", ax=None, xlabel=None, ylabel=None, **kwar plot_func = getattr(ax, plot_method) self.surface_ = plot_func(self.xx0, self.xx1, self.response, **kwargs) - xlabel = self.xlabel if xlabel is None else xlabel - ylabel = self.ylabel if ylabel is None else ylabel - if not ax.get_xlabel(): + if xlabel is not None or not ax.get_xlabel(): + xlabel = self.xlabel if xlabel is None else xlabel ax.set_xlabel(xlabel) - if not ax.get_ylabel(): + if ylabel is not None or not ax.get_ylabel(): + ylabel = self.ylabel if ylabel is None else ylabel ax.set_ylabel(ylabel) self.ax_ = ax @@ -171,6 +171,8 @@ def from_estimator( eps=1.0, plot_method="contourf", response_method="auto", + xlabel=None, + ylabel=None, ax=None, **kwargs, ): @@ -211,6 +213,16 @@ def from_estimator( If set to 'auto', the response method is tried in the following order: :term:`predict_proba`, :term:`decision_function`, :term:`predict`. + xlabel : str, default=None + The label used for the x-axis. If `None`, an attempt is made to + extract a label from `X` if it is a dataframe, otherwise an empty + string is used. + + ylabel : str, default=None + The label used for the y-axis. If `None`, an attempt is made to + extract a label from `X` if it is a dataframe, otherwise an empty + string is used. + ax : Matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is created. @@ -266,10 +278,15 @@ def from_estimator( ) response = response[:, 1] - if hasattr(X, "columns"): - xlabel, ylabel = X.columns[0], X.columns[1] + if xlabel is not None: + xlabel = xlabel + else: + xlabel = X.columns[0] if hasattr(X, "columns") else "" + + if ylabel is not None: + ylabel = ylabel else: - xlabel, ylabel = "", "" + ylabel = X.columns[1] if hasattr(X, "columns") else "" display = DecisionBoundaryDisplay( xx0=xx0, diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 32905ecf49b72..1730439781f5a 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -211,3 +211,16 @@ def test_dataframe_labels_used(pyplot, data, fitted_clf): disp.plot(ax=ax) assert ax.get_xlabel() == "hello" assert ax.get_ylabel() == "world" + + # labels get overriden only if provided to the `plot` method + disp.plot(ax=ax, xlabel="overwritten_x", ylabel="overwritten_y") + assert ax.get_xlabel() == "overwritten_x" + assert ax.get_ylabel() == "overwritten_y" + + # labels do not get inferred if provided to `from_estimator` + _, ax = pyplot.subplots() + disp = DecisionBoundaryDisplay.from_estimator( + fitted_clf, df, ax=ax, xlabel="overwritten_x", ylabel="overwritten_y" + ) + assert ax.get_xlabel() == "overwritten_x" + assert ax.get_ylabel() == "overwritten_y" From 5c0273fc1427ba2a7715cf8ef856d9522ea17049 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 6 Aug 2021 17:41:24 +0200 Subject: [PATCH 23/48] DOC fix --- doc/visualizations.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/visualizations.rst b/doc/visualizations.rst index 434b9f82e05c4..84dcaf82a47f6 100644 --- a/doc/visualizations.rst +++ b/doc/visualizations.rst @@ -77,7 +77,6 @@ Functions .. autosummary:: inspection.plot_partial_dependence - inspection.plot_decision_boundary metrics.plot_confusion_matrix metrics.plot_det_curve metrics.plot_precision_recall_curve From f77997274d8d9b52f7dcc25f92eca17d56fc4d5a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 6 Aug 2021 17:55:25 +0200 Subject: [PATCH 24/48] avoid to set axis outside display --- examples/ensemble/plot_adaboost_twoclass.py | 6 +++--- examples/linear_model/plot_iris_logistic.py | 5 +++-- examples/linear_model/plot_sgd_iris.py | 3 ++- examples/neighbors/plot_classification.py | 7 ++++--- examples/neighbors/plot_nca_classification.py | 1 + examples/svm/plot_iris_svc.py | 5 ++--- examples/tree/plot_iris_dtc.py | 6 ++---- 7 files changed, 17 insertions(+), 16 deletions(-) diff --git a/examples/ensemble/plot_adaboost_twoclass.py b/examples/ensemble/plot_adaboost_twoclass.py index 9f8a277bd1ef3..9807638ba8eea 100644 --- a/examples/ensemble/plot_adaboost_twoclass.py +++ b/examples/ensemble/plot_adaboost_twoclass.py @@ -57,7 +57,8 @@ # Plot the decision boundaries ax = plt.subplot(121) disp = DecisionBoundaryDisplay.from_estimator( - bdt, X, cmap=plt.cm.Paired, response_method='predict', ax=ax + bdt, X, cmap=plt.cm.Paired, response_method='predict', ax=ax, xlabel="x", + ylabel="y", ) x_min, x_max = disp.xx0.min(), disp.xx0.max() y_min, y_max = disp.xx1.min(), disp.xx1.max() @@ -73,8 +74,7 @@ plt.xlim(x_min, x_max) plt.ylim(y_min, y_max) plt.legend(loc='upper right') -plt.xlabel('x') -plt.ylabel('y') + plt.title('Decision Boundary') # Plot the two-class decision scores diff --git a/examples/linear_model/plot_iris_logistic.py b/examples/linear_model/plot_iris_logistic.py index bff11201aed16..d5ed07b384c47 100644 --- a/examples/linear_model/plot_iris_logistic.py +++ b/examples/linear_model/plot_iris_logistic.py @@ -41,12 +41,13 @@ response_method="predict", plot_method="pcolormesh", shading="auto", + xlabel="Sepal length", + ylabel="Sepal width", ) # Plot also the training points plt.scatter(X[:, 0], X[:, 1], c=Y, edgecolors="k", cmap=plt.cm.Paired) -plt.xlabel("Sepal length") -plt.ylabel("Sepal width") + plt.xticks(()) plt.yticks(()) diff --git a/examples/linear_model/plot_sgd_iris.py b/examples/linear_model/plot_sgd_iris.py index 1a4a13d86a6eb..d51ecf80000e9 100644 --- a/examples/linear_model/plot_sgd_iris.py +++ b/examples/linear_model/plot_sgd_iris.py @@ -40,7 +40,8 @@ clf = SGDClassifier(alpha=0.001, max_iter=100).fit(X, y) ax = plt.gca() DecisionBoundaryDisplay.from_estimator( - clf, X, cmap=plt.cm.Paired, ax=ax, response_method='predict' + clf, X, cmap=plt.cm.Paired, ax=ax, response_method='predict', + xlabel=iris.feature_names[0], ylabel=iris.feature_names[1], ) plt.axis('tight') diff --git a/examples/neighbors/plot_classification.py b/examples/neighbors/plot_classification.py index d7c4f7fbe3df0..a2878a127bb2b 100644 --- a/examples/neighbors/plot_classification.py +++ b/examples/neighbors/plot_classification.py @@ -40,7 +40,10 @@ cmap=cmap_light, ax=ax, response_method='predict', - plot_method='pcolormesh' + plot_method='pcolormesh', + xlabel=iris.feature_names[0], + ylabel=iris.feature_names[1], + shading="auto", ) # Plot also the training points @@ -48,7 +51,5 @@ palette=cmap_bold, alpha=1.0, edgecolor="black") plt.title("3-Class classification (k = %i, weights = '%s')" % (n_neighbors, weights)) - plt.xlabel(iris.feature_names[0]) - plt.ylabel(iris.feature_names[1]) plt.show() diff --git a/examples/neighbors/plot_nca_classification.py b/examples/neighbors/plot_nca_classification.py index 5f8ad51e7045a..e958b27e4dace 100644 --- a/examples/neighbors/plot_nca_classification.py +++ b/examples/neighbors/plot_nca_classification.py @@ -72,6 +72,7 @@ ax=ax, response_method='predict', plot_method='pcolormesh', + shading="auto", ) # Plot also the training and testing points diff --git a/examples/svm/plot_iris_svc.py b/examples/svm/plot_iris_svc.py index 08c43bab098d4..6b4a1a5d440e9 100644 --- a/examples/svm/plot_iris_svc.py +++ b/examples/svm/plot_iris_svc.py @@ -69,11 +69,10 @@ for clf, title, ax in zip(models, titles, sub.flatten()): disp = DecisionBoundaryDisplay.from_estimator( - clf, X, response_method='predict', cmap=plt.cm.coolwarm, alpha=0.8, ax=ax + clf, X, response_method='predict', cmap=plt.cm.coolwarm, alpha=0.8, ax=ax, + xlabel=iris.feature_names[0], ylabel=iris.feature_names[1] ) ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k') - ax.set_xlabel('Sepal length') - ax.set_ylabel('Sepal width') ax.set_xticks(()) ax.set_yticks(()) ax.set_title(title) diff --git a/examples/tree/plot_iris_dtc.py b/examples/tree/plot_iris_dtc.py index d65eb959123a3..691232c4ea176 100644 --- a/examples/tree/plot_iris_dtc.py +++ b/examples/tree/plot_iris_dtc.py @@ -44,12 +44,10 @@ ax = plt.subplot(2, 3, pairidx + 1) plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5) DecisionBoundaryDisplay.from_estimator( - clf, X, cmap=plt.cm.RdYlBu, response_method='predict', ax=ax + clf, X, cmap=plt.cm.RdYlBu, response_method='predict', ax=ax, + xlabel=iris.feature_names[pair[0]], ylabel=iris.feature_names[pair[1]], ) - plt.xlabel(iris.feature_names[pair[0]]) - plt.ylabel(iris.feature_names[pair[1]]) - # Plot the training points for i, color in zip(range(n_classes), plot_colors): idx = np.where(y == i) From b3805d6b470bcb3898e9505b960d820e7b354ce7 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 6 Aug 2021 23:08:04 +0200 Subject: [PATCH 25/48] DOC add example and see also --- sklearn/inspection/_plot/decision_boundary.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 8b60d52a5dbba..ec862550e5e7b 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -238,9 +238,29 @@ def from_estimator( See Also -------- + DecisionBoundaryDisplay : Decision boundary visualization. + ConfusionMatrixDisplay.from_estimator : Plot the confusion matrix + given an estimator, the data, and the label. + ConfusionMatrixDisplay.from_predictions : Plot the confusion matrix + given the true and predicted labels. Examples -------- + >>> import matplotlib.pyplot as plt + >>> from sklearn.datasets import load_iris + >>> from sklearn.linear_model import LogisticRegression + >>> from sklearn.inspection import DecisionBoundaryDisplay + >>> iris = load_iris() + >>> X = iris.data[:, :2] + >>> classifier = LogisticRegression().fit(X, iris.target) + >>> disp = DecisionBoundaryDisplay.from_estimator( + ... classifier, X, response_method="predict", + ... xlabel=iris.feature_names[0], ylabel=iris.feature_names[1], + ... alpha=0.5, + ... ) + >>> disp.ax_.scatter(X[:, 0], X[:, 1], c=iris.target, edgecolor="k") + <...> + >>> plt.show() """ check_matplotlib_support(f"{cls.__name__}.from_estimator") From dd09fd3c64485a6333a45e556862baaac84188a8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 6 Aug 2021 23:15:35 +0200 Subject: [PATCH 26/48] iter --- sklearn/inspection/_plot/decision_boundary.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index ec862550e5e7b..95b77008b4540 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -5,14 +5,14 @@ def _check_boundary_response_method(estimator, response_method): - """Return prediction method from the response_method for decision boundary. + """Return prediction method from the `response_method` for decision boundary. Parameters ---------- - estimator: object + estimator : object Estimator to check. - response_method: {'auto', 'predict_proba', 'decision_function', 'predict'} + response_method : {'auto', 'predict_proba', 'decision_function', 'predict'} Specifies whether to use :term:`predict_proba`, :term:`decision_function`, :term:`predict` as the target response. If set to 'auto', the response method is tried in the following order: From c55ac8e4a96ffb19db20a6bb04982403a6ebae35 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 29 Aug 2021 15:46:21 -0400 Subject: [PATCH 27/48] REV Revert examples for now --- .pre-commit-config.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 54c1349263bf8..db2c5084cbbb2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,6 @@ repos: rev: 21.6b0 hooks: - id: black - exclude: ^examples - repo: https://gitlab.com/pycqa/flake8 rev: 3.9.2 hooks: From 88e62fe8184715c2299f3856499bc11ae30bb4ba Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 6 Sep 2021 10:46:19 -0400 Subject: [PATCH 28/48] ENH Better validation errors --- sklearn/inspection/_plot/decision_boundary.py | 16 +++++++++++----- .../tests/test_boundary_decision_display.py | 12 +++++++++--- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 95b77008b4540..339273a1828dd 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -265,15 +265,21 @@ def from_estimator( check_matplotlib_support(f"{cls.__name__}.from_estimator") if not grid_resolution > 1: - raise ValueError("grid_resolution must be greater than 1") + raise ValueError( + "grid_resolution must be greater than 1. Got" + f" {grid_resolution} instead." + ) if not eps >= 0: - raise ValueError("eps must be greater than or equal to 0") + raise ValueError( + f"eps must be greater than or equal to 0. Got {eps} instead." + ) - possible_plot_method = ("contourf", "contour", "pcolormesh") - if plot_method not in possible_plot_method: + possible_plot_methods = ("contourf", "contour", "pcolormesh") + if plot_method not in possible_plot_methods: + avaliable_methods = ", ".join(possible_plot_methods) raise ValueError( - f"plot_method must be one of {', '.join(possible_plot_method)}. " + f"plot_method must be one of {avaliable_methods}. " f"Got {plot_method} instead." ) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 1730439781f5a..4c304781bc764 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -91,9 +91,15 @@ def test_multiclass_error(pyplot, response_method): r"plot_method must be one of contourf, contour, pcolormesh. Got hello_world" r" instead.", ), - ({"grid_resolution": 1}, r"grid_resolution must be greater than 1"), - ({"grid_resolution": -1}, r"grid_resolution must be greater than 1"), - ({"eps": -1.1}, r"eps must be greater than or equal to 0"), + ( + {"grid_resolution": 1}, + r"grid_resolution must be greater than 1. Got 1 instead", + ), + ( + {"grid_resolution": -1}, + r"grid_resolution must be greater than 1. Got -1 instead", + ), + ({"eps": -1.1}, r"eps must be greater than or equal to 0. Got -1.1 instead"), ], ) def test_input_validation_errors(pyplot, kwargs, error_msg, fitted_clf, data): From 39c5e0f15e9b7843cab24935fe2dd42a02379eb9 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sat, 23 Oct 2021 19:45:47 -0400 Subject: [PATCH 29/48] ENH Remvoe --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 71da52002cc96..3762d2f229f76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ exclude = ''' | \.git # root of the project | \.mypy_cache | \.vscode - | examples | build | dist | doc/tutorial From e92e0320188c440fdcc1bd25287ab97324e4cb17 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sat, 23 Oct 2021 21:01:49 -0400 Subject: [PATCH 30/48] ENH Update avaliable methods --- sklearn/inspection/_plot/decision_boundary.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 339273a1828dd..5585560a57944 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -277,9 +277,9 @@ def from_estimator( possible_plot_methods = ("contourf", "contour", "pcolormesh") if plot_method not in possible_plot_methods: - avaliable_methods = ", ".join(possible_plot_methods) + available_methods = ", ".join(possible_plot_methods) raise ValueError( - f"plot_method must be one of {avaliable_methods}. " + f"plot_method must be one of {available_methods}. " f"Got {plot_method} instead." ) From fa4ded7b00a82cca431d8ad98795e7c14d1a24e0 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 28 Nov 2021 20:38:04 -0500 Subject: [PATCH 31/48] DOC Move to 1.1 --- doc/whats_new/v1.0.rst | 4 ---- doc/whats_new/v1.1.rst | 11 +++++++++-- sklearn/inspection/_plot/decision_boundary.py | 6 ++++-- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 8f734c2d7e52a..7f7f6b3f509ec 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -650,10 +650,6 @@ Changelog :mod:`sklearn.inspection` ......................... -- |Feature| Add a display to plot the boundary decision of a classifier by - using the method :func:`inspection.DecisionBoundaryDisplay.from_estimator`. - :pr:`16061` by `Thomas Fan`_. - - |Enhancement| Add `max_samples` parameter in :func:`inspection.permutation_importance`. It enables to draw a subset of the samples to compute the permutation importance. This is useful to keep the diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index f755aeba20030..f843c61998f93 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -113,10 +113,10 @@ Changelog - |Fix| :class:`decomposition.FastICA` now validates input parameters in `fit` instead of `__init__`. :pr:`21432` by :user:`Hannah Bohle ` and :user:`Maren Westermann `. - + - |Fix| :class:`decomposition.FactorAnalysis` now validates input parameters in `fit` instead of `__init__`. - :pr:`21713` by :user:`Haya ` and + :pr:`21713` by :user:`Haya ` and :user:`Krum Arnaudov `. - |Fix| :class:`decomposition.KernelPCA` now validates input parameters in @@ -178,6 +178,13 @@ Changelog multilabel classification. :pr:`19689` by :user:`Guillaume Lemaitre `. +:mod:`sklearn.inspection` +......................... + +- |Feature| Add a display to plot the boundary decision of a classifier by + using the method :func:`inspection.DecisionBoundaryDisplay.from_estimator`. + :pr:`16061` by `Thomas Fan`_. + :mod:`sklearn.linear_model` ........................... diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 5585560a57944..1e7c9a2f4c7fe 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -23,6 +23,8 @@ def _check_boundary_response_method(estimator, response_method): prediction_method: callable Prediction method of estimator. """ + # if "auto": + # response_methods = ["decision_function", "predict_proba", "predict"] possible_response_methods = ( "predict_proba", @@ -79,10 +81,10 @@ class DecisionBoundaryDisplay: response : ndarray of shape (grid_resolution, grid_resolution) Values of the response function. - xlabel : str, default="" + xlabel : str, default=None Default label to place on x axis. - ylabel : str, default="" + ylabel : str, default=None Default label to place on y axis. Attributes From d97231bf7f4f9ccf8bfbc2c1af654bf00b7de671 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 28 Nov 2021 21:32:06 -0500 Subject: [PATCH 32/48] WIP --- sklearn/inspection/_plot/decision_boundary.py | 50 ++++++------- .../tests/test_boundary_decision_display.py | 72 +++++++++++-------- 2 files changed, 61 insertions(+), 61 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 1e7c9a2f4c7fe..ff3e8d8dfb8d9 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -1,3 +1,5 @@ +from functools import reduce + import numpy as np from ...utils import check_matplotlib_support @@ -23,39 +25,20 @@ def _check_boundary_response_method(estimator, response_method): prediction_method: callable Prediction method of estimator. """ - # if "auto": - # response_methods = ["decision_function", "predict_proba", "predict"] - - possible_response_methods = ( - "predict_proba", - "decision_function", - "auto", - "predict", - ) - if response_method not in possible_response_methods: + if response_method == "auto": + list_methods = ["decision_function", "predict_proba", "predict"] + else: + list_methods = [response_method] + + prediction_method = [getattr(estimator, method, None) for method in list_methods] + prediction_method = reduce(lambda x, y: x or y, prediction_method) + if prediction_method is None: raise ValueError( - f"response_method must be one of {', '.join(possible_response_methods)}" + f"{estimator.__class__.__name__} has none of the following attributes: " + f"{', '.join(list_methods)}." ) - error_msg = "response method {} is not defined in {}" - if response_method != "auto": - if not hasattr(estimator, response_method): - raise ValueError( - error_msg.format(response_method, estimator.__class__.__name__) - ) - return getattr(estimator, response_method) - elif hasattr(estimator, "decision_function"): - return getattr(estimator, "decision_function") - elif hasattr(estimator, "predict_proba"): - return getattr(estimator, "predict_proba") - elif hasattr(estimator, "predict"): - return getattr(estimator, "predict") - - raise ValueError( - error_msg.format( - "decision_function, predict_proba, or predict", estimator.__class__.__name__ - ) - ) + return prediction_method class DecisionBoundaryDisplay: @@ -298,6 +281,13 @@ def from_estimator( pred_func = _check_boundary_response_method(estimator, response_method) response = pred_func(np.c_[xx0.ravel(), xx1.ravel()]) + # convert strings to integers + if response.dtype.kind in {"O", "U"}: + class_name_to_idx = { + name: idx for idx, name in enumerate(estimator.classes_) + } + response = np.asarray([class_name_to_idx[target] for target in response]) + if response.ndim != 1: if response.shape[1] != 2: raise ValueError( diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 4c304781bc764..440dcf44ee05b 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -4,6 +4,7 @@ from sklearn.base import ClassifierMixin from sklearn.datasets import make_classification from sklearn.linear_model import LogisticRegression +from sklearn.datasets import load_iris from sklearn.inspection import DecisionBoundaryDisplay from sklearn.inspection._plot.decision_boundary import _check_boundary_response_method @@ -16,21 +17,18 @@ ) -@pytest.fixture(scope="module") -def data(): - X, y = make_classification( - n_informative=1, - n_redundant=1, - n_clusters_per_class=1, - n_features=2, - random_state=42, - ) - return X, y +X, y = make_classification( + n_informative=1, + n_redundant=1, + n_clusters_per_class=1, + n_features=2, + random_state=42, +) @pytest.fixture(scope="module") -def fitted_clf(data): - return LogisticRegression().fit(*data) +def fitted_clf(): + return LogisticRegression().fit(X, y) def test_check_boundary_response_method_auto(): @@ -102,14 +100,12 @@ def test_multiclass_error(pyplot, response_method): ({"eps": -1.1}, r"eps must be greater than or equal to 0. Got -1.1 instead"), ], ) -def test_input_validation_errors(pyplot, kwargs, error_msg, fitted_clf, data): - X, _ = data +def test_input_validation_errors(pyplot, kwargs, error_msg, fitted_clf): with pytest.raises(ValueError, match=error_msg): DecisionBoundaryDisplay.from_estimator(fitted_clf, X, **kwargs) -def test_display_plot_input_error(pyplot, fitted_clf, data): - X, y = data +def test_display_plot_input_error(pyplot, fitted_clf): disp = DecisionBoundaryDisplay.from_estimator(fitted_clf, X, grid_resolution=5) with pytest.raises(ValueError, match="plot_method must be 'contourf'"): @@ -120,12 +116,9 @@ def test_display_plot_input_error(pyplot, fitted_clf, data): "response_method", ["auto", "predict", "predict_proba", "decision_function"] ) @pytest.mark.parametrize("plot_method", ["contourf", "contour"]) -def test_decision_boundary_display( - pyplot, fitted_clf, data, response_method, plot_method -): +def test_decision_boundary_display(pyplot, fitted_clf, response_method, plot_method): fig, ax = pyplot.subplots() eps = 2.0 - X, y = data disp = DecisionBoundaryDisplay.from_estimator( fitted_clf, X, @@ -162,27 +155,24 @@ def test_decision_boundary_display( [ ( "predict_proba", - "response method predict_proba is not defined in MyClassifier", + "MyClassifier has none of the following attributes: predict_proba", ), ( "decision_function", - "response method decision_function is not defined in MyClassifier", + "MyClassifier has none of the following attributes: decision_function", ), ( "auto", - "response method decision_function, predict_proba, or predict " - "is not defined in MyClassifier", + "MyClassifier has none of the following attributes: decision_function, " + "predict_proba, predict", ), ( "bad_method", - "response_method must be one of predict_proba, decision_function, auto," - " predict", + "MyClassifier has none of the following attributes: bad_method", ), ], ) -def test_error_bad_response(pyplot, response_method, msg, data): - X, y = data - +def test_error_bad_response(pyplot, response_method, msg): class MyClassifier(BaseEstimator, ClassifierMixin): def fit(self, X, y): self.fitted_ = True @@ -195,9 +185,9 @@ def fit(self, X, y): DecisionBoundaryDisplay.from_estimator(clf, X, response_method=response_method) -def test_dataframe_labels_used(pyplot, data, fitted_clf): +def test_dataframe_labels_used(pyplot, fitted_clf): pd = pytest.importorskip("pandas") - df = pd.DataFrame(data[0], columns=["col_x", "col_y"]) + df = pd.DataFrame(X, columns=["col_x", "col_y"]) # pandas column names are used by default _, ax = pyplot.subplots() @@ -230,3 +220,23 @@ def test_dataframe_labels_used(pyplot, data, fitted_clf): ) assert ax.get_xlabel() == "overwritten_x" assert ax.get_ylabel() == "overwritten_y" + + +def test_string_target(pyplot): + """Check that decision boundary works with classifiers trained on string labels.""" + iris = load_iris() + X = iris.data[:, [0, 1]] + + # Use strings as target + y = iris.target_names[iris.target] + + log_reg = LogisticRegression().fit(X, y) + + disp = DecisionBoundaryDisplay.from_estimator( + log_reg, + X, + grid_resolution=5, + response_method="predict", + plot_method="pcolormesh", + ) + disp.plot() From 3c1cb4dbaad4eb3a0b3fefe880ef517982199bf8 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 28 Nov 2021 21:36:53 -0500 Subject: [PATCH 33/48] FIX Support string as targets --- examples/neighbors/plot_nearest_centroid.py | 21 +------------------ sklearn/inspection/_plot/decision_boundary.py | 8 ++++--- .../tests/test_boundary_decision_display.py | 6 ++---- 3 files changed, 8 insertions(+), 27 deletions(-) diff --git a/examples/neighbors/plot_nearest_centroid.py b/examples/neighbors/plot_nearest_centroid.py index 0dcd6b09bc222..0ea3c0c6b1209 100644 --- a/examples/neighbors/plot_nearest_centroid.py +++ b/examples/neighbors/plot_nearest_centroid.py @@ -24,11 +24,6 @@ X = iris.data[:, :2] y = iris.target -<<<<<<< HEAD -======= -h = 0.02 # step size in the mesh - ->>>>>>> upstream/main # Create color maps cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"]) cmap_bold = ListedColormap(["darkorange", "c", "darkblue"]) @@ -39,25 +34,11 @@ clf.fit(X, y) y_pred = clf.predict(X) print(shrinkage, np.mean(y == y_pred)) -<<<<<<< HEAD _, ax = plt.subplots() DecisionBoundaryDisplay.from_estimator( - clf, X, cmap=cmap_light, ax=ax, response_method='predict' + clf, X, cmap=cmap_light, ax=ax, response_method="predict" ) -======= - # Plot the decision boundary. For that, we will assign a color to each - # point in the mesh [x_min, x_max]x[y_min, y_max]. - x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 - y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 - xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) - Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) - - # Put the result into a color plot - Z = Z.reshape(xx.shape) - plt.figure() - plt.pcolormesh(xx, yy, Z, cmap=cmap_light) ->>>>>>> upstream/main # Plot also the training points plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor="k", s=20) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index ff3e8d8dfb8d9..afcc11919fc14 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -281,12 +281,14 @@ def from_estimator( pred_func = _check_boundary_response_method(estimator, response_method) response = pred_func(np.c_[xx0.ravel(), xx1.ravel()]) - # convert strings to integers - if response.dtype.kind in {"O", "U"}: + # convert strings predictions to integers + if pred_func.__name__ == "predict" and response.dtype.kind in {"O", "U"}: class_name_to_idx = { name: idx for idx, name in enumerate(estimator.classes_) } - response = np.asarray([class_name_to_idx[target] for target in response]) + response = np.asarray( + [class_name_to_idx[target] for target in response], dtype=np.int32 + ) if response.ndim != 1: if response.shape[1] != 2: diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 440dcf44ee05b..a6c30418320cc 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -229,14 +229,12 @@ def test_string_target(pyplot): # Use strings as target y = iris.target_names[iris.target] - log_reg = LogisticRegression().fit(X, y) - disp = DecisionBoundaryDisplay.from_estimator( + # Does not raise + DecisionBoundaryDisplay.from_estimator( log_reg, X, grid_resolution=5, response_method="predict", - plot_method="pcolormesh", ) - disp.plot() From fa66c813c1b0603abefeeca449cceedd8f2a9cad Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 29 Nov 2021 00:47:16 -0500 Subject: [PATCH 34/48] STY Fix black formatting --- examples/cluster/plot_inductive_clustering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cluster/plot_inductive_clustering.py b/examples/cluster/plot_inductive_clustering.py index 6ee18d71cf63f..e395571a1caad 100644 --- a/examples/cluster/plot_inductive_clustering.py +++ b/examples/cluster/plot_inductive_clustering.py @@ -122,7 +122,7 @@ def plot_scatter(X, color, alpha=0.5): # Plotting decision regions DecisionBoundaryDisplay.from_estimator( - inductive_learner, X, response_method='predict', alpha=0.4, ax=ax + inductive_learner, X, response_method="predict", alpha=0.4, ax=ax ) plt.title("Classify unknown instances") From 833492c4301221fd1dac84a761488d68e9befb1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Wed, 23 Feb 2022 17:41:58 +0100 Subject: [PATCH 35/48] Fix whats_new --- doc/whats_new/v1.1.rst | 3 --- 1 file changed, 3 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index afbfb964a45a3..949ca3c9e4d35 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -474,9 +474,6 @@ Changelog `max_iter` and `tol`. :pr:`21341` by :user:`Arturo Amor `. -:mod:`sklearn.inspection` -......................... - - |Enhancement| In :meth:`~sklearn.inspection.PartialDependenceDisplay.from_estimator` and :meth:`~sklearn.inspection.PartialDependenceDisplay.from_predictions`, allow From 9f9b1d13e0066d7c8c25798994e3a376b6b3d824 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 23 Feb 2022 15:01:36 -0500 Subject: [PATCH 36/48] DOC Fixes whats nwe --- doc/whats_new/v1.1.rst | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 949ca3c9e4d35..52b2ce22a10ce 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -458,12 +458,6 @@ Changelog multilabel classification. :pr:`19689` by :user:`Guillaume Lemaitre `. -:mod:`sklearn.inspection` -......................... - -- |Feature| Add a display to plot the boundary decision of a classifier by - using the method :func:`inspection.DecisionBoundaryDisplay.from_estimator`. - :pr:`16061` by `Thomas Fan`_. - |Enhancement| :class:`linear_model.RidgeCV` and :class:`linear_model.RidgeClassifierCV` now raise consistent error message when passed invalid values for `alphas`. @@ -474,6 +468,13 @@ Changelog `max_iter` and `tol`. :pr:`21341` by :user:`Arturo Amor `. +:mod:`sklearn.inspection` +......................... + +- |Feature| Add a display to plot the boundary decision of a classifier by + using the method :func:`inspection.DecisionBoundaryDisplay.from_estimator`. + :pr:`16061` by `Thomas Fan`_. + - |Enhancement| In :meth:`~sklearn.inspection.PartialDependenceDisplay.from_estimator` and :meth:`~sklearn.inspection.PartialDependenceDisplay.from_predictions`, allow From 3e8162cdba39ffe21212c10e39273177d4ed91ac Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 23 Feb 2022 16:52:11 -0500 Subject: [PATCH 37/48] ENH Improve auto behavior for multiclass problems --- examples/linear_model/plot_iris_logistic.py | 1 + sklearn/inspection/_plot/decision_boundary.py | 24 ++++++++--- .../tests/test_boundary_decision_display.py | 43 +++++++++++++++++-- 3 files changed, 57 insertions(+), 11 deletions(-) diff --git a/examples/linear_model/plot_iris_logistic.py b/examples/linear_model/plot_iris_logistic.py index aa6b379fa966b..10a1f0f15ad79 100644 --- a/examples/linear_model/plot_iris_logistic.py +++ b/examples/linear_model/plot_iris_logistic.py @@ -40,6 +40,7 @@ shading="auto", xlabel="Sepal length", ylabel="Sepal width", + eps=0.5, ) # Plot also the training points diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index afcc11919fc14..4b71f338ad937 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -12,7 +12,7 @@ def _check_boundary_response_method(estimator, response_method): Parameters ---------- estimator : object - Estimator to check. + Fitted estimator to check. response_method : {'auto', 'predict_proba', 'decision_function', 'predict'} Specifies whether to use :term:`predict_proba`, @@ -25,17 +25,25 @@ def _check_boundary_response_method(estimator, response_method): prediction_method: callable Prediction method of estimator. """ - if response_method == "auto": - list_methods = ["decision_function", "predict_proba", "predict"] + if len(estimator.classes_) > 2: + if response_method not in {"auto", "predict"}: + msg = ( + "Multiclass classifiers are only supported when response_method is" + " 'predict' or 'auto'" + ) + raise ValueError(msg) + methods_list = ["predict"] + elif response_method == "auto": + methods_list = ["decision_function", "predict_proba", "predict"] else: - list_methods = [response_method] + methods_list = [response_method] - prediction_method = [getattr(estimator, method, None) for method in list_methods] + prediction_method = [getattr(estimator, method, None) for method in methods_list] prediction_method = reduce(lambda x, y: x or y, prediction_method) if prediction_method is None: raise ValueError( f"{estimator.__class__.__name__} has none of the following attributes: " - f"{', '.join(list_methods)}." + f"{', '.join(methods_list)}." ) return prediction_method @@ -51,7 +59,7 @@ class DecisionBoundaryDisplay: Read more in the :ref:`User Guide `. - .. versionadded:: 1.0 + .. versionadded:: 1.1 Parameters ---------- @@ -197,6 +205,8 @@ def from_estimator( :term:`decision_function`, :term:`predict` as the target response. If set to 'auto', the response method is tried in the following order: :term:`predict_proba`, :term:`decision_function`, :term:`predict`. + For multiclass problems, :term:`predict` is selected when + `response_method="auto"`. xlabel : str, default=None The label used for the x-axis. If `None`, an attempt is made to diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index a6c30418320cc..21bb5e3215e2d 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -1,4 +1,6 @@ import pytest +import numpy as np +from numpy.testing import assert_allclose from sklearn.base import BaseEstimator from sklearn.base import ClassifierMixin @@ -33,6 +35,8 @@ def fitted_clf(): def test_check_boundary_response_method_auto(): class A: + classes_ = [0, 1] + def decision_function(self): pass @@ -41,6 +45,8 @@ def decision_function(self): assert method == a_inst.decision_function class B: + classes_ = [0, 1] + def predict_proba(self): pass @@ -49,6 +55,8 @@ def predict_proba(self): assert method == b_inst.predict_proba class C: + classes_ = [0, 1] + def predict_proba(self): pass @@ -60,6 +68,8 @@ def decision_function(self): assert method == c_inst.decision_function class D: + classes_ = [0, 1] + def predict(self): pass @@ -68,19 +78,44 @@ def predict(self): assert method == d_inst.predict -@pytest.mark.parametrize( - "response_method", ["auto", "predict_proba", "decision_function"] -) +@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) def test_multiclass_error(pyplot, response_method): X, y = make_classification(n_classes=3, n_informative=3, random_state=0) X = X[:, [0, 1]] lr = LogisticRegression().fit(X, y) - msg = "Multiclass classifiers are only supported when response_method='predict'" + msg = ( + "Multiclass classifiers are only supported when response_method is 'predict' or" + " 'auto'" + ) with pytest.raises(ValueError, match=msg): DecisionBoundaryDisplay.from_estimator(lr, X, response_method=response_method) +@pytest.mark.parametrize("response_method", ["auto", "predict"]) +def test_multiclass(pyplot, response_method): + grid_resolution = 10 + eps = 1.0 + X, y = make_classification(n_classes=3, n_informative=3, random_state=0) + X = X[:, [0, 1]] + lr = LogisticRegression(random_state=0).fit(X, y) + + disp = DecisionBoundaryDisplay.from_estimator( + lr, X, response_method=response_method, grid_resolution=grid_resolution, eps=1.0 + ) + + x0_min, x0_max = X[:, 0].min() - eps, X[:, 0].max() + eps + x1_min, x1_max = X[:, 1].min() - eps, X[:, 1].max() + eps + xx0, xx1 = np.meshgrid( + np.linspace(x0_min, x0_max, grid_resolution), + np.linspace(x1_min, x1_max, grid_resolution), + ) + response = lr.predict(np.c_[xx0.ravel(), xx1.ravel()]) + assert_allclose(disp.response, response.reshape(xx0.shape)) + assert_allclose(disp.xx0, xx0) + assert_allclose(disp.xx1, xx1) + + @pytest.mark.parametrize( "kwargs, error_msg", [ From 60d67d9d321665f985b67eca1dbad1445e6f1705 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 23 Feb 2022 16:52:50 -0500 Subject: [PATCH 38/48] FIX Removes unneeded code --- sklearn/inspection/_plot/decision_boundary.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 4b71f338ad937..d53504c2cc44d 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -301,11 +301,6 @@ def from_estimator( ) if response.ndim != 1: - if response.shape[1] != 2: - raise ValueError( - "Multiclass classifiers are only supported when " - "response_method='predict'" - ) response = response[:, 1] if xlabel is not None: From 741ff084e451a2f9a35ba0926dfbe38e6a966276 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 23 Feb 2022 20:35:22 -0500 Subject: [PATCH 39/48] FIX Adds classes to InductiveClusterer --- examples/cluster/plot_inductive_clustering.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/cluster/plot_inductive_clustering.py b/examples/cluster/plot_inductive_clustering.py index e395571a1caad..6c894ee76e5f8 100644 --- a/examples/cluster/plot_inductive_clustering.py +++ b/examples/cluster/plot_inductive_clustering.py @@ -60,6 +60,7 @@ def fit(self, X, y=None): self.classifier_ = clone(self.classifier) y = self.clusterer_.fit_predict(X) self.classifier_.fit(X, y) + self.classes_ = self.classifier_.classes_ return self @available_if(_classifier_has("predict")) From 86f24e9153829bb9765e2d4b7a1c3b8868695c22 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 23 Feb 2022 20:38:09 -0500 Subject: [PATCH 40/48] ENH Do not require classes --- examples/cluster/plot_inductive_clustering.py | 1 - sklearn/inspection/_plot/decision_boundary.py | 2 +- .../_plot/tests/test_boundary_decision_display.py | 8 -------- 3 files changed, 1 insertion(+), 10 deletions(-) diff --git a/examples/cluster/plot_inductive_clustering.py b/examples/cluster/plot_inductive_clustering.py index 6c894ee76e5f8..e395571a1caad 100644 --- a/examples/cluster/plot_inductive_clustering.py +++ b/examples/cluster/plot_inductive_clustering.py @@ -60,7 +60,6 @@ def fit(self, X, y=None): self.classifier_ = clone(self.classifier) y = self.clusterer_.fit_predict(X) self.classifier_.fit(X, y) - self.classes_ = self.classifier_.classes_ return self @available_if(_classifier_has("predict")) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index d53504c2cc44d..b9ef71ed0d444 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -25,7 +25,7 @@ def _check_boundary_response_method(estimator, response_method): prediction_method: callable Prediction method of estimator. """ - if len(estimator.classes_) > 2: + if hasattr(estimator, "classes_") and len(estimator.classes_) > 2: if response_method not in {"auto", "predict"}: msg = ( "Multiclass classifiers are only supported when response_method is" diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 21bb5e3215e2d..288b96924c071 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -35,8 +35,6 @@ def fitted_clf(): def test_check_boundary_response_method_auto(): class A: - classes_ = [0, 1] - def decision_function(self): pass @@ -45,8 +43,6 @@ def decision_function(self): assert method == a_inst.decision_function class B: - classes_ = [0, 1] - def predict_proba(self): pass @@ -55,8 +51,6 @@ def predict_proba(self): assert method == b_inst.predict_proba class C: - classes_ = [0, 1] - def predict_proba(self): pass @@ -68,8 +62,6 @@ def decision_function(self): assert method == c_inst.decision_function class D: - classes_ = [0, 1] - def predict(self): pass From 9d00d7331d005b6f1b5b9bba2687641c3b665912 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 23 Feb 2022 20:43:08 -0500 Subject: [PATCH 41/48] DOC Adds TODO --- sklearn/inspection/_plot/decision_boundary.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index b9ef71ed0d444..ce25cfc49ec03 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -18,7 +18,7 @@ def _check_boundary_response_method(estimator, response_method): Specifies whether to use :term:`predict_proba`, :term:`decision_function`, :term:`predict` as the target response. If set to 'auto', the response method is tried in the following order: - :term:`predict_proba`, :term:`decision_function`, :term:`predict`. + :term:`decision_function`, :term:`predict_proba`, :term:`predict`. Returns ------- @@ -204,7 +204,7 @@ def from_estimator( Specifies whether to use :term:`predict_proba`, :term:`decision_function`, :term:`predict` as the target response. If set to 'auto', the response method is tried in the following order: - :term:`predict_proba`, :term:`decision_function`, :term:`predict`. + :term:`decision_function`, :term:`predict_proba`, :term:`predict`. For multiclass problems, :term:`predict` is selected when `response_method="auto"`. @@ -300,7 +300,10 @@ def from_estimator( [class_name_to_idx[target] for target in response], dtype=np.int32 ) + # TODO: Check for multi-label classifier, a multi-output multiclass + # classifier or a multioutput regressor and error. if response.ndim != 1: + # TODO: Support pos_label response = response[:, 1] if xlabel is not None: From 1f2d6e68124ebebe2a22b090e57f150161a9073e Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 23 Feb 2022 23:41:02 -0500 Subject: [PATCH 42/48] ENH Improve error message for unsupported estimators --- sklearn/inspection/_plot/decision_boundary.py | 13 ++++- .../tests/test_boundary_decision_display.py | 54 +++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index ce25cfc49ec03..4aa0f0040be4f 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -4,6 +4,8 @@ from ...utils import check_matplotlib_support from ...utils import _safe_indexing +from ...base import is_regressor +from ...utils.validation import check_is_fitted, _is_arraylike_not_scalar def _check_boundary_response_method(estimator, response_method): @@ -25,7 +27,12 @@ def _check_boundary_response_method(estimator, response_method): prediction_method: callable Prediction method of estimator. """ - if hasattr(estimator, "classes_") and len(estimator.classes_) > 2: + has_classes = hasattr(estimator, "classes_") + if has_classes and _is_arraylike_not_scalar(estimator.classes_[0]): + msg = "Multi-label and multi-output multi-class classifiers are not supported" + raise ValueError(msg) + + if has_classes and len(estimator.classes_) > 2: if response_method not in {"auto", "predict"}: msg = ( "Multiclass classifiers are only supported when response_method is" @@ -258,6 +265,7 @@ def from_estimator( >>> plt.show() """ check_matplotlib_support(f"{cls.__name__}.from_estimator") + check_is_fitted(estimator) if not grid_resolution > 1: raise ValueError( @@ -303,6 +311,9 @@ def from_estimator( # TODO: Check for multi-label classifier, a multi-output multiclass # classifier or a multioutput regressor and error. if response.ndim != 1: + if is_regressor(estimator): + raise ValueError("Multi-output regressors are not supported") + # TODO: Support pos_label response = response[:, 1] diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index 288b96924c071..e1f25ff0cc607 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -7,6 +7,9 @@ from sklearn.datasets import make_classification from sklearn.linear_model import LogisticRegression from sklearn.datasets import load_iris +from sklearn.datasets import make_multilabel_classification +from sklearn.tree import DecisionTreeRegressor +from sklearn.tree import DecisionTreeClassifier from sklearn.inspection import DecisionBoundaryDisplay from sklearn.inspection._plot.decision_boundary import _check_boundary_response_method @@ -34,6 +37,8 @@ def fitted_clf(): def test_check_boundary_response_method_auto(): + """Check _check_boundary_response_method behavior with 'auto'.""" + class A: def decision_function(self): pass @@ -72,6 +77,7 @@ def predict(self): @pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"]) def test_multiclass_error(pyplot, response_method): + """Check multiclass errors.""" X, y = make_classification(n_classes=3, n_informative=3, random_state=0) X = X[:, [0, 1]] lr = LogisticRegression().fit(X, y) @@ -86,6 +92,7 @@ def test_multiclass_error(pyplot, response_method): @pytest.mark.parametrize("response_method", ["auto", "predict"]) def test_multiclass(pyplot, response_method): + """Check multiclass gives expected results.""" grid_resolution = 10 eps = 1.0 X, y = make_classification(n_classes=3, n_informative=3, random_state=0) @@ -128,11 +135,13 @@ def test_multiclass(pyplot, response_method): ], ) def test_input_validation_errors(pyplot, kwargs, error_msg, fitted_clf): + """Check input validation from_estimator.""" with pytest.raises(ValueError, match=error_msg): DecisionBoundaryDisplay.from_estimator(fitted_clf, X, **kwargs) def test_display_plot_input_error(pyplot, fitted_clf): + """Check input validation for `plot`.""" disp = DecisionBoundaryDisplay.from_estimator(fitted_clf, X, grid_resolution=5) with pytest.raises(ValueError, match="plot_method must be 'contourf'"): @@ -144,6 +153,7 @@ def test_display_plot_input_error(pyplot, fitted_clf): ) @pytest.mark.parametrize("plot_method", ["contourf", "contour"]) def test_decision_boundary_display(pyplot, fitted_clf, response_method, plot_method): + """Check that decision boundary is correct.""" fig, ax = pyplot.subplots() eps = 2.0 disp = DecisionBoundaryDisplay.from_estimator( @@ -200,6 +210,8 @@ def test_decision_boundary_display(pyplot, fitted_clf, response_method, plot_met ], ) def test_error_bad_response(pyplot, response_method, msg): + """Check errors for bad response.""" + class MyClassifier(BaseEstimator, ClassifierMixin): def fit(self, X, y): self.fitted_ = True @@ -212,7 +224,49 @@ def fit(self, X, y): DecisionBoundaryDisplay.from_estimator(clf, X, response_method=response_method) +@pytest.mark.parametrize("response_method", ["auto", "predict", "predict_proba"]) +def test_multilabel_classifier_error(response_method): + """Check that multilabel classifier raises correct error.""" + X, y = make_multilabel_classification(random_state=0) + X = X[:, :2] + tree = DecisionTreeClassifier().fit(X, y) + + msg = "Multi-label and multi-output multi-class classifiers are not supported" + with pytest.raises(ValueError, match=msg): + DecisionBoundaryDisplay.from_estimator( + tree, + X, + response_method=response_method, + ) + + +@pytest.mark.parametrize("response_method", ["auto", "predict", "predict_proba"]) +def test_multi_output_multi_class_classifier_error(response_method): + """Check that multi-output multi-class classifier raises correct error.""" + X = np.asarray([[0, 1], [1, 2]]) + y = np.asarray([["tree", "cat"], ["cat", "tree"]]) + tree = DecisionTreeClassifier().fit(X, y) + + msg = "Multi-label and multi-output multi-class classifiers are not supported" + with pytest.raises(ValueError, match=msg): + DecisionBoundaryDisplay.from_estimator( + tree, + X, + response_method=response_method, + ) + + +def test_multioutput_regressor_error(pyplot): + """Check that multioutput regressor raises correct error.""" + X = np.asarray([[0, 1], [1, 2]]) + y = np.asarray([[0, 1], [4, 1]]) + tree = DecisionTreeRegressor().fit(X, y) + with pytest.raises(ValueError, match="Multi-output regressors are not supported"): + DecisionBoundaryDisplay.from_estimator(tree, X) + + def test_dataframe_labels_used(pyplot, fitted_clf): + """Check that column names are used for pandas.""" pd = pytest.importorskip("pandas") df = pd.DataFrame(X, columns=["col_x", "col_y"]) From 702e44625d23dc76264391e3bb61805a9f53c017 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 23 Feb 2022 23:41:56 -0500 Subject: [PATCH 43/48] CLN Remove comment --- sklearn/inspection/_plot/decision_boundary.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 4aa0f0040be4f..2dfe57c82c7e9 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -308,8 +308,6 @@ def from_estimator( [class_name_to_idx[target] for target in response], dtype=np.int32 ) - # TODO: Check for multi-label classifier, a multi-output multiclass - # classifier or a multioutput regressor and error. if response.ndim != 1: if is_regressor(estimator): raise ValueError("Multi-output regressors are not supported") From 0f923140f236bd89791b2d062ae59584227294e7 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 7 Mar 2022 14:49:02 -0500 Subject: [PATCH 44/48] TST Fixes test error --- .../inspection/_plot/tests/test_boundary_decision_display.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py index e1f25ff0cc607..955deb33331d6 100644 --- a/sklearn/inspection/_plot/tests/test_boundary_decision_display.py +++ b/sklearn/inspection/_plot/tests/test_boundary_decision_display.py @@ -225,7 +225,7 @@ def fit(self, X, y): @pytest.mark.parametrize("response_method", ["auto", "predict", "predict_proba"]) -def test_multilabel_classifier_error(response_method): +def test_multilabel_classifier_error(pyplot, response_method): """Check that multilabel classifier raises correct error.""" X, y = make_multilabel_classification(random_state=0) X = X[:, :2] @@ -241,7 +241,7 @@ def test_multilabel_classifier_error(response_method): @pytest.mark.parametrize("response_method", ["auto", "predict", "predict_proba"]) -def test_multi_output_multi_class_classifier_error(response_method): +def test_multi_output_multi_class_classifier_error(pyplot, response_method): """Check that multi-output multi-class classifier raises correct error.""" X = np.asarray([[0, 1], [1, 2]]) y = np.asarray([["tree", "cat"], ["cat", "tree"]]) From b864194b551173689aa6ad07341da01a4f6b54f7 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Sun, 13 Mar 2022 15:59:50 -0400 Subject: [PATCH 45/48] FIX Uses labelencoder --- sklearn/inspection/_plot/decision_boundary.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 2dfe57c82c7e9..19d7bd51a3b8d 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -2,6 +2,7 @@ import numpy as np +from ...preprocessing import LabelEncoder from ...utils import check_matplotlib_support from ...utils import _safe_indexing from ...base import is_regressor @@ -299,14 +300,11 @@ def from_estimator( pred_func = _check_boundary_response_method(estimator, response_method) response = pred_func(np.c_[xx0.ravel(), xx1.ravel()]) - # convert strings predictions to integers - if pred_func.__name__ == "predict" and response.dtype.kind in {"O", "U"}: - class_name_to_idx = { - name: idx for idx, name in enumerate(estimator.classes_) - } - response = np.asarray( - [class_name_to_idx[target] for target in response], dtype=np.int32 - ) + # convert classes predictions into integers + if pred_func.__name__ == "predict" and hasattr(estimator, "classes_"): + encoder = LabelEncoder() + encoder.classes_ = estimator.classes_ + response = encoder.transform(response) if response.ndim != 1: if is_regressor(estimator): From aebbb6b7f61cad94d150b47f87567188645c3825 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 15 Mar 2022 12:41:08 -0400 Subject: [PATCH 46/48] REV Remove unneeded whats new item --- doc/whats_new/v1.1.rst | 7 ------- 1 file changed, 7 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 6c761654609ac..cb5477deb0ae4 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -239,13 +239,6 @@ Changelog get accurate results when the number of features is large. :pr:`21109` by :user:`Smile `. -- |Fix| :class:`decomposition.FastICA` now validates input parameters in `fit` instead of `__init__`. - :pr:`21432` by :user:`Hannah Bohle ` and :user:`Maren Westermann `. - -- |Fix| :class:`decomposition.FactorAnalysis` now validates input parameters - in `fit` instead of `__init__`. - :pr:`21713` by :user:`Haya ` and - :user:`Krum Arnaudov `. - |Enhancement| :func:`decomposition.dict_learning`, :func:`decomposition.dict_learning_online` and :func:`decomposition.sparse_encode` preserve dtype for `numpy.float32`. :class:`decomposition.DictionaryLearning`, :class:`decompsition.MiniBatchDictionaryLearning` From 666b1be618e290915408aa9f34dd05fab8d1fb61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= <34657725+jeremiedbb@users.noreply.github.com> Date: Fri, 25 Mar 2022 16:01:22 +0100 Subject: [PATCH 47/48] Update sklearn/inspection/_plot/decision_boundary.py --- sklearn/inspection/_plot/decision_boundary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index 19d7bd51a3b8d..adab0e50538ed 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -60,7 +60,7 @@ def _check_boundary_response_method(estimator, response_method): class DecisionBoundaryDisplay: """Decisions boundary visualization. - It is recommend to use + It is recommended to use :func:`~sklearn.inspection.DecisionBoundaryDisplay.from_estimator` to create a :class:`DecisionBoundaryDisplay`. All parameters are stored as attributes. From 67c8fc8a2d149e8e2c0a4471c29a39fb8eea4697 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9mie=20du=20Boisberranger?= <34657725+jeremiedbb@users.noreply.github.com> Date: Fri, 25 Mar 2022 16:02:48 +0100 Subject: [PATCH 48/48] Update sklearn/inspection/_plot/decision_boundary.py --- sklearn/inspection/_plot/decision_boundary.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearn/inspection/_plot/decision_boundary.py b/sklearn/inspection/_plot/decision_boundary.py index adab0e50538ed..78a8b16bd577a 100644 --- a/sklearn/inspection/_plot/decision_boundary.py +++ b/sklearn/inspection/_plot/decision_boundary.py @@ -181,8 +181,6 @@ def from_estimator( Read more in the :ref:`User Guide `. - .. versionadded:: 1.0 - Parameters ---------- estimator : object