Skip to content

[WIP] Generalized partial dependence plots #5653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
57f9a6f
general partial dependence plots
trevorstephens Oct 19, 2015
9a09888
add init
trevorstephens Oct 21, 2015
e714e16
implement exact and estimated methods
trevorstephens Nov 1, 2015
19ed28e
support for Pipeline and GridSearchCV type estimators
trevorstephens Nov 6, 2015
6fafc5e
add multioutput support
trevorstephens Nov 16, 2015
152a190
rebase and catch up to #6762, #7673, #7846
trevorstephens Jul 22, 2017
2cdc5ea
catch up on #9434
trevorstephens Jul 26, 2017
ba1f8da
initial update of plot_partial_dependence
trevorstephens Jul 26, 2017
1b1d8f0
deprecate ensemble.partial_dependence
trevorstephens Aug 12, 2017
9095305
refactor estimated and exact functions to _predict
trevorstephens Aug 15, 2017
3fc1727
make "auto" the default rather than None for method
trevorstephens Aug 15, 2017
259ec99
some more refactoring
trevorstephens Aug 16, 2017
cbc20af
avoid namespace collision
trevorstephens Aug 21, 2017
63da115
fix output shapes of all estimators
trevorstephens Aug 22, 2017
8f7d2b0
add tests to ensure all estimators output same shape
trevorstephens Aug 22, 2017
6fc3a49
quick fixes
trevorstephens Aug 30, 2017
b1f8bfc
fix docstring, test fails
trevorstephens Sep 1, 2017
dc93b69
refactor tests for easier debugging
trevorstephens Sep 1, 2017
cd8f8de
speed up tests, add two-way plot test
trevorstephens Sep 2, 2017
4eb1a80
move input validation on X
trevorstephens Sep 2, 2017
21544ce
fix output shape for multi-label classification
trevorstephens Oct 28, 2017
610b5c5
update plot helper to support multi-output
trevorstephens Oct 28, 2017
dcbd0c6
update plot helper to pass-through output
trevorstephens Oct 28, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions sklearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@
'isotonic', 'kernel_approximation', 'kernel_ridge',
'learning_curve', 'linear_model', 'manifold', 'metrics',
'mixture', 'model_selection', 'multiclass', 'multioutput',
'naive_bayes', 'neighbors', 'neural_network', 'pipeline',
'preprocessing', 'random_projection', 'semi_supervised',
'svm', 'tree', 'discriminant_analysis', 'impute', 'compose',
'naive_bayes', 'neighbors', 'neural_network',
'partial_dependence', 'pipeline', 'preprocessing',
'random_projection', 'semi_supervised', 'svm', 'tree',
'discriminant_analysis', 'impute', 'compose',
# Non-modules:
'clone', 'get_config', 'set_config', 'config_context']

Expand Down
294 changes: 33 additions & 261 deletions sklearn/ensemble/partial_dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,70 +3,9 @@
# Authors: Peter Prettenhofer
# License: BSD 3 clause

from itertools import count
import numbers

import numpy as np
from scipy.stats.mstats import mquantiles

from ..utils.extmath import cartesian
from ..externals.joblib import Parallel, delayed
from ..externals import six
from ..externals.six.moves import map, range, zip
from ..utils import check_array
from ..utils.validation import check_is_fitted
from ..tree._tree import DTYPE

from ._gradient_boosting import _partial_dependence_tree
from .gradient_boosting import BaseGradientBoosting


def _grid_from_X(X, percentiles=(0.05, 0.95), grid_resolution=100):
"""Generate a grid of points based on the ``percentiles of ``X``.

The grid is generated by placing ``grid_resolution`` equally
spaced points between the ``percentiles`` of each column
of ``X``.

Parameters
----------
X : ndarray
The data
percentiles : tuple of floats
The percentiles which are used to construct the extreme
values of the grid axes.
grid_resolution : int
The number of equally spaced points that are placed
on the grid.

Returns
-------
grid : ndarray
All data points on the grid; ``grid.shape[1] == X.shape[1]``
and ``grid.shape[0] == grid_resolution * X.shape[1]``.
axes : seq of ndarray
The axes with which the grid has been created.
"""
if len(percentiles) != 2:
raise ValueError('percentile must be tuple of len 2')
if not all(0. <= x <= 1. for x in percentiles):
raise ValueError('percentile values must be in [0, 1]')

axes = []
emp_percentiles = mquantiles(X, prob=percentiles, axis=0)
for col in range(X.shape[1]):
uniques = np.unique(X[:, col])
if uniques.shape[0] < grid_resolution:
# feature has low resolution use unique vals
axis = uniques
else:
# create axis based on percentiles and grid resolution
axis = np.linspace(emp_percentiles[0, col],
emp_percentiles[1, col],
num=grid_resolution, endpoint=True)
axes.append(axis)

return cartesian(axes), axes
import warnings
from ..partial_dependence import partial_dependence as new_pd
from ..partial_dependence import plot_partial_dependence as new_ppd


def partial_dependence(gbrt, target_variables, grid=None, X=None,
Expand Down Expand Up @@ -120,47 +59,17 @@ def partial_dependence(gbrt, target_variables, grid=None, X=None,
>>> partial_dependence(gb, [0], **kwargs) # doctest: +SKIP
(array([[-4.52..., 4.52...]]), [array([ 0., 1.])])
"""
if not isinstance(gbrt, BaseGradientBoosting):
raise ValueError('gbrt has to be an instance of BaseGradientBoosting')
check_is_fitted(gbrt, 'estimators_')
if (grid is None and X is None) or (grid is not None and X is not None):
raise ValueError('Either grid or X must be specified')

target_variables = np.asarray(target_variables, dtype=np.int32,
order='C').ravel()

if any([not (0 <= fx < gbrt.n_features_) for fx in target_variables]):
raise ValueError('target_variables must be in [0, %d]'
% (gbrt.n_features_ - 1))

if X is not None:
X = check_array(X, dtype=DTYPE, order='C')
grid, axes = _grid_from_X(X[:, target_variables], percentiles,
grid_resolution)
else:
assert grid is not None
# dont return axes if grid is given
axes = None
# grid must be 2d
if grid.ndim == 1:
grid = grid[:, np.newaxis]
if grid.ndim != 2:
raise ValueError('grid must be 2d but is %dd' % grid.ndim)

grid = np.asarray(grid, dtype=DTYPE, order='C')
assert grid.shape[1] == target_variables.shape[0]

n_trees_per_stage = gbrt.estimators_.shape[1]
n_estimators = gbrt.estimators_.shape[0]
pdp = np.zeros((n_trees_per_stage, grid.shape[0],), dtype=np.float64,
order='C')
for stage in range(n_estimators):
for k in range(n_trees_per_stage):
tree = gbrt.estimators_[stage, k].tree_
_partial_dependence_tree(tree, grid, target_variables,
gbrt.learning_rate, pdp[k])

return pdp, axes
warnings.warn("The function ensemble.partial_dependence has been moved to "
"partial_dependence in 0.20 and will be removed in 0.22.",
DeprecationWarning)
return new_pd(est=gbrt,
target_variables=target_variables,
grid=grid,
X=X,
output=None,
percentiles=percentiles,
grid_resolution=grid_resolution,
method='recursion')


def plot_partial_dependence(gbrt, X, features, feature_names=None,
Expand Down Expand Up @@ -237,159 +146,22 @@ def plot_partial_dependence(gbrt, X, features, feature_names=None,
>>> fig, axs = plot_partial_dependence(clf, X, [0, (0, 1)]) #doctest: +SKIP
...
"""
import matplotlib.pyplot as plt
from matplotlib import transforms
from matplotlib.ticker import MaxNLocator
from matplotlib.ticker import ScalarFormatter

if not isinstance(gbrt, BaseGradientBoosting):
raise ValueError('gbrt has to be an instance of BaseGradientBoosting')
check_is_fitted(gbrt, 'estimators_')

# set label_idx for multi-class GBRT
if hasattr(gbrt, 'classes_') and np.size(gbrt.classes_) > 2:
if label is None:
raise ValueError('label is not given for multi-class PDP')
label_idx = np.searchsorted(gbrt.classes_, label)
if gbrt.classes_[label_idx] != label:
raise ValueError('label %s not in ``gbrt.classes_``' % str(label))
else:
# regression and binary classification
label_idx = 0

X = check_array(X, dtype=DTYPE, order='C')
if gbrt.n_features_ != X.shape[1]:
raise ValueError('X.shape[1] does not match gbrt.n_features_')

if line_kw is None:
line_kw = {'color': 'green'}
if contour_kw is None:
contour_kw = {}

# convert feature_names to list
if feature_names is None:
# if not feature_names use fx indices as name
feature_names = [str(i) for i in range(gbrt.n_features_)]
elif isinstance(feature_names, np.ndarray):
feature_names = feature_names.tolist()

def convert_feature(fx):
if isinstance(fx, six.string_types):
try:
fx = feature_names.index(fx)
except ValueError:
raise ValueError('Feature %s not in feature_names' % fx)
return fx

# convert features into a seq of int tuples
tmp_features = []
for fxs in features:
if isinstance(fxs, (numbers.Integral,) + six.string_types):
fxs = (fxs,)
try:
fxs = np.array([convert_feature(fx) for fx in fxs], dtype=np.int32)
except TypeError:
raise ValueError('features must be either int, str, or tuple '
'of int/str')
if not (1 <= np.size(fxs) <= 2):
raise ValueError('target features must be either one or two')

tmp_features.append(fxs)

features = tmp_features

names = []
try:
for fxs in features:
l = []
# explicit loop so "i" is bound for exception below
for i in fxs:
l.append(feature_names[i])
names.append(l)
except IndexError:
raise ValueError('All entries of features must be less than '
'len(feature_names) = {0}, got {1}.'
.format(len(feature_names), i))

# compute PD functions
pd_result = Parallel(n_jobs=n_jobs, verbose=verbose)(
delayed(partial_dependence)(gbrt, fxs, X=X,
grid_resolution=grid_resolution,
percentiles=percentiles)
for fxs in features)

# get global min and max values of PD grouped by plot type
pdp_lim = {}
for pdp, axes in pd_result:
min_pd, max_pd = pdp[label_idx].min(), pdp[label_idx].max()
n_fx = len(axes)
old_min_pd, old_max_pd = pdp_lim.get(n_fx, (min_pd, max_pd))
min_pd = min(min_pd, old_min_pd)
max_pd = max(max_pd, old_max_pd)
pdp_lim[n_fx] = (min_pd, max_pd)

# create contour levels for two-way plots
if 2 in pdp_lim:
Z_level = np.linspace(*pdp_lim[2], num=8)

if ax is None:
fig = plt.figure(**fig_kw)
else:
fig = ax.get_figure()
fig.clear()

n_cols = min(n_cols, len(features))
n_rows = int(np.ceil(len(features) / float(n_cols)))
axs = []
for i, fx, name, (pdp, axes) in zip(count(), features, names,
pd_result):
ax = fig.add_subplot(n_rows, n_cols, i + 1)

if len(axes) == 1:
ax.plot(axes[0], pdp[label_idx].ravel(), **line_kw)
else:
# make contour plot
assert len(axes) == 2
XX, YY = np.meshgrid(axes[0], axes[1])
Z = pdp[label_idx].reshape(list(map(np.size, axes))).T
CS = ax.contour(XX, YY, Z, levels=Z_level, linewidths=0.5,
colors='k')
ax.contourf(XX, YY, Z, levels=Z_level, vmax=Z_level[-1],
vmin=Z_level[0], alpha=0.75, **contour_kw)
ax.clabel(CS, fmt='%2.2f', colors='k', fontsize=10, inline=True)

# plot data deciles + axes labels
deciles = mquantiles(X[:, fx[0]], prob=np.arange(0.1, 1.0, 0.1))
trans = transforms.blended_transform_factory(ax.transData,
ax.transAxes)
ylim = ax.get_ylim()
ax.vlines(deciles, [0], 0.05, transform=trans, color='k')
ax.set_xlabel(name[0])
ax.set_ylim(ylim)

# prevent x-axis ticks from overlapping
ax.xaxis.set_major_locator(MaxNLocator(nbins=6, prune='lower'))
tick_formatter = ScalarFormatter()
tick_formatter.set_powerlimits((-3, 4))
ax.xaxis.set_major_formatter(tick_formatter)

if len(axes) > 1:
# two-way PDP - y-axis deciles + labels
deciles = mquantiles(X[:, fx[1]], prob=np.arange(0.1, 1.0, 0.1))
trans = transforms.blended_transform_factory(ax.transAxes,
ax.transData)
xlim = ax.get_xlim()
ax.hlines(deciles, [0], 0.05, transform=trans, color='k')
ax.set_ylabel(name[1])
# hline erases xlim
ax.set_xlim(xlim)
else:
ax.set_ylabel('Partial dependence')

if len(axes) == 1:
ax.set_ylim(pdp_lim[1])
axs.append(ax)

fig.subplots_adjust(bottom=0.15, top=0.7, left=0.1, right=0.95, wspace=0.4,
hspace=0.3)
return fig, axs
warnings.warn("The function ensemble.plot_partial_dependence has been "
"moved to partial_dependence in 0.20 and will be removed "
"in 0.22.",
DeprecationWarning)
return new_ppd(est=gbrt,
X=X,
features=features,
feature_names=feature_names,
label=label,
n_cols=n_cols,
grid_resolution=grid_resolution,
method='recursion',
percentiles=percentiles,
n_jobs=n_jobs,
verbose=verbose,
ax=ax,
line_kw=line_kw,
contour_kw=contour_kw,
**fig_kw)
Loading