Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 30 additions & 16 deletions sklearn/ensemble/_hist_gradient_boosting/grower.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from sklearn.utils._openmp_helpers import _openmp_effective_n_threads

from ...utils.arrayfuncs import sum_parallel
from ._bitset import set_raw_bitset_from_binned_bitset
from .common import (
PREDICTOR_RECORD_DTYPE,
Expand Down Expand Up @@ -353,7 +352,7 @@ def __init__(
self.total_compute_hist_time = 0.0 # time spent computing histograms
self.total_apply_split_time = 0.0 # time spent splitting nodes
self.n_categorical_splits = 0
self._initialize_root(gradients, hessians)
self._initialize_root()
self.n_nodes = 1

def _validate_parameters(
Expand Down Expand Up @@ -401,15 +400,38 @@ def _apply_shrinkage(self):
for leaf in self.finalized_leaves:
leaf.value *= self.shrinkage

def _initialize_root(self, gradients, hessians):
def _initialize_root(self):
"""Initialize root node and finalize it if needed."""
tic = time()
if self.interaction_cst is not None:
allowed_features = set().union(*self.interaction_cst)
allowed_features = np.fromiter(
allowed_features, dtype=np.uint32, count=len(allowed_features)
)
arbitrary_feature = allowed_features[0]
else:
allowed_features = None
arbitrary_feature = 0

# TreeNode init needs the total sum of gradients and hessians. Therefore, we
# first compute the histograms and then compute the total grad/hess on an
# arbitrary feature histogram. This way we replace a loop over n_samples by a
# loop over n_bins.
histograms = self.histogram_builder.compute_histograms_brute(
self.splitter.partition, # =self.root.sample_indices
allowed_features,
)
self.total_compute_hist_time += time() - tic

tic = time()
n_samples = self.X_binned.shape[0]
depth = 0
sum_gradients = sum_parallel(gradients, self.n_threads)
histogram_array = np.asarray(histograms[arbitrary_feature])
sum_gradients = histogram_array["sum_gradients"].sum()
if self.histogram_builder.hessians_are_constant:
sum_hessians = hessians[0] * n_samples
sum_hessians = self.histogram_builder.hessians[0] * n_samples
else:
sum_hessians = sum_parallel(hessians, self.n_threads)
sum_hessians = histogram_array["sum_hessians"].sum()
self.root = TreeNode(
depth=depth,
sample_indices=self.splitter.partition,
Expand All @@ -430,18 +452,10 @@ def _initialize_root(self, gradients, hessians):

if self.interaction_cst is not None:
self.root.interaction_cst_indices = range(len(self.interaction_cst))
allowed_features = set().union(*self.interaction_cst)
self.root.allowed_features = np.fromiter(
allowed_features, dtype=np.uint32, count=len(allowed_features)
)
self.root.allowed_features = allowed_features

tic = time()
self.root.histograms = self.histogram_builder.compute_histograms_brute(
self.root.sample_indices, self.root.allowed_features
)
self.total_compute_hist_time += time() - tic
self.root.histograms = histograms

tic = time()
self._compute_best_split_and_push(self.root)
self.total_find_split_time += time() - tic

Expand Down
16 changes: 0 additions & 16 deletions sklearn/utils/arrayfuncs.pyx
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""A small collection of auxiliary functions that operate on arrays."""

from cython cimport floating
from cython.parallel cimport prange
from libc.math cimport fabs
from libc.float cimport DBL_MAX, FLT_MAX

from ._cython_blas cimport _copy, _rotg, _rot
from ._typedefs cimport float64_t


ctypedef fused real_numeric:
Expand Down Expand Up @@ -118,17 +116,3 @@ def cholesky_delete(floating[:, :] L, int go_out):
L1 += m

_rot(n - i - 2, L1 + i, m, L1 + i + 1, m, c, s)


def sum_parallel(const floating [:] array, int n_threads):
"""Parallel sum, always using float64 internally."""
cdef:
float64_t out = 0.
int i = 0

for i in prange(
array.shape[0], schedule='static', nogil=True, num_threads=n_threads
):
out += array[i]

return out
2 changes: 1 addition & 1 deletion sklearn/utils/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ utils_extension_metadata = {
'sparsefuncs_fast':
{'sources': ['sparsefuncs_fast.pyx']},
'_cython_blas': {'sources': ['_cython_blas.pyx']},
'arrayfuncs': {'sources': ['arrayfuncs.pyx'], 'dependencies': [openmp_dep]},
'arrayfuncs': {'sources': ['arrayfuncs.pyx']},
'murmurhash': {
'sources': ['murmurhash.pyx', 'src' / 'MurmurHash3.cpp'],
},
Expand Down