diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 923ec57bf40f4..049a22ad73b78 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -370,6 +370,10 @@ Changelog :class:`ensemble.HistGradientBoostingRegressor`. :pr:`13769` by `Nicolas Hug`_. +- |Enhancement| :func:`inspection.partial_dependence` accepts pandas DataFrame + and :class:`pipeline.Pipeline` containing :class:`compose.ColumnTransformer`. + :pr:`14028` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.kernel_approximation` ................................... diff --git a/examples/inspection/plot_partial_dependence.py b/examples/inspection/plot_partial_dependence.py index 0d79401e3f662..d7564d5ec95c7 100644 --- a/examples/inspection/plot_partial_dependence.py +++ b/examples/inspection/plot_partial_dependence.py @@ -30,6 +30,7 @@ from time import time import numpy as np +import pandas as pd import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D @@ -54,8 +55,8 @@ # (here the average target, by default) cal_housing = fetch_california_housing() -names = cal_housing.feature_names -X, y = cal_housing.data, cal_housing.target +X = pd.DataFrame(cal_housing.data, columns=cal_housing.feature_names) +y = cal_housing.target y -= y.mean() @@ -104,8 +105,9 @@ tic = time() # We don't compute the 2-way PDP (5, 1) here, because it is a lot slower # with the brute method. -features = [0, 5, 1, 2] -plot_partial_dependence(est, X_train, features, feature_names=names, +features = ['MedInc', 'AveOccup', 'HouseAge', 'AveRooms'] +plot_partial_dependence(est, X_train, features, + feature_names=X_train.columns.tolist(), n_jobs=3, grid_resolution=20) print("done in {:.3f}s".format(time() - tic)) fig = plt.gcf() @@ -143,8 +145,10 @@ print('Computing partial dependence plots...') tic = time() -features = [0, 5, 1, 2, (5, 1)] -plot_partial_dependence(est, X_train, features, feature_names=names, +features = ['MedInc', 'AveOccup', 'HouseAge', 'AveRooms', + ('AveOccup', 'HouseAge')] +plot_partial_dependence(est, X_train, features, + feature_names=X_train.columns.tolist(), n_jobs=3, grid_resolution=20) print("done in {:.3f}s".format(time() - tic)) fig = plt.gcf() @@ -192,16 +196,16 @@ fig = plt.figure() -target_feature = (1, 5) -pdp, axes = partial_dependence(est, X_train, target_feature, +features = ('AveOccup', 'HouseAge') +pdp, axes = partial_dependence(est, X_train, features=features, grid_resolution=20) XX, YY = np.meshgrid(axes[0], axes[1]) Z = pdp[0].T ax = Axes3D(fig) surf = ax.plot_surface(XX, YY, Z, rstride=1, cstride=1, cmap=plt.cm.BuPu, edgecolor='k') -ax.set_xlabel(names[target_feature[0]]) -ax.set_ylabel(names[target_feature[1]]) +ax.set_xlabel(features[0]) +ax.set_ylabel(features[1]) ax.set_zlabel('Partial dependence') # pretty init view ax.view_init(elev=22, azim=122) diff --git a/examples/plot_partial_dependence_visualization_api.py b/examples/plot_partial_dependence_visualization_api.py index 8884d52f80d25..911a2409efe0b 100644 --- a/examples/plot_partial_dependence_visualization_api.py +++ b/examples/plot_partial_dependence_visualization_api.py @@ -15,6 +15,7 @@ """ print(__doc__) +import pandas as pd import matplotlib.pyplot as plt from sklearn.datasets import load_boston from sklearn.neural_network import MLPRegressor @@ -32,8 +33,8 @@ # housing price dataset. 
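# A minimal sketch of the behaviour this changeset enables: partial_dependence()
# called on a pandas DataFrame through a Pipeline containing a ColumnTransformer,
# with the target features selected by column name rather than by position.
# The dataset, scalers and estimator below are illustrative only; they mirror
# the new tests added in sklearn/inspection/tests/test_partial_dependence.py.
import pandas as pd
from sklearn.compose import make_column_transformer
from sklearn.datasets import load_iris
from sklearn.inspection import partial_dependence
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)

# Scale two of the columns and pass the remaining ones through unchanged.
preprocessor = make_column_transformer(
    (StandardScaler(), [iris.feature_names[i] for i in (0, 2)]),
    remainder='passthrough'
)
pipe = make_pipeline(preprocessor, LogisticRegression(max_iter=1000))
pipe.fit(df, iris.target)

# Features can now be given as DataFrame column names instead of integer
# indices; the pipeline handles the column reordering internally.
averaged_predictions, values = partial_dependence(
    pipe, df, features=[iris.feature_names[0]], grid_resolution=10
)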
boston = load_boston() -X, y = boston.data, boston.target -feature_names = boston.feature_names +X = pd.DataFrame(boston.data, columns=boston.feature_names) +y = boston.target tree = DecisionTreeRegressor() mlp = make_pipeline(StandardScaler(), @@ -55,7 +56,7 @@ fig, ax = plt.subplots(figsize=(12, 6)) ax.set_title("Decision Tree") tree_disp = plot_partial_dependence(tree, X, ["LSTAT", "RM"], - feature_names=feature_names, ax=ax) + feature_names=X.columns.tolist(), ax=ax) ############################################################################## # The partial depdendence curves can be plotted for the multi-layer perceptron. @@ -65,7 +66,7 @@ fig, ax = plt.subplots(figsize=(12, 6)) ax.set_title("Multi-layer Perceptron") mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT", "RM"], - feature_names=feature_names, ax=ax, + feature_names=X.columns.tolist(), ax=ax, line_kw={"c": "red"}) ############################################################################## @@ -134,7 +135,7 @@ # the same axes. In this case, `tree_disp.axes_` is passed into the second # plot function. tree_disp = plot_partial_dependence(tree, X, ["LSTAT"], - feature_names=feature_names) + feature_names=X.columns.tolist()) mlp_disp = plot_partial_dependence(mlp, X, ["LSTAT"], - feature_names=feature_names, + feature_names=X.columns.tolist(), ax=tree_disp.axes_, line_kw={"c": "red"}) diff --git a/sklearn/inspection/_partial_dependence.py b/sklearn/inspection/_partial_dependence.py index 885c07a23883b..75cfdb10a621d 100644 --- a/sklearn/inspection/_partial_dependence.py +++ b/sklearn/inspection/_partial_dependence.py @@ -12,13 +12,18 @@ import warnings import numpy as np +from scipy import sparse from scipy.stats.mstats import mquantiles from joblib import Parallel, delayed from ..base import is_classifier, is_regressor +from ..pipeline import Pipeline from ..utils.extmath import cartesian from ..utils import check_array from ..utils import check_matplotlib_support # noqa +from ..utils import _safe_indexing +from ..utils import _determine_key_type +from ..utils import _get_column_indices from ..utils.validation import check_is_fitted from ..tree._tree import DTYPE from ..exceptions import NotFittedError @@ -44,9 +49,11 @@ def _grid_from_X(X, percentiles, grid_resolution): ---------- X : ndarray, shape (n_samples, n_target_features) The data + percentiles : tuple of floats The percentiles which are used to construct the extreme values of the grid. Must be in [0, 1]. + grid_resolution : int The number of equally spaced points to be placed on the grid for each feature. @@ -56,6 +63,7 @@ def _grid_from_X(X, percentiles, grid_resolution): grid : ndarray, shape (n_points, n_target_features) A value for each feature at each point in the grid. ``n_points`` is always ``<= grid_resolution ** X.shape[1]``. + values : list of 1d ndarrays The values with which the grid has been created. 
The size of each array ``values[j]`` is either ``grid_resolution``, or the number of @@ -74,16 +82,16 @@ def _grid_from_X(X, percentiles, grid_resolution): values = [] for feature in range(X.shape[1]): - uniques = np.unique(X[:, feature]) + uniques = np.unique(_safe_indexing(X, feature, axis=1)) if uniques.shape[0] < grid_resolution: # feature has low resolution use unique vals axis = uniques else: # create axis based on percentiles and grid resolution - emp_percentiles = mquantiles(X[:, feature], prob=percentiles, - axis=0) - if np.allclose(emp_percentiles[0], - emp_percentiles[1]): + emp_percentiles = mquantiles( + _safe_indexing(X, feature, axis=1), prob=percentiles, axis=0 + ) + if np.allclose(emp_percentiles[0], emp_percentiles[1]): raise ValueError( 'percentiles are too close to each other, ' 'unable to build the grid. Please choose percentiles ' @@ -130,7 +138,10 @@ def _partial_dependence_brute(est, grid, features, X, response_method): for new_values in grid: X_eval = X.copy() for i, variable in enumerate(features): - X_eval[:, variable] = new_values[i] + if hasattr(X_eval, 'iloc'): + X_eval.iloc[:, variable] = new_values[i] + else: + X_eval[:, variable] = new_values[i] try: predictions = prediction_method(X_eval) @@ -142,7 +153,7 @@ def _partial_dependence_brute(est, grid, features, X, response_method): # (n_points,) for non-multioutput regressors # (n_points, n_tasks) for multioutput regressors # (n_points, 1) for the regressors in cross_decomposition (I think) - # (n_points, 2) for binary classifaction + # (n_points, 2) for binary classification # (n_points, n_classes) for multiclass classification # average over samples @@ -183,13 +194,16 @@ def partial_dependence(estimator, X, features, response_method='auto', A fitted estimator object implementing :term:`predict`, :term:`predict_proba`, or :term:`decision_function`. Multioutput-multiclass classifiers are not supported. - X : array-like, shape (n_samples, n_features) + + X : {array-like or dataframe} of shape (n_samples, n_features) ``X`` is used both to generate a grid of values for the ``features``, and to compute the averaged predictions when method is 'brute'. - features : list or array-like of int - The target features for which the partial dependency should be - computed. + + features : array-like of {int, str} + The feature (e.g. `[0]`) or pair of interacting features + (e.g. `[(0, 1)]`) for which the partial dependency should be computed. + response_method : 'auto', 'predict_proba' or 'decision_function', \ optional (default='auto') Specifies whether to use :term:`predict_proba` or @@ -199,12 +213,15 @@ def partial_dependence(estimator, X, features, response_method='auto', and we revert to :term:`decision_function` if it doesn't exist. If ``method`` is 'recursion', the response is always the output of :term:`decision_function`. + percentiles : tuple of float, optional (default=(0.05, 0.95)) The lower and upper percentile used to create the extreme values for the grid. Must be in [0, 1]. + grid_resolution : int, optional (default=100) The number of equally spaced points on the grid, for each target feature. + method : str, optional (default='auto') The method used to calculate the averaged predictions: @@ -216,7 +233,7 @@ def partial_dependence(estimator, X, features, response_method='auto', but is more efficient in terms of speed. With this method, ``X`` is only used to build the grid and the partial dependences are computed using the training - data. This method does not account for the ``init`` predicor of + data. 
This method does not account for the ``init`` predictor of the boosting process, which may lead to incorrect values (see warning below). With this method, the target response of a classifier is always the decision function, not the predicted @@ -248,6 +265,7 @@ def partial_dependence(estimator, X, features, response_method='auto', regression. For classical regression and binary classification ``n_outputs==1``. ``n_values_feature_j`` corresponds to the size ``values[j]``. + values : seq of 1d ndarrays The values with which the grid has been created. The generated grid is a cartesian product of the arrays in ``values``. ``len(values) == @@ -284,22 +302,32 @@ def partial_dependence(estimator, X, features, response_method='auto', `, which do not have an ``init`` parameter. """ - if not (is_classifier(estimator) or is_regressor(estimator)): raise ValueError( - "'estimator' must be a fitted regressor or classifier.") + "'estimator' must be a fitted regressor or classifier." + ) - if is_classifier(estimator): - if not hasattr(estimator, 'classes_'): - raise ValueError( - "'estimator' parameter must be a fitted estimator" - ) - if isinstance(estimator.classes_[0], np.ndarray): - raise ValueError( - 'Multiclass-multioutput estimators are not supported' - ) + if isinstance(estimator, Pipeline): + # TODO: to be removed if/when pipeline get a `steps_` attributes + # assuming Pipeline is the only estimator that does not store a new + # attribute + for est in estimator: + # FIXME: remove the None option when it will be deprecated + if est not in (None, 'drop'): + check_is_fitted(est) + else: + check_is_fitted(estimator) - X = check_array(X) + if (is_classifier(estimator) and + isinstance(estimator.classes_[0], np.ndarray)): + raise ValueError( + 'Multiclass-multioutput estimators are not supported' + ) + + # Use check_array only on lists and other non-array-likes / sparse. Do not + # convert DataFrame into a NumPy array. + if not(hasattr(X, '__array__') or sparse.issparse(X)): + X = check_array(X, force_all_finite='allow-nan', dtype=np.object) accepted_responses = ('auto', 'predict_proba', 'decision_function') if response_method not in accepted_responses: @@ -312,6 +340,7 @@ def partial_dependence(estimator, X, features, response_method='auto', "The response_method parameter is ignored for regressors and " "must be 'auto'." ) + accepted_methods = ('brute', 'recursion', 'auto') if method not in accepted_methods: raise ValueError( @@ -349,21 +378,32 @@ def partial_dependence(estimator, X, features, response_method='auto', "'decision_function'. Got {}.".format(response_method) ) - n_features = X.shape[1] - features = np.asarray(features, dtype=np.int32, order='C').ravel() - if any(not (0 <= f < n_features) for f in features): - raise ValueError('all features must be in [0, %d]' - % (n_features - 1)) + if _determine_key_type(features, accept_slice=False) == 'int': + # _get_column_indices() supports negative indexing. Here, we limit + # the indexing to be positive. 
The upper bound will be checked + # by _get_column_indices() + if np.any(np.less(features, 0)): + raise ValueError( + 'all features must be in [0, {}]'.format(X.shape[1] - 1) + ) + + features_indices = np.asarray( + _get_column_indices(X, features), dtype=np.int32, order='C' + ).ravel() + + grid, values = _grid_from_X( + _safe_indexing(X, features_indices, axis=1), percentiles, + grid_resolution + ) - grid, values = _grid_from_X(X[:, features], percentiles, - grid_resolution) if method == 'brute': - averaged_predictions = _partial_dependence_brute(estimator, grid, - features, X, - response_method) + averaged_predictions = _partial_dependence_brute( + estimator, grid, features_indices, X, response_method + ) else: - averaged_predictions = _partial_dependence_recursion(estimator, grid, - features) + averaged_predictions = _partial_dependence_recursion( + estimator, grid, features_indices + ) # reshape averaged_predictions to # (n_outputs, n_values_feature_0, n_values_feature_1, ...) @@ -394,7 +434,7 @@ def plot_partial_dependence(estimator, X, features, feature_names=None, :term:`predict_proba`, or :term:`decision_function`. Multioutput-multiclass classifiers are not supported. - X : array-like, shape (n_samples, n_features) + X : {array-like or dataframe} of shape (n_samples, n_features) The data to use to build the grid of values on which the dependence will be evaluated. This is usually the training data. @@ -452,7 +492,7 @@ def plot_partial_dependence(estimator, X, features, feature_names=None, but is more efficient in terms of speed. With this method, ``X`` is optional and is only used to build the grid and the partial dependences are computed using the training - data. This method does not account for the ``init`` predicor of + data. This method does not account for the ``init`` predictor of the boosting process, which may lead to incorrect values (see warning below. With this method, the target response of a classifier is always the decision function, not the predicted @@ -491,7 +531,7 @@ def plot_partial_dependence(estimator, X, features, feature_names=None, ax : Matplotlib axes or array-like of Matplotlib axes, default=None - If a single axis is passed in, it is treated as a bounding axes - and a grid of partial depedendence plots will be drawn within + and a grid of partial dependence plots will be drawn within these bounds. The `n_cols` parameter controls the number of columns in the grid. - If an array-like of axes are passed in, the partial dependence @@ -582,7 +622,7 @@ def convert_feature(fx): except TypeError: raise ValueError('Each entry in features must be either an int, ' 'a string, or an iterable of size at most 2.') - if not (1 <= np.size(fxs) <= 2): + if not 1 <= np.size(fxs) <= 2: raise ValueError('Each entry in features must be either an int, ' 'a string, or an iterable of size at most 2.') @@ -680,7 +720,7 @@ class PartialDependenceDisplay: plot a two-way partial dependence curve as a contour plot. feature_names : list of str - Feature names corrsponding to the indicies in ``features``. + Feature names corresponding to the indices in ``features``. target_idx : int @@ -748,7 +788,7 @@ def plot(self, ax=None, n_cols=3, line_kw=None, contour_kw=None): ---------- ax : Matplotlib axes or array-like of Matplotlib axes, default=None - If a single axis is passed in, it is treated as a bounding axes - and a grid of partial depedendence plots will be drawn within + and a grid of partial dependence plots will be drawn within these bounds. 
The `n_cols` parameter controls the number of columns in the grid. - If an array-like of axes are passed in, the partial dependence diff --git a/sklearn/inspection/tests/test_partial_dependence.py b/sklearn/inspection/tests/test_partial_dependence.py index 89d411eafa616..8d3194f34249f 100644 --- a/sklearn/inspection/tests/test_partial_dependence.py +++ b/sklearn/inspection/tests/test_partial_dependence.py @@ -24,12 +24,15 @@ from sklearn.datasets import load_iris from sklearn.datasets import make_classification, make_regression from sklearn.cluster import KMeans +from sklearn.compose import make_column_transformer from sklearn.metrics import r2_score -from sklearn.pipeline import make_pipeline from sklearn.preprocessing import PolynomialFeatures from sklearn.preprocessing import StandardScaler +from sklearn.preprocessing import RobustScaler +from sklearn.pipeline import make_pipeline from sklearn.dummy import DummyClassifier -from sklearn.base import BaseEstimator, ClassifierMixin +from sklearn.base import BaseEstimator, ClassifierMixin, clone +from sklearn.exceptions import NotFittedError from sklearn.utils._testing import assert_allclose from sklearn.utils._testing import assert_array_equal @@ -50,6 +53,9 @@ multioutput_regression_data = (make_regression(n_samples=50, n_targets=2, random_state=0), 2) +# iris +iris = load_iris() + @pytest.mark.parametrize('Estimator, method, data', [ (GradientBoostingClassifier, 'recursion', binary_classification_data), @@ -335,12 +341,28 @@ def test_partial_dependence_error(estimator, params, err_msg): partial_dependence(estimator, X, **params) +@pytest.mark.parametrize( + "with_dataframe, err_msg", + [(True, "Only array-like or scalar are supported"), + (False, "Only array-like or scalar are supported")] +) +def test_partial_dependence_slice_error(with_dataframe, err_msg): + X, y = make_classification(random_state=0) + if with_dataframe: + pd = pytest.importorskip('pandas') + X = pd.DataFrame(X) + estimator = LogisticRegression().fit(X, y) + + with pytest.raises(TypeError, match=err_msg): + partial_dependence(estimator, X, features=slice(0, 2, 1)) + + @pytest.mark.parametrize( 'estimator', [LinearRegression(), GradientBoostingClassifier(random_state=0)] ) -@pytest.mark.parametrize('features', [-1, 1000000]) -def test_partial_dependence_unknown_feature(estimator, features): +@pytest.mark.parametrize('features', [-1, 10000]) +def test_partial_dependence_unknown_feature_indices(estimator, features): X, y = make_classification(random_state=0) estimator.fit(X, y) @@ -353,10 +375,16 @@ def test_partial_dependence_unknown_feature(estimator, features): 'estimator', [LinearRegression(), GradientBoostingClassifier(random_state=0)] ) -def test_partial_dependence_unfitted_estimator(estimator): - err_msg = "'estimator' parameter must be a fitted estimator" +def test_partial_dependence_unknown_feature_string(estimator): + pd = pytest.importorskip("pandas") + X, y = make_classification(random_state=0) + df = pd.DataFrame(X) + estimator.fit(df, y) + + features = ['random'] + err_msg = 'A given column is not a column of the dataframe' with pytest.raises(ValueError, match=err_msg): - partial_dependence(estimator, X, [0]) + partial_dependence(estimator, df, features) @pytest.mark.parametrize( @@ -427,13 +455,119 @@ def test_partial_dependence_pipeline(): features = 0 pdp_pipe, values_pipe = partial_dependence( - pipe, iris.data, features=[features] + pipe, iris.data, features=[features], grid_resolution=10 ) pdp_clf, values_clf = partial_dependence( - clf, 
scaler.transform(iris.data), features=[features] + clf, scaler.transform(iris.data), features=[features], + grid_resolution=10 ) assert_allclose(pdp_pipe, pdp_clf) assert_allclose( values_pipe[0], values_clf[0] * scaler.scale_[features] + scaler.mean_[features] ) + + +@pytest.mark.parametrize( + "estimator", + [LogisticRegression(max_iter=1000, random_state=0), + GradientBoostingClassifier(random_state=0, n_estimators=5)], + ids=['estimator-brute', 'estimator-recursion'] +) +@pytest.mark.parametrize( + "preprocessor", + [None, + make_column_transformer( + (StandardScaler(), [iris.feature_names[i] for i in (0, 2)]), + (RobustScaler(), [iris.feature_names[i] for i in (1, 3)])), + make_column_transformer( + (StandardScaler(), [iris.feature_names[i] for i in (0, 2)]), + remainder='passthrough')], + ids=['None', 'column-transformer', 'column-transformer-passthrough'] +) +@pytest.mark.parametrize( + "features", + [[0, 2], [iris.feature_names[i] for i in (0, 2)]], + ids=['features-integer', 'features-string'] +) +def test_partial_dependence_dataframe(estimator, preprocessor, features): + # check that the partial dependence support dataframe and pipeline + # including a column transformer + pd = pytest.importorskip("pandas") + df = pd.DataFrame(iris.data, columns=iris.feature_names) + + pipe = make_pipeline(preprocessor, estimator) + pipe.fit(df, iris.target) + pdp_pipe, values_pipe = partial_dependence( + pipe, df, features=features, grid_resolution=10 + ) + + # the column transformer will reorder the column when transforming + # we mixed the index to be sure that we are computing the partial + # dependence of the right columns + if preprocessor is not None: + X_proc = clone(preprocessor).fit_transform(df) + features_clf = [0, 1] + else: + X_proc = df + features_clf = [0, 2] + + clf = clone(estimator).fit(X_proc, iris.target) + pdp_clf, values_clf = partial_dependence( + clf, X_proc, features=features_clf, method='brute', grid_resolution=10 + ) + + assert_allclose(pdp_pipe, pdp_clf) + if preprocessor is not None: + scaler = preprocessor.named_transformers_['standardscaler'] + assert_allclose( + values_pipe[1], + values_clf[1] * scaler.scale_[1] + scaler.mean_[1] + ) + else: + assert_allclose(values_pipe[1], values_clf[1]) + + +@pytest.mark.parametrize( + "features, expected_pd_shape", + [(0, (3, 10)), + (iris.feature_names[0], (3, 10)), + ([0, 2], (3, 10, 10)), + ([iris.feature_names[i] for i in (0, 2)], (3, 10, 10)), + ([True, False, True, False], (3, 10, 10))], + ids=['scalar-int', 'scalar-str', 'list-int', 'list-str', 'mask'] +) +def test_partial_dependence_feature_type(features, expected_pd_shape): + # check all possible features type supported in PDP + pd = pytest.importorskip("pandas") + df = pd.DataFrame(iris.data, columns=iris.feature_names) + + preprocessor = make_column_transformer( + (StandardScaler(), [iris.feature_names[i] for i in (0, 2)]), + (RobustScaler(), [iris.feature_names[i] for i in (1, 3)]) + ) + pipe = make_pipeline( + preprocessor, LogisticRegression(max_iter=1000, random_state=0) + ) + pipe.fit(df, iris.target) + pdp_pipe, values_pipe = partial_dependence( + pipe, df, features=features, grid_resolution=10 + ) + assert pdp_pipe.shape == expected_pd_shape + assert len(values_pipe) == len(pdp_pipe.shape) - 1 + + +@pytest.mark.parametrize( + "estimator", [LinearRegression(), LogisticRegression(), + GradientBoostingRegressor(), GradientBoostingClassifier()] +) +def test_partial_dependence_unfitted(estimator): + X = iris.data + preprocessor = make_column_transformer( + 
(StandardScaler(), [0, 2]), (RobustScaler(), [1, 3]) + ) + pipe = make_pipeline(preprocessor, estimator) + with pytest.raises(NotFittedError, match="is not fitted yet"): + partial_dependence(pipe, X, features=[0, 2], grid_resolution=10) + with pytest.raises(NotFittedError, match="is not fitted yet"): + partial_dependence(estimator, X, features=[0, 2], grid_resolution=10) diff --git a/sklearn/inspection/tests/test_permutation_importance.py b/sklearn/inspection/tests/test_permutation_importance.py index b444310695dee..671a1e11b1fec 100644 --- a/sklearn/inspection/tests/test_permutation_importance.py +++ b/sklearn/inspection/tests/test_permutation_importance.py @@ -18,7 +18,6 @@ from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import scale - @pytest.mark.parametrize("n_jobs", [1, 2]) def test_permutation_importance_correlated_feature_regression(n_jobs): # Make sure that feature highly correlated to the target have a higher diff --git a/sklearn/inspection/tests/test_plot_partial_dependence.py b/sklearn/inspection/tests/test_plot_partial_dependence.py index 5349985de2f10..9d6b1ecb6f94d 100644 --- a/sklearn/inspection/tests/test_plot_partial_dependence.py +++ b/sklearn/inspection/tests/test_plot_partial_dependence.py @@ -307,6 +307,18 @@ def test_plot_partial_dependence_multioutput(pyplot, target): assert ax.get_xlabel() == "{}".format(i) +def test_plot_partial_dependence_dataframe(pyplot, clf_boston, boston): + pd = pytest.importorskip('pandas') + df = pd.DataFrame(boston.data, columns=boston.feature_names) + + grid_resolution = 25 + + plot_partial_dependence( + clf_boston, df, ['TAX', 'AGE'], grid_resolution=grid_resolution, + feature_names=df.columns.tolist() + ) + + dummy_classification_data = make_classification(random_state=0) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 4efe89aa844f7..4d4ef606341ca 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -192,6 +192,8 @@ def _array_indexing(array, key, key_dtype, axis): # check if we have an boolean array-likes to make the proper indexing if key_dtype == 'bool': key = np.asarray(key) + if isinstance(key, tuple): + key = list(key) return array[key] if axis == 0 else array[:, key] @@ -202,6 +204,8 @@ def _pandas_indexing(X, key, key_dtype, axis): # FIXME: solved in pandas 0.25 key = np.asarray(key) key = key if key.flags.writeable else key.copy() + elif isinstance(key, tuple): + key = list(key) # check whether we should index with loc or iloc indexer = X.iloc if key_dtype == 'int' else X.loc return indexer[:, key] if axis else indexer[key] @@ -219,7 +223,7 @@ def _list_indexing(X, key, key_dtype): return [X[idx] for idx in key] -def _determine_key_type(key): +def _determine_key_type(key, accept_slice=True): """Determine the data type of key. Parameters @@ -227,6 +231,9 @@ def _determine_key_type(key): key : scalar, slice or array-like The key from which we want to infer the data type. + accept_slice : bool, default=True + Whether or not to raise an error if the key is a slice. + Returns ------- dtype : {'int', 'str', 'bool', None} @@ -248,6 +255,11 @@ def _determine_key_type(key): except KeyError: raise ValueError(err_msg) if isinstance(key, slice): + if not accept_slice: + raise TypeError( + 'Only array-like or scalar are supported. ' + 'A Python slice was given.' 
+ ) if key.start is None and key.stop is None: return None key_start_type = _determine_key_type(key.start) @@ -258,7 +270,7 @@ def _determine_key_type(key): if key_start_type is not None: return key_start_type return key_stop_type - if isinstance(key, list): + if isinstance(key, (list, tuple)): unique_key = set(key) key_type = {_determine_key_type(elt) for elt in unique_key} if not key_type: @@ -411,7 +423,7 @@ def _get_column_indices(X, key): key_dtype = _determine_key_type(key) - if isinstance(key, list) and not key: + if isinstance(key, (list, tuple)) and not key: # we get an empty list return [] elif key_dtype in ('bool', 'int'): diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index eeafa14b4020e..2cf1e59a73f29 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -205,15 +205,19 @@ def test_column_or_1d(): (np.bool_(True), 'bool'), ([0, 1, 2], 'int'), (['0', '1', '2'], 'str'), + ((0, 1, 2), 'int'), + (('0', '1', '2'), 'str'), (slice(None, None), None), (slice(0, 2), 'int'), (np.array([0, 1, 2], dtype=np.int32), 'int'), (np.array([0, 1, 2], dtype=np.int64), 'int'), (np.array([0, 1, 2], dtype=np.uint8), 'int'), ([True, False], 'bool'), + ((True, False), 'bool'), (np.array([True, False]), 'bool'), ('col_0', 'str'), (['col_0', 'col_1', 'col_2'], 'str'), + (('col_0', 'col_1', 'col_2'), 'str'), (slice('begin', 'end'), 'str'), (np.array(['col_0', 'col_1', 'col_2']), 'str'), (np.array(['col_0', 'col_1', 'col_2'], dtype=object), 'str')] @@ -227,9 +231,16 @@ def test_determine_key_type_error(): _determine_key_type(1.0) +def test_determine_key_type_slice_error(): + with pytest.raises(TypeError, match="Only array-like or scalar are"): + _determine_key_type(slice(0, 2, 1), accept_slice=False) + + def _convert_container(container, constructor_name, columns_name=None): if constructor_name == 'list': return list(container) + elif constructor_name == 'tuple': + return tuple(container) elif constructor_name == 'array': return np.asarray(container) elif constructor_name == 'sparse': @@ -247,7 +258,9 @@ def _convert_container(container, constructor_name, columns_name=None): @pytest.mark.parametrize( "array_type", ["list", "array", "sparse", "dataframe"] ) -@pytest.mark.parametrize("indices_type", ["list", "array", "series", "slice"]) +@pytest.mark.parametrize( + "indices_type", ["list", "tuple", "array", "series", "slice"] +) def test_safe_indexing_2d_container_axis_0(array_type, indices_type): indices = [1, 2] if indices_type == 'slice' and isinstance(indices[1], int): @@ -261,7 +274,9 @@ def test_safe_indexing_2d_container_axis_0(array_type, indices_type): @pytest.mark.parametrize("array_type", ["list", "array", "series"]) -@pytest.mark.parametrize("indices_type", ["list", "array", "series", "slice"]) +@pytest.mark.parametrize( + "indices_type", ["list", "tuple", "array", "series", "slice"] +) def test_safe_indexing_1d_container(array_type, indices_type): indices = [1, 2] if indices_type == 'slice' and isinstance(indices[1], int): @@ -275,7 +290,9 @@ def test_safe_indexing_1d_container(array_type, indices_type): @pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe"]) -@pytest.mark.parametrize("indices_type", ["list", "array", "series", "slice"]) +@pytest.mark.parametrize( + "indices_type", ["list", "tuple", "array", "series", "slice"] +) @pytest.mark.parametrize("indices", [[1, 2], ["col_1", "col_2"]]) def test_safe_indexing_2d_container_axis_1(array_type, indices_type, indices): # validation of the indices @@ 
-328,7 +345,7 @@ def test_safe_indexing_2d_read_only_axis_1(array_read_only, indices_read_only, @pytest.mark.parametrize("array_type", ["list", "array", "series"]) -@pytest.mark.parametrize("indices_type", ["list", "array", "series"]) +@pytest.mark.parametrize("indices_type", ["list", "tuple", "array", "series"]) def test_safe_indexing_1d_container_mask(array_type, indices_type): indices = [False] + [True] * 2 + [False] * 6 array = _convert_container([1, 2, 3, 4, 5, 6, 7, 8, 9], array_type) @@ -340,7 +357,7 @@ def test_safe_indexing_1d_container_mask(array_type, indices_type): @pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe"]) -@pytest.mark.parametrize("indices_type", ["list", "array", "series"]) +@pytest.mark.parametrize("indices_type", ["list", "tuple", "array", "series"]) @pytest.mark.parametrize( "axis, expected_subset", [(0, [[4, 5, 6], [7, 8, 9]]),