Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
74d369f
WIP
NicolasHug Dec 6, 2019
17bc2a3
test and doc
NicolasHug Dec 9, 2019
9fe5234
added comment
NicolasHug Dec 9, 2019
d8f5ee3
pep8
NicolasHug Dec 9, 2019
e52ae50
maybe fix 32 bits issue?
NicolasHug Dec 9, 2019
597006b
skip test if 32 bits
NicolasHug Dec 9, 2019
7d2761e
test recursion instead of brute
NicolasHug Dec 10, 2019
11b6489
use np.testing
NicolasHug Dec 10, 2019
7263640
put back skipif32bits
NicolasHug Dec 10, 2019
3ee45b6
try converting grid to float32
NicolasHug Dec 10, 2019
cf0df03
Update sklearn/inspection/tests/test_partial_dependence.py
NicolasHug Dec 10, 2019
fb8257b
pep8
NicolasHug Dec 10, 2019
db8dc09
still nope
NicolasHug Dec 10, 2019
f141732
Merge branch 'master' of github.com:scikit-learn/scikit-learn into pa…
NicolasHug Dec 11, 2019
397ce88
assert tree from DecisionTree and GBDT is exactly the same
NicolasHug Dec 11, 2019
be260a0
pep
NicolasHug Dec 11, 2019
c09565a
skip if 32 bits but better
NicolasHug Dec 11, 2019
4ec827d
add fast partial dep for regression forest
NicolasHug Dec 11, 2019
de09aa4
test name
NicolasHug Dec 11, 2019
fe0f6f3
whats new
NicolasHug Dec 11, 2019
07785f9
doc
NicolasHug Dec 11, 2019
2a9470c
Merge branch 'master' of github.com:scikit-learn/scikit-learn into pa…
NicolasHug Jan 12, 2020
d7bbce5
Merge branch 'master' of github.com:scikit-learn/scikit-learn into pa…
NicolasHug Jan 13, 2020
5d3badb
Better docs and UG for PDPs
NicolasHug Jan 13, 2020
a6a661d
typo
NicolasHug Jan 13, 2020
f89ef33
Merge branch 'partial_dep_doc_update' into partial_dep_forest
NicolasHug Jan 13, 2020
2ecd25e
fix import issue
NicolasHug Jan 13, 2020
bd94eda
Merge branch 'master' of github.com:scikit-learn/scikit-learn into pa…
NicolasHug Jan 14, 2020
fea11a0
Addressed comments
NicolasHug Jan 14, 2020
53962d3
remove constant
NicolasHug Jan 15, 2020
b58c3da
Merge branch 'master' of github.com:scikit-learn/scikit-learn into pa…
NicolasHug Jan 15, 2020
ec65848
minor docstring correction
NicolasHug Jan 16, 2020
a352720
Merge branch 'master' of github.com:scikit-learn/scikit-learn into pa…
NicolasHug Feb 23, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions doc/whats_new/v0.23.rst
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,16 @@ Changelog
``max_value`` and ``min_value``. Array-like inputs allow a different max and min to be specified
for each feature. :pr:`16403` by :user:`Narendra Mukherjee <narendramukherjee>`.

:mod:`sklearn.inspection`
.........................

- |Feature| :func:`inspection.partial_dependence` and
:func:`inspection.plot_partial_dependence` now support the fast 'recursion'
method for :class:`ensemble.RandomForestRegressor` and
:class:`tree.DecisionTreeRegressor`. :pr:`15864` by
`Nicolas Hug`_.


:mod:`sklearn.linear_model`
...........................

Expand Down
30 changes: 30 additions & 0 deletions sklearn/ensemble/_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,36 @@ def _set_oob_score(self, X, y):

self.oob_score_ /= self.n_outputs_

def _compute_partial_dependence_recursion(self, grid, target_features):
    """Compute partial dependence with the fast 'recursion' method.

    Parameters
    ----------
    grid : ndarray of shape (n_samples, n_target_features)
        Grid points at which the partial dependence is evaluated.
    target_features : ndarray of shape (n_target_features)
        Target features for which the partial dependence is evaluated.

    Returns
    -------
    averaged_predictions : ndarray of shape (n_samples,)
        Partial dependence value at each grid point, averaged over the
        trees of the forest.
    """
    grid = np.asarray(grid, dtype=DTYPE, order='C')
    n_trees = len(self.estimators_)
    # One float64 slot per grid point, shared by every tree below.
    pd_values = np.zeros(grid.shape[0], dtype=np.float64)

    # The trees are visited one after the other: the fast method does not
    # release the GIL, so summing in parallel would not help.
    for estimator in self.estimators_:
        estimator.tree_.compute_partial_dependence(
            grid, target_features, pd_values)

    # Turn the value accumulated over all trees into the forest average.
    pd_values /= n_trees
    return pd_values

class RandomForestClassifier(ForestClassifier):
"""
Expand Down
2 changes: 0 additions & 2 deletions sklearn/ensemble/_gb.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,8 +708,6 @@ def _compute_partial_dependence_recursion(self, grid, target_features):
(n_trees_per_iteration, n_samples)
The value of the partial dependence function on each grid point.
"""
check_is_fitted(self,
msg="'estimator' parameter must be a fitted estimator")
if self.init is not None:
warnings.warn(
'Using recursion method with a non-constant init predictor '
Expand Down
28 changes: 22 additions & 6 deletions sklearn/inspection/_partial_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from ..utils import _determine_key_type
from ..utils import _get_column_indices
from ..utils.validation import check_is_fitted
from ..tree import DecisionTreeRegressor
from ..ensemble import RandomForestRegressor
from ..exceptions import NotFittedError
from ..ensemble._gb import BaseGradientBoosting
from sklearn.ensemble._hist_gradient_boosting.gradient_boosting import (
Expand Down Expand Up @@ -100,7 +102,14 @@ def _grid_from_X(X, percentiles, grid_resolution):


def _partial_dependence_recursion(est, grid, features):
return est._compute_partial_dependence_recursion(grid, features)
averaged_predictions = est._compute_partial_dependence_recursion(grid,
features)
if averaged_predictions.ndim == 1:
# reshape to (1, n_points) for consistency with
# _partial_dependence_brute
averaged_predictions = averaged_predictions.reshape(1, -1)

return averaged_predictions


def _partial_dependence_brute(est, grid, features, X, response_method):
Expand Down Expand Up @@ -242,7 +251,10 @@ def partial_dependence(estimator, X, features, response_method='auto',
:class:`~sklearn.ensemble.GradientBoostingClassifier`,
:class:`~sklearn.ensemble.GradientBoostingRegressor`,
:class:`~sklearn.ensemble.HistGradientBoostingClassifier`,
:class:`~sklearn.ensemble.HistGradientBoostingRegressor`)
:class:`~sklearn.ensemble.HistGradientBoostingRegressor`,
:class:`~sklearn.tree.DecisionTreeRegressor`,
:class:`~sklearn.ensemble.RandomForestRegressor`,
)
but is more efficient in terms of speed.
With this method, the target response of a
classifier is always the decision function, not the predicted
Expand Down Expand Up @@ -339,19 +351,25 @@ def partial_dependence(estimator, X, features, response_method='auto',
if (isinstance(estimator, BaseGradientBoosting) and
estimator.init is None):
method = 'recursion'
elif isinstance(estimator, BaseHistGradientBoosting):
elif isinstance(estimator, (BaseHistGradientBoosting,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you think that it would make sense to avoid this isinstance and check for hasattr(estimator, 'compute_partial_dependence') or a method that could be shared across all these estimators?

I could imagine that libraries such as xgboost or lightgbm, or other tree-based models, could expose the same method without us checking for the type.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's a nice idea but that means we need to support a new public API, i.e. the interface of _compute_partial_dependence_recursion would now be fixed

We would also still need a hardcoded list of supported estimators for the error message.

Maybe we can keep that in mind for later?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, we could check how other libraries are exposing this and put all this info in an issue

DecisionTreeRegressor,
RandomForestRegressor)):
method = 'recursion'
else:
method = 'brute'

if method == 'recursion':
if not isinstance(estimator,
(BaseGradientBoosting, BaseHistGradientBoosting)):
(BaseGradientBoosting, BaseHistGradientBoosting,
DecisionTreeRegressor, RandomForestRegressor)):
supported_classes_recursion = (
'GradientBoostingClassifier',
'GradientBoostingRegressor',
'HistGradientBoostingClassifier',
'HistGradientBoostingRegressor',
'HistGradientBoostingRegressor',
'DecisionTreeRegressor',
'RandomForestRegressor',
)
raise ValueError(
"Only the following estimators support the 'recursion' "
Expand Down Expand Up @@ -399,5 +417,3 @@ def partial_dependence(estimator, X, features, response_method='auto',
-1, *[val.shape[0] for val in values])

return averaged_predictions, values


8 changes: 4 additions & 4 deletions sklearn/inspection/_plot/partial_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def plot_partial_dependence(estimator, X, features, feature_names=None,
:class:`~sklearn.ensemble.GradientBoostingClassifier`,
:class:`~sklearn.ensemble.GradientBoostingRegressor`,
:class:`~sklearn.ensemble.HistGradientBoostingClassifier`,
:class:`~sklearn.ensemble.HistGradientBoostingRegressor`)
:class:`~sklearn.ensemble.HistGradientBoostingRegressor`,
:class:`~sklearn.tree.DecisionTreeRegressor`,
:class:`~sklearn.ensemble.RandomForestRegressor`)
but is more efficient in terms of speed.
With this method, the target response of a
classifier is always the decision function, not the predicted
Expand Down Expand Up @@ -201,9 +203,7 @@ def plot_partial_dependence(estimator, X, features, feature_names=None,
from matplotlib.ticker import ScalarFormatter # noqa

# set target_idx for multi-class estimators
if (is_classifier(estimator) and
hasattr(estimator, 'classes_') and
np.size(estimator.classes_) > 2):
if hasattr(estimator, 'classes_') and np.size(estimator.classes_) > 2:
if target is None:
raise ValueError('target must be specified for multi-class')
target_idx = np.searchsorted(estimator.classes_, target)
Expand Down
79 changes: 77 additions & 2 deletions sklearn/inspection/tests/test_partial_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.experimental import enable_hist_gradient_boosting # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.ensemble import HistGradientBoostingRegressor
Expand All @@ -36,6 +37,9 @@
from sklearn.utils._testing import assert_allclose
from sklearn.utils._testing import assert_array_equal
from sklearn.utils._testing import ignore_warnings
from sklearn.utils import _IS_32BIT
from sklearn.utils.validation import check_random_state
from sklearn.tree.tests.test_tree import assert_is_subtree


# toy sample
Expand Down Expand Up @@ -174,6 +178,11 @@ def test_partial_dependence_helpers(est, method, target_feature):
# samples.
# This also checks that the brute and recursion methods give the same
# output.
# Note that even on the trainset, the brute and the recursion methods
# aren't always strictly equivalent, in particular when the slow method
# generates unrealistic samples that have low mass in the joint
# distribution of the input features, and when some of the features are
# dependent. Hence the high tolerance on the checks.

X, y = make_regression(random_state=0, n_features=5, n_informative=5)
# The 'init' estimator for GBDT (here the average prediction) isn't taken
Expand Down Expand Up @@ -206,6 +215,71 @@ def test_partial_dependence_helpers(est, method, target_feature):
assert np.allclose(pdp, mean_predictions, rtol=rtol)


@pytest.mark.parametrize('seed', range(1))
def test_recursion_decision_tree_vs_forest_and_gbdt(seed):
    """Check the recursion method agrees across equivalent tree models.

    A DecisionTreeRegressor, a one-tree GradientBoostingRegressor and a
    one-tree RandomForestRegressor are fitted with equivalent parameters on
    the same data; their recursion-based partial dependence values are then
    compared feature by feature.
    """
    # Make sure that the recursion method gives the same results on a
    # DecisionTreeRegressor and a GradientBoostingRegressor or a
    # RandomForestRegressor with 1 tree and equivalent parameters.

    rng = np.random.RandomState(seed)

    # Purely random dataset to avoid correlated features
    n_samples = 1000
    n_features = 5
    X = rng.randn(n_samples, n_features)
    y = rng.randn(n_samples) * 10

    # The 'init' estimator for GBDT (here the average prediction) isn't taken
    # into account with the recursion method, for technical reasons. We set
    # the mean to 0 so that this 'bug' doesn't have any effect.
    y = y - y.mean()

    # set max_depth not too high to avoid splits with same gain but different
    # features
    max_depth = 5

    tree_seed = 0
    forest = RandomForestRegressor(n_estimators=1, max_features=None,
                                   bootstrap=False, max_depth=max_depth,
                                   random_state=tree_seed)
    # The forest will use ensemble.base._set_random_states to set the
    # random_state of the tree sub-estimator. We simulate this here to have
    # equivalent estimators.
    equiv_random_state = check_random_state(tree_seed).randint(
        np.iinfo(np.int32).max)
    gbdt = GradientBoostingRegressor(n_estimators=1, learning_rate=1,
                                     criterion='mse', max_depth=max_depth,
                                     random_state=equiv_random_state)
    tree = DecisionTreeRegressor(max_depth=max_depth,
                                 random_state=equiv_random_state)

    forest.fit(X, y)
    gbdt.fit(X, y)
    tree.fit(X, y)

    # sanity check: if the trees aren't the same, the PD values won't be equal
    try:
        assert_is_subtree(tree.tree_, gbdt[0, 0].tree_)
        assert_is_subtree(tree.tree_, forest[0].tree_)
    except AssertionError:
        # For some reason the trees aren't exactly equal on 32bits, so the PDs
        # cannot be equal either. See
        # https://github.com/scikit-learn/scikit-learn/issues/8853
        assert _IS_32BIT, "this should only fail on 32 bit platforms"
        return

    # A single-column grid of 50 random points, reused for every feature.
    grid = rng.randn(50).reshape(-1, 1)
    for f in range(n_features):
        # One target feature at a time, passed as an int32 index array.
        features = np.array([f], dtype=np.int32)

        pdp_forest = _partial_dependence_recursion(forest, grid, features)
        pdp_gbdt = _partial_dependence_recursion(gbdt, grid, features)
        pdp_tree = _partial_dependence_recursion(tree, grid, features)

        np.testing.assert_allclose(pdp_gbdt, pdp_tree)
        np.testing.assert_allclose(pdp_forest, pdp_tree)


@pytest.mark.parametrize('est', (
GradientBoostingClassifier(random_state=0),
HistGradientBoostingClassifier(random_state=0),
Expand Down Expand Up @@ -236,8 +310,9 @@ def test_recursion_decision_function(est, target_feature):
LinearRegression(),
GradientBoostingRegressor(random_state=0),
HistGradientBoostingRegressor(random_state=0, min_samples_leaf=1,
max_leaf_nodes=None, max_iter=1))
)
max_leaf_nodes=None, max_iter=1),
DecisionTreeRegressor(random_state=0),
))
@pytest.mark.parametrize('power', (1, 2))
def test_partial_dependence_easy_target(est, power):
# If the target y only depends on one feature in an obvious way (linear or
Expand Down
25 changes: 25 additions & 0 deletions sklearn/tree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,6 +1252,31 @@ def n_classes_(self):
warnings.warn(msg, FutureWarning)
return np.array([1] * self.n_outputs_, dtype=np.intp)

def _compute_partial_dependence_recursion(self, grid, target_features):
    """Compute partial dependence with the fast 'recursion' method.

    Parameters
    ----------
    grid : ndarray of shape (n_samples, n_target_features)
        Grid points at which the partial dependence is evaluated.
    target_features : ndarray of shape (n_target_features)
        Target features for which the partial dependence is evaluated.

    Returns
    -------
    averaged_predictions : ndarray of shape (n_samples,)
        Partial dependence value at each grid point.
    """
    grid = np.asarray(grid, dtype=DTYPE, order='C')
    n_points = grid.shape[0]

    # Zero-initialised output buffer handed to the underlying tree, which is
    # expected to fill it in place.
    pd_values = np.zeros(n_points, dtype=np.float64)
    self.tree_.compute_partial_dependence(grid, target_features, pd_values)

    return pd_values


class ExtraTreeClassifier(DecisionTreeClassifier):
"""An extremely randomized tree classifier.
Expand Down