Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 27 additions & 13 deletions sklearn/inspection/_partial_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ..utils import _get_column_indices
from ..utils.validation import check_is_fitted
from ..tree._tree import DTYPE
from ..tree import DecisionTreeRegressor
from ..exceptions import NotFittedError
from ..ensemble._gb import BaseGradientBoosting
from sklearn.ensemble._hist_gradient_boosting.gradient_boosting import (
Expand Down Expand Up @@ -105,7 +106,14 @@ def _grid_from_X(X, percentiles, grid_resolution):


def _partial_dependence_recursion(est, grid, features):
return est._compute_partial_dependence_recursion(grid, features)
averaged_predictions = est._compute_partial_dependence_recursion(grid,
features)
if averaged_predictions.ndim == 1:
# reshape to (1, n_points) for consistency with
# _partial_dependence_brute
averaged_predictions = averaged_predictions.reshape(1, -1)

return averaged_predictions


def _partial_dependence_brute(est, grid, features, X, response_method):
Expand Down Expand Up @@ -225,11 +233,12 @@ def partial_dependence(estimator, X, features, response_method='auto',
method : str, optional (default='auto')
The method used to calculate the averaged predictions:

- 'recursion' is only supported for gradient boosting estimator (namely
:class:`GradientBoostingClassifier<sklearn.ensemble.GradientBoostingClassifier>`,
:class:`GradientBoostingRegressor<sklearn.ensemble.GradientBoostingRegressor>`,
:class:`HistGradientBoostingClassifier<sklearn.ensemble.HistGradientBoostingClassifier>`,
:class:`HistGradientBoostingRegressor<sklearn.ensemble.HistGradientBoostingRegressor>`)
- 'recursion' is only supported for some tree-based estimators, (namely
:class:`~sklearn.ensemble.GradientBoostingClassifier`,
:class:`~sklearn.ensemble.GradientBoostingRegressor`,
:class:`~sklearn.ensemble.HistGradientBoostingClassifier`,
:class:`~sklearn.ensemble.HistGradientBoostingRegressor`,
:class:`~sklearn.tree.DecisionTreeRegressor`)
but is more efficient in terms of speed.
With this method, ``X`` is only used to build the
grid and the partial dependences are computed using the training
Expand Down Expand Up @@ -351,19 +360,23 @@ def partial_dependence(estimator, X, features, response_method='auto',
if (isinstance(estimator, BaseGradientBoosting) and
estimator.init is None):
method = 'recursion'
elif isinstance(estimator, BaseHistGradientBoosting):
elif isinstance(estimator, (BaseHistGradientBoosting,
DecisionTreeRegressor)):
method = 'recursion'
else:
method = 'brute'

if method == 'recursion':
if not isinstance(estimator,
(BaseGradientBoosting, BaseHistGradientBoosting)):
(BaseGradientBoosting, BaseHistGradientBoosting,
DecisionTreeRegressor)):
supported_classes_recursion = (
'GradientBoostingClassifier',
'GradientBoostingRegressor',
'HistGradientBoostingClassifier',
'HistGradientBoostingRegressor',
'HistGradientBoostingRegressor',
'DecisionTreeRegressor',
)
raise ValueError(
"Only the following estimators support the 'recursion' "
Expand Down Expand Up @@ -501,11 +514,12 @@ def plot_partial_dependence(estimator, X, features, feature_names=None,
method : str, optional (default='auto')
The method to use to calculate the partial dependence predictions:

- 'recursion' is only supported for gradient boosting estimator (namely
:class:`GradientBoostingClassifier<sklearn.ensemble.GradientBoostingClassifier>`,
:class:`GradientBoostingRegressor<sklearn.ensemble.GradientBoostingRegressor>`,
:class:`HistGradientBoostingClassifier<sklearn.ensemble.HistGradientBoostingClassifier>`,
:class:`HistGradientBoostingRegressor<sklearn.ensemble.HistGradientBoostingRegressor>`)
- 'recursion' is only supported for some tree-based estimators, (namely
:class:`~sklearn.ensemble.GradientBoostingClassifier`,
:class:`~sklearn.ensemble.GradientBoostingRegressor`,
:class:`~sklearn.ensemble.HistGradientBoostingClassifier`,
:class:`~sklearn.ensemble.HistGradientBoostingRegressor`,
:class:`~sklearn.tree.DecisionTreeRegressor`)
but is more efficient in terms of speed.
With this method, ``X`` is optional and is only used to build the
grid and the partial dependences are computed using the training
Expand Down
59 changes: 57 additions & 2 deletions sklearn/inspection/tests/test_partial_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from sklearn.utils._testing import assert_allclose
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import ignore_warnings
from sklearn.utils import _IS_32BIT
from sklearn.tree.tests.test_tree import assert_is_subtree


# toy sample
Expand Down Expand Up @@ -174,6 +176,11 @@ def test_partial_dependence_helpers(est, method, target_feature):
# samples.
# This also checks that the brute and recursion methods give the same
# output.
# Note that even on the trainset, the brute and the recursion methods
# aren't always strictly equivalent (despite what we say in the docs), in
# particular when the slow method generates unrealistic samples that have
# low mass in the joint distribution of the input features, and when some
# of the features are dependent. Hence the high tolerance on the checks.

X, y = make_regression(random_state=0, n_features=5, n_informative=5)
# The 'init' estimator for GBDT (here the average prediction) isn't taken
Expand Down Expand Up @@ -206,6 +213,53 @@ def test_partial_dependence_helpers(est, method, target_feature):
assert np.allclose(pdp, mean_predictions, rtol=rtol)


def test_decision_tree_vs_gradient_boosting():
# Make sure that the recursion method gives the same results on a
# DecisionTreeRegressor and a GradientBoostingRegressor with 1 tree and
# same parameters. The DecisionTreeRegressor doesn't pass the
# test_partial_dependence_helpers() test.

# Purely random dataset to avoid correlated features
n_samples = 100
n_features = 5
X = np.random.RandomState(0).randn(n_samples, n_features)
y = np.random.RandomState(0).randn(n_samples)

# The 'init' estimator for GBDT (here the average prediction) isn't taken
# into account with the recursion method, for technical reasons. We set
# the mean to 0 to that this 'bug' doesn't have any effect.
y = y - y.mean()

# set max_depth not too high to avoid splits with same gain but different
# features
max_depth = 5
gbdt = GradientBoostingRegressor(n_estimators=1, learning_rate=1,
criterion='mse', max_depth=max_depth,
random_state=0)
gbdt.fit(X, y)

tree = DecisionTreeRegressor(random_state=0, max_depth=max_depth)
tree.fit(X, y)

# sanity check
try:
assert_is_subtree(tree.tree_, gbdt[0, 0].tree_)
except AssertionError:
# For some reason the trees aren't exactly equal on 32bits, so the PDs
# cannot be equal either.
assert _IS_32BIT
return

grid = np.random.RandomState(0).randn(50).reshape(-1, 1)
for f in range(n_features):
features = np.array([f], dtype=np.int32)

pdp_gbdt = _partial_dependence_recursion(gbdt, grid, features)
pdp_tree = _partial_dependence_recursion(tree, grid, features)

np.testing.assert_allclose(pdp_gbdt, pdp_tree)


@pytest.mark.parametrize('est', (
GradientBoostingClassifier(random_state=0),
HistGradientBoostingClassifier(random_state=0),
Expand Down Expand Up @@ -236,8 +290,9 @@ def test_recursion_decision_function(est, target_feature):
LinearRegression(),
GradientBoostingRegressor(random_state=0),
HistGradientBoostingRegressor(random_state=0, min_samples_leaf=1,
max_leaf_nodes=None, max_iter=1))
)
max_leaf_nodes=None, max_iter=1),
DecisionTreeRegressor(random_state=0),
))
@pytest.mark.parametrize('power', (1, 2))
def test_partial_dependence_easy_target(est, power):
# If the target y only depends on one feature in an obvious way (linear or
Expand Down
28 changes: 28 additions & 0 deletions sklearn/tree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,34 @@ def n_classes_(self):
warnings.warn(msg, FutureWarning)
return np.array([1] * self.n_outputs_, dtype=np.intp)

def _compute_partial_dependence_recursion(self, grid, target_features):
"""Fast partial dependence computation.

Parameters
----------
grid : ndarray, shape (n_samples, n_target_features)
The grid points on which the partial dependence should be
evaluated.
target_features : ndarray, shape (n_target_features)
The set of target features for which the partial dependence
should be evaluated.

Returns
-------
averaged_predictions : ndarray, shape \
(n_trees_per_iteration, n_samples)
The value of the partial dependence function on each grid point.
"""
check_is_fitted(self,
msg="'estimator' parameter must be a fitted estimator")
grid = np.asarray(grid, dtype=DTYPE, order='C')
averaged_predictions = np.zeros(shape=grid.shape[0],
dtype=np.float64, order='C')

self.tree_.compute_partial_dependence(
grid, target_features, averaged_predictions)
return averaged_predictions


class ExtraTreeClassifier(DecisionTreeClassifier):
"""An extremely randomized tree classifier.
Expand Down