diff --git a/sklearn/ensemble/_hist_gradient_boosting/grower.py b/sklearn/ensemble/_hist_gradient_boosting/grower.py index a71e564056f8f..c3dbbe7d82948 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/grower.py +++ b/sklearn/ensemble/_hist_gradient_boosting/grower.py @@ -16,7 +16,6 @@ from sklearn.utils._openmp_helpers import _openmp_effective_n_threads -from ...utils.arrayfuncs import sum_parallel from ._bitset import set_raw_bitset_from_binned_bitset from .common import ( PREDICTOR_RECORD_DTYPE, @@ -353,7 +352,7 @@ def __init__( self.total_compute_hist_time = 0.0 # time spent computing histograms self.total_apply_split_time = 0.0 # time spent splitting nodes self.n_categorical_splits = 0 - self._initialize_root(gradients, hessians) + self._initialize_root() self.n_nodes = 1 def _validate_parameters( @@ -401,15 +400,38 @@ def _apply_shrinkage(self): for leaf in self.finalized_leaves: leaf.value *= self.shrinkage - def _initialize_root(self, gradients, hessians): + def _initialize_root(self): """Initialize root node and finalize it if needed.""" + tic = time() + if self.interaction_cst is not None: + allowed_features = set().union(*self.interaction_cst) + allowed_features = np.fromiter( + allowed_features, dtype=np.uint32, count=len(allowed_features) + ) + arbitrary_feature = allowed_features[0] + else: + allowed_features = None + arbitrary_feature = 0 + + # TreeNode init needs the total sum of gradients and hessians. Therefore, we + # first compute the histograms and then compute the total grad/hess on an + # arbitrary feature histogram. This way we replace a loop over n_samples by a + # loop over n_bins. + histograms = self.histogram_builder.compute_histograms_brute( + self.splitter.partition, # =self.root.sample_indices + allowed_features, + ) + self.total_compute_hist_time += time() - tic + + tic = time() n_samples = self.X_binned.shape[0] depth = 0 - sum_gradients = sum_parallel(gradients, self.n_threads) + histogram_array = np.asarray(histograms[arbitrary_feature]) + sum_gradients = histogram_array["sum_gradients"].sum() if self.histogram_builder.hessians_are_constant: - sum_hessians = hessians[0] * n_samples + sum_hessians = self.histogram_builder.hessians[0] * n_samples else: - sum_hessians = sum_parallel(hessians, self.n_threads) + sum_hessians = histogram_array["sum_hessians"].sum() self.root = TreeNode( depth=depth, sample_indices=self.splitter.partition, @@ -430,18 +452,10 @@ def _initialize_root(self, gradients, hessians): if self.interaction_cst is not None: self.root.interaction_cst_indices = range(len(self.interaction_cst)) - allowed_features = set().union(*self.interaction_cst) - self.root.allowed_features = np.fromiter( - allowed_features, dtype=np.uint32, count=len(allowed_features) - ) + self.root.allowed_features = allowed_features - tic = time() - self.root.histograms = self.histogram_builder.compute_histograms_brute( - self.root.sample_indices, self.root.allowed_features - ) - self.total_compute_hist_time += time() - tic + self.root.histograms = histograms - tic = time() self._compute_best_split_and_push(self.root) self.total_find_split_time += time() - tic diff --git a/sklearn/utils/arrayfuncs.pyx b/sklearn/utils/arrayfuncs.pyx index 2cf98e0f5cc3e..951751fd08fed 100644 --- a/sklearn/utils/arrayfuncs.pyx +++ b/sklearn/utils/arrayfuncs.pyx @@ -1,12 +1,10 @@ """A small collection of auxiliary functions that operate on arrays.""" from cython cimport floating -from cython.parallel cimport prange from libc.math cimport fabs from libc.float cimport DBL_MAX, FLT_MAX from ._cython_blas cimport _copy, _rotg, _rot -from ._typedefs cimport float64_t ctypedef fused real_numeric: @@ -118,17 +116,3 @@ def cholesky_delete(floating[:, :] L, int go_out): L1 += m _rot(n - i - 2, L1 + i, m, L1 + i + 1, m, c, s) - - -def sum_parallel(const floating [:] array, int n_threads): - """Parallel sum, always using float64 internally.""" - cdef: - float64_t out = 0. - int i = 0 - - for i in prange( - array.shape[0], schedule='static', nogil=True, num_threads=n_threads - ): - out += array[i] - - return out diff --git a/sklearn/utils/meson.build b/sklearn/utils/meson.build index c7a6102b956e8..9bbfc01b7b6bf 100644 --- a/sklearn/utils/meson.build +++ b/sklearn/utils/meson.build @@ -18,7 +18,7 @@ utils_extension_metadata = { 'sparsefuncs_fast': {'sources': ['sparsefuncs_fast.pyx']}, '_cython_blas': {'sources': ['_cython_blas.pyx']}, - 'arrayfuncs': {'sources': ['arrayfuncs.pyx'], 'dependencies': [openmp_dep]}, + 'arrayfuncs': {'sources': ['arrayfuncs.pyx']}, 'murmurhash': { 'sources': ['murmurhash.pyx', 'src' / 'MurmurHash3.cpp'], },