Skip to content

Commit 8c788e4

Browse files
committed
ENH Uses _openmp_effective_n_threads to get the number of threads
1 parent f2fb075 commit 8c788e4

20 files changed

+259
-101
lines changed

doc/whats_new/v1.0.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,11 @@ Changelog
262262
:mod:`sklearn.ensemble`
263263
.......................
264264

265+
- |Enhancement| :class:`~sklearn.ensemble.HistGradientBoostingClassifier` and
266+
:class:`~sklearn.ensemble.HistGradientBoostingRegressor` takes cgroups quotas
267+
into account when deciding the number of threads used for OpenMP. :pr:`xxxx`
268+
by `Thomas Fan`_.
269+
265270
- |Fix| Do not allow to compute out-of-bag (OOB) score in
266271
:class:`ensemble.RandomForestClassifier` and
267272
:class:`ensemble.ExtraTreesClassifier` with multiclass-multioutput target

sklearn/ensemble/_hist_gradient_boosting/_binning.pyx

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ np.import_array()
2222
def _map_to_bins(const X_DTYPE_C [:, :] data,
2323
list binning_thresholds,
2424
const unsigned char missing_values_bin_idx,
25+
int n_threads,
2526
X_BINNED_DTYPE_C [::1, :] binned):
2627
"""Bin continuous and categorical values to discrete integer-coded levels.
2728
@@ -37,6 +38,8 @@ def _map_to_bins(const X_DTYPE_C [:, :] data,
3738
used to separate the bins.
3839
binned : ndarray, shape (n_samples, n_features)
3940
Output array, must be fortran aligned.
41+
n_threads : int
42+
Number of threads for openmp
4043
"""
4144
cdef:
4245
int feature_idx
@@ -45,12 +48,14 @@ def _map_to_bins(const X_DTYPE_C [:, :] data,
4548
_map_col_to_bins(data[:, feature_idx],
4649
binning_thresholds[feature_idx],
4750
missing_values_bin_idx,
51+
n_threads,
4852
binned[:, feature_idx])
4953

5054

5155
cdef void _map_col_to_bins(const X_DTYPE_C [:] data,
5256
const X_DTYPE_C [:] binning_thresholds,
5357
const unsigned char missing_values_bin_idx,
58+
int n_threads,
5459
X_BINNED_DTYPE_C [:] binned):
5560
"""Binary search to find the bin index for each value in the data."""
5661
cdef:
@@ -59,8 +64,8 @@ cdef void _map_col_to_bins(const X_DTYPE_C [:] data,
5964
int right
6065
int middle
6166

62-
for i in prange(data.shape[0], schedule='static', nogil=True):
63-
67+
for i in prange(data.shape[0], schedule='static', nogil=True,
68+
num_threads=n_threads):
6469
if isnan(data[i]):
6570
binned[i] = missing_values_bin_idx
6671
else:

sklearn/ensemble/_hist_gradient_boosting/_gradient_boosting.pyx

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ np.import_array()
1818

1919
def _update_raw_predictions(
2020
Y_DTYPE_C [::1] raw_predictions, # OUT
21-
grower):
21+
grower,
22+
n_threads,
23+
):
2224
"""Update raw_predictions with the predictions of the newest tree.
2325
2426
This is equivalent to (and much faster than):
@@ -42,21 +44,24 @@ def _update_raw_predictions(
4244
values = np.array([leaf.value for leaf in leaves], dtype=Y_DTYPE)
4345

4446
_update_raw_predictions_helper(raw_predictions, starts, stops, partition,
45-
values)
47+
values, n_threads)
4648

4749

4850
cdef inline void _update_raw_predictions_helper(
4951
Y_DTYPE_C [::1] raw_predictions, # OUT
5052
const unsigned int [::1] starts,
5153
const unsigned int [::1] stops,
5254
const unsigned int [::1] partition,
53-
const Y_DTYPE_C [::1] values):
55+
const Y_DTYPE_C [::1] values,
56+
int n_threads,
57+
):
5458

5559
cdef:
5660
unsigned int position
5761
int leaf_idx
5862
int n_leaves = starts.shape[0]
5963

60-
for leaf_idx in prange(n_leaves, schedule='static', nogil=True):
64+
for leaf_idx in prange(n_leaves, schedule='static', nogil=True,
65+
num_threads=n_threads):
6166
for position in range(starts[leaf_idx], stops[leaf_idx]):
6267
raw_predictions[partition[position]] += values[leaf_idx]

sklearn/ensemble/_hist_gradient_boosting/_loss.pyx

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,16 @@ np.import_array()
2121
def _update_gradients_least_squares(
2222
G_H_DTYPE_C [::1] gradients, # OUT
2323
const Y_DTYPE_C [::1] y_true, # IN
24-
const Y_DTYPE_C [::1] raw_predictions): # IN
24+
const Y_DTYPE_C [::1] raw_predictions, # IN
25+
int n_threads, # IN
26+
):
2527

2628
cdef:
2729
int n_samples
2830
int i
2931

3032
n_samples = raw_predictions.shape[0]
31-
for i in prange(n_samples, schedule='static', nogil=True):
33+
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
3234
# Note: a more correct expression is 2 * (raw_predictions - y_true)
3335
# but since we use 1 for the constant hessian value (and not 2) this
3436
# is strictly equivalent for the leaves values.
@@ -40,14 +42,16 @@ def _update_gradients_hessians_least_squares(
4042
G_H_DTYPE_C [::1] hessians, # OUT
4143
const Y_DTYPE_C [::1] y_true, # IN
4244
const Y_DTYPE_C [::1] raw_predictions, # IN
43-
const Y_DTYPE_C [::1] sample_weight): # IN
45+
const Y_DTYPE_C [::1] sample_weight, # IN
46+
int n_threads, # IN
47+
):
4448

4549
cdef:
4650
int n_samples
4751
int i
4852

4953
n_samples = raw_predictions.shape[0]
50-
for i in prange(n_samples, schedule='static', nogil=True):
54+
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
5155
# Note: a more correct exp is 2 * (raw_predictions - y_true) * sample_weight
5256
# but since we use 1 for the constant hessian value (and not 2) this
5357
# is strictly equivalent for the leaves values.
@@ -60,14 +64,15 @@ def _update_gradients_hessians_least_absolute_deviation(
6064
G_H_DTYPE_C [::1] hessians, # OUT
6165
const Y_DTYPE_C [::1] y_true, # IN
6266
const Y_DTYPE_C [::1] raw_predictions, # IN
63-
const Y_DTYPE_C [::1] sample_weight): # IN
64-
67+
const Y_DTYPE_C [::1] sample_weight, # IN
68+
int n_threads, # IN
69+
):
6570
cdef:
6671
int n_samples
6772
int i
6873

6974
n_samples = raw_predictions.shape[0]
70-
for i in prange(n_samples, schedule='static', nogil=True):
75+
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
7176
# gradient = sign(raw_predicition - y_pred) * sample_weight
7277
gradients[i] = sample_weight[i] * (2 *
7378
(y_true[i] - raw_predictions[i] < 0) - 1)
@@ -77,14 +82,15 @@ def _update_gradients_hessians_least_absolute_deviation(
7782
def _update_gradients_least_absolute_deviation(
7883
G_H_DTYPE_C [::1] gradients, # OUT
7984
const Y_DTYPE_C [::1] y_true, # IN
80-
const Y_DTYPE_C [::1] raw_predictions): # IN
81-
85+
const Y_DTYPE_C [::1] raw_predictions, # IN
86+
int n_threads, # IN
87+
):
8288
cdef:
8389
int n_samples
8490
int i
8591

8692
n_samples = raw_predictions.shape[0]
87-
for i in prange(n_samples, schedule='static', nogil=True):
93+
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
8894
# gradient = sign(raw_predicition - y_pred)
8995
gradients[i] = 2 * (y_true[i] - raw_predictions[i] < 0) - 1
9096

@@ -94,23 +100,24 @@ def _update_gradients_hessians_poisson(
94100
G_H_DTYPE_C [::1] hessians, # OUT
95101
const Y_DTYPE_C [::1] y_true, # IN
96102
const Y_DTYPE_C [::1] raw_predictions, # IN
97-
const Y_DTYPE_C [::1] sample_weight): # IN
98-
103+
const Y_DTYPE_C [::1] sample_weight, # IN
104+
int n_threads, # IN
105+
):
99106
cdef:
100107
int n_samples
101108
int i
102109
Y_DTYPE_C y_pred
103110

104111
n_samples = raw_predictions.shape[0]
105112
if sample_weight is None:
106-
for i in prange(n_samples, schedule='static', nogil=True):
113+
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
107114
# Note: We use only half of the deviance loss. Therefore, there is
108115
# no factor of 2.
109116
y_pred = exp(raw_predictions[i])
110117
gradients[i] = (y_pred - y_true[i])
111118
hessians[i] = y_pred
112119
else:
113-
for i in prange(n_samples, schedule='static', nogil=True):
120+
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
114121
# Note: We use only half of the deviance loss. Therefore, there is
115122
# no factor of 2.
116123
y_pred = exp(raw_predictions[i])
@@ -123,20 +130,22 @@ def _update_gradients_hessians_binary_crossentropy(
123130
G_H_DTYPE_C [::1] hessians, # OUT
124131
const Y_DTYPE_C [::1] y_true, # IN
125132
const Y_DTYPE_C [::1] raw_predictions, # IN
126-
const Y_DTYPE_C [::1] sample_weight): # IN
133+
const Y_DTYPE_C [::1] sample_weight, # IN
134+
int n_threads, # IN
135+
):
127136
cdef:
128137
int n_samples
129138
Y_DTYPE_C p_i # proba that ith sample belongs to positive class
130139
int i
131140

132141
n_samples = raw_predictions.shape[0]
133142
if sample_weight is None:
134-
for i in prange(n_samples, schedule='static', nogil=True):
143+
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
135144
p_i = _cexpit(raw_predictions[i])
136145
gradients[i] = p_i - y_true[i]
137146
hessians[i] = p_i * (1. - p_i)
138147
else:
139-
for i in prange(n_samples, schedule='static', nogil=True):
148+
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
140149
p_i = _cexpit(raw_predictions[i])
141150
gradients[i] = (p_i - y_true[i]) * sample_weight[i]
142151
hessians[i] = p_i * (1. - p_i) * sample_weight[i]
@@ -147,7 +156,9 @@ def _update_gradients_hessians_categorical_crossentropy(
147156
G_H_DTYPE_C [:, ::1] hessians, # OUT
148157
const Y_DTYPE_C [::1] y_true, # IN
149158
const Y_DTYPE_C [:, ::1] raw_predictions, # IN
150-
const Y_DTYPE_C [::1] sample_weight): # IN
159+
const Y_DTYPE_C [::1] sample_weight, # IN
160+
int n_threads, # IN
161+
):
151162
cdef:
152163
int prediction_dim = raw_predictions.shape[0]
153164
int n_samples = raw_predictions.shape[1]
@@ -160,7 +171,7 @@ def _update_gradients_hessians_categorical_crossentropy(
160171
Y_DTYPE_C p_i_k
161172

162173
if sample_weight is None:
163-
for i in prange(n_samples, schedule='static', nogil=True):
174+
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
164175
# first compute softmaxes of sample i for each class
165176
for k in range(prediction_dim):
166177
p[i, k] = raw_predictions[k, i] # prepare softmax
@@ -171,7 +182,7 @@ def _update_gradients_hessians_categorical_crossentropy(
171182
gradients[k, i] = p_i_k - (y_true[i] == k)
172183
hessians[k, i] = p_i_k * (1. - p_i_k)
173184
else:
174-
for i in prange(n_samples, schedule='static', nogil=True):
185+
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
175186
# first compute softmaxes of sample i for each class
176187
for k in range(prediction_dim):
177188
p[i, k] = raw_predictions[k, i] # prepare softmax

sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@ def _predict_from_raw_data( # raw data = non-binned data
3030
const BITSET_INNER_DTYPE_C [:, ::1] raw_left_cat_bitsets,
3131
const BITSET_INNER_DTYPE_C [:, ::1] known_cat_bitsets,
3232
const unsigned int [::1] f_idx_map,
33+
int n_threads,
3334
Y_DTYPE_C [:] out):
3435

3536
cdef:
3637
int i
3738

38-
for i in prange(numeric_data.shape[0], schedule='static', nogil=True):
39+
for i in prange(numeric_data.shape[0], schedule='static', nogil=True,
40+
num_threads=n_threads):
3941
out[i] = _predict_one_from_raw_data(
4042
nodes, numeric_data, raw_left_cat_bitsets,
4143
known_cat_bitsets,
@@ -95,12 +97,14 @@ def _predict_from_binned_data(
9597
const X_BINNED_DTYPE_C [:, :] binned_data,
9698
BITSET_INNER_DTYPE_C [:, :] binned_left_cat_bitsets,
9799
const unsigned char missing_values_bin_idx,
100+
int n_threads,
98101
Y_DTYPE_C [:] out):
99102

100103
cdef:
101104
int i
102105

103-
for i in prange(binned_data.shape[0], schedule='static', nogil=True):
106+
for i in prange(binned_data.shape[0], schedule='static', nogil=True,
107+
num_threads=n_threads):
104108
out[i] = _predict_one_from_binned_data(nodes,
105109
binned_data,
106110
binned_left_cat_bitsets, i,

sklearn/ensemble/_hist_gradient_boosting/binning.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ...utils import check_random_state, check_array
1313
from ...base import BaseEstimator, TransformerMixin
1414
from ...utils.validation import check_is_fitted
15+
from ...utils._openmp_helpers import _openmp_effective_n_threads
1516
from ._binning import _map_to_bins
1617
from .common import X_DTYPE, X_BINNED_DTYPE, ALMOST_INF, X_BITSET_INNER_DTYPE
1718
from ._bitset import set_bitset_memoryview
@@ -264,8 +265,12 @@ def transform(self, X):
264265
"This estimator was fitted with {} features but {} got passed "
265266
"to transform()".format(self.n_bins_non_missing_.shape[0], X.shape[1])
266267
)
268+
269+
n_threads = _openmp_effective_n_threads()
267270
binned = np.zeros_like(X, dtype=X_BINNED_DTYPE, order="F")
268-
_map_to_bins(X, self.bin_thresholds_, self.missing_values_bin_idx_, binned)
271+
_map_to_bins(
272+
X, self.bin_thresholds_, self.missing_values_bin_idx_, n_threads, binned
273+
)
269274
return binned
270275

271276
def make_known_categories_bitsets(self):

sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
check_consistent_length,
1515
_check_sample_weight,
1616
)
17+
from ...utils._openmp_helpers import _openmp_effective_n_threads
1718
from ...utils.multiclass import check_classification_targets
1819
from ...metrics import check_scoring
1920
from ...model_selection import train_test_split
@@ -478,6 +479,7 @@ def fit(self, X, y, sample_weight=None):
478479
sample_weight=sample_weight_train,
479480
)
480481

482+
n_threads = _openmp_effective_n_threads()
481483
for iteration in range(begin_at_stage, self.max_iter):
482484

483485
if self.verbose:
@@ -488,7 +490,12 @@ def fit(self, X, y, sample_weight=None):
488490

489491
# Update gradients and hessians, inplace
490492
self._loss.update_gradients_and_hessians(
491-
gradients, hessians, y_train, raw_predictions, sample_weight_train
493+
gradients,
494+
hessians,
495+
y_train,
496+
raw_predictions,
497+
sample_weight_train,
498+
n_threads,
492499
)
493500

494501
# Append a list since there may be more than 1 predictor per iter
@@ -510,6 +517,7 @@ def fit(self, X, y, sample_weight=None):
510517
min_samples_leaf=self.min_samples_leaf,
511518
l2_regularization=self.l2_regularization,
512519
shrinkage=self.learning_rate,
520+
n_threads=n_threads,
513521
)
514522
grower.grow()
515523

@@ -530,7 +538,7 @@ def fit(self, X, y, sample_weight=None):
530538
# Update raw_predictions with the predictions of the newly
531539
# created tree.
532540
tic_pred = time()
533-
_update_raw_predictions(raw_predictions[k, :], grower)
541+
_update_raw_predictions(raw_predictions[k, :], grower, n_threads)
534542
toc_pred = time()
535543
acc_prediction_time += toc_pred - tic_pred
536544

@@ -541,7 +549,9 @@ def fit(self, X, y, sample_weight=None):
541549
if self._use_validation_data:
542550
for k, pred in enumerate(self._predictors[-1]):
543551
raw_predictions_val[k, :] += pred.predict_binned(
544-
X_binned_val, self._bin_mapper.missing_values_bin_idx_
552+
X_binned_val,
553+
self._bin_mapper.missing_values_bin_idx_,
554+
n_threads,
545555
)
546556

547557
should_early_stop = self._check_early_stopping_loss(
@@ -848,18 +858,21 @@ def _predict_iterations(self, X, predictors, raw_predictions, is_binned):
848858
f_idx_map,
849859
) = self._bin_mapper.make_known_categories_bitsets()
850860

861+
n_threads = _openmp_effective_n_threads()
851862
for predictors_of_ith_iteration in predictors:
852863
for k, predictor in enumerate(predictors_of_ith_iteration):
853864
if is_binned:
854865
predict = partial(
855866
predictor.predict_binned,
856-
missing_values_bin_idx=self._bin_mapper.missing_values_bin_idx_, # noqa
867+
missing_values_bin_idx=self._bin_mapper.missing_values_bin_idx_,
868+
n_threads=n_threads,
857869
)
858870
else:
859871
predict = partial(
860872
predictor.predict,
861873
known_cat_bitsets=known_cat_bitsets,
862874
f_idx_map=f_idx_map,
875+
n_threads=n_threads,
863876
)
864877
raw_predictions[k, :] += predict(X)
865878

0 commit comments

Comments
 (0)