ENH Uses _openmp_effective_n_threads to get the number of threads in HistGradientBoosting* #20477

Merged: 8 commits, Jul 9, 2021
7 changes: 7 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -262,6 +262,13 @@ Changelog
:mod:`sklearn.ensemble`
.......................

- |Enhancement| :class:`~sklearn.ensemble.HistGradientBoostingClassifier` and
:class:`~sklearn.ensemble.HistGradientBoostingRegressor` take cgroups quotas
Member commented:
Might be worth adding a link to https://en.wikipedia.org/wiki/Cgroups for further context. I'm not sure we can expect readers to know what cgroups is

into account when deciding the number of threads used by OpenMP. This
avoids performance problems caused by over-subscription when using those
classes in a docker container for instance. :pr:`20477`
Comment on lines +268 to +269
Member commented:
> when using those classes in a docker container for instance

This is unrelated to docker, isn't it?
Some docker images will set CPU quotas, but some don't. I think that the over-generalization to all docker containers is confusing.

Member commented:
Also, this only affects linux machines right?

@jjerphan (Member) commented on Jul 12, 2021:
> Also, this only affects linux machines right?

This will affect most machines because the Docker daemon needs a Linux kernel (to use cgroups and other kernel features); that kernel is generally the host OS's, or that of a virtual machine running Linux.

I think Docker started developing support for Windows-native images, but this is rather niche.

Member commented:

My comment about Linux was separate from the one about Docker.

Basically, we should probably clarify that this change does not affect Windows or OSX users.

Member commented:

I think that it still affects Windows or OSX users because a Linux VM has to be used for Docker in those cases.

Member commented:

I am not talking about docker.

I am saying that as far as I understand, this entry will not affect users using scikit-learn on Windows or OSX.
(as long as they don't use docker, which is probably the vast majority of them).

by `Thomas Fan`_.

- |Fix| Do not allow computing the out-of-bag (OOB) score in
:class:`ensemble.RandomForestClassifier` and
:class:`ensemble.ExtraTreesClassifier` with multiclass-multioutput target
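For context on the review discussion above: a container CPU quota is just a pair of numbers exposed by the Linux cgroups controller (for example set by `docker run --cpus=2`). The sketch below is not what `_openmp_effective_n_threads` itself does internally; it is only an illustrative, minimal example of where such a quota comes from and how it can cap a default thread count. The cgroup v1/v2 file paths and the clamping logic are simplifying assumptions for illustration.

```python
import math
import os


def cgroup_cpu_quota():
    """Return the CPU quota (in CPUs) imposed by cgroups, or None if unlimited.

    Illustrative only: checks the common cgroup v2 and v1 locations on Linux.
    """
    # cgroup v2: one file "cpu.max" containing "<quota> <period>" or "max <period>".
    try:
        with open("/sys/fs/cgroup/cpu.max") as f:
            quota, period = f.read().split()
            if quota != "max":
                return int(quota) / int(period)
    except OSError:
        pass
    # cgroup v1: quota and period live in two files; quota == -1 means unlimited.
    try:
        with open("/sys/fs/cgroup/cpu/cpu.cfs_quota_us") as f:
            quota = int(f.read())
        with open("/sys/fs/cgroup/cpu/cpu.cfs_period_us") as f:
            period = int(f.read())
        if quota > 0:
            return quota / period
    except OSError:
        pass
    return None  # no quota found


def effective_n_threads(n_threads=None):
    """Rough analogue of resolving an OpenMP thread count (simplified)."""
    limit = os.cpu_count()
    quota = cgroup_cpu_quota()
    if quota is not None:
        limit = min(limit, max(1, math.ceil(quota)))
    if n_threads is None:
        return limit  # like n_threads=None: use every CPU the process may use
    if n_threads > 0:
        # Simplified clamping; the real helper's handling of explicit values may differ.
        return min(n_threads, limit)
    raise ValueError("n_threads must be a positive integer or None")
```

With `docker run --cpus=2`, for instance, the quota/period pair resolves to 2, so a default of "all cores" would be capped at 2 threads instead of the host's full core count, which is exactly the over-subscription scenario the changelog entry describes.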
9 changes: 7 additions & 2 deletions sklearn/ensemble/_hist_gradient_boosting/_binning.pyx
@@ -22,6 +22,7 @@ np.import_array()
def _map_to_bins(const X_DTYPE_C [:, :] data,
list binning_thresholds,
const unsigned char missing_values_bin_idx,
int n_threads,
X_BINNED_DTYPE_C [::1, :] binned):
"""Bin continuous and categorical values to discrete integer-coded levels.

@@ -35,6 +36,8 @@ def _map_to_bins(const X_DTYPE_C [:, :] data,
binning_thresholds : list of arrays
For each feature, stores the increasing numeric values that are
used to separate the bins.
n_threads : int
Number of OpenMP threads to use.
binned : ndarray, shape (n_samples, n_features)
Output array, must be fortran aligned.
"""
@@ -45,12 +48,14 @@ def _map_to_bins(const X_DTYPE_C [:, :] data,
_map_col_to_bins(data[:, feature_idx],
binning_thresholds[feature_idx],
missing_values_bin_idx,
n_threads,
binned[:, feature_idx])


cdef void _map_col_to_bins(const X_DTYPE_C [:] data,
const X_DTYPE_C [:] binning_thresholds,
const unsigned char missing_values_bin_idx,
int n_threads,
X_BINNED_DTYPE_C [:] binned):
"""Binary search to find the bin index for each value in the data."""
cdef:
@@ -59,8 +64,8 @@ cdef void _map_col_to_bins(const X_DTYPE_C [:] data,
int right
int middle

for i in prange(data.shape[0], schedule='static', nogil=True):

for i in prange(data.shape[0], schedule='static', nogil=True,
num_threads=n_threads):
if isnan(data[i]):
binned[i] = missing_values_bin_idx
else:
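For readers unfamiliar with this file: the `prange` loop above is the parallel version of a simple per-column binning, where each value is mapped by binary search over the increasing thresholds to a bin index, and NaNs go to the dedicated missing-values bin. A rough, single-threaded NumPy equivalent (an illustrative sketch only, not the actual Cython kernel; the `uint8` dtype mirrors the binned dtype but is an assumption here) could look like:

```python
import numpy as np


def map_to_bins_reference(data, binning_thresholds, missing_values_bin_idx):
    """Sequential reference for _map_to_bins (illustrative sketch only)."""
    n_samples, n_features = data.shape
    binned = np.zeros((n_samples, n_features), dtype=np.uint8, order="F")
    for feature_idx in range(n_features):
        col = data[:, feature_idx]
        thresholds = binning_thresholds[feature_idx]
        # Index of the first threshold >= value: the same result as the
        # binary-search while-loop in _map_col_to_bins.
        binned[:, feature_idx] = np.searchsorted(thresholds, col, side="left")
        # Missing values get their own dedicated bin.
        binned[np.isnan(col), feature_idx] = missing_values_bin_idx
    return binned
```

The Cython version parallelizes the per-value loop with `prange(..., num_threads=n_threads)`, and this PR makes that thread count the cgroups-aware effective value instead of the OpenMP default.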
13 changes: 9 additions & 4 deletions sklearn/ensemble/_hist_gradient_boosting/_gradient_boosting.pyx
@@ -18,7 +18,9 @@ np.import_array()

def _update_raw_predictions(
Y_DTYPE_C [::1] raw_predictions, # OUT
grower):
grower,
n_threads,
):
"""Update raw_predictions with the predictions of the newest tree.

This is equivalent to (and much faster than):
@@ -42,21 +44,24 @@ def _update_raw_predictions(
values = np.array([leaf.value for leaf in leaves], dtype=Y_DTYPE)

_update_raw_predictions_helper(raw_predictions, starts, stops, partition,
values)
values, n_threads)


cdef inline void _update_raw_predictions_helper(
Y_DTYPE_C [::1] raw_predictions, # OUT
const unsigned int [::1] starts,
const unsigned int [::1] stops,
const unsigned int [::1] partition,
const Y_DTYPE_C [::1] values):
const Y_DTYPE_C [::1] values,
int n_threads,
):

cdef:
unsigned int position
int leaf_idx
int n_leaves = starts.shape[0]

for leaf_idx in prange(n_leaves, schedule='static', nogil=True):
for leaf_idx in prange(n_leaves, schedule='static', nogil=True,
num_threads=n_threads):
for position in range(starts[leaf_idx], stops[leaf_idx]):
raw_predictions[partition[position]] += values[leaf_idx]
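As the docstring of `_update_raw_predictions` says, this helper is a fast parallel version of adding each leaf's value to the raw predictions of the samples that landed in that leaf. A pure-Python sketch of the equivalent update is below; it is illustrative only, and the attribute names `partition_start`, `partition_stop` and `value` are assumed to follow the grower's conventions hinted at by the `starts`/`stops`/`values` arrays above.

```python
def update_raw_predictions_reference(raw_predictions, leaves, partition):
    """Sequential sketch of the parallel helper above (illustrative only).

    `partition` holds sample indices grouped per leaf; each leaf records the
    [start, stop) slice of `partition` that belongs to it, plus its value.
    """
    for leaf in leaves:
        sample_indices = partition[leaf.partition_start:leaf.partition_stop]
        raw_predictions[sample_indices] += leaf.value
```

The Cython helper instead parallelizes over leaves with `prange(..., num_threads=n_threads)`, which is why it needs the new `n_threads` argument.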
51 changes: 31 additions & 20 deletions sklearn/ensemble/_hist_gradient_boosting/_loss.pyx
@@ -21,14 +21,16 @@ np.import_array()
def _update_gradients_least_squares(
G_H_DTYPE_C [::1] gradients, # OUT
const Y_DTYPE_C [::1] y_true, # IN
const Y_DTYPE_C [::1] raw_predictions): # IN
const Y_DTYPE_C [::1] raw_predictions, # IN
int n_threads, # IN
):

cdef:
int n_samples
int i

n_samples = raw_predictions.shape[0]
for i in prange(n_samples, schedule='static', nogil=True):
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
# Note: a more correct expression is 2 * (raw_predictions - y_true)
# but since we use 1 for the constant hessian value (and not 2) this
# is strictly equivalent for the leaves values.
@@ -40,14 +42,16 @@ def _update_gradients_hessians_least_squares(
G_H_DTYPE_C [::1] hessians, # OUT
const Y_DTYPE_C [::1] y_true, # IN
const Y_DTYPE_C [::1] raw_predictions, # IN
const Y_DTYPE_C [::1] sample_weight): # IN
const Y_DTYPE_C [::1] sample_weight, # IN
int n_threads, # IN
):

cdef:
int n_samples
int i

n_samples = raw_predictions.shape[0]
for i in prange(n_samples, schedule='static', nogil=True):
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
# Note: a more correct expression is 2 * (raw_predictions - y_true) * sample_weight
# but since we use 1 for the constant hessian value (and not 2) this
# is strictly equivalent for the leaves values.
@@ -60,14 +64,15 @@ def _update_gradients_hessians_least_absolute_deviation(
G_H_DTYPE_C [::1] hessians, # OUT
const Y_DTYPE_C [::1] y_true, # IN
const Y_DTYPE_C [::1] raw_predictions, # IN
const Y_DTYPE_C [::1] sample_weight): # IN

const Y_DTYPE_C [::1] sample_weight, # IN
int n_threads, # IN
):
cdef:
int n_samples
int i

n_samples = raw_predictions.shape[0]
for i in prange(n_samples, schedule='static', nogil=True):
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
# gradient = sign(raw_prediction - y_true) * sample_weight
gradients[i] = sample_weight[i] * (2 *
(y_true[i] - raw_predictions[i] < 0) - 1)
@@ -77,14 +82,15 @@
def _update_gradients_least_absolute_deviation(
G_H_DTYPE_C [::1] gradients, # OUT
const Y_DTYPE_C [::1] y_true, # IN
const Y_DTYPE_C [::1] raw_predictions): # IN

const Y_DTYPE_C [::1] raw_predictions, # IN
int n_threads, # IN
):
cdef:
int n_samples
int i

n_samples = raw_predictions.shape[0]
for i in prange(n_samples, schedule='static', nogil=True):
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
# gradient = sign(raw_prediction - y_true)
gradients[i] = 2 * (y_true[i] - raw_predictions[i] < 0) - 1

@@ -94,23 +100,24 @@ def _update_gradients_hessians_poisson(
G_H_DTYPE_C [::1] hessians, # OUT
const Y_DTYPE_C [::1] y_true, # IN
const Y_DTYPE_C [::1] raw_predictions, # IN
const Y_DTYPE_C [::1] sample_weight): # IN

const Y_DTYPE_C [::1] sample_weight, # IN
int n_threads, # IN
):
cdef:
int n_samples
int i
Y_DTYPE_C y_pred

n_samples = raw_predictions.shape[0]
if sample_weight is None:
for i in prange(n_samples, schedule='static', nogil=True):
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
# Note: We use only half of the deviance loss. Therefore, there is
# no factor of 2.
y_pred = exp(raw_predictions[i])
gradients[i] = (y_pred - y_true[i])
hessians[i] = y_pred
else:
for i in prange(n_samples, schedule='static', nogil=True):
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
# Note: We use only half of the deviance loss. Therefore, there is
# no factor of 2.
y_pred = exp(raw_predictions[i])
@@ -123,20 +130,22 @@ def _update_gradients_hessians_binary_crossentropy(
G_H_DTYPE_C [::1] hessians, # OUT
const Y_DTYPE_C [::1] y_true, # IN
const Y_DTYPE_C [::1] raw_predictions, # IN
const Y_DTYPE_C [::1] sample_weight): # IN
const Y_DTYPE_C [::1] sample_weight, # IN
int n_threads, # IN
):
cdef:
int n_samples
Y_DTYPE_C p_i # proba that ith sample belongs to positive class
int i

n_samples = raw_predictions.shape[0]
if sample_weight is None:
for i in prange(n_samples, schedule='static', nogil=True):
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
p_i = _cexpit(raw_predictions[i])
gradients[i] = p_i - y_true[i]
hessians[i] = p_i * (1. - p_i)
else:
for i in prange(n_samples, schedule='static', nogil=True):
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
p_i = _cexpit(raw_predictions[i])
gradients[i] = (p_i - y_true[i]) * sample_weight[i]
hessians[i] = p_i * (1. - p_i) * sample_weight[i]
@@ -147,7 +156,9 @@ def _update_gradients_hessians_categorical_crossentropy(
G_H_DTYPE_C [:, ::1] hessians, # OUT
const Y_DTYPE_C [::1] y_true, # IN
const Y_DTYPE_C [:, ::1] raw_predictions, # IN
const Y_DTYPE_C [::1] sample_weight): # IN
const Y_DTYPE_C [::1] sample_weight, # IN
int n_threads, # IN
):
cdef:
int prediction_dim = raw_predictions.shape[0]
int n_samples = raw_predictions.shape[1]
@@ -160,7 +171,7 @@ Y_DTYPE_C p_i_k
Y_DTYPE_C p_i_k

if sample_weight is None:
for i in prange(n_samples, schedule='static', nogil=True):
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
# first compute softmaxes of sample i for each class
for k in range(prediction_dim):
p[i, k] = raw_predictions[k, i] # prepare softmax
@@ -171,7 +182,7 @@ gradients[k, i] = p_i_k - (y_true[i] == k)
gradients[k, i] = p_i_k - (y_true[i] == k)
hessians[k, i] = p_i_k * (1. - p_i_k)
else:
for i in prange(n_samples, schedule='static', nogil=True):
for i in prange(n_samples, schedule='static', nogil=True, num_threads=n_threads):
# first compute softmaxes of sample i for each class
for k in range(prediction_dim):
p[i, k] = raw_predictions[k, i] # prepare softmax
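A note on the repeated "no factor of 2" comments in this file: leaf values are computed roughly as -sum(gradients) / (sum(hessians) + regularization), so multiplying both gradients and hessians by the same constant cancels out, which is why the least-squares kernels can use g = raw_predictions - y_true with a constant hessian of 1. The following is a small vectorized sketch of two of the updates above, illustrative only (the real kernels are the prange loops in the diff, now bounded by num_threads=n_threads):

```python
import numpy as np
from scipy.special import expit  # sigmoid, analogue of _cexpit above


def gradients_hessians_least_squares(y_true, raw_predictions, sample_weight=None):
    # Half of the squared-error gradient; the dropped factor of 2 cancels in the
    # leaf-value formula -sum(g) / (sum(h) + reg).
    gradients = raw_predictions - y_true
    hessians = np.ones_like(raw_predictions)
    if sample_weight is not None:
        gradients *= sample_weight
        hessians *= sample_weight
    return gradients, hessians


def gradients_hessians_binary_crossentropy(y_true, raw_predictions, sample_weight=None):
    p = expit(raw_predictions)  # probability of the positive class
    gradients = p - y_true
    hessians = p * (1.0 - p)
    if sample_weight is not None:
        gradients *= sample_weight
        hessians *= sample_weight
    return gradients, hessians
```

These updates are embarrassingly parallel over samples, which is why every loop in this file simply gains a `num_threads=n_threads` argument rather than any restructuring.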
8 changes: 6 additions & 2 deletions sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx
@@ -30,12 +30,14 @@ def _predict_from_raw_data( # raw data = non-binned data
const BITSET_INNER_DTYPE_C [:, ::1] raw_left_cat_bitsets,
const BITSET_INNER_DTYPE_C [:, ::1] known_cat_bitsets,
const unsigned int [::1] f_idx_map,
int n_threads,
Y_DTYPE_C [:] out):

cdef:
int i

for i in prange(numeric_data.shape[0], schedule='static', nogil=True):
for i in prange(numeric_data.shape[0], schedule='static', nogil=True,
num_threads=n_threads):
out[i] = _predict_one_from_raw_data(
nodes, numeric_data, raw_left_cat_bitsets,
known_cat_bitsets,
@@ -95,12 +97,14 @@ def _predict_from_binned_data(
const X_BINNED_DTYPE_C [:, :] binned_data,
BITSET_INNER_DTYPE_C [:, :] binned_left_cat_bitsets,
const unsigned char missing_values_bin_idx,
int n_threads,
Y_DTYPE_C [:] out):

cdef:
int i

for i in prange(binned_data.shape[0], schedule='static', nogil=True):
for i in prange(binned_data.shape[0], schedule='static', nogil=True,
num_threads=n_threads):
out[i] = _predict_one_from_binned_data(nodes,
binned_data,
binned_left_cat_bitsets, i,
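Both loops above parallelize prediction over samples; each iteration just walks one tree from the root to a leaf. A heavily simplified, hypothetical sketch of what a single traversal over purely numerical splits looks like is below (the field names `is_leaf`, `feature_idx`, `num_threshold`, `left`, `right` and `value` are used for illustration and may not match the actual node dtype exactly; categorical splits and missing-value routing are omitted):

```python
def predict_one_from_raw_data(nodes, x):
    """Walk a single tree for one sample (simplified, numerical splits only)."""
    node = nodes[0]  # root node
    while not node["is_leaf"]:
        if x[node["feature_idx"]] <= node["num_threshold"]:
            node = nodes[node["left"]]
        else:
            node = nodes[node["right"]]
    return node["value"]
```

Since each sample's traversal is independent, capping `prange` with `num_threads=n_threads` changes only how many samples are processed concurrently, not the predictions themselves.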
14 changes: 13 additions & 1 deletion sklearn/ensemble/_hist_gradient_boosting/binning.py
@@ -12,6 +12,7 @@
from ...utils import check_random_state, check_array
from ...base import BaseEstimator, TransformerMixin
from ...utils.validation import check_is_fitted
from ...utils._openmp_helpers import _openmp_effective_n_threads
from ._binning import _map_to_bins
from .common import X_DTYPE, X_BINNED_DTYPE, ALMOST_INF, X_BITSET_INNER_DTYPE
from ._bitset import set_bitset_memoryview
@@ -115,6 +116,11 @@ class _BinMapper(TransformerMixin, BaseEstimator):
Pass an int for reproducible output across multiple
function calls.
See :term: `Glossary <random_state>`.
n_threads : int, default=None
Number of OpenMP threads to use. `_openmp_effective_n_threads` is called
to determine the effective number of threads to use, which takes cgroups CPU
quotas into account. See the docstring of `_openmp_effective_n_threads`
for details.

Attributes
----------
@@ -151,12 +157,14 @@ def __init__(
is_categorical=None,
known_categories=None,
random_state=None,
n_threads=None,
):
self.n_bins = n_bins
self.subsample = subsample
self.is_categorical = is_categorical
self.known_categories = known_categories
self.random_state = random_state
self.n_threads = n_threads

def fit(self, X, y=None):
"""Fit data X by computing the binning thresholds.
@@ -264,8 +272,12 @@ def transform(self, X):
"This estimator was fitted with {} features but {} got passed "
"to transform()".format(self.n_bins_non_missing_.shape[0], X.shape[1])
)

n_threads = _openmp_effective_n_threads(self.n_threads)
binned = np.zeros_like(X, dtype=X_BINNED_DTYPE, order="F")
_map_to_bins(X, self.bin_thresholds_, self.missing_values_bin_idx_, binned)
_map_to_bins(
X, self.bin_thresholds_, self.missing_values_bin_idx_, n_threads, binned
)
return binned

def make_known_categories_bitsets(self):
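The pattern used in `transform` above, storing the user-facing `n_threads` (possibly None) at construction time and only resolving it to a concrete count with `_openmp_effective_n_threads` right before calling a Cython kernel, is what the rest of the PR repeats in the other estimator classes. A hedged usage sketch (this is a private helper and its exact semantics are an implementation detail; the printed value depends on your machine and on any cgroups CPU quota):

```python
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads

# None means "use every core the process is actually allowed to use", which,
# after this PR, also respects cgroups CPU quotas (e.g. docker run --cpus=2).
n_threads = _openmp_effective_n_threads(None)
print(f"OpenMP kernels would run with {n_threads} threads")
```

Resolving at call time rather than in `__init__` follows scikit-learn's convention of keeping constructors side-effect free, and it also means the effective count reflects the environment in which the kernel actually runs.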