Commit fc11dea

ENH add feature subsampling per split for HGBT (#27139)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent ff0274d commit fc11dea

6 files changed: +227 -34 lines

benchmarks/bench_hist_gradient_boosting_higgsboson.py

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,7 @@
 parser.add_argument("--no-predict", action="store_true", default=False)
 parser.add_argument("--cache-loc", type=str, default="/tmp")
 parser.add_argument("--no-interactions", type=bool, default=False)
+parser.add_argument("--max-features", type=float, default=1)
 args = parser.parse_args()
 
 HERE = os.path.dirname(__file__)
@@ -36,6 +37,7 @@
 subsample = args.subsample
 lr = args.learning_rate
 max_bins = args.max_bins
+max_features = args.max_features
 
 
 @m.cache
@@ -104,6 +106,7 @@ def predict(est, data_test, target_test):
     random_state=0,
     verbose=1,
     interaction_cst=interaction_cst,
+    max_features=max_features,
 )
 fit(est, data_train, target_train, "sklearn")
 predict(est, data_test, target_test)
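
With this addition, the per-split feature subsampling can be exercised directly from the command line, e.g. `python benchmarks/bench_hist_gradient_boosting_higgsboson.py --max-features 0.5`; the parsed value is forwarded to the estimator's new `max_features` parameter.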

doc/whats_new/v1.4.rst

Lines changed: 6 additions & 0 deletions
@@ -303,6 +303,12 @@ Changelog
   :pr:`13649` by :user:`Samuel Ronsin <samronsin>`,
   initiated by :user:`Patrick O'Reilly <pat-oreilly>`.
 
+- |Feature| :class:`ensemble.HistGradientBoostingClassifier` and
+  :class:`ensemble.HistGradientBoostingRegressor` got the new parameter
+  `max_features` to specify the proportion of randomly chosen features considered
+  in each split.
+  :pr:`27139` by :user:`Christian Lorentzen <lorentzenchr>`.
+
 - |Efficiency| :class:`ensemble.GradientBoostingClassifier` is faster,
   for binary and in particular for multiclass problems thanks to the private loss
   function module.
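
A minimal usage sketch of the new parameter (not part of the commit; the synthetic data below is invented for illustration). With `max_features=0.5` on 10 features, every split considers ceil(0.5 * 10) = 5 randomly drawn features:

import numpy as np
from sklearn.ensemble import HistGradientBoostingClassifier

rng = np.random.RandomState(0)
X = rng.normal(size=(1000, 10))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

# Each split draws a fresh random subset of 5 of the 10 features.
clf = HistGradientBoostingClassifier(max_features=0.5, random_state=0)
clf.fit(X, y)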

sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py

Lines changed: 34 additions & 6 deletions
@@ -105,6 +105,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC):
         "max_depth": [Interval(Integral, 1, None, closed="left"), None],
         "min_samples_leaf": [Interval(Integral, 1, None, closed="left")],
         "l2_regularization": [Interval(Real, 0, None, closed="left")],
+        "max_features": [Interval(RealNotInt, 0, 1, closed="right")],
         "monotonic_cst": ["array-like", dict, None],
         "interaction_cst": [
             list,
@@ -139,6 +140,7 @@
         max_depth,
         min_samples_leaf,
         l2_regularization,
+        max_features,
         max_bins,
         categorical_features,
         monotonic_cst,
@@ -159,6 +161,7 @@
         self.max_depth = max_depth
         self.min_samples_leaf = min_samples_leaf
         self.l2_regularization = l2_regularization
+        self.max_features = max_features
         self.max_bins = max_bins
         self.monotonic_cst = monotonic_cst
         self.interaction_cst = interaction_cst
@@ -393,10 +396,12 @@ def fit(self, X, y, sample_weight=None):
         rng = check_random_state(self.random_state)
 
         # When warm starting, we want to reuse the same seed that was used
-        # the first time fit was called (e.g. for subsampling or for the
-        # train/val split).
-        if not (self.warm_start and self._is_fitted()):
+        # the first time fit was called (e.g. train/val split).
+        # For feature subsampling, we want to continue with the rng we started with.
+        if not self.warm_start or not self._is_fitted():
             self._random_seed = rng.randint(np.iinfo(np.uint32).max, dtype="u8")
+            feature_subsample_seed = rng.randint(np.iinfo(np.uint32).max, dtype="u8")
+            self._feature_subsample_rng = np.random.default_rng(feature_subsample_seed)
 
         self._validate_parameters()
         monotonic_cst = _check_monotonic_cst(self, self.monotonic_cst)
@@ -700,6 +705,8 @@ def fit(self, X, y, sample_weight=None):
                 max_depth=self.max_depth,
                 min_samples_leaf=self.min_samples_leaf,
                 l2_regularization=self.l2_regularization,
+                feature_fraction_per_split=self.max_features,
+                rng=self._feature_subsample_rng,
                 shrinkage=self.learning_rate,
                 n_threads=n_threads,
             )
@@ -1261,8 +1268,16 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
         than a few hundred samples, it is recommended to lower this value
         since only very shallow trees would be built.
     l2_regularization : float, default=0
-        The L2 regularization parameter. Use ``0`` for no regularization
-        (default).
+        The L2 regularization parameter. Use ``0`` for no regularization (default).
+    max_features : float, default=1.0
+        Proportion of randomly chosen features in each and every node split.
+        This is a form of regularization, smaller values make the trees weaker
+        learners and might prevent overfitting.
+        If interaction constraints from `interaction_cst` are present, only allowed
+        features are taken into account for the subsampling.
+
+        .. versionadded:: 1.4
+
     max_bins : int, default=255
         The maximum number of bins to use for non-missing values. Before
         training, each feature of the input array `X` is binned into
@@ -1463,6 +1478,7 @@ def __init__(
         max_depth=None,
         min_samples_leaf=20,
         l2_regularization=0.0,
+        max_features=1.0,
         max_bins=255,
         categorical_features=None,
         monotonic_cst=None,
@@ -1484,6 +1500,7 @@
             max_depth=max_depth,
             min_samples_leaf=min_samples_leaf,
             l2_regularization=l2_regularization,
+            max_features=max_features,
             max_bins=max_bins,
             monotonic_cst=monotonic_cst,
             interaction_cst=interaction_cst,
@@ -1620,7 +1637,16 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):
         than a few hundred samples, it is recommended to lower this value
         since only very shallow trees would be built.
     l2_regularization : float, default=0
-        The L2 regularization parameter. Use 0 for no regularization.
+        The L2 regularization parameter. Use ``0`` for no regularization (default).
+    max_features : float, default=1.0
+        Proportion of randomly chosen features in each and every node split.
+        This is a form of regularization, smaller values make the trees weaker
+        learners and might prevent overfitting.
+        If interaction constraints from `interaction_cst` are present, only allowed
+        features are taken into account for the subsampling.
+
+        .. versionadded:: 1.4
+
     max_bins : int, default=255
         The maximum number of bins to use for non-missing values. Before
         training, each feature of the input array `X` is binned into
@@ -1823,6 +1849,7 @@ def __init__(
         max_depth=None,
         min_samples_leaf=20,
         l2_regularization=0.0,
+        max_features=1.0,
         max_bins=255,
         categorical_features=None,
         monotonic_cst=None,
@@ -1845,6 +1872,7 @@
             max_depth=max_depth,
             min_samples_leaf=min_samples_leaf,
             l2_regularization=l2_regularization,
+            max_features=max_features,
             max_bins=max_bins,
             categorical_features=categorical_features,
             monotonic_cst=monotonic_cst,
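
As a reading aid (not part of the commit), the seeding logic in `fit` above reduces to the following sketch: both seeds are drawn on the first, non-warm-started fit, and a warm-started refit redraws neither, so the train/val split stays stable while the feature-subsampling stream simply continues.

import numpy as np

rng = np.random.RandomState(0)  # stands in for check_random_state(self.random_state)

# First fit, or any fit with warm_start=False:
random_seed = rng.randint(np.iinfo(np.uint32).max, dtype="u8")
feature_subsample_seed = rng.randint(np.iinfo(np.uint32).max, dtype="u8")
feature_subsample_rng = np.random.default_rng(feature_subsample_seed)

# Warm-started refit: the branch above is skipped entirely; `random_seed`
# is reused for the train/val split and `feature_subsample_rng` keeps
# advancing from wherever the previous round of tree growing left it.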

sklearn/ensemble/_hist_gradient_boosting/grower.py

Lines changed: 34 additions & 24 deletions
@@ -164,6 +164,10 @@ class TreeGrower:
     min_gain_to_split : float, default=0.
         The minimum gain needed to split a node. Splits with lower gain will
         be ignored.
+    min_hessian_to_split : float, default=1e-3
+        The minimum sum of hessians needed in each node. Splits that result in
+        at least one child having a sum of hessians less than
+        ``min_hessian_to_split`` are discarded.
     n_bins : int, default=256
         The total number of bins, including the bin for missing values. Used
         to define the shape of the histograms.
@@ -189,10 +193,12 @@ class TreeGrower:
         List of interaction constraints.
     l2_regularization : float, default=0.
         The L2 regularization parameter.
-    min_hessian_to_split : float, default=1e-3
-        The minimum sum of hessians needed in each node. Splits that result in
-        at least one child having a sum of hessians less than
-        ``min_hessian_to_split`` are discarded.
+    feature_fraction_per_split : float, default=1
+        Proportion of randomly chosen features in each and every node split.
+        This is a form of regularization, smaller values make the trees weaker
+        learners and might prevent overfitting.
+    rng : Generator
+        Numpy random Generator used for feature subsampling.
     shrinkage : float, default=1.
         The shrinkage parameter to apply to the leaves values, also known as
         learning rate.
@@ -234,14 +240,16 @@ def __init__(
         max_depth=None,
         min_samples_leaf=20,
         min_gain_to_split=0.0,
+        min_hessian_to_split=1e-3,
         n_bins=256,
         n_bins_non_missing=None,
         has_missing_values=False,
         is_categorical=None,
         monotonic_cst=None,
         interaction_cst=None,
         l2_regularization=0.0,
-        min_hessian_to_split=1e-3,
+        feature_fraction_per_split=1.0,
+        rng=np.random.default_rng(),
         shrinkage=1.0,
         n_threads=None,
     ):
@@ -297,33 +305,35 @@ def __init__(
         )
         missing_values_bin_idx = n_bins - 1
         self.splitter = Splitter(
-            X_binned,
-            n_bins_non_missing,
-            missing_values_bin_idx,
-            has_missing_values,
-            is_categorical,
-            monotonic_cst,
-            l2_regularization,
-            min_hessian_to_split,
-            min_samples_leaf,
-            min_gain_to_split,
-            hessians_are_constant,
-            n_threads,
+            X_binned=X_binned,
+            n_bins_non_missing=n_bins_non_missing,
+            missing_values_bin_idx=missing_values_bin_idx,
+            has_missing_values=has_missing_values,
+            is_categorical=is_categorical,
+            monotonic_cst=monotonic_cst,
+            l2_regularization=l2_regularization,
+            min_hessian_to_split=min_hessian_to_split,
+            min_samples_leaf=min_samples_leaf,
+            min_gain_to_split=min_gain_to_split,
+            hessians_are_constant=hessians_are_constant,
+            feature_fraction_per_split=feature_fraction_per_split,
+            rng=rng,
+            n_threads=n_threads,
         )
+        self.X_binned = X_binned
+        self.max_leaf_nodes = max_leaf_nodes
+        self.max_depth = max_depth
+        self.min_samples_leaf = min_samples_leaf
+        self.min_gain_to_split = min_gain_to_split
         self.n_bins_non_missing = n_bins_non_missing
         self.missing_values_bin_idx = missing_values_bin_idx
-        self.max_leaf_nodes = max_leaf_nodes
         self.has_missing_values = has_missing_values
+        self.is_categorical = is_categorical
         self.monotonic_cst = monotonic_cst
         self.interaction_cst = interaction_cst
-        self.is_categorical = is_categorical
         self.l2_regularization = l2_regularization
-        self.n_features = X_binned.shape[1]
-        self.max_depth = max_depth
-        self.min_samples_leaf = min_samples_leaf
-        self.X_binned = X_binned
-        self.min_gain_to_split = min_gain_to_split
         self.shrinkage = shrinkage
+        self.n_features = X_binned.shape[1]
         self.n_threads = n_threads
         self.splittable_nodes = []
         self.finalized_leaves = []
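
One subtlety in the signature above (an observation, not a change made by the commit): a default like `rng=np.random.default_rng()` is evaluated once, when `__init__` is defined, so all callers that omit `rng` share a single Generator; the estimator sidesteps this by always passing `self._feature_subsample_rng` explicitly. A minimal demonstration of the shared-default behavior:

import numpy as np

def draw(rng=np.random.default_rng()):
    # The default Generator is created once, at function definition time.
    return rng.integers(1000)

# Consecutive calls advance the *same* shared stream; they do not
# re-seed, so the two draws below are almost surely different.
print(draw(), draw())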

sklearn/ensemble/_hist_gradient_boosting/splitting.pyx

Lines changed: 41 additions & 4 deletions
@@ -9,8 +9,9 @@
 
 cimport cython
 from cython.parallel import prange
+cimport numpy as cnp
 import numpy as np
-from libc.math cimport INFINITY
+from libc.math cimport INFINITY, ceil
 from libc.stdlib cimport malloc, free, qsort
 from libc.string cimport memcpy
 
@@ -24,6 +25,8 @@ from ._bitset cimport init_bitset
 from ._bitset cimport set_bitset
 from ._bitset cimport in_bitset
 
+cnp.import_array()
+
 
 cdef struct split_info_struct:
     # Same as the SplitInfo class, but we need a C struct to use it in the
@@ -155,6 +158,11 @@ cdef class Splitter:
         be ignored.
     hessians_are_constant: bool, default is False
         Whether hessians are constant.
+    feature_fraction_per_split : float, default=1
+        Proportion of randomly chosen features in each and every node split.
+        This is a form of regularization, smaller values make the trees weaker
+        learners and might prevent overfitting.
+    rng : Generator
     n_threads : int, default=1
         Number of OpenMP threads to use.
     """
@@ -171,6 +179,8 @@ cdef class Splitter:
         Y_DTYPE_C min_hessian_to_split
         unsigned int min_samples_leaf
         Y_DTYPE_C min_gain_to_split
+        Y_DTYPE_C feature_fraction_per_split
+        rng
 
         unsigned int [::1] partition
         unsigned int [::1] left_indices_buffer
@@ -189,20 +199,24 @@
                  unsigned int min_samples_leaf=20,
                  Y_DTYPE_C min_gain_to_split=0.,
                  unsigned char hessians_are_constant=False,
+                 Y_DTYPE_C feature_fraction_per_split=1.0,
+                 rng=np.random.RandomState(),
                  unsigned int n_threads=1):
 
         self.X_binned = X_binned
         self.n_features = X_binned.shape[1]
         self.n_bins_non_missing = n_bins_non_missing
         self.missing_values_bin_idx = missing_values_bin_idx
         self.has_missing_values = has_missing_values
-        self.monotonic_cst = monotonic_cst
         self.is_categorical = is_categorical
+        self.monotonic_cst = monotonic_cst
         self.l2_regularization = l2_regularization
         self.min_hessian_to_split = min_hessian_to_split
         self.min_samples_leaf = min_samples_leaf
         self.min_gain_to_split = min_gain_to_split
         self.hessians_are_constant = hessians_are_constant
+        self.feature_fraction_per_split = feature_fraction_per_split
+        self.rng = rng
         self.n_threads = n_threads
 
         # The partition array maps each sample index into the leaves of the
@@ -475,20 +489,36 @@ cdef class Splitter:
             const signed char [::1] monotonic_cst = self.monotonic_cst
             int n_threads = self.n_threads
             bint has_interaction_cst = False
+            Y_DTYPE_C feature_fraction_per_split = self.feature_fraction_per_split
+            cnp.npy_bool [:] subsample_mask
+            int n_subsampled_features
 
         has_interaction_cst = allowed_features is not None
         if has_interaction_cst:
            n_allowed_features = allowed_features.shape[0]
        else:
            n_allowed_features = self.n_features
 
+        if feature_fraction_per_split < 1.0:
+            # We do all random sampling before the nogil and make sure that we sample
+            # exactly n_subsampled_features >= 1 features.
+            n_subsampled_features = max(
+                1,
+                int(ceil(feature_fraction_per_split * n_allowed_features)),
+            )
+            subsample_mask_arr = np.full(n_allowed_features, False)
+            subsample_mask_arr[:n_subsampled_features] = True
+            self.rng.shuffle(subsample_mask_arr)
+            # https://github.com/numpy/numpy/issues/18273
+            subsample_mask = subsample_mask_arr
+
         with nogil:
 
             split_infos = <split_info_struct *> malloc(
                 n_allowed_features * sizeof(split_info_struct))
 
-            # split_info_idx is index of split_infos of size n_features_allowed
-            # features_idx is the index of the feature column in X
+            # split_info_idx is index of split_infos of size n_allowed_features.
+            # features_idx is the index of the feature column in X.
             for split_info_idx in prange(n_allowed_features, schedule='static',
                                          num_threads=n_threads):
                 if has_interaction_cst:
@@ -506,6 +536,13 @@ cdef class Splitter:
                 split_infos[split_info_idx].gain = -1
                 split_infos[split_info_idx].is_categorical = is_categorical[feature_idx]
 
+                # Note that subsample_mask is indexed by split_info_idx and not by
+                # feature_idx because we only need to exclude the same features again
+                # and again. We do NOT need to access the features directly by using
+                # allowed_features.
+                if feature_fraction_per_split < 1.0 and not subsample_mask[split_info_idx]:
+                    continue
+
                 if is_categorical[feature_idx]:
                     self._find_best_bin_to_split_category(
                         feature_idx, has_missing_values[feature_idx],
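
A pure-NumPy sketch of the subsampling above (the same logic outside Cython; the concrete numbers are illustrative): exactly max(1, ceil(fraction * n_allowed_features)) entries of a boolean mask are set, and shuffling the mask selects which features survive for this split.

from math import ceil

import numpy as np

rng = np.random.default_rng(0)
n_allowed_features = 10
feature_fraction_per_split = 0.25

# Always keep at least one feature; ceil keeps small fractions usable.
n_subsampled_features = max(1, ceil(feature_fraction_per_split * n_allowed_features))
subsample_mask = np.full(n_allowed_features, False)
subsample_mask[:n_subsampled_features] = True
rng.shuffle(subsample_mask)        # a fresh random subset is drawn at every split
print(int(subsample_mask.sum()))   # -> 3 features considered for this split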
