ENH Support sample weights when fitting HistGradientBoosting estimator #25431

Closed
Changes from all commits
31 commits
a5168dd
Support user-supplied sample weights when fitting HistGradientBoostin…
Andrew-Wang-IB45 Jan 19, 2023
964edd0
Fix ordering of min_weight_leaf calculation
Andrew-Wang-IB45 Jan 20, 2023
01f5586
Update all references of 'count' to 'weighted_n_node_samples' in tests
Andrew-Wang-IB45 Jan 26, 2023
6d27ef8
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Jan 27, 2023
dfc41fb
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Jan 30, 2023
bae1162
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Feb 6, 2023
89bc47a
Merge remote-tracking branch 'origin/hist_gradient_boosting_fit_sampl…
Andrew-Wang-IB45 Feb 8, 2023
22d80e3
Add tests
Andrew-Wang-IB45 Feb 10, 2023
0f244d1
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Feb 10, 2023
edadc0d
Update changelog
Andrew-Wang-IB45 Feb 10, 2023
3056dd5
Black fixes
Andrew-Wang-IB45 Feb 10, 2023
8928b87
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Feb 19, 2023
62f3520
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Feb 25, 2023
18ffab3
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Mar 2, 2023
ac2bb62
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Mar 9, 2023
59720f9
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Mar 16, 2023
0c7e91e
Fix changelog
Andrew-Wang-IB45 Mar 16, 2023
c123aca
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Mar 19, 2023
5e325bd
Remove stopping criteria that use sample_weight
Andrew-Wang-IB45 Mar 21, 2023
a8e2174
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Mar 21, 2023
ecdcc80
Move computation of weighted_n_node_samples to split_indices function
Andrew-Wang-IB45 Mar 23, 2023
228f4b2
Update test cases
Andrew-Wang-IB45 Mar 23, 2023
6cac71d
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Mar 23, 2023
fb234b7
Fix formatting
Andrew-Wang-IB45 Mar 23, 2023
fa61578
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Mar 25, 2023
321332c
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Mar 26, 2023
354adb8
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Mar 28, 2023
88e5d97
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Mar 29, 2023
5c8b689
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Mar 31, 2023
3135ce9
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Apr 8, 2023
754d3ea
Merge branch 'main' into hist_gradient_boosting_fit_sample_weights
Andrew-Wang-IB45 Apr 17, 2023
6 changes: 6 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -260,6 +260,12 @@ Changelog
pandas' conventions.
:pr:`25629` by `Thomas Fan`_.

- |Enhancement| :class:`ensemble.HistGradientBoostingClassifier` and
:class:`ensemble.HistGradientBoostingRegressor` now take the user-supplied
`sample_weight` into account at `fit` time, so each node in the estimator's
predictors stores and uses the weighted sample count.
:pr:`25431` by :user:`Andrew Wang <Andrew-Wang-IB45>`.

:mod:`sklearn.exceptions`
.........................
- |Feature| Added :class:`exceptions.InconsistentVersionWarning` which is raised
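As context for the changelog entry above: with this change, the per-node statistics exposed by the fitted predictors become weighted sums rather than raw sample counts. The sketch below is illustrative only; it assumes a build of this branch (released scikit-learn still stores an unweighted `count` field) and pokes at the private `_predictors` attribute purely for inspection.

```python
# A minimal sketch, assuming a build of this branch.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifier

X, y = make_classification(n_samples=200, random_state=0)
sample_weight = np.random.RandomState(0).uniform(0.5, 2.0, size=200)

clf = HistGradientBoostingClassifier(max_depth=2, random_state=0)
clf.fit(X, y, sample_weight=sample_weight)

# The root node of the first tree should hold the total weight, not 200.
root = clf._predictors[0][0].nodes[0]
print(root["weighted_n_node_samples"], sample_weight.sum())
```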
4 changes: 2 additions & 2 deletions sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx
@@ -236,8 +236,8 @@ def _compute_partial_dependence(
# push left child
node_idx_stack[stack_size] = current_node.left
left_sample_frac = (
<Y_DTYPE_C> nodes[current_node.left].count /
current_node.count)
<Y_DTYPE_C> nodes[current_node.left].weighted_n_node_samples /
current_node.weighted_n_node_samples)
current_weight = weight_stack[stack_size]
weight_stack[stack_size] = current_weight * left_sample_frac
stack_size += 1
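The hunk above changes how partial dependence distributes probability mass across the children of a node whose split feature is not on the requested grid: the left/right fractions now come from the weighted sample counts. A heavily simplified recursive sketch of that weighting rule only (the real Cython routine uses an explicit stack and also handles categorical splits and missing values; the nested-dict node layout here is hypothetical):

```python
# Simplified sketch of the weighting rule, not the Cython traversal.
def partial_dependence_value(node, x, target_features, weight=1.0):
    if node["is_leaf"]:
        return weight * node["value"]
    if node["feature_idx"] in target_features:
        # Split feature is on the grid: follow the matching child with full weight.
        go_left = x[node["feature_idx"]] <= node["num_threshold"]
        child = node["left"] if go_left else node["right"]
        return partial_dependence_value(child, x, target_features, weight)
    # Otherwise split the weight proportionally to the weighted sample counts.
    left_frac = node["left"]["weighted_n_node_samples"] / node["weighted_n_node_samples"]
    return partial_dependence_value(
        node["left"], x, target_features, weight * left_frac
    ) + partial_dependence_value(
        node["right"], x, target_features, weight * (1.0 - left_frac)
    )
```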
2 changes: 1 addition & 1 deletion sklearn/ensemble/_hist_gradient_boosting/common.pxd
@@ -22,7 +22,7 @@ cdef packed struct node_struct:
# Equivalent struct to PREDICTOR_RECORD_DTYPE to use in memory views. It
# needs to be packed since by default numpy dtypes aren't aligned
Y_DTYPE_C value
unsigned int count
Y_DTYPE_C weighted_n_node_samples
unsigned int feature_idx
X_DTYPE_C num_threshold
unsigned char missing_go_to_left
2 changes: 1 addition & 1 deletion sklearn/ensemble/_hist_gradient_boosting/common.pyx
@@ -18,7 +18,7 @@ HISTOGRAM_DTYPE = np.dtype([

PREDICTOR_RECORD_DTYPE = np.dtype([
('value', Y_DTYPE),
('count', np.uint32),
('weighted_n_node_samples', Y_DTYPE),
('feature_idx', np.uint32),
('num_threshold', X_DTYPE),
('missing_go_to_left', np.uint8),
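The dtype change above is why the field type switches from `np.uint32` to `Y_DTYPE` (float64): weighted counts are generally fractional. A small sketch of the idea with a truncated field list (the real `PREDICTOR_RECORD_DTYPE` has more fields):

```python
import numpy as np

Y_DTYPE = np.float64  # matches Y_DTYPE in common.pyx
X_DTYPE = np.float64

# Truncated stand-in for PREDICTOR_RECORD_DTYPE, for illustration only.
record_dtype = np.dtype([
    ("value", Y_DTYPE),
    ("weighted_n_node_samples", Y_DTYPE),  # was ('count', np.uint32)
    ("feature_idx", np.uint32),
    ("num_threshold", X_DTYPE),
    ("missing_go_to_left", np.uint8),
])

nodes = np.zeros(1, dtype=record_dtype)
nodes[0]["weighted_n_node_samples"] = 12.5  # fractional weights are now representable
print(nodes[0])
```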
sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -677,6 +677,7 @@ def fit(self, X, y, sample_weight=None):
X_binned=X_binned_train,
gradients=g_view[:, k],
hessians=h_view[:, k],
sample_weight=sample_weight_train,
n_bins=n_bins,
n_bins_non_missing=self._bin_mapper.n_bins_non_missing_,
has_missing_values=has_missing_values,
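For context, the added keyword above simply forwards the validated training-fold weights to each grower. A conceptual sketch, assuming this branch (names come from the private `_hist_gradient_boosting` module; the real `fit` also handles binning, early stopping, and multiclass outputs):

```python
from sklearn.ensemble._hist_gradient_boosting.grower import TreeGrower

def grow_one_tree(X_binned_train, gradients_k, hessians_k, sample_weight_train):
    # `sample_weight` is the keyword added by this PR; every other TreeGrower
    # option keeps its default here for brevity.
    grower = TreeGrower(
        X_binned_train,
        gradients_k,
        hessians_k,
        sample_weight=sample_weight_train,
    )
    grower.grow()
    return grower
```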
28 changes: 25 additions & 3 deletions sklearn/ensemble/_hist_gradient_boosting/grower.py
@@ -15,6 +15,7 @@
from .histogram import HistogramBuilder
from .predictor import TreePredictor
from .utils import sum_parallel
from ...utils.validation import _check_sample_weight
from .common import PREDICTOR_RECORD_DTYPE
from .common import X_BITSET_INNER_DTYPE
from .common import Y_DTYPE
@@ -38,6 +39,8 @@ class TreeNode:
The depth of the node, i.e. its distance from the root.
sample_indices : ndarray of shape (n_samples_at_node,), dtype=np.uint
The indices of the samples at the node.
weighted_n_node_samples : float
The weighted number of training samples at the node.
sum_gradients : float
The sum of the gradients of the samples at the node.
sum_hessians : float
@@ -94,10 +97,19 @@ class TreeNode:
partition_start = 0
partition_stop = 0

def __init__(self, depth, sample_indices, sum_gradients, sum_hessians, value=None):
def __init__(
self,
depth,
sample_indices,
weighted_n_node_samples,
sum_gradients,
sum_hessians,
value=None,
):
self.depth = depth
self.sample_indices = sample_indices
self.n_samples = sample_indices.shape[0]
self.weighted_n_node_samples = weighted_n_node_samples
self.sum_gradients = sum_gradients
self.sum_hessians = sum_hessians
self.value = value
@@ -149,6 +161,8 @@ class TreeGrower:
hessians : ndarray of shape (n_samples,)
The hessians of each training sample. Those are the hessians of the
loss w.r.t the predictions, evaluated at iteration ``i - 1``.
sample_weight : array-like of shape (n_samples,), default=None
Weights of training data.
max_leaf_nodes : int, default=None
The maximum number of leaves for each tree. If None, there is no
maximum limit.
@@ -227,6 +241,7 @@ def __init__(
X_binned,
gradients,
hessians,
sample_weight=None,
max_leaf_nodes=None,
max_depth=None,
min_samples_leaf=20,
@@ -263,6 +278,8 @@ def __init__(
has_missing_values = [has_missing_values] * X_binned.shape[1]
has_missing_values = np.asarray(has_missing_values, dtype=np.uint8)

sample_weight = _check_sample_weight(sample_weight, X_binned, dtype=np.float64)

# `monotonic_cst` validation is done in _validate_monotonic_cst
# at the estimator level and therefore the following should not be
# needed when using the public API.
@@ -305,6 +322,7 @@ def __init__(
min_samples_leaf,
min_gain_to_split,
hessians_are_constant,
sample_weight,
n_threads,
)
self.n_bins_non_missing = n_bins_non_missing
@@ -388,6 +406,7 @@ def _intilialize_root(self, gradients, hessians, hessians_are_constant):
self.root = TreeNode(
depth=depth,
sample_indices=self.splitter.partition,
weighted_n_node_samples=np.sum(self.splitter.sample_weight),
sum_gradients=sum_gradients,
sum_hessians=sum_hessians,
value=0,
@@ -396,7 +415,7 @@ def _intilialize_root(self, gradients, hessians, hessians_are_constant):
self.root.partition_start = 0
self.root.partition_stop = n_samples

if self.root.n_samples < 2 * self.min_samples_leaf:
if self.root.n_samples < self.min_samples_leaf * 2:
# Do not even bother computing any splitting statistics.
self._finalize_leaf(self.root)
return
@@ -463,6 +482,7 @@ def split_next(self):
(
sample_indices_left,
sample_indices_right,
right_weighted_n_node_samples,
right_child_pos,
) = self.splitter.split_indices(node.split_info, node.sample_indices)
self.total_apply_split_time += time() - tic
@@ -474,13 +494,15 @@ def split_next(self):
left_child_node = TreeNode(
depth,
sample_indices_left,
node.weighted_n_node_samples - right_weighted_n_node_samples,
node.split_info.sum_gradient_left,
node.split_info.sum_hessian_left,
value=node.split_info.value_left,
)
right_child_node = TreeNode(
depth,
sample_indices_right,
right_weighted_n_node_samples,
node.split_info.sum_gradient_right,
node.split_info.sum_hessian_right,
value=node.split_info.value_right,
@@ -716,7 +738,7 @@ def _fill_predictor_arrays(
):
"""Helper used in make_predictor to set the TreePredictor fields."""
node = predictor_nodes[next_free_node_idx]
node["count"] = grower_node.n_samples
node["weighted_n_node_samples"] = grower_node.weighted_n_node_samples
node["depth"] = grower_node.depth
if grower_node.split_info is not None:
node["gain"] = grower_node.split_info.gain
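To summarize the bookkeeping introduced in the grower above: the root's `weighted_n_node_samples` is the total of the (validated) sample weights, `split_indices` returns the weight that moves to the right child, and the left child receives the remainder, so the parent weight always equals left plus right. A minimal NumPy sketch of that invariant (the boolean mask stands in for the actual split decision):

```python
import numpy as np

rng = np.random.RandomState(0)
sample_weight = rng.uniform(0.5, 2.0, size=8)
goes_right = rng.rand(8) > 0.5  # hypothetical per-sample split decision

root_weight = sample_weight.sum()               # weighted_n_node_samples of the root
right_weight = sample_weight[goes_right].sum()  # accumulated inside split_indices
left_weight = root_weight - right_weight        # what the left TreeNode receives

assert np.isclose(left_weight + right_weight, root_weight)
```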
13 changes: 13 additions & 0 deletions sklearn/ensemble/_hist_gradient_boosting/splitting.pyx
@@ -155,6 +155,9 @@ cdef class Splitter:
be ignored.
hessians_are_constant: bool, default is False
Whether hessians are constant.
sample_weight: ndarray of float, shape (n_samples,), default=None
Weights of training data. If not provided, all samples are assumed
to have uniform weight.
n_threads : int, default=1
Number of OpenMP threads to use.
"""
@@ -166,6 +169,7 @@ cdef class Splitter:
const unsigned char [::1] has_missing_values
const unsigned char [::1] is_categorical
const signed char [::1] monotonic_cst
const Y_DTYPE_C [::1] sample_weight
unsigned char hessians_are_constant
Y_DTYPE_C l2_regularization
Y_DTYPE_C min_hessian_to_split
@@ -189,6 +193,7 @@
unsigned int min_samples_leaf=20,
Y_DTYPE_C min_gain_to_split=0.,
unsigned char hessians_are_constant=False,
const Y_DTYPE_C [::1] sample_weight=None,
unsigned int n_threads=1):

self.X_binned = X_binned
@@ -203,6 +208,7 @@
self.min_samples_leaf = min_samples_leaf
self.min_gain_to_split = min_gain_to_split
self.hessians_are_constant = hessians_are_constant
self.sample_weight = sample_weight if sample_weight is not None else np.ones(X_binned.shape[0], dtype=np.float64)
self.n_threads = n_threads

# The partition array maps each sample index into the leaves of the
@@ -247,6 +253,8 @@ cdef class Splitter:
right_indices : ndarray of int, shape (n_right_samples,)
The indices of the samples in the right child. This is a view on
self.partition.
right_weighted_n_node_samples : float
The weighted number of training samples in the right child.
right_child_position : int
The position of the right child in ``sample_indices``.
"""
@@ -302,6 +310,7 @@ cdef class Splitter:
self.X_binned[:, feature_idx]
unsigned int [::1] left_indices_buffer = self.left_indices_buffer
unsigned int [::1] right_indices_buffer = self.right_indices_buffer
const Y_DTYPE_C [::1] sample_weight = self.sample_weight
unsigned char is_categorical = split_info.is_categorical
# Cython is unhappy if we set left_cat_bitset to
# split_info.left_cat_bitset directly, so we need a tmp var
@@ -321,6 +330,7 @@
int i
int thread_idx
int sample_idx
double right_weighted_n_node_samples
int right_child_position
unsigned char turn_left
int [:] left_offset = np.zeros(n_threads, dtype=np.int32)
@@ -339,6 +349,7 @@
offset_in_buffers[thread_idx - 1] + sizes[thread_idx - 1]

# map indices from sample_indices to left/right_indices_buffer
right_weighted_n_node_samples = 0
for thread_idx in prange(n_threads, schedule='static',
chunksize=1, num_threads=n_threads):
left_count = 0
@@ -360,6 +371,7 @@
else:
right_indices_buffer[start + right_count] = sample_idx
right_count = right_count + 1
right_weighted_n_node_samples += sample_weight[sample_idx]

left_counts[thread_idx] = left_count
right_counts[thread_idx] = right_count
@@ -410,6 +422,7 @@

return (sample_indices[:right_child_position],
sample_indices[right_child_position:],
right_weighted_n_node_samples,
right_child_position)

def find_node_split(
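The splitter change above makes `split_indices` return a fourth value, the weighted count of the right child, accumulated from `sample_weight` while the samples are partitioned across threads. A pure-NumPy sketch of the returned tuple (the parallel buffering is omitted; `goes_left` stands in for the per-sample split decision):

```python
import numpy as np

def split_indices_sketch(sample_indices, sample_weight, goes_left):
    """Return (left_indices, right_indices, right_weighted_n_node_samples,
    right_child_position), mirroring the shape of the Cython return value."""
    left = sample_indices[goes_left]
    right = sample_indices[~goes_left]
    right_weighted_n_node_samples = sample_weight[right].sum()
    right_child_position = left.shape[0]
    return left, right, right_weighted_n_node_samples, right_child_position

indices = np.arange(6, dtype=np.uint32)
weights = np.array([1.0, 2.0, 0.5, 0.5, 1.5, 1.0])
goes_left = np.array([True, False, True, True, False, False])
print(split_indices_sketch(indices, weights, goes_left))
```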
sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
@@ -759,6 +759,66 @@ def test_sample_weight_effect(problem, duplication):
assert np.allclose(est_sw._raw_predict(X_dup), est_dup._raw_predict(X_dup))


@pytest.mark.parametrize(
"sample_weight_distribution", ("poisson", "exponential", "uniform")
)
def test_sample_weight_leaf_weighted_nodes_classification_random(
sample_weight_distribution,
):
# Ensures that the `weighted_n_node_samples` of each node in the predictor
# tree is the sum of the sample weights of the samples that fall in that node.

n_samples = 1000
X, y = make_classification(n_samples=n_samples, random_state=0)

if sample_weight_distribution == "poisson":
sample_weight = np.random.RandomState(0).poisson(lam=1 + 4 * y)
elif sample_weight_distribution == "exponential":
sample_weight = np.random.RandomState(0).exponential(scale=1 + 4 * y)
else:
sample_weight = np.random.RandomState(0).uniform(high=1 + 4 * y)

hgbc = HistGradientBoostingClassifier(
random_state=0, min_samples_leaf=1, max_depth=1
).fit(X, y, sample_weight)

for predictor in hgbc._predictors:
nodes = predictor[0].nodes
feat_idx = int(nodes[0][2])
num_tresh = nodes[0][3]

assert_allclose(nodes[0][1], sample_weight.sum())
assert_allclose(nodes[1][1], sample_weight[X[:, feat_idx] < num_tresh].sum())
assert_allclose(nodes[2][1], sample_weight[X[:, feat_idx] >= num_tresh].sum())


@pytest.mark.parametrize(
"left_sample_weight, right_sample_weight", [(2.5, 7.5), (1, 1), (0.5, 0.5)]
)
def test_sample_weight_leaf_weighted_nodes_classification_two_values(
left_sample_weight, right_sample_weight
):
# Ensures that the `weighted_n_node_samples` of each node in the predictor
# tree is the sum of the sample weights of the samples that fall in that node.

n_samples = 1000
X = np.array(n_samples * [0] + n_samples * [1]).reshape(-1, 1)
y = np.array(n_samples * [0] + n_samples * [1])

sample_weight = n_samples * [left_sample_weight] + n_samples * [right_sample_weight]

hgbc = HistGradientBoostingClassifier(min_samples_leaf=1, max_depth=1).fit(
X, y, sample_weight
)

for predictor in hgbc._predictors:
nodes = predictor[0].nodes

assert nodes[0][1] == (left_sample_weight + right_sample_weight) * n_samples
assert nodes[1][1] == left_sample_weight * n_samples
assert nodes[2][1] == right_sample_weight * n_samples


@pytest.mark.parametrize("Loss", (HalfSquaredError, AbsoluteError))
def test_sum_hessians_are_sample_weight(Loss):
# For losses with constant hessians, the sum_hessians field of the
14 changes: 9 additions & 5 deletions sklearn/ensemble/_hist_gradient_boosting/tests/test_grower.py
@@ -250,11 +250,11 @@ def test_min_samples_leaf(n_samples, min_samples_leaf, n_bins, constant_hessian,
if n_samples >= min_samples_leaf:
for node in predictor.nodes:
if node["is_leaf"]:
assert node["count"] >= min_samples_leaf
assert node["weighted_n_node_samples"] >= min_samples_leaf
else:
assert predictor.nodes.shape[0] == 1
assert predictor.nodes[0]["is_leaf"]
assert predictor.nodes[0]["count"] == n_samples
assert predictor.nodes[0]["weighted_n_node_samples"] == n_samples


@pytest.mark.parametrize("n_samples, min_samples_leaf", [(99, 50), (100, 50)])
@@ -374,7 +374,11 @@ def test_missing_value_predict_only():
while not node["is_leaf"]:
left = predictor.nodes[node["left"]]
right = predictor.nodes[node["right"]]
node = left if left["count"] > right["count"] else right
node = (
left
if left["weighted_n_node_samples"] > right["weighted_n_node_samples"]
else right
)

prediction_main_path = node["value"]

@@ -464,14 +468,14 @@ def test_grow_tree_categories():
categories = [np.array([4, 9], dtype=X_DTYPE)]
predictor = grower.make_predictor(binning_thresholds=categories)
root = predictor.nodes[0]
assert root["count"] == 23
assert root["weighted_n_node_samples"] == 23
assert root["depth"] == 0
assert root["is_categorical"]

left, right = predictor.nodes[root["left"]], predictor.nodes[root["right"]]

# arbitrary validation, but this means ones go to the left.
assert left["count"] >= right["count"]
assert left["weighted_n_node_samples"] >= right["weighted_n_node_samples"]

# check binned category value (1)
expected_binned_cat_bitset = [2**1] + [0] * 7