Skip to content

Commit ea9894a

Browse files
authored
ENH Support sample weights in partial_dependence (#25209)
1 parent fbe7e5e commit ea9894a

File tree

3 files changed

+285
-18
lines changed

3 files changed

+285
-18
lines changed

doc/whats_new/v1.3.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,12 @@ Changelog
391391
:mod:`sklearn.inspection`
392392
.........................
393393

394+
- |Enhancement| Added support for `sample_weight` in
395+
:func:`inspection.partial_dependence`. This allows for weighted averaging when
396+
aggregating for each value of the grid we are making the inspection on. The
397+
option is only available when `method` is set to `brute`. :pr:`25209`
398+
by :user:`Carlo Lemos <vitaliset>`.
399+
394400
- |API| :func:`inspection.partial_dependence` returns a :class:`utils.Bunch` with
395401
new key: `grid_values`. The `values` key is deprecated in favor of `grid_values`
396402
and the `values` key will be removed in 1.5.

sklearn/inspection/_partial_dependence.py

Lines changed: 160 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ..utils import _safe_assign
2121
from ..utils import _determine_key_type
2222
from ..utils import _get_column_indices
23+
from ..utils.validation import _check_sample_weight
2324
from ..utils.validation import check_is_fitted
2425
from ..utils import Bunch
2526
from ..utils._param_validation import (
@@ -136,6 +137,54 @@ def _grid_from_X(X, percentiles, is_categorical, grid_resolution):
136137

137138

138139
def _partial_dependence_recursion(est, grid, features):
140+
"""Calculate partial dependence via the recursion method.
141+
142+
The recursion method is in particular enabled for tree-based estimators.
143+
144+
For each `grid` value, a weighted tree traversal is performed: if a split node
145+
involves an input feature of interest, the corresponding left or right branch
146+
is followed; otherwise both branches are followed, each branch being weighted
147+
by the fraction of training samples that entered that branch. Finally, the
148+
partial dependence is given by a weighted average of all the visited leaves
149+
values.
150+
151+
This method is more efficient in terms of speed than the `'brute'` method
152+
(:func:`~sklearn.inspection._partial_dependence._partial_dependence_brute`).
153+
However, here, the partial dependence computation is done explicitly with the
154+
`X` used during training of `est`.
155+
156+
Parameters
157+
----------
158+
est : BaseEstimator
159+
A fitted estimator object implementing :term:`predict` or
160+
:term:`decision_function`. Multioutput-multiclass classifiers are not
161+
supported. Note that `'recursion'` is only supported for some tree-based
162+
estimators (namely
163+
:class:`~sklearn.ensemble.GradientBoostingClassifier`,
164+
:class:`~sklearn.ensemble.GradientBoostingRegressor`,
165+
:class:`~sklearn.ensemble.HistGradientBoostingClassifier`,
166+
:class:`~sklearn.ensemble.HistGradientBoostingRegressor`,
167+
:class:`~sklearn.tree.DecisionTreeRegressor`,
168+
:class:`~sklearn.ensemble.RandomForestRegressor`,
169+
).
170+
171+
grid : array-like of shape (n_points, n_target_features)
172+
The grid of feature values for which the partial dependence is calculated.
173+
Note that `n_points` is the number of points in the grid and `n_target_features`
174+
is the number of features you are doing partial dependence at.
175+
176+
features : array-like of {int, str}
177+
The feature (e.g. `[0]`) or pair of interacting features
178+
(e.g. `[(0, 1)]`) for which the partial dependency should be computed.
179+
180+
Returns
181+
-------
182+
averaged_predictions : array-like of shape (n_targets, n_points)
183+
The averaged predictions for the given `grid` of features values.
184+
Note that `n_targets` is the number of targets (e.g. 1 for binary
185+
classification, `n_tasks` for multi-output regression, and `n_classes` for
186+
multiclass classification) and `n_points` is the number of points in the `grid`.
187+
"""
139188
averaged_predictions = est._compute_partial_dependence_recursion(grid, features)
140189
if averaged_predictions.ndim == 1:
141190
# reshape to (1, n_points) for consistency with
@@ -145,7 +194,78 @@ def _partial_dependence_recursion(est, grid, features):
145194
return averaged_predictions
146195

147196

148-
def _partial_dependence_brute(est, grid, features, X, response_method):
197+
def _partial_dependence_brute(
198+
est, grid, features, X, response_method, sample_weight=None
199+
):
200+
"""Calculate partial dependence via the brute force method.
201+
202+
The brute method explicitly averages the predictions of an estimator over a
203+
grid of feature values.
204+
205+
For each `grid` value, all the samples from `X` have their variables of
206+
interest replaced by that specific `grid` value. The predictions are then made
207+
and averaged across the samples.
208+
209+
This method is slower than the `'recursion'`
210+
(:func:`~sklearn.inspection._partial_dependence._partial_dependence_recursion`)
211+
version for estimators with this second option. However, with the `'brute'`
212+
force method, the average will be done with the given `X` and not the `X`
213+
used during training, as it is done in the `'recursion'` version. Therefore
214+
the average can always accept `sample_weight` (even when the estimator was
215+
fitted without).
216+
217+
Parameters
218+
----------
219+
est : BaseEstimator
220+
A fitted estimator object implementing :term:`predict`,
221+
:term:`predict_proba`, or :term:`decision_function`.
222+
Multioutput-multiclass classifiers are not supported.
223+
224+
grid : array-like of shape (n_points, n_target_features)
225+
The grid of feature values for which the partial dependence is calculated.
226+
Note that `n_points` is the number of points in the grid and `n_target_features`
227+
is the number of features you are doing partial dependence at.
228+
229+
features : array-like of {int, str}
230+
The feature (e.g. `[0]`) or pair of interacting features
231+
(e.g. `[(0, 1)]`) for which the partial dependency should be computed.
232+
233+
X : array-like of shape (n_samples, n_features)
234+
`X` is used to generate values for the complement features. That is, for
235+
each value in `grid`, the method will average the prediction of each
236+
sample from `X` having that grid value for `features`.
237+
238+
response_method : {'auto', 'predict_proba', 'decision_function'}, \
239+
default='auto'
240+
Specifies whether to use :term:`predict_proba` or
241+
:term:`decision_function` as the target response. For regressors
242+
this parameter is ignored and the response is always the output of
243+
:term:`predict`. By default, :term:`predict_proba` is tried first
244+
and we revert to :term:`decision_function` if it doesn't exist.
245+
246+
sample_weight : array-like of shape (n_samples,), default=None
247+
Sample weights are used to calculate weighted means when averaging the
248+
model output. If `None`, then samples are equally weighted. Note that
249+
`sample_weight` does not change the individual predictions.
250+
251+
Returns
252+
-------
253+
averaged_predictions : array-like of shape (n_targets, n_points)
254+
The averaged predictions for the given `grid` of features values.
255+
Note that `n_targets` is the number of targets (e.g. 1 for binary
256+
classification, `n_tasks` for multi-output regression, and `n_classes` for
257+
multiclass classification) and `n_points` is the number of points in the `grid`.
258+
259+
predictions : array-like
260+
The predictions for the given `grid` of features values over the samples
261+
from `X`. For non-multioutput regression and binary classification the
262+
shape is `(n_instances, n_points)` and for multi-output regression and
263+
multiclass classification the shape is `(n_targets, n_instances, n_points)`,
264+
where `n_targets` is the number of targets (`n_tasks` for multi-output
265+
regression, and `n_classes` for multiclass classification), `n_instances`
266+
is the number of instances in `X`, and `n_points` is the number of points
267+
in the `grid`.
268+
"""
149269
predictions = []
150270
averaged_predictions = []
151271

@@ -191,7 +311,7 @@ def _partial_dependence_brute(est, grid, features, X, response_method):
191311

192312
predictions.append(pred)
193313
# average over samples
194-
averaged_predictions.append(np.mean(pred, axis=0))
314+
averaged_predictions.append(np.average(pred, axis=0, weights=sample_weight))
195315
except NotFittedError as e:
196316
raise ValueError("'estimator' parameter must be a fitted estimator") from e
197317

@@ -239,6 +359,7 @@ def _partial_dependence_brute(est, grid, features, X, response_method):
239359
],
240360
"X": ["array-like", "sparse matrix"],
241361
"features": ["array-like", Integral, str],
362+
"sample_weight": ["array-like", None],
242363
"categorical_features": ["array-like", None],
243364
"feature_names": ["array-like", None],
244365
"response_method": [StrOptions({"auto", "predict_proba", "decision_function"})],
@@ -253,6 +374,7 @@ def partial_dependence(
253374
X,
254375
features,
255376
*,
377+
sample_weight=None,
256378
categorical_features=None,
257379
feature_names=None,
258380
response_method="auto",
@@ -303,6 +425,14 @@ def partial_dependence(
303425
The feature (e.g. `[0]`) or pair of interacting features
304426
(e.g. `[(0, 1)]`) for which the partial dependency should be computed.
305427
428+
sample_weight : array-like of shape (n_samples,), default=None
429+
Sample weights are used to calculate weighted means when averaging the
430+
model output. If `None`, then samples are equally weighted. If
431+
`sample_weight` is not `None`, then `method` will be set to `'brute'`.
432+
Note that `sample_weight` is ignored for `kind='individual'`.
433+
434+
.. versionadded:: 1.3
435+
306436
categorical_features : array-like of shape (n_features,) or shape \
307437
(n_categorical_features,), dtype={bool, int, str}, default=None
308438
Indicates the categorical features.
@@ -366,7 +496,8 @@ def partial_dependence(
366496
computationally intensive.
367497
368498
- `'auto'`: the `'recursion'` is used for estimators that support it,
369-
and `'brute'` is used otherwise.
499+
and `'brute'` is used otherwise. If `sample_weight` is not `None`,
500+
then `'brute'` is used regardless of the estimator.
370501
371502
Please see :ref:`this note <pdp_method_differences>` for
372503
differences between the `'brute'` and `'recursion'` method.
@@ -377,8 +508,9 @@ def partial_dependence(
377508
See Returns below.
378509
379510
Note that the fast `method='recursion'` option is only available for
380-
`kind='average'`. Computing individual dependencies requires using the
381-
slower `method='brute'` option.
511+
`kind='average'` and `sample_weights=None`. Computing individual
512+
dependencies and doing weighted averages requires using the slower
513+
`method='brute'`.
382514
383515
.. versionadded:: 0.24
384516
@@ -391,14 +523,15 @@ def partial_dependence(
391523
len(values[0]), len(values[1]), ...)
392524
The predictions for all the points in the grid for all
393525
samples in X. This is also known as Individual
394-
Conditional Expectation (ICE)
526+
Conditional Expectation (ICE).
527+
Only available when `kind='individual'` or `kind='both'`.
395528
396529
average : ndarray of shape (n_outputs, len(values[0]), \
397530
len(values[1]), ...)
398531
The predictions for all the points in the grid, averaged
399532
over all samples in X (or over the training data if
400-
``method`` is 'recursion').
401-
Only available when ``kind='both'``.
533+
`method` is 'recursion').
534+
Only available when `kind='average'` or `kind='both'`.
402535
403536
values : seq of 1d ndarrays
404537
The values with which the grid has been created.
@@ -410,17 +543,17 @@ def partial_dependence(
410543
411544
grid_values : seq of 1d ndarrays
412545
The values with which the grid has been created. The generated
413-
grid is a cartesian product of the arrays in ``grid_values`` where
414-
``len(grid_values) == len(features)``. The size of each array
415-
``grid_values[j]`` is either ``grid_resolution``, or the number of
416-
unique values in ``X[:, j]``, whichever is smaller.
546+
grid is a cartesian product of the arrays in `grid_values` where
547+
`len(grid_values) == len(features)`. The size of each array
548+
`grid_values[j]` is either `grid_resolution`, or the number of
549+
unique values in `X[:, j]`, whichever is smaller.
417550
418551
.. versionadded:: 1.3
419552
420-
``n_outputs`` corresponds to the number of classes in a multi-class
553+
`n_outputs` corresponds to the number of classes in a multi-class
421554
setting, or to the number of tasks for multi-output regression.
422-
For classical regression and binary classification ``n_outputs==1``.
423-
``n_values_feature_j`` corresponds to the size ``grid_values[j]``.
555+
For classical regression and binary classification `n_outputs==1`.
556+
`n_values_feature_j` corresponds to the size `grid_values[j]`.
424557
425558
See Also
426559
--------
@@ -463,8 +596,15 @@ def partial_dependence(
463596
)
464597
method = "brute"
465598

599+
if method == "recursion" and sample_weight is not None:
600+
raise ValueError(
601+
"The 'recursion' method can only be applied when sample_weight is None."
602+
)
603+
466604
if method == "auto":
467-
if isinstance(estimator, BaseGradientBoosting) and estimator.init is None:
605+
if sample_weight is not None:
606+
method = "brute"
607+
elif isinstance(estimator, BaseGradientBoosting) and estimator.init is None:
468608
method = "recursion"
469609
elif isinstance(
470610
estimator,
@@ -508,6 +648,9 @@ def partial_dependence(
508648
"'decision_function'. Got {}.".format(response_method)
509649
)
510650

651+
if sample_weight is not None:
652+
sample_weight = _check_sample_weight(sample_weight, X)
653+
511654
if _determine_key_type(features, accept_slice=False) == "int":
512655
# _get_column_indices() supports negative indexing. Here, we limit
513656
# the indexing to be positive. The upper bound will be checked
@@ -560,7 +703,7 @@ def partial_dependence(
560703

561704
if method == "brute":
562705
averaged_predictions, predictions = _partial_dependence_brute(
563-
estimator, grid, features_indices, X, response_method
706+
estimator, grid, features_indices, X, response_method, sample_weight
564707
)
565708

566709
# reshape predictions to

0 commit comments

Comments
 (0)