diff --git a/benchmarks/bench_hist_gradient_boosting_adult.py b/benchmarks/bench_hist_gradient_boosting_adult.py new file mode 100644 index 0000000000000..80447edec34bc --- /dev/null +++ b/benchmarks/bench_hist_gradient_boosting_adult.py @@ -0,0 +1,88 @@ +import argparse +from time import time + +import numpy as np +from sklearn.model_selection import train_test_split +from sklearn.datasets import fetch_openml +from sklearn.metrics import accuracy_score, roc_auc_score +from sklearn.experimental import enable_hist_gradient_boosting # noqa +from sklearn.ensemble import HistGradientBoostingClassifier +from sklearn.ensemble._hist_gradient_boosting.utils import ( + get_equivalent_estimator) + + +parser = argparse.ArgumentParser() +parser.add_argument('--n-leaf-nodes', type=int, default=31) +parser.add_argument('--n-trees', type=int, default=40) +parser.add_argument('--lightgbm', action="store_true", default=False) +parser.add_argument('--learning-rate', type=float, default=1.) +parser.add_argument('--max-bins', type=int, default=255) +parser.add_argument('--no-predict', action="store_true", default=False) +args = parser.parse_args() + +n_leaf_nodes = args.n_leaf_nodes +n_trees = args.n_trees +lr = args.learning_rate +max_bins = args.max_bins + + +def fit(est, data_train, target_train, libname, **fit_params): + print(f"Fitting a {libname} model...") + tic = time() + est.fit(data_train, target_train, **fit_params) + toc = time() + print(f"fitted in {toc - tic:.3f}s") + + +def predict(est, data_test, target_test): + if args.no_predict: + return + tic = time() + predicted_test = est.predict(data_test) + predicted_proba_test = est.predict_proba(data_test) + toc = time() + roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1]) + acc = accuracy_score(target_test, predicted_test) + print(f"predicted in {toc - tic:.3f}s, " + f"ROC AUC: {roc_auc:.4f}, ACC: {acc :.4f}") + + +data, target = fetch_openml(data_id=179, as_frame=True, return_X_y=True) + +# does not support categories in encoding y yet +target = target.cat.codes + +n_features = data.shape[1] +is_categorical = data.dtypes == 'category' +n_categorical_features = is_categorical.sum() +n_numerical_features = (data.dtypes == 'float').sum() +print(f"Number of features: {data.shape[1]}") +print(f"Number of categorical features: {n_categorical_features}") +print(f"Number of numerical features: {n_numerical_features}") + +categorical_features = np.flatnonzero(is_categorical) +for i in categorical_features: + data.iloc[:, i] = data.iloc[:, i].cat.codes + +data_train, data_test, target_train, target_test = train_test_split( + data, target, test_size=.2, random_state=0) + +est = HistGradientBoostingClassifier(loss='binary_crossentropy', + learning_rate=lr, + max_iter=n_trees, + max_bins=max_bins, + categorical_features=categorical_features, + max_leaf_nodes=n_leaf_nodes, + early_stopping=False, + random_state=0, + verbose=1) + +fit(est, data_train, target_train, 'sklearn') +predict(est, data_test, target_test) + +# lightgbm infers the categories from the dtype +if args.lightgbm: + est = get_equivalent_estimator(est, lib='lightgbm') + fit(est, data_train, target_train, 'lightgbm', + categorical_feature=is_categorical[is_categorical].index.tolist()) + predict(est, data_test, target_test) diff --git a/doc/modules/ensemble.rst b/doc/modules/ensemble.rst index 5c2e9f94c1b4d..d6411b48b6c37 100644 --- a/doc/modules/ensemble.rst +++ b/doc/modules/ensemble.rst @@ -1051,6 +1051,43 @@ multiplying the gradients (and the hessians) by the sample weights. 
Note that the binning stage (specifically the quantiles computation) does not take the weights into account. +.. _categorical_support_gbdt: + +Categorical Features Support +---------------------------- + +For datasets with categorical data, :class:`HistGradientBoostingClassifier` +and :class:`HistGradientBoostingRegressor` have native support for splitting +on categorical features. This is often better than one hot encoding because +it leads to faster training times and trees with less depth. The canonical way +of considering categorical splits is to consider all of the :math:`2^{K - 1} - +1` partitions where `K` is the number of categories. This can quickly become +prohibitive when `K` is large. Fortunately, since gradient boosting trees are +always regression trees (even for classification problems), there exists a +faster strategy that can yield equivalent splits. First, the categories of a +feature are sorted according to the ratio `sum_gradient_k / sum_hessians_k` of +each category `k`. Once the categories are sorted, one can consider *continuous +partitions*, i.e. treat the categories as if they were ordered continuous +values (see Fisher [Fisher1958]_ for a formal proof). As a result, only `K - 1` +splits need to be considered instead of :math:`2^{K - 1} - 1`. + +If there are missing values during training, the missing values will be +considered as a single category. When predicting, categories that were unknown +during fit time will be considered missing. If the cardinality of a categorical +feature is greater than `max_bins`, then the top `max_bins` categories based on +frequency will be kept, and the less frequent categories will be considered +as missing. + +To enable categorical support, a boolean mask can be passed to the +`categorical_features` parameter. In the following, the first feature will be +treated as categorical and the second feature as numerical:: + + >>> gbdt = HistGradientBoostingClassifier(categorical_features=[True, False]) + +.. topic:: Examples: + + * :ref:`sphx_glr_auto_examples_ensemble_plot_gradient_boosting_categorical.py` + .. _monotonic_cst_gbdt: Monotonic Constraints @@ -1158,6 +1195,8 @@ Finally, many parts of the implementation of .. [LightGBM] Ke et al. `"LightGBM: A Highly Efficient Gradient Boosting Decision Tree" `_ + .. [Fisher1958] Walter D. Fisher. `"On Grouping for Maximum Homogeneity" + `_ .. _voting_classifier: diff --git a/examples/ensemble/plot_gradient_boosting_categorical.py b/examples/ensemble/plot_gradient_boosting_categorical.py new file mode 100644 index 0000000000000..20500953c5285 --- /dev/null +++ b/examples/ensemble/plot_gradient_boosting_categorical.py @@ -0,0 +1,91 @@ +""" +======================================== +Categorical Support in Gradient Boosting +======================================== + +.. currentmodule:: sklearn + +In this example, we will compare the performance of +:class:`~ensemble.HistGradientBoostingRegressor` with one hot encoding +and with native categorical support. + +We will work with the Ames Iowa Housing dataset which consists of numerical +and categorical features, where the houses' sale prices are the target. +""" +############################################################################## +# Load Ames Housing dataset +# ------------------------- +# First, we load the Ames housing data as a pandas dataframe.
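Before the full example below, a minimal sketch of the `categorical_features` API documented in the User Guide addition above. This snippet is not part of the patch: it assumes the parameter added by this diff, and the toy data and column names are hypothetical. Categories are first integer-encoded with an :class:`~sklearn.preprocessing.OrdinalEncoder` so that the estimator only sees non-negative integer codes::

    import pandas as pd
    from sklearn.experimental import enable_hist_gradient_boosting  # noqa
    from sklearn.ensemble import HistGradientBoostingClassifier
    from sklearn.compose import make_column_transformer
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import OrdinalEncoder

    # Hypothetical toy data: one string-valued categorical column, one numeric.
    X = pd.DataFrame({'pet': ['cat', 'dog', 'dog', 'cat'],
                      'age': [1.0, 2.5, 0.3, 4.1]})
    y = [0, 1, 1, 0]

    # The encoded 'pet' column comes first in the transformer output, so the
    # boolean mask marks the first column as categorical.
    model = make_pipeline(
        make_column_transformer((OrdinalEncoder(), ['pet']),
                                remainder='passthrough'),
        HistGradientBoostingClassifier(categorical_features=[True, False]))
    model.fit(X, y)

Per the docstrings added in this diff, negative codes and categories unseen during ``fit`` end up in the missing-values bin at prediction time.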
The features +# are either categorical or numerical: +print(__doc__) + +from sklearn.datasets import fetch_openml + +X, y = fetch_openml(data_id=41211, as_frame=True, return_X_y=True) + +n_features = X.shape[1] +n_categorical_features = (X.dtypes == 'category').sum() +n_numerical_features = (X.dtypes == 'float').sum() +print(f"Number of features: {X.shape[1]}") +print(f"Number of categorical features: {n_categorical_features}") +print(f"Number of numerical features: {n_numerical_features}") + +############################################################################## +# Create gradient boosting estimator with one hot encoding +# -------------------------------------------------------- +# Next, we create a pipeline that will one hot encode the categorical features +# and let the rest of the numerical data pass through: + +from sklearn.experimental import enable_hist_gradient_boosting # noqa +from sklearn.ensemble import HistGradientBoostingRegressor +from sklearn.pipeline import make_pipeline +from sklearn.compose import make_column_transformer +from sklearn.compose import make_column_selector +from sklearn.preprocessing import OneHotEncoder + +preprocessor = make_column_transformer( + (OneHotEncoder(sparse=False, handle_unknown='ignore'), + make_column_selector(dtype_include='category')), + remainder='passthrough') + +hist_one_hot = make_pipeline(preprocessor, + HistGradientBoostingRegressor(random_state=42)) + +############################################################################## +# Create gradient boosting estimator with native categorical support +# ------------------------------------------------------------------ +# The :class:`~ensemble.HistGradientBoostingRegressor` has native support +# for categorical features using the `categorical_features` parameter: + +hist_native = HistGradientBoostingRegressor(categorical_features='pandas', + random_state=42) + +############################################################################## +# Train the models with cross-validation +# -------------------------------------- +# Finally, we train the models using cross validation. Here we compare the +# models' performance in terms of :func:`~metrics.r2_score` and fit times.
We +# show that fit times are faster with native categorical support and that the +# test scores and score times are comparable: + +from sklearn.model_selection import cross_validate +import matplotlib.pyplot as plt +import numpy as np + +one_hot_result = cross_validate(hist_one_hot, X, y) +native_result = cross_validate(hist_native, X, y) + +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8)) + +plot_info = [('fit_time', 'Fit times (s)', ax1), + ('test_score', 'Test Scores (r2 score)', ax2)] + +x, width = np.arange(2), 0.9 +for key, title, ax in plot_info: + items = [native_result[key], one_hot_result[key]] + ax.bar(x, [np.mean(item) for item in items], + width, yerr=[np.std(item) for item in items], + color=['b', 'r']) + ax.set(xlabel='Model', title=title, xticks=[0, 1], + xticklabels=['Native', "One Hot"]) +plt.show() diff --git a/sklearn/ensemble/_hist_gradient_boosting/_binning.pyx b/sklearn/ensemble/_hist_gradient_boosting/_binning.pyx index 7c36a2366806f..67b9c3724af3b 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/_binning.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/_binning.pyx @@ -15,6 +15,7 @@ from cython.parallel import prange from libc.math cimport isnan from .common cimport X_DTYPE_C, X_BINNED_DTYPE_C +from ._cat_mapper cimport CategoryMapper np.import_array() @@ -22,8 +23,11 @@ np.import_array() def _map_to_bins(const X_DTYPE_C [:, :] data, list binning_thresholds, const unsigned char missing_values_bin_idx, + CategoryMapper category_mapper, + const unsigned char[::1] is_categorical, X_BINNED_DTYPE_C [::1, :] binned): """Bin numerical values to discrete integer-coded levels. + TODO docstring needs update Parameters ---------- @@ -32,6 +36,8 @@ def _map_to_bins(const X_DTYPE_C [:, :] data, binning_thresholds : list of arrays For each feature, stores the increasing numeric values that are used to separate the bins. + is_categorical : ndarray, shape (n_features,) + Indicates categorical features. binned : ndarray, shape (n_samples, n_features) Output array, must be fortran aligned.
""" @@ -39,10 +45,15 @@ def _map_to_bins(const X_DTYPE_C [:, :] data, int feature_idx for feature_idx in range(data.shape[1]): - _map_num_col_to_bins(data[:, feature_idx], - binning_thresholds[feature_idx], - missing_values_bin_idx, - binned[:, feature_idx]) + if is_categorical[feature_idx]: + _map_cat_col_to_bins(data[:, feature_idx], feature_idx, + category_mapper, missing_values_bin_idx, + binned[:, feature_idx]) + else: + _map_num_col_to_bins(data[:, feature_idx], + binning_thresholds[feature_idx], + missing_values_bin_idx, + binned[:, feature_idx]) cdef void _map_num_col_to_bins(const X_DTYPE_C [:] data, @@ -71,3 +82,14 @@ cdef void _map_num_col_to_bins(const X_DTYPE_C [:] data, else: left = middle + 1 binned[i] = left + + +cdef void _map_cat_col_to_bins(const X_DTYPE_C [:] data, + int feature_idx, + CategoryMapper category_mapper, + const unsigned char missing_values_bin_idx, + X_BINNED_DTYPE_C [:] binned): + """Map form raw categories to bin""" + cdef int i + for i in prange(data.shape[0], schedule='static', nogil=True): + binned[i] = category_mapper.map_to_bin(feature_idx, data[i]) diff --git a/sklearn/ensemble/_hist_gradient_boosting/_bitset.pxd b/sklearn/ensemble/_hist_gradient_boosting/_bitset.pxd new file mode 100644 index 0000000000000..1b6d4433cc063 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/_bitset.pxd @@ -0,0 +1,9 @@ +# cython: language_level=3 +from .common cimport X_BINNED_DTYPE_C +from .common cimport BITSET_DTYPE_C + +cdef void init_bitset(BITSET_DTYPE_C bitset) nogil + +cdef void set_bitset(BITSET_DTYPE_C bitset, X_BINNED_DTYPE_C val) nogil + +cdef unsigned char in_bitset(BITSET_DTYPE_C bitset, X_BINNED_DTYPE_C val) nogil diff --git a/sklearn/ensemble/_hist_gradient_boosting/_bitset.pyx b/sklearn/ensemble/_hist_gradient_boosting/_bitset.pyx new file mode 100644 index 0000000000000..b7de68c0d0849 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/_bitset.pyx @@ -0,0 +1,38 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False +# cython: language_level=3 +from .common cimport BITSET_INNER_DTYPE_C + +cdef inline void init_bitset(BITSET_DTYPE_C bitset) nogil: # OUT + cdef: + unsigned int i + + for i in range(8): + bitset[i] = 0 + +cdef inline void set_bitset(BITSET_DTYPE_C bitset, # OUT + X_BINNED_DTYPE_C val) nogil: + cdef: + unsigned int i1 = val // 32 + unsigned int i2 = val % 32 + + # It is assumed that val < 256 so that i1 < 8 + bitset[i1] |= (1 << i2) + +cdef inline unsigned char in_bitset(BITSET_DTYPE_C bitset, + X_BINNED_DTYPE_C val) nogil: + cdef: + unsigned int i1 = val // 32 + unsigned int i2 = val % 32 + + return (bitset[i1] >> i2) & 1 + + +def set_bitset_py(BITSET_INNER_DTYPE_C[:] bitset, X_BINNED_DTYPE_C val): + cdef: + unsigned int i1 = val // 32 + unsigned int i2 = val % 32 + + # It is assumed that val < 256 so that i1 < 8 + bitset[i1] |= (1 << i2) diff --git a/sklearn/ensemble/_hist_gradient_boosting/_cat_mapper.pxd b/sklearn/ensemble/_hist_gradient_boosting/_cat_mapper.pxd new file mode 100644 index 0000000000000..55834c749bd5a --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/_cat_mapper.pxd @@ -0,0 +1,19 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False +# cython: language_level=3 +# cython: nonecheck=False +# distutils: language=c++ + +from libcpp.map cimport map +from libcpp.vector cimport vector +from .common cimport X_DTYPE_C +from .common cimport X_BINNED_DTYPE_C + +cdef class CategoryMapper: + cdef: + map[int, map[int, X_BINNED_DTYPE_C]] 
raw_category_to_bin + X_BINNED_DTYPE_C missing_values_bin_idx + + cdef X_BINNED_DTYPE_C map_to_bin(self, int feature_idx, + X_DTYPE_C raw_category) nogil diff --git a/sklearn/ensemble/_hist_gradient_boosting/_cat_mapper.pyx b/sklearn/ensemble/_hist_gradient_boosting/_cat_mapper.pyx new file mode 100644 index 0000000000000..2cd448fb1b008 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/_cat_mapper.pyx @@ -0,0 +1,40 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False +# cython: language_level=3 +# cython: nonecheck=False + +from .common cimport X_DTYPE_C +from libc.math cimport isnan + +cdef class CategoryMapper: + + def __init__(self, X_BINNED_DTYPE_C missing_values_bin_idx): + self.missing_values_bin_idx = missing_values_bin_idx + + def insert(self, int feature_idx, X_DTYPE_C[:] category_bins): + cdef int i + + for i in range(category_bins.shape[0]): + self.raw_category_to_bin[feature_idx][(category_bins[i])] = i + + cdef X_BINNED_DTYPE_C map_to_bin(self, int feature_idx, + X_DTYPE_C raw_category) nogil: + # negative values and nans are mapped to missing value + if isnan(raw_category) or raw_category < 0: + return self.missing_values_bin_idx + + # This should never happen, but to be safe we check for feature_idx in + # raw_category_to_bin + if self.raw_category_to_bin.count(feature_idx) == 0: + return self.missing_values_bin_idx + + cdef: + int int_value = (raw_category) + map[int, X_BINNED_DTYPE_C] category_to_bin = \ + self.raw_category_to_bin[feature_idx] + + if category_to_bin.count(int_value) == 0: + return self.missing_values_bin_idx + + return category_to_bin[int_value] diff --git a/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx b/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx index 8f7241b6d5589..bb2a3afc7a981 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx @@ -17,12 +17,16 @@ from .common cimport Y_DTYPE_C from .common import Y_DTYPE from .common cimport X_BINNED_DTYPE_C from .common cimport node_struct +from ._bitset cimport in_bitset +from ._predictor_bitset cimport PredictorBitSet +from ._cat_mapper cimport CategoryMapper np.import_array() -def _predict_from_numeric_data( +def _predict_from_data( node_struct [:] nodes, + PredictorBitSet predictor_bitset, const X_DTYPE_C [:, :] numeric_data, Y_DTYPE_C [:] out): @@ -30,11 +34,13 @@ def _predict_from_numeric_data( int i for i in prange(numeric_data.shape[0], schedule='static', nogil=True): - out[i] = _predict_one_from_numeric_data(nodes, numeric_data, i) + out[i] = _predict_one_from_numeric_data( + nodes, predictor_bitset, numeric_data, i) cdef inline Y_DTYPE_C _predict_one_from_numeric_data( node_struct [:] nodes, + PredictorBitSet predictor_bitset, const X_DTYPE_C [:, :] numeric_data, const int row) nogil: # Need to pass the whole array and the row index, else prange won't work. @@ -42,6 +48,7 @@ cdef inline Y_DTYPE_C _predict_one_from_numeric_data( cdef: node_struct node = nodes[0] + unsigned int node_idx = 0 while True: if node.is_leaf: @@ -49,18 +56,30 @@ cdef inline Y_DTYPE_C _predict_one_from_numeric_data( if isnan(numeric_data[row, node.feature_idx]): if node.missing_go_to_left: - node = nodes[node.left] + node_idx = node.left else: - node = nodes[node.right] + node_idx = node.right + elif node.is_categorical: + if not predictor_bitset.is_known_category( + node.feature_idx, numeric_data[row, node.feature_idx]): + # treat unknown categories as missing. 
+ node_idx = node.left if node.missing_go_to_left else node.right + elif predictor_bitset.raw_category_in_bitset( + node_idx, numeric_data[row, node.feature_idx]): + node_idx = node.left + else: + node_idx = node.right else: if numeric_data[row, node.feature_idx] <= node.num_threshold: - node = nodes[node.left] + node_idx = node.left else: - node = nodes[node.right] + node_idx = node.right + node = nodes[node_idx] def _predict_from_binned_data( node_struct [:] nodes, + PredictorBitSet predictor_bitset, const X_BINNED_DTYPE_C [:, :] binned_data, const unsigned char missing_values_bin_idx, Y_DTYPE_C [:] out): @@ -69,12 +88,14 @@ def _predict_from_binned_data( int i for i in prange(binned_data.shape[0], schedule='static', nogil=True): - out[i] = _predict_one_from_binned_data(nodes, binned_data, i, + out[i] = _predict_one_from_binned_data(nodes, predictor_bitset, + binned_data, i, missing_values_bin_idx) cdef inline Y_DTYPE_C _predict_one_from_binned_data( node_struct [:] nodes, + PredictorBitSet predictor_bitset, const X_BINNED_DTYPE_C [:, :] binned_data, const int row, const unsigned char missing_values_bin_idx) nogil: @@ -83,20 +104,30 @@ cdef inline Y_DTYPE_C _predict_one_from_binned_data( cdef: node_struct node = nodes[0] + unsigned int node_idx = 0 while True: if node.is_leaf: return node.value - if binned_data[row, node.feature_idx] == missing_values_bin_idx: + + if node.is_categorical: + if predictor_bitset.binned_category_in_bitset( + node_idx, binned_data[row, node.feature_idx]): + node_idx = node.left + else: + node_idx = node.right + elif binned_data[row, node.feature_idx] == missing_values_bin_idx: if node.missing_go_to_left: - node = nodes[node.left] + node_idx = node.left else: - node = nodes[node.right] + node_idx = node.right else: if binned_data[row, node.feature_idx] <= node.bin_threshold: - node = nodes[node.left] + node_idx = node.left else: - node = nodes[node.right] + node_idx = node.right + node = nodes[node_idx] + def _compute_partial_dependence( node_struct [:] nodes, diff --git a/sklearn/ensemble/_hist_gradient_boosting/_predictor_bitset.pxd b/sklearn/ensemble/_hist_gradient_boosting/_predictor_bitset.pxd new file mode 100644 index 0000000000000..24959663e4993 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/_predictor_bitset.pxd @@ -0,0 +1,27 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False +# cython: language_level=3 +# cython: nonecheck=False +# distutils: language=c++ + +from libcpp.set cimport set +from libcpp.map cimport map +from libcpp.vector cimport vector +from .common cimport BITSET_INNER_DTYPE_C +from .common cimport X_BINNED_DTYPE_C +from .common cimport X_DTYPE_C + +cdef class PredictorBitSet: + cdef map[int, vector[BITSET_INNER_DTYPE_C]] node_to_raw_bitset + cdef map[int, vector[BITSET_INNER_DTYPE_C]] node_to_binned_bitset + cdef map[int, vector[BITSET_INNER_DTYPE_C]] feature_idx_raw_cats + + cdef unsigned char raw_category_in_bitset(self, unsigned int node_idx, + X_DTYPE_C category) nogil + + cdef unsigned char binned_category_in_bitset(self, unsigned int node_idx, + X_BINNED_DTYPE_C category) nogil + + cdef unsigned char is_known_category(self, unsigned int feature_idx, + X_DTYPE_C category) nogil diff --git a/sklearn/ensemble/_hist_gradient_boosting/_predictor_bitset.pyx b/sklearn/ensemble/_hist_gradient_boosting/_predictor_bitset.pyx new file mode 100644 index 0000000000000..e936e947a9b9b --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/_predictor_bitset.pyx @@ -0,0 +1,103 @@ +# cython: 
cdivision=True +# cython: boundscheck=False +# cython: wraparound=False +# cython: language_level=3 +# cython: nonecheck=False +# distutils: language=c++ + +from ._bitset cimport in_bitset +from .common cimport BITSET_INNER_DTYPE_C +from libc.limits cimport CHAR_BIT +from libcpp.vector cimport vector + + +cdef inline unsigned char in_vec_bitset(vector[BITSET_INNER_DTYPE_C]& bitset, + int value) nogil: + cdef: + unsigned int i1 = value // 32 + unsigned int i2 = value % 32 + + if bitset.size() < i1 + 1: + return 0 + return (bitset[i1] >> i2) & 1 + + +cdef inline void insert_vec_bitset(vector[BITSET_INNER_DTYPE_C]& bitset, # OUT + int value) nogil: + cdef: + unsigned int i1 = value // 32 + unsigned int i2 = value % 32 + + if bitset.size() < i1 + 1: + bitset.resize(i1 + 1, 0) + bitset[i1] |= (1 << i2) + + +cdef class PredictorBitSet: + def __init__(self, list bin_thresholds, + const unsigned char [:] is_categorical): + """Creates bitset for all known categories""" + if is_categorical is None or bin_thresholds is None: + return + + cdef: + int i + X_DTYPE_C raw_cat + unsigned int i1 + unsigned int i2 + + for f_idx in range(is_categorical.shape[0]): + if not is_categorical[f_idx]: + continue + for raw_cat in bin_thresholds[f_idx]: + insert_vec_bitset(self.feature_idx_raw_cats[f_idx], + (raw_cat)) + + def insert_categories_bitset(self, unsigned int node_idx, + X_DTYPE_C[:] category_bins, + BITSET_INNER_DTYPE_C[:] cat_bitset): + """Insert category into bitset for raw categories and binned cateogires + for node_idx. + """ + cdef: + BITSET_INNER_DTYPE_C val + int k, offset + int cardinality = category_bins.shape[0] + int BITSET_SIZE = sizeof(BITSET_INNER_DTYPE_C) * CHAR_BIT + unsigned int i1, i2 + + self.node_to_binned_bitset[node_idx].resize(cat_bitset.shape[0]) + for k, val in enumerate(cat_bitset): + offset = BITSET_SIZE * k + self.node_to_binned_bitset[node_idx][k] = val + while val and offset < cardinality: + if val & 1: + insert_vec_bitset(self.node_to_raw_bitset[node_idx], + (category_bins[offset])) + val >>= 1 + offset += 1 + + cdef unsigned char is_known_category(self, unsigned int feature_idx, + X_DTYPE_C category) nogil: + """Check if category is known""" + return in_vec_bitset(self.feature_idx_raw_cats[feature_idx], + category) + + cdef unsigned char raw_category_in_bitset(self, unsigned int node_idx, + X_DTYPE_C category) nogil: + """Check if raw category is in bitset for node_idx""" + return in_vec_bitset(self.node_to_raw_bitset[node_idx], category) + + cdef unsigned char binned_category_in_bitset(self, unsigned int node_idx, + X_BINNED_DTYPE_C category) nogil: + """Check if binned cateogry is in bitset for node_idx""" + return in_vec_bitset(self.node_to_binned_bitset[node_idx], + category) + + def get_binned_categories(self, unsigned int node_idx): + """Used for testing""" + return self.node_to_binned_bitset[node_idx] + + def get_raw_categories(self, unsigned int node_idx): + """Used for testing""" + return self.node_to_raw_bitset[node_idx] diff --git a/sklearn/ensemble/_hist_gradient_boosting/binning.py b/sklearn/ensemble/_hist_gradient_boosting/binning.py index 37a19aad6cb97..6a300de976320 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/binning.py +++ b/sklearn/ensemble/_hist_gradient_boosting/binning.py @@ -14,78 +14,95 @@ from ...utils.validation import check_is_fitted from ._binning import _map_to_bins from .common import X_DTYPE, X_BINNED_DTYPE, ALMOST_INF +from ._cat_mapper import CategoryMapper -def _find_binning_thresholds(data, max_bins, subsample, random_state): +def 
_find_binning_threshold(col_data, max_bins): """Extract feature-wise quantiles from numerical data. Missing values are ignored for finding the thresholds. Parameters ---------- - data : array-like of shape (n_samples, n_features) - The data to bin. - + col_data : array-like, shape (n_features,) + The numerical feature to bin. max_bins: int The maximum number of bins to use for non-missing values. If for a given feature the number of unique values is less than ``max_bins``, then those unique values will be used to compute the bin thresholds, - instead of the quantiles. - - subsample : int or None - If ``n_samples > subsample``, then ``sub_samples`` samples will be - randomly chosen to compute the quantiles. If ``None``, the whole data - is used. - - random_state: int, RandomState instance or None - Pseudo-random number generator to control the random sub-sampling. - Pass an int for reproducible output across multiple - function calls. - See :term: `Glossary `. + instead of the quantiles Return ------ - binning_thresholds: list of ndarray + binning_thresholds: ndarray For each feature, stores the increasing numeric values that can be used to separate the bins. Thus ``len(binning_thresholds) == n_features``. """ - rng = check_random_state(random_state) - if subsample is not None and data.shape[0] > subsample: - subset = rng.choice(data.shape[0], subsample, replace=False) - data = data.take(subset, axis=0) - - binning_thresholds = [] - for f_idx in range(data.shape[1]): - col_data = data[:, f_idx] - # ignore missing values when computing bin thresholds - missing_mask = np.isnan(col_data) - if missing_mask.any(): - col_data = col_data[~missing_mask] - col_data = np.ascontiguousarray(col_data, dtype=X_DTYPE) - distinct_values = np.unique(col_data) - if len(distinct_values) <= max_bins: - midpoints = distinct_values[:-1] + distinct_values[1:] - midpoints *= .5 - else: - # We sort again the data in this case. We could compute - # approximate midpoint percentiles using the output of - # np.unique(col_data, return_counts) instead but this is more - # work and the performance benefit will be limited because we - # work on a fixed-size subsample of the full data. - percentiles = np.linspace(0, 100, num=max_bins + 1) - percentiles = percentiles[1:-1] - midpoints = np.percentile(col_data, percentiles, - interpolation='midpoint').astype(X_DTYPE) - assert midpoints.shape[0] == max_bins - 1 + # ignore missing values when computing bin thresholds + missing_mask = np.isnan(col_data) + if missing_mask.any(): + col_data = col_data[~missing_mask] + col_data = np.ascontiguousarray(col_data, dtype=X_DTYPE) + distinct_values = np.unique(col_data) + if len(distinct_values) <= max_bins: + midpoints = distinct_values[:-1] + distinct_values[1:] + midpoints *= .5 + else: + # We sort again the data in this case. We could compute + # approximate midpoint percentiles using the output of + # np.unique(col_data, return_counts) instead but this is more + # work and the performance benefit will be limited because we + # work on a fixed-size subsample of the full data. + percentiles = np.linspace(0, 100, num=max_bins + 1) + percentiles = percentiles[1:-1] + midpoints = np.percentile(col_data, percentiles, + interpolation='midpoint').astype(X_DTYPE) + assert midpoints.shape[0] == max_bins - 1 + + # We avoid having +inf thresholds: +inf thresholds are only allowed in + # a "split on nan" situation. 
+ np.clip(midpoints, a_min=None, a_max=ALMOST_INF, out=midpoints) + return midpoints + + +def _find_bin_categories(col_data, max_bins): + """Extract the categories from a categorical feature + + Missing values and negative values (considered missing) are ignored. They + will be handled at transform time in the CategoryMapper object. - # We avoid having +inf thresholds: +inf thresholds are only allowed in - # a "split on nan" situation. - np.clip(midpoints, a_min=None, a_max=ALMOST_INF, out=midpoints) + Parameters + ---------- + col_data : array-like, shape (n_samples,) + The categorical feature to bin. + max_bins: int + The maximum number of bins to be used for categories. - binning_thresholds.append(midpoints) + Return + ------ + bin: ndarray + Map from bin index to categorical value. The size of the array is + equal to the minimum of `max_bins` and the categories' cardinality, + ignoring missing and negative values. + """ + categories, counts = np.unique(col_data, return_counts=True) + + # sort by highest count + sorted_idx = np.argsort(counts)[::-1] + categories = categories[sorted_idx] - return binning_thresholds + # nans and negative values will be considered missing + missing = np.isnan(categories) + negative = categories < 0 + both = missing | negative + if both.any(): + categories = categories[~both] + + # keep at most max_bins categories + # needs to be sorted, because `_map_cat_col_to_bins` will assume + # that the categories are sorted + return np.sort(categories[:max_bins]) class _BinMapper(TransformerMixin, BaseEstimator): @@ -109,13 +126,19 @@ class _BinMapper(TransformerMixin, BaseEstimator): bins. The last bin is always reserved for missing values. If for a given feature the number of unique values is less than ``max_bins``, then those unique values will be used to compute the bin thresholds, - instead of the quantiles. - - subsample : int or None, default=2e5 + instead of the quantiles. For categorical features indicated by + ``is_categorical``, the docstring for ``is_categorical`` gives details on + this procedure. + subsample : int or None, optional (default=2e5) If ``n_samples > subsample``, then ``sub_samples`` samples will be randomly chosen to compute the quantiles. If ``None``, the whole data is used. - + is_categorical : ndarray of bool of shape (n_features,), default=None + Indicates categorical features. If the cardinality of a categorical + feature is greater than ``n_bins``, then the ``n_bins`` most frequent + categories are kept. The infrequent categories will be considered + missing. During ``transform`` time, unknown categories will also be + considered missing. random_state: int, RandomState instance or None, default=None Pseudo-random number generator to control the random sub-sampling. Pass an int for reproducible output across multiple @@ -125,15 +148,27 @@ class _BinMapper(TransformerMixin, BaseEstimator): Attributes ---------- bin_thresholds_ : list of ndarray - For each feature, gives the real-valued bin thresholds. There are - ``max_bins - 1`` thresholds, where ``max_bins = n_bins - 1`` is the - number of bins used for non-missing values. - + TODO: do we really need this for categorical features? + For each feature, each array indicates how to map a feature into a + binned feature. The semantics and size depend on the nature of the + feature: + + - for real-valued features, the array corresponds to the real-valued + bin thresholds (the upper bound of each bin).
There are ``max_bins + - 1`` thresholds, where ``max_bins = n_bins - 1`` is the number of + bins used for non-missing values. + - for categorical features, the array is a map from a binned category + value to the raw category value. The size of the array is equal to + ``min(max_bins, category_cardinality)`` where we ignore negative + categories and missing values in the cardinality. n_bins_non_missing_ : ndarray, dtype=np.uint32 For each feature, gives the number of bins actually used for non-missing values. For features with a lot of unique values, this is equal to ``n_bins - 1``. - + is_categorical_ : ndarray of shape (n_features,), dtype=np.uint8 + Indicator for categorical features. + category_mapper_ : CategoryMapper + Object used to map raw categories into bins. missing_values_bin_idx_ : np.uint8 The index of the bin where missing values are mapped. This is a constant across all features. This corresponds to the last bin, and @@ -141,9 +176,11 @@ class _BinMapper(TransformerMixin, BaseEstimator): is less than ``n_bins - 1`` for a given feature, then there are empty (and unused) bins. """ - def __init__(self, n_bins=256, subsample=int(2e5), random_state=None): + def __init__(self, n_bins=256, subsample=int(2e5), is_categorical=None, + random_state=None): self.n_bins = n_bins self.subsample = subsample + self.is_categorical = is_categorical self.random_state = random_state def fit(self, X, y=None): @@ -170,16 +207,39 @@ def fit(self, X, y=None): X = check_array(X, dtype=[X_DTYPE], force_all_finite=False) max_bins = self.n_bins - 1 - self.bin_thresholds_ = _find_binning_thresholds( - X, max_bins, subsample=self.subsample, - random_state=self.random_state) - self.n_bins_non_missing_ = np.array( - [thresholds.shape[0] + 1 for thresholds in self.bin_thresholds_], - dtype=np.uint32) + rng = check_random_state(self.random_state) + if self.subsample is not None and X.shape[0] > self.subsample: + subset = rng.choice(X.shape[0], self.subsample, replace=False) + X = X.take(subset, axis=0) - self.missing_values_bin_idx_ = self.n_bins - 1 + if self.is_categorical is None: + self.is_categorical_ = np.zeros(X.shape[1], dtype=np.uint8) + else: + self.is_categorical_ = np.asarray(self.is_categorical, + dtype=np.uint8) + self.missing_values_bin_idx_ = self.n_bins - 1 + self.category_mapper_ = CategoryMapper(self.missing_values_bin_idx_) + + bin_thresholds = [] + n_bins_non_missing = [] + + for f_idx in range(X.shape[1]): + col_data = X[:, f_idx] + + if self.is_categorical_[f_idx] == 0: + bins = _find_binning_threshold(col_data, max_bins) + n_bins_non_missing.append(bins.shape[0] + 1) + else: + bins = _find_bin_categories(col_data, max_bins) + n_bins_non_missing.append(bins.shape[0]) + self.category_mapper_.insert(f_idx, bins) + bin_thresholds.append(bins) + + self.bin_thresholds_ = bin_thresholds + self.n_bins_non_missing_ = np.array(n_bins_non_missing, + dtype=np.uint32) return self def transform(self, X): @@ -205,7 +265,10 @@ def transform(self, X): 'to transform()'.format(self.n_bins_non_missing_.shape[0], X.shape[1]) ) + binned = np.zeros_like(X, dtype=X_BINNED_DTYPE, order='F') - _map_to_bins(X, self.bin_thresholds_, self.missing_values_bin_idx_, - binned) + _map_to_bins(X, self.bin_thresholds_, + self.missing_values_bin_idx_, + self.category_mapper_, + self.is_categorical_, binned) return binned diff --git a/sklearn/ensemble/_hist_gradient_boosting/common.pxd b/sklearn/ensemble/_hist_gradient_boosting/common.pxd index 7d09bc9365c26..0eaa16268f973 100644 --- 
a/sklearn/ensemble/_hist_gradient_boosting/common.pxd +++ b/sklearn/ensemble/_hist_gradient_boosting/common.pxd @@ -4,11 +4,12 @@ cimport numpy as np np.import_array() - ctypedef np.npy_float64 X_DTYPE_C ctypedef np.npy_uint8 X_BINNED_DTYPE_C ctypedef np.npy_float64 Y_DTYPE_C ctypedef np.npy_float32 G_H_DTYPE_C +ctypedef np.npy_uint32 BITSET_INNER_DTYPE_C +ctypedef BITSET_INNER_DTYPE_C[8] BITSET_DTYPE_C cdef packed struct hist_struct: # Same as histogram dtype but we need a struct to declare views. It needs @@ -32,7 +33,8 @@ cdef packed struct node_struct: unsigned int depth unsigned char is_leaf X_BINNED_DTYPE_C bin_threshold - + unsigned char is_categorical + # int category_bitset_idx cpdef enum MonotonicConstraint: NO_CST = 0 diff --git a/sklearn/ensemble/_hist_gradient_boosting/common.pyx b/sklearn/ensemble/_hist_gradient_boosting/common.pyx index ffd1a05e7aca0..93e60e1f19660 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/common.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/common.pyx @@ -8,6 +8,7 @@ X_DTYPE = np.float64 X_BINNED_DTYPE = np.uint8 # hence max_bins == 256 # dtype for gradients and hessians arrays G_H_DTYPE = np.float32 +X_BITSET_INNER_DTYPE = np.uint32 HISTOGRAM_DTYPE = np.dtype([ ('sum_gradients', Y_DTYPE), # sum of sample gradients in bin @@ -27,6 +28,7 @@ PREDICTOR_RECORD_DTYPE = np.dtype([ ('depth', np.uint32), ('is_leaf', np.uint8), ('bin_threshold', X_BINNED_DTYPE), + ('is_categorical', np.uint8), ]) ALMOST_INF = 1e300 # see LightGBM AvoidInf() diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index ef646df5e8738..789aeee264d4f 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -32,7 +32,8 @@ class BaseHistGradientBoosting(BaseEstimator, ABC): @abstractmethod def __init__(self, loss, *, learning_rate, max_iter, max_leaf_nodes, max_depth, min_samples_leaf, l2_regularization, max_bins, - monotonic_cst, warm_start, early_stopping, scoring, + categorical_features, monotonic_cst, + warm_start, early_stopping, scoring, validation_fraction, n_iter_no_change, tol, verbose, random_state): self.loss = loss @@ -44,6 +45,7 @@ def __init__(self, loss, *, learning_rate, max_iter, max_leaf_nodes, self.l2_regularization = l2_regularization self.max_bins = max_bins self.monotonic_cst = monotonic_cst + self.categorical_features = categorical_features self.warm_start = warm_start self.early_stopping = early_stopping self.scoring = scoring @@ -94,6 +96,39 @@ def _validate_parameters(self): 'multiclass classification.' 
) + def _check_categories(self, n_features, X_orig): + """Check and validate categories params in X""" + if self.categorical_features is None: + self.is_categorical_ = None + return + + cat_features_input = np.asarray(self.categorical_features) + + if cat_features_input.dtype.kind not in ('i', 'b'): + raise ValueError("categorical_features must be an array-like of " + "bools or array-like of ints.") + + # check for categorical features as indices + if cat_features_input.dtype.kind == 'i': + if (np.max(cat_features_input) >= n_features + or np.min(cat_features_input) < 0): + raise ValueError("categorical_features set as integer " + "indices must be in [0, n_features - 1]") + cat_feats = np.zeros(n_features, dtype=bool) + cat_feats[cat_features_input] = True + + if cat_features_input.dtype.kind == 'b': + if cat_features_input.shape[0] != n_features: + raise ValueError("categorical_features set as a boolean mask " + "must have shape (n_features,)") + cat_feats = cat_features_input + + if np.any(cat_feats): + self.is_categorical_ = cat_feats + else: + # no categories + self.is_categorical_ = None + def fit(self, X, y, sample_weight=None): """Fit the gradient boosting model. @@ -118,8 +153,10 @@ def fit(self, X, y, sample_weight=None): acc_compute_hist_time = 0. # time spent computing histograms # time spent predicting X for gradient and hessians update acc_prediction_time = 0. - X, y = self._validate_data(X, y, dtype=[X_DTYPE], - force_all_finite=False) + X_orig = X + X, y = self._validate_data( + X, y, dtype=[X_DTYPE], force_all_finite=False + ) y = self._encode_y(y) check_consistent_length(X, y) # Do not create unit sample weights by default to later skip some @@ -140,9 +177,8 @@ def fit(self, X, y, sample_weight=None): dtype='u8') self._validate_parameters() - - # used for validation in predict - n_samples, self._n_features = X.shape + n_samples, self._n_features = X.shape # used for validation in predict + self._check_categories(self._n_features, X_orig) # we need this stateful variable to tell raw_predict() that it was # called from fit() (this current method), and that the data it has @@ -201,8 +237,10 @@ def fit(self, X, y, sample_weight=None): # actual total number of bins. 
Everywhere in the code, the # convention is that n_bins == max_bins + 1 n_bins = self.max_bins + 1 # + 1 for missing values - self._bin_mapper = _BinMapper(n_bins=n_bins, - random_state=self._random_seed) + self._bin_mapper = _BinMapper( + n_bins=n_bins, + is_categorical=self.is_categorical_, + random_state=self._random_seed) X_binned_train = self._bin_data(X_train, is_training_data=True) if X_val is not None: X_binned_val = self._bin_data(X_val, is_training_data=False) @@ -217,6 +255,11 @@ def fit(self, X, y, sample_weight=None): if self.verbose: print("Fitting gradient boosted rounds:") + # Uses binned data to check for missing values + has_missing_values = ( + X_binned_train == self._bin_mapper.missing_values_bin_idx_).any( + axis=0).astype(np.uint8) + n_samples = X_binned_train.shape[0] # First time calling fit, or no warm start @@ -363,6 +406,7 @@ def fit(self, X, y, sample_weight=None): n_bins=n_bins, n_bins_non_missing=self._bin_mapper.n_bins_non_missing_, has_missing_values=has_missing_values, + is_categorical=self.is_categorical_, monotonic_cst=self.monotonic_cst, max_leaf_nodes=self.max_leaf_nodes, max_depth=self.max_depth, @@ -381,7 +425,8 @@ def fit(self, X, y, sample_weight=None): sample_weight_train) predictor = grower.make_predictor( - num_thresholds=self._bin_mapper.bin_thresholds_ + num_thresholds=self._bin_mapper.bin_thresholds_, + is_categorical=self._bin_mapper.is_categorical_, ) predictors[-1].append(predictor) @@ -566,7 +611,7 @@ def _should_stop(self, scores): for score in recent_scores] return not any(recent_improvements) - def _bin_data(self, X, is_training_data): + def _bin_data(self, X, is_training_data, categorical_only=False): """Bin data X. If is_training_data, then set the _bin_mapper attribute. @@ -581,7 +626,8 @@ def _bin_data(self, X, is_training_data): if is_training_data: X_binned = self._bin_mapper.fit_transform(X) # F-aligned array else: - X_binned = self._bin_mapper.transform(X) # F-aligned array + # F-aligned array + X_binned = self._bin_mapper.transform(X) # We convert the array to C-contiguous since predicting is faster # with this layout (training is faster on F-arrays though) X_binned = np.ascontiguousarray(X_binned) @@ -645,8 +691,9 @@ def _raw_predict(self, X): raw_predictions : array, shape (n_trees_per_iteration, n_samples) The raw predicted values. """ - X = check_array(X, dtype=[X_DTYPE, X_BINNED_DTYPE], - force_all_finite=False) + X = check_array( + X, dtype=[X_DTYPE, X_BINNED_DTYPE], force_all_finite=False + ) check_is_fitted(self) if X.shape[1] != self._n_features: raise ValueError( @@ -667,6 +714,7 @@ def _raw_predict(self, X): def _predict_iterations(self, X, predictors, raw_predictions, is_binned): """Add the predictions of the predictors to raw_predictions.""" + # bin categorical features when predicting outside of training loop for predictors_of_ith_iteration in predictors: for k, predictor in enumerate(predictors_of_ith_iteration): if is_binned: @@ -848,6 +896,22 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): and 0 respectively correspond to a positive constraint, negative constraint and no constraint. Read more in the :ref:`User Guide `. + categorical_features : array-like of {bool, int} of shape (n_features), \ + default=None. + Indicates the categorical features. + + - None : no feature will be considered categorical. + - boolean array-like : boolean mask indicating categorical features. + The categories must have been already be numerical i.e. 
encoded by + an :class:`~sklearn.preprocessing.OrdinalEncoder`. + - integer array-like : integer indices indicating categorical + features. The categories must have been already be numerical i.e. + encoded by an :class:`~sklearn.preprocessing.OrdinalEncoder`. + + If the number of categories is greater than ``max_bins``, then the top + ``max_bins`` categories based on cardinality are kept. Categories + encoded as negative number will be considered missing. Read more in + the :ref:`User Guide `. warm_start : bool, default=False When set to ``True``, reuse the solution of the previous call to fit and add more estimators to the ensemble. For results to be valid, the @@ -909,6 +973,9 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): first entry is the score of the ensemble before the first iteration. Scores are computed according to the ``scoring`` parameter. Empty if no early stopping or if ``validation_fraction`` is None. + is_categorical_ : ndarray, shape (n_features, ) or None + Boolean mask for the categorical features. ``None`` if there are no + categorical features. Examples -------- @@ -929,7 +996,8 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting): def __init__(self, loss='least_squares', *, learning_rate=0.1, max_iter=100, max_leaf_nodes=31, max_depth=None, min_samples_leaf=20, l2_regularization=0., max_bins=255, - monotonic_cst=None, warm_start=False, early_stopping='auto', + categorical_features=None, monotonic_cst=None, + warm_start=False, early_stopping='auto', scoring='loss', validation_fraction=0.1, n_iter_no_change=10, tol=1e-7, verbose=0, random_state=None): @@ -938,7 +1006,9 @@ def __init__(self, loss='least_squares', *, learning_rate=0.1, max_leaf_nodes=max_leaf_nodes, max_depth=max_depth, min_samples_leaf=min_samples_leaf, l2_regularization=l2_regularization, max_bins=max_bins, - monotonic_cst=monotonic_cst, early_stopping=early_stopping, + monotonic_cst=monotonic_cst, + categorical_features=categorical_features, + early_stopping=early_stopping, warm_start=warm_start, scoring=scoring, validation_fraction=validation_fraction, n_iter_no_change=n_iter_no_change, tol=tol, verbose=verbose, @@ -1074,6 +1144,22 @@ class HistGradientBoostingClassifier(ClassifierMixin, and 0 respectively correspond to a positive constraint, negative constraint and no constraint. Read more in the :ref:`User Guide `. + categorical_features : array-like of {bool, int} of shape (n_features), \ + default=None. + Indicates the categorical features. + + - None : no feature will be considered categorical. + - boolean array-like : boolean mask indicating categorical features. + The categories must have been already be numerical i.e. encoded by + an :class:`~sklearn.preprocessing.OrdinalEncoder`. + - integer array-like : integer indices indicating categorical + features. The categories must have been already be numerical i.e. + encoded by an :class:`~sklearn.preprocessing.OrdinalEncoder`. + + If the number of categories is greater than ``max_bins``, then the top + ``max_bins`` categories based on cardinality are kept. Categories + encoded as negative number will be considered missing. Read more in + the :ref:`User Guide `. warm_start : bool, default=False When set to ``True``, reuse the solution of the previous call to fit and add more estimators to the ensemble. For results to be valid, the @@ -1138,6 +1224,9 @@ class HistGradientBoostingClassifier(ClassifierMixin, first entry is the score of the ensemble before the first iteration. 
Scores are computed according to the ``scoring`` parameter. Empty if no early stopping or if ``validation_fraction`` is None. + is_categorical_ : ndarray, shape (n_features, ) or None + Boolean mask for the categorical features. ``None`` if there are no + categorical features. Examples -------- @@ -1157,7 +1246,8 @@ class HistGradientBoostingClassifier(ClassifierMixin, @_deprecate_positional_args def __init__(self, loss='auto', *, learning_rate=0.1, max_iter=100, max_leaf_nodes=31, max_depth=None, min_samples_leaf=20, - l2_regularization=0., max_bins=255, monotonic_cst=None, + l2_regularization=0., max_bins=255, + categorical_features=None, monotonic_cst=None, warm_start=False, early_stopping='auto', scoring='loss', validation_fraction=0.1, n_iter_no_change=10, tol=1e-7, verbose=0, random_state=None): @@ -1166,7 +1256,9 @@ def __init__(self, loss='auto', *, learning_rate=0.1, max_iter=100, max_leaf_nodes=max_leaf_nodes, max_depth=max_depth, min_samples_leaf=min_samples_leaf, l2_regularization=l2_regularization, max_bins=max_bins, - monotonic_cst=monotonic_cst, warm_start=warm_start, + categorical_features=categorical_features, + monotonic_cst=monotonic_cst, + warm_start=warm_start, early_stopping=early_stopping, scoring=scoring, validation_fraction=validation_fraction, n_iter_no_change=n_iter_no_change, tol=tol, verbose=verbose, diff --git a/sklearn/ensemble/_hist_gradient_boosting/grower.py b/sklearn/ensemble/_hist_gradient_boosting/grower.py index 473dc37674684..fa69caebffac8 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/grower.py +++ b/sklearn/ensemble/_hist_gradient_boosting/grower.py @@ -18,7 +18,8 @@ from .common import PREDICTOR_RECORD_DTYPE from .common import Y_DTYPE from .common import MonotonicConstraint - +from ._bitset import set_bitset_py +from ._predictor_bitset import PredictorBitSet EPS = np.finfo(Y_DTYPE).eps # to avoid zero division errors @@ -159,6 +160,8 @@ class TreeGrower: has_missing_values : bool or ndarray, dtype=bool, default=False Whether each feature contains missing values (in the training data). If it's a bool, the same value is used for all features. + is_categorical : ndarray of bool of shape (n_features,), default=None + Indicates categorical features. monotonic_cst : array-like of shape (n_features,), dtype=int, default=None Indicates the monotonic constraint to enforce on each feature. -1, 1 and 0 respectively correspond to a positive constraint, negative @@ -178,7 +181,8 @@ class TreeGrower: def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None, max_depth=None, min_samples_leaf=20, min_gain_to_split=0., n_bins=256, n_bins_non_missing=None, has_missing_values=False, - monotonic_cst=None, l2_regularization=0., + is_categorical=None, monotonic_cst=None, + l2_regularization=0., min_hessian_to_split=1e-3, shrinkage=1.): self._validate_parameters(X_binned, max_leaf_nodes, max_depth, @@ -222,19 +226,30 @@ def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None, "-1, 0 or 1." 
) + if is_categorical is None: + is_categorical = np.zeros(shape=X_binned.shape[1], dtype=np.uint8) + else: + is_categorical = np.asarray(is_categorical, dtype=np.uint8) + + if np.any(np.logical_and(is_categorical == 1, monotonic_cst != 0)): + raise ValueError("categorical features can not have monotonic " + "constraints") + hessians_are_constant = hessians.shape[0] == 1 self.histogram_builder = HistogramBuilder( X_binned, n_bins, gradients, hessians, hessians_are_constant) missing_values_bin_idx = n_bins - 1 self.splitter = Splitter( X_binned, n_bins_non_missing, missing_values_bin_idx, - has_missing_values, monotonic_cst, + has_missing_values, is_categorical, monotonic_cst, l2_regularization, min_hessian_to_split, min_samples_leaf, min_gain_to_split, hessians_are_constant) self.n_bins_non_missing = n_bins_non_missing + self.missing_values_bin_idx = missing_values_bin_idx self.max_leaf_nodes = max_leaf_nodes self.has_missing_values = has_missing_values self.monotonic_cst = monotonic_cst + self.is_categorical = is_categorical self.l2_regularization = l2_regularization self.n_features = X_binned.shape[1] self.max_depth = max_depth @@ -406,6 +421,12 @@ def split_next(self): node.split_info.missing_go_to_left = ( left_child_node.n_samples > right_child_node.n_samples) + # For binned predictions with categorical splits. + if (node.split_info.is_categorical and + node.split_info.missing_go_to_left): + set_bitset_py(node.split_info.cat_bitset, + self.missing_values_bin_idx) + self.n_nodes += 2 if (self.max_leaf_nodes is not None @@ -508,7 +529,7 @@ def _finalize_splittable_nodes(self): node = self.splittable_nodes.pop() self._finalize_leaf(node) - def make_predictor(self, num_thresholds): + def make_predictor(self, num_thresholds, is_categorical=None): """Make a TreePredictor object out of the current tree. Parameters @@ -521,12 +542,14 @@ def make_predictor(self, num_thresholds): A TreePredictor object. """ predictor_nodes = np.zeros(self.n_nodes, dtype=PREDICTOR_RECORD_DTYPE) - _fill_predictor_node_array(predictor_nodes, self.root, - num_thresholds, self.n_bins_non_missing) - return TreePredictor(predictor_nodes) + predictor_bitset = PredictorBitSet(num_thresholds, is_categorical) + _fill_predictor_node_array(predictor_nodes, predictor_bitset, + self.root, num_thresholds, + self.n_bins_non_missing) + return TreePredictor(predictor_nodes, predictor_bitset) -def _fill_predictor_node_array(predictor_nodes, grower_node, +def _fill_predictor_node_array(predictor_nodes, predictor_bitset, grower_node, num_thresholds, n_bins_non_missing, next_free_idx=0): """Helper used in make_predictor to set the TreePredictor fields.""" @@ -551,26 +574,31 @@ def _fill_predictor_node_array(predictor_nodes, grower_node, node['feature_idx'] = feature_idx node['bin_threshold'] = bin_idx node['missing_go_to_left'] = split_info.missing_go_to_left + node['is_categorical'] = split_info.is_categorical if split_info.bin_idx == n_bins_non_missing[feature_idx] - 1: - # Split is on the last non-missing bin: it's a "split on nans". All - # nans go to the right, the rest go to the left. + # Split is on the last non-missing bin: it's a "split on nans". + # All nans go to the right, the rest go to the left. 
node['num_threshold'] = np.inf else: - node['num_threshold'] = num_thresholds[feature_idx][bin_idx] + bins = num_thresholds[feature_idx] + node['num_threshold'] = bins[bin_idx] + if split_info.is_categorical: + predictor_bitset.insert_categories_bitset( + next_free_idx, bins, split_info.cat_bitset) next_free_idx += 1 node['left'] = next_free_idx next_free_idx = _fill_predictor_node_array( - predictor_nodes, grower_node.left_child, + predictor_nodes, predictor_bitset, grower_node.left_child, num_thresholds=num_thresholds, n_bins_non_missing=n_bins_non_missing, next_free_idx=next_free_idx) node['right'] = next_free_idx return _fill_predictor_node_array( - predictor_nodes, grower_node.right_child, + predictor_nodes, predictor_bitset, grower_node.right_child, num_thresholds=num_thresholds, n_bins_non_missing=n_bins_non_missing, next_free_idx=next_free_idx) diff --git a/sklearn/ensemble/_hist_gradient_boosting/predictor.py b/sklearn/ensemble/_hist_gradient_boosting/predictor.py index 0b359c8f98224..4864bef9a7c76 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/predictor.py +++ b/sklearn/ensemble/_hist_gradient_boosting/predictor.py @@ -6,7 +6,7 @@ import numpy as np from .common import Y_DTYPE -from ._predictor import _predict_from_numeric_data +from ._predictor import _predict_from_data from ._predictor import _predict_from_binned_data from ._predictor import _compute_partial_dependence @@ -18,9 +18,12 @@ class TreePredictor: ---------- nodes : ndarray of PREDICTOR_RECORD_DTYPE The nodes of the tree. + predictor_bitset : PredictorBitSet + Bitset used to check for predictions. """ - def __init__(self, nodes): + def __init__(self, nodes, predictor_bitset): self.nodes = nodes + self.predictor_bitset = predictor_bitset def get_n_leaf_nodes(self): """Return number of leaves.""" @@ -31,7 +34,8 @@ def get_max_depth(self): return int(self.nodes['depth'].max()) def predict(self, X): - """Predict raw values for non-binned data. + """Predict raw values for non-binned numerical data and binned + categorical data. Parameters ---------- @@ -44,7 +48,7 @@ def predict(self, X): The raw predicted values. """ out = np.empty(X.shape[0], dtype=Y_DTYPE) - _predict_from_numeric_data(self.nodes, X, out) + _predict_from_data(self.nodes, self.predictor_bitset, X, out) return out def predict_binned(self, X, missing_values_bin_idx): @@ -65,7 +69,8 @@ def predict_binned(self, X, missing_values_bin_idx): The raw predicted values. """ out = np.empty(X.shape[0], dtype=Y_DTYPE) - _predict_from_binned_data(self.nodes, X, missing_values_bin_idx, out) + _predict_from_binned_data(self.nodes, self.predictor_bitset, + X, missing_values_bin_idx, out) return out def compute_partial_dependence(self, grid, target_features, out): diff --git a/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx b/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx index 49440adf098d9..0b47931c71324 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx @@ -2,6 +2,7 @@ # cython: boundscheck=False # cython: wraparound=False # cython: language_level=3 + """This module contains routines and data structures to: - Find the best possible split of a node. 
For a given node, a split is @@ -17,7 +18,7 @@ import numpy as np cimport numpy as np IF SKLEARN_OPENMP_PARALLELISM_ENABLED: from openmp cimport omp_get_max_threads -from libc.stdlib cimport malloc, free +from libc.stdlib cimport malloc, free, qsort from libc.string cimport memcpy from numpy.math cimport INFINITY @@ -25,7 +26,12 @@ from .common cimport X_BINNED_DTYPE_C from .common cimport Y_DTYPE_C from .common cimport hist_struct from .common import HISTOGRAM_DTYPE +from .common cimport BITSET_INNER_DTYPE_C +from .common cimport BITSET_DTYPE_C from .common cimport MonotonicConstraint +from ._bitset cimport init_bitset +from ._bitset cimport set_bitset +from ._bitset cimport in_bitset np.import_array() @@ -45,6 +51,15 @@ cdef struct split_info_struct: unsigned int n_samples_right Y_DTYPE_C value_left Y_DTYPE_C value_right + unsigned char is_categorical + BITSET_DTYPE_C cat_bitset + + +# used in categorical splits for sorting categories by increasing values of +# sum_gradients / sum_hessians +cdef struct categorical_info: + X_BINNED_DTYPE_C bin_idx + Y_DTYPE_C value class SplitInfo: @@ -57,9 +72,13 @@ class SplitInfo: feature_idx : int The index of the feature to be split. bin_idx : int - The index of the bin on which the split is made. + The index of the bin on which the split is made. Ignored if + ``is_categorical`` is True. ``cat_bitset`` will be used to determine + the split. missing_go_to_left : bool - Whether missing values should go to the left child. + Whether missing values should go to the left child. Ignored if + ``is_categorical`` is True. ``cat_bitset`` will be used to determine + the split. sum_gradient_left : float The sum of the gradients of all the samples in the left child. sum_hessian_left : float @@ -72,11 +91,17 @@ class SplitInfo: The number of samples in the left child. n_samples_right : int The number of samples in the right child. + is_categorical : bool + Whether split is categorical. + cat_bitset : ndarray of shape=(8,), dtype=uint32 or None + Bitset representing the categories that go to the left. This is used + only when ``is_categorical`` is True. """ def __init__(self, gain, feature_idx, bin_idx, missing_go_to_left, sum_gradient_left, sum_hessian_left, sum_gradient_right, sum_hessian_right, n_samples_left, - n_samples_right, value_left, value_right): + n_samples_right, value_left, value_right, + is_categorical, cat_bitset): self.gain = gain self.feature_idx = feature_idx self.bin_idx = bin_idx @@ -89,6 +114,8 @@ class SplitInfo: self.n_samples_right = n_samples_right self.value_left = value_left self.value_right = value_right + self.is_categorical = is_categorical + self.cat_bitset = cat_bitset @cython.final @@ -114,6 +141,8 @@ cdef class Splitter: has_missing_values : ndarray, shape (n_features,) Whether missing values were observed in the training data, for each feature. + is_categorical : ndarray of bool of shape (n_features,) + Indicates categorical features. l2_regularization : float The L2 regularization parameter. 
min_hessian_to_split : float, default=1e-3 @@ -134,6 +163,7 @@ cdef class Splitter: const unsigned int [::1] n_bins_non_missing unsigned char missing_values_bin_idx const unsigned char [::1] has_missing_values + const unsigned char [::1] is_categorical const signed char [::1] monotonic_cst unsigned char hessians_are_constant Y_DTYPE_C l2_regularization @@ -150,6 +180,7 @@ cdef class Splitter: const unsigned int [::1] n_bins_non_missing, const unsigned char missing_values_bin_idx, const unsigned char [::1] has_missing_values, + const unsigned char [::1] is_categorical, const signed char [::1] monotonic_cst, Y_DTYPE_C l2_regularization, Y_DTYPE_C min_hessian_to_split=1e-3, @@ -163,6 +194,7 @@ cdef class Splitter: self.missing_values_bin_idx = missing_values_bin_idx self.has_missing_values = has_missing_values self.monotonic_cst = monotonic_cst + self.is_categorical = is_categorical self.l2_regularization = l2_regularization self.min_hessian_to_split = min_hessian_to_split self.min_samples_leaf = min_samples_leaf @@ -266,7 +298,10 @@ cdef class Splitter: self.X_binned[:, feature_idx] unsigned int [::1] left_indices_buffer = self.left_indices_buffer unsigned int [::1] right_indices_buffer = self.right_indices_buffer - + unsigned char is_categorical = split_info.is_categorical + BITSET_INNER_DTYPE_C [:] cat_bitset_tmp = \ + split_info.cat_bitset + BITSET_DTYPE_C cat_bitset IF SKLEARN_OPENMP_PARALLELISM_ENABLED: int n_threads = omp_get_max_threads() ELSE: @@ -289,6 +324,10 @@ cdef class Splitter: int [:] left_offset = np.zeros(n_threads, dtype=np.int32) int [:] right_offset = np.zeros(n_threads, dtype=np.int32) + # only set cat_bitset when is_categorical is True + if is_categorical: + cat_bitset = &cat_bitset_tmp[0] + with nogil: for thread_idx in range(n_samples % n_threads): sizes[thread_idx] += 1 @@ -310,7 +349,8 @@ cdef class Splitter: turn_left = sample_goes_left( missing_go_to_left, missing_values_bin_idx, bin_idx, - X_binned[sample_idx]) + X_binned[sample_idx], is_categorical, + cat_bitset) if turn_left: left_indices_buffer[start + left_count] = sample_idx @@ -410,6 +450,7 @@ cdef class Splitter: split_info_struct split_info split_info_struct * split_infos const unsigned char [::1] has_missing_values = self.has_missing_values + const unsigned char [::1] is_categorical = self.is_categorical const signed char [::1] monotonic_cst = self.monotonic_cst with nogil: @@ -426,31 +467,44 @@ cdef class Splitter: # (min_samples_leaf, etc) and the grower will later turn the # node into a leaf. split_infos[feature_idx].gain = -1 - - # We will scan bins from left to right (in all cases), and if - # there are any missing values, we will also scan bins from - # right to left. This way, we can consider whichever case - # yields the best gain: either missing values go to the right - # (left to right scan) or to the left (right to left case). 
- # See algo 3 from the XGBoost paper - # https://arxiv.org/abs/1603.02754 - - self._find_best_bin_to_split_left_to_right( - feature_idx, has_missing_values[feature_idx], - histograms, n_samples, sum_gradients, sum_hessians, - value, monotonic_cst[feature_idx], - lower_bound, upper_bound, &split_infos[feature_idx]) - - if has_missing_values[feature_idx]: - # We need to explore both directions to check whether - # sending the nans to the left child would lead to a higher - # gain - self._find_best_bin_to_split_right_to_left( - feature_idx, histograms, n_samples, - sum_gradients, sum_hessians, + split_infos[feature_idx].is_categorical = \ + is_categorical[feature_idx] + + if is_categorical[feature_idx]: + # While numerical features need to scan in both directions + # when there are missing values, categorical features do + # not need to do this because they treat missing values + # as a native category + self._find_best_bin_to_split_category( + feature_idx, has_missing_values[feature_idx], + histograms, n_samples, sum_gradients, sum_hessians, + value, monotonic_cst[feature_idx], lower_bound, + upper_bound, &split_infos[feature_idx]) + else: + # We will scan bins from left to right (in all cases), and + # if there are any missing values, we will also scan bins + # from right to left. This way, we can consider whichever + # case yields the best gain: either missing values go to + # the right (left to right scan) or to the left (right to + # left case). See algo 3 from the XGBoost paper + # https://arxiv.org/abs/1603.02754 + self._find_best_bin_to_split_left_to_right( + feature_idx, has_missing_values[feature_idx], + histograms, n_samples, sum_gradients, sum_hessians, value, monotonic_cst[feature_idx], lower_bound, upper_bound, &split_infos[feature_idx]) + if has_missing_values[feature_idx]: + # We need to explore both directions to check whether + # sending the nans to the left child would lead to a + # higher gain + self._find_best_bin_to_split_right_to_left( + feature_idx, histograms, n_samples, + sum_gradients, sum_hessians, + value, monotonic_cst[feature_idx], + lower_bound, upper_bound, + &split_infos[feature_idx]) + # then compute best possible split among all features best_feature_idx = self._find_best_feature_to_split_helper( split_infos) @@ -469,7 +523,13 @@ cdef class Splitter: split_info.n_samples_right, split_info.value_left, split_info.value_right, + split_info.is_categorical, + None, # will only be set if the split is categorical ) + # Only set bitset if the split is categorical + if split_info.is_categorical: + out.cat_bitset = np.asarray(split_info.cat_bitset, dtype=np.uint32) + free(split_infos) return out @@ -716,6 +776,196 @@ cdef class Splitter: split_info.sum_gradient_right, split_info.sum_hessian_right, lower_bound, upper_bound, self.l2_regularization) + @cython.initializedcheck(False) + cdef void _find_best_bin_to_split_category( + self, + unsigned int feature_idx, + unsigned char has_missing_values, + const hist_struct [:, ::1] histograms, # IN + unsigned int n_samples, + Y_DTYPE_C sum_gradients, + Y_DTYPE_C sum_hessians, + Y_DTYPE_C value, + char monotonic_cst, + Y_DTYPE_C lower_bound, + Y_DTYPE_C upper_bound, + split_info_struct * split_info) nogil: # OUT + """Find best split for categorical features.
""" + + cdef: + unsigned int bin_idx + unsigned int end = self.n_bins_non_missing[feature_idx] + unsigned int missing_values_bin_idx = self.missing_values_bin_idx + categorical_info * cat_sorted_infos + unsigned int sorted_idx + unsigned int n_used_bin = 0 + const hist_struct[::1] feature_hist = histograms[feature_idx, :] + Y_DTYPE_C sum_gradients_bin + Y_DTYPE_C sum_hessians_bin + Y_DTYPE_C loss_current_node + # Reduces the effect of noises in categorical features, + # especially for categoires with few data + # TODO: Make this user adjustable? + unsigned int MIN_CAT_SUPPORT = 10 + # Used for find best split + unsigned int MAX_CAT_THRESHOLD = 32 + unsigned int max_n_cat + # holds directional information + Y_DTYPE_C sum_gradient_left, sum_hessian_left + Y_DTYPE_C sum_gradient_right, sum_hessian_right + unsigned int n_samples_left, n_samples_right + unsigned int i + Y_DTYPE_C gain + Y_DTYPE_C best_gain = -1.0 + unsigned char found_better_split = False + Y_DTYPE_C best_sum_hessian_left + Y_DTYPE_C best_sum_gradient_left + unsigned int best_n_samples_left + unsigned int best_sorted_thres + int direction, current_position + int best_direction + int[2] directions + int[2] start_positions + + cat_sorted_infos = malloc( + (end + has_missing_values) * sizeof(categorical_info)) + + # fill cat_sorted_infos while filtering out categories based on + # MIN_CAT_SUPPORT + for bin_idx in range(end): + if feature_hist[bin_idx].count >= MIN_CAT_SUPPORT: + cat_sorted_infos[n_used_bin].bin_idx = bin_idx + sum_gradients_bin = feature_hist[bin_idx].sum_gradients + if self.hessians_are_constant: + sum_hessians_bin = feature_hist[bin_idx].count + else: + sum_hessians_bin = feature_hist[bin_idx].sum_hessians + + cat_sorted_infos[n_used_bin].value = \ + sum_gradients_bin / (sum_hessians_bin + MIN_CAT_SUPPORT) + n_used_bin += 1 + + # check missing bin + if has_missing_values: + if feature_hist[missing_values_bin_idx].count >= MIN_CAT_SUPPORT: + cat_sorted_infos[n_used_bin].bin_idx = missing_values_bin_idx + sum_gradients_bin = \ + feature_hist[missing_values_bin_idx].sum_gradients + if self.hessians_are_constant: + sum_hessians_bin = \ + feature_hist[missing_values_bin_idx].count + else: + sum_hessians_bin = \ + feature_hist[missing_values_bin_idx].sum_hessians + + cat_sorted_infos[n_used_bin].value = \ + sum_gradients_bin / (sum_hessians_bin + MIN_CAT_SUPPORT) + n_used_bin += 1 + + # not enough categories to form a split + if n_used_bin <= 1: + free(cat_sorted_infos) + return + + qsort(cat_sorted_infos, n_used_bin, sizeof(categorical_info), + compare_cat_infos) + + max_n_cat = min(MAX_CAT_THRESHOLD, (n_used_bin + 1) / 2) + + directions[0], directions[1] = 1, -1 + start_positions[0], start_positions[1] = 0, n_used_bin - 1 + loss_current_node = _loss_from_value(value, sum_gradients) + + for i in range(2): + direction, current_position = directions[i], start_positions[i] + sum_gradient_left, sum_hessian_left = 0., 0. 
+ n_samples_left = 0 + + for sorted_idx in range(max_n_cat): + bin_idx = cat_sorted_infos[current_position].bin_idx; + current_position += direction + + n_samples_left += feature_hist[bin_idx].count + n_samples_right = n_samples - n_samples_left + + if self.hessians_are_constant: + sum_hessian_left += feature_hist[bin_idx].count + else: + sum_hessian_left += feature_hist[bin_idx].sum_hessians + sum_hessian_right = sum_hessians - sum_hessian_left + + sum_gradient_left += feature_hist[bin_idx].sum_gradients + sum_gradient_right = sum_gradients - sum_gradient_left + + if (n_samples_left < self.min_samples_leaf or + sum_hessian_left < self.min_hessian_to_split): + continue + if (n_samples_right < self.min_samples_leaf or + sum_hessian_right < self.min_hessian_to_split): + break + + gain = _split_gain(sum_gradient_left, sum_hessian_left, + sum_gradient_right, sum_hessian_right, + loss_current_node, monotonic_cst, + lower_bound, upper_bound, + self.l2_regularization) + if gain > best_gain and gain > self.min_gain_to_split: + found_better_split = True + best_gain = gain + best_direction = direction + best_sorted_thres = sorted_idx + best_sum_gradient_left = sum_gradient_left + best_sum_hessian_left = sum_hessian_left + best_n_samples_left = n_samples_left + + + if found_better_split: + split_info.gain = best_gain + + # bin_idx is unused for categorical splits: cat_bitset is used + # instead and set below + split_info.bin_idx = 0 + + split_info.sum_gradient_left = best_sum_gradient_left + split_info.sum_gradient_right = sum_gradients - best_sum_gradient_left + split_info.sum_hessian_left = best_sum_hessian_left + split_info.sum_hessian_right = sum_hessians - best_sum_hessian_left + split_info.n_samples_left = best_n_samples_left + split_info.n_samples_right = n_samples - best_n_samples_left + + # We recompute best values here but it's cheap + split_info.value_left = compute_node_value( + split_info.sum_gradient_left, split_info.sum_hessian_left, + lower_bound, upper_bound, self.l2_regularization) + + split_info.value_right = compute_node_value( + split_info.sum_gradient_right, split_info.sum_hessian_right, + lower_bound, upper_bound, self.l2_regularization) + + # create bitset with values from best_sorted_thres + init_bitset(split_info.cat_bitset) + + if best_direction == 1: # left + for i in range(best_sorted_thres + 1): + bin_idx = cat_sorted_infos[i].bin_idx + set_bitset(split_info.cat_bitset, bin_idx) + else: + for i in range(best_sorted_thres + 1): + bin_idx = cat_sorted_infos[n_used_bin - 1 - i].bin_idx + set_bitset(split_info.cat_bitset, bin_idx) + + if has_missing_values: + split_info.missing_go_to_left = in_bitset( + split_info.cat_bitset, missing_values_bin_idx) + + free(cat_sorted_infos) + + +cdef int compare_cat_infos(const void * a, const void * b) nogil: + cdef: + categorical_info a_info = (a)[0] + categorical_info b_info = (b)[0] + return 1 if a_info.value <= b_info.value else -1 cdef inline Y_DTYPE_C _split_gain( Y_DTYPE_C sum_gradient_left, @@ -781,17 +1031,23 @@ cdef inline unsigned char sample_goes_left( unsigned char missing_go_to_left, unsigned char missing_values_bin_idx, X_BINNED_DTYPE_C split_bin_idx, - X_BINNED_DTYPE_C bin_value) nogil: + X_BINNED_DTYPE_C bin_value, + unsigned char is_categorical, + BITSET_DTYPE_C cat_bitset) nogil: """Helper to decide whether sample should go to left or right child.""" - return ( - ( - missing_go_to_left and - bin_value == missing_values_bin_idx - ) - or ( - bin_value <= split_bin_idx - )) + if is_categorical: + # missing value is encoded 
in cat_bitset + return in_bitset(cat_bitset, bin_value) + else: # numerical + return ( + ( + missing_go_to_left and + bin_value == missing_values_bin_idx + ) + or ( + bin_value <= split_bin_idx + )) cpdef inline Y_DTYPE_C compute_node_value( diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_binning.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_binning.py index 06e38d62f7638..c080a13d4043c 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_binning.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_binning.py @@ -4,12 +4,13 @@ from sklearn.ensemble._hist_gradient_boosting.binning import ( _BinMapper, - _find_binning_thresholds as _find_binning_thresholds_orig, + _find_binning_threshold, _map_to_bins ) from sklearn.ensemble._hist_gradient_boosting.common import X_DTYPE from sklearn.ensemble._hist_gradient_boosting.common import X_BINNED_DTYPE from sklearn.ensemble._hist_gradient_boosting.common import ALMOST_INF +from sklearn.ensemble._hist_gradient_boosting._cat_mapper import CategoryMapper DATA = np.random.RandomState(42).normal( @@ -17,45 +18,34 @@ ).astype(X_DTYPE) -def _find_binning_thresholds(data, max_bins=255, subsample=int(2e5), - random_state=None): - # Just a redef to avoid having to pass arguments all the time (as the - # function is private we don't use default values for parameters) - return _find_binning_thresholds_orig(data, max_bins, subsample, - random_state) - - def test_find_binning_thresholds_regular_data(): - data = np.linspace(0, 10, 1001).reshape(-1, 1) - bin_thresholds = _find_binning_thresholds(data, max_bins=10) - assert_allclose(bin_thresholds[0], [1, 2, 3, 4, 5, 6, 7, 8, 9]) - assert len(bin_thresholds) == 1 + data = np.linspace(0, 10, 1001) + bin_thresholds = _find_binning_threshold(data, max_bins=10) + assert_allclose(bin_thresholds, [1, 2, 3, 4, 5, 6, 7, 8, 9]) - bin_thresholds = _find_binning_thresholds(data, max_bins=5) - assert_allclose(bin_thresholds[0], [2, 4, 6, 8]) - assert len(bin_thresholds) == 1 + bin_thresholds = _find_binning_threshold(data, max_bins=5) + assert_allclose(bin_thresholds, [2, 4, 6, 8]) def test_find_binning_thresholds_small_regular_data(): - data = np.linspace(0, 10, 11).reshape(-1, 1) + data = np.linspace(0, 10, 11) - bin_thresholds = _find_binning_thresholds(data, max_bins=5) - assert_allclose(bin_thresholds[0], [2, 4, 6, 8]) + bin_thresholds = _find_binning_threshold(data, max_bins=5) + assert_allclose(bin_thresholds, [2, 4, 6, 8]) - bin_thresholds = _find_binning_thresholds(data, max_bins=10) - assert_allclose(bin_thresholds[0], [1, 2, 3, 4, 5, 6, 7, 8, 9]) + bin_thresholds = _find_binning_threshold(data, max_bins=10) + assert_allclose(bin_thresholds, [1, 2, 3, 4, 5, 6, 7, 8, 9]) - bin_thresholds = _find_binning_thresholds(data, max_bins=11) - assert_allclose(bin_thresholds[0], np.arange(10) + .5) + bin_thresholds = _find_binning_threshold(data, max_bins=11) + assert_allclose(bin_thresholds, np.arange(10) + .5) - bin_thresholds = _find_binning_thresholds(data, max_bins=255) - assert_allclose(bin_thresholds[0], np.arange(10) + .5) + bin_thresholds = _find_binning_threshold(data, max_bins=255) + assert_allclose(bin_thresholds, np.arange(10) + .5) def test_find_binning_thresholds_random_data(): - bin_thresholds = _find_binning_thresholds(DATA, max_bins=255, - random_state=0) - assert len(bin_thresholds) == 2 + bin_thresholds = [_find_binning_threshold(DATA[:, i], max_bins=255) + for i in range(2)] for i in range(len(bin_thresholds)): assert bin_thresholds[i].shape == (254,) # 255 - 1 
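Note: the categorical binning tests below boil down to a simple mapping rule. As a rough mental model (an illustrative sketch only, with a made-up helper `map_categories_to_bins`, not the `CategoryMapper`/`_BinMapper` code): categories seen during fit get consecutive bins in sorted order, while unknown, negative and NaN values are sent to the missing-values bin:

    import numpy as np

    def map_categories_to_bins(column, known_categories, missing_bin_idx):
        # known categories get their (sorted) bin index; anything else
        # (unknown, negative, NaN) goes to the missing-values bin
        known = {cat: i for i, cat in enumerate(sorted(known_categories))}
        binned = np.empty(len(column), dtype=np.uint8)
        for i, v in enumerate(column):
            if np.isnan(v) or v < 0 or v not in known:
                binned[i] = missing_bin_idx
            else:
                binned[i] = known[v]
        return binned

    # categories seen at fit time: {0, 1, 4, 7, 9, 10}; n_bins = 15 -> missing bin 14
    print(map_categories_to_bins(
        np.array([10., 1., -1., 9., np.nan, 7., 4., 100., 0.]),
        known_categories=[0, 1, 4, 7, 9, 10], missing_bin_idx=14))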
assert bin_thresholds[i].dtype == DATA.dtype @@ -68,8 +58,8 @@ def test_find_binning_thresholds_random_data(): def test_find_binning_thresholds_low_n_bins(): - bin_thresholds = _find_binning_thresholds(DATA, max_bins=128, - random_state=0) + bin_thresholds = [_find_binning_threshold(DATA[:, i], max_bins=128) + for i in range(2)] assert len(bin_thresholds) == 2 for i in range(len(bin_thresholds)): assert bin_thresholds[i].shape == (127,) # 128 - 1 @@ -94,11 +84,14 @@ def test_bin_mapper_n_features_transform(): @pytest.mark.parametrize('max_bins', [16, 128, 255]) def test_map_to_bins(max_bins): - bin_thresholds = _find_binning_thresholds(DATA, max_bins=max_bins, - random_state=0) + bin_thresholds = [_find_binning_threshold(DATA[:, i], max_bins=max_bins) + for i in range(2)] binned = np.zeros_like(DATA, dtype=X_BINNED_DTYPE, order='F') + is_categorical = np.zeros(2, dtype=np.uint8) last_bin_idx = max_bins - _map_to_bins(DATA, bin_thresholds, last_bin_idx, binned) + category_mapper = CategoryMapper(255) + _map_to_bins(DATA, bin_thresholds, last_bin_idx, category_mapper, + is_categorical, binned) assert binned.shape == DATA.shape assert binned.dtype == np.uint8 assert binned.flags.f_contiguous @@ -312,3 +305,121 @@ def test_infinite_values(): expected_binned_X = np.array([0, 1, 2, 3]).reshape(-1, 1) assert_array_equal(bin_mapper.transform(X), expected_binned_X) + + +@pytest.mark.parametrize("n_bins", [15, 256]) +def test_categorical_n_bins_greater_than_cardinality(n_bins): + # test when n_bins is large enough to hold all categories (+ missing + # values bin which is always allocated) + X = np.array([[4] * 2 + [1] * 3 + [10] * 4 + + [0] * 4 + [9] + [7] * 5], dtype=X_DTYPE).T + + bin_mapper = _BinMapper(n_bins=n_bins, + is_categorical=np.array([True])).fit(X) + assert bin_mapper.n_bins_non_missing_ == [6] + assert_allclose(bin_mapper.bin_thresholds_[0], [0, 1, 4, 7, 9, 10]) + + X_trans = bin_mapper.transform( + np.array([[10, 1, -1, 9, np.nan, 7, 4, 100, 0]], dtype=X_DTYPE).T) + + # missing, negative, and unknown values are mapped to the missing bin + missing_val_bin = n_bins - 1 + expected_trans = np.array([[5, 1, missing_val_bin, 4, missing_val_bin, 3, + 2, missing_val_bin, 0]]).T + assert_array_equal(X_trans, expected_trans) + + +@pytest.mark.parametrize( + "n_bins, expected_trans, expected_bin_categories", + [ + (3, [1, 2, 2, 2, 2, 0, 2, 2, 2], [7, 10]), + # 0 is chosen instead of 10 because it comes after in lexicon + # order + (4, [2, 3, 3, 3, 3, 1, 3, 3, 0], [0, 7, 10]), + (5, [3, 1, 4, 4, 4, 2, 4, 4, 0], [0, 1, 7, 10]), + (6, [4, 1, 5, 5, 5, 3, 2, 5, 0], [0, 1, 4, 7, 10]), + ]) +def test_categorical_n_bins_less_than_cardinality( + n_bins, expected_trans, expected_bin_categories): + # test when there is not enough bins: only the top n_bins - 1 distinct + # values (based on cardinality) will have their own bin + X = np.array([[4] * 2 + [1] * 3 + [10] * 4 + + [0] * 4 + [9] + [7] * 5], dtype=X_DTYPE).T + X_test = np.array([[10, 1, -1, 9, np.nan, 7, 4, 100, 0]], dtype=X_DTYPE).T + + bin_mapper = _BinMapper(n_bins=n_bins, + is_categorical=np.array([True])).fit(X) + assert bin_mapper.n_bins_non_missing_ == [n_bins - 1] + assert_allclose(bin_mapper.bin_thresholds_[0], expected_bin_categories) + X_trans = bin_mapper.transform(X_test) + + # missing, negative, unknown values are mapped to the missing bin + expected_trans = np.array([expected_trans]).T + assert_array_equal(X_trans, expected_trans) + + +def test_categorical_n_bins_less_than_cardinality_ties(): + # when there are not enough bins and 
the cardinality is the same + # the elements with a higher lexicographic ordering will be chosen as + # categories + X = np.array([[1] * 4 + [2] * 4 + [3] * 4 + + [4] * 2 + [5] * 2 + [6] * 1], dtype=X_DTYPE).T + X_test = np.array([[1, 2, 3, 4, 5, 6, 7]], dtype=X_DTYPE).T + + # With 4 bins used for non-missing values, categories 1, 2, 3, 5 + # will have their own bin, the rest will be placed in the missing bin. + bin_mapper = _BinMapper(n_bins=5, is_categorical=np.array([True])).fit(X) + assert bin_mapper.n_bins_non_missing_ == [4] + assert_allclose(bin_mapper.bin_thresholds_[0], [1, 2, 3, 5]) + + X_trans = bin_mapper.transform(X_test) + expected_trans = np.array([[0, 1, 2, 4, 3, 4, 4]]).T + assert_array_equal(X_trans, expected_trans) + + # With 3 bins used for non-missing values, categories 2, 3 + # will have their own bin, the rest will be placed in the missing bin. + bin_mapper = _BinMapper(n_bins=3, is_categorical=np.array([True])).fit(X) + assert bin_mapper.n_bins_non_missing_ == [2] + assert_allclose(bin_mapper.bin_thresholds_[0], [2, 3]) + + X_trans = bin_mapper.transform(X_test) + expected_trans = np.array([[2, 0, 1, 2, 2, 2, 2]]).T + assert_array_equal(X_trans, expected_trans) + + +def test_categorical_default_categories_are_missing(): + # check that negative values are considered missing + X = np.array([[0] * 10 + [1] * 12 + + [-1] * 20 + [np.nan] * 20], dtype=X_DTYPE).T + + bin_mapper = _BinMapper(n_bins=3, is_categorical=np.array([True])).fit(X) + assert bin_mapper.n_bins_non_missing_ == [2] + assert_array_equal(bin_mapper.bin_thresholds_[0], [0, 1]) + + X_test = np.array([[0, 1, -1, 2, np.nan]], dtype=X_DTYPE).T + X_trans = bin_mapper.transform(X_test) + expected_trans = np.array([[0, 1, 2, 2, 2]]).T + assert_array_equal(X_trans, expected_trans) + + +@pytest.mark.parametrize( + "n_bins, expected_n_bins_non_missing", + [(128, [10, 5]), + (256, [10, 5])] +) +def test_categorical_with_numerical_features(n_bins, + expected_n_bins_non_missing): + # check binmapper with mixed data + X1 = np.arange(10, 20).reshape(-1, 1) + X2 = np.arange(10, 15).reshape(-1, 1) + X2 = np.r_[X2, X2] + X = np.c_[X1, X2] + + bin_mapper = _BinMapper(n_bins=n_bins, + is_categorical=np.array([False, True])).fit(X) + assert_array_equal(bin_mapper.n_bins_non_missing_, + expected_n_bins_non_missing) + bin_thresholds = bin_mapper.bin_thresholds_ + + assert len(bin_thresholds) == 2 + assert_array_equal(bin_thresholds[1], np.arange(10, 15)) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index 3cbcd4702fe30..8b604eae364a1 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -8,6 +8,7 @@ from sklearn.base import clone, BaseEstimator, TransformerMixin from sklearn.base import is_regressor from sklearn.pipeline import make_pipeline +from sklearn.metrics import r2_score from sklearn.metrics import mean_poisson_deviance from sklearn.dummy import DummyRegressor from sklearn.exceptions import NotFittedError @@ -794,3 +795,72 @@ def test_staged_predict(HistGradientBoosting, X, y): assert_allclose(staged_predictions, pred_aux) assert staged_predictions.shape == pred_aux.shape + + +@pytest.mark.parametrize("insert_missing", [False, True]) +@pytest.mark.parametrize("make_dataset, Est", [ + (make_regression, HistGradientBoostingRegressor), + (make_classification, HistGradientBoostingClassifier) +])
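Note: at the estimator level, the test below exercises the public `categorical_features` parameter. A minimal usage sketch (not part of the patch; it assumes this scikit-learn version still requires the experimental import, and the data generation is invented for illustration):

    import numpy as np
    from sklearn.experimental import enable_hist_gradient_boosting  # noqa
    from sklearn.ensemble import HistGradientBoostingClassifier

    rng = np.random.RandomState(0)
    n_samples = 500
    # feature 0: integer-coded categories (with some missing), feature 1: numerical
    X = np.c_[rng.randint(0, 10, size=n_samples).astype(float),
              rng.randn(n_samples)]
    X[rng.rand(n_samples) < 0.05, 0] = np.nan
    y = (X[:, 0] % 2 == 0).astype(int)

    clf = HistGradientBoostingClassifier(categorical_features=[True, False],
                                         random_state=0).fit(X, y)
    # categories unknown at predict time are treated like missing values
    print(clf.predict(np.array([[42.0, 0.0], [np.nan, 0.0]])))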
+@pytest.mark.parametrize("bool_categorical_parameter", [True, False]) +def test_categorical_sanity(insert_missing, make_dataset, Est, + bool_categorical_parameter): + # Test support categories with or without missing data + X, y = make_dataset(n_samples=1000, n_features=8, random_state=0) + + # even indices are categorical + categorical = np.zeros(X.shape[1], dtype=bool) + categorical[::2] = True + + X[:, categorical] = KBinsDiscretizer( + encode='ordinal', n_bins=20).fit_transform(X[:, categorical]) + + if insert_missing: + rng = np.random.RandomState(42) + mask = rng.binomial(1, 0.01, size=X.shape).astype(np.bool) + X[mask] = np.nan + + if bool_categorical_parameter: + categorical_features = categorical + else: + categorical_features = np.flatnonzero(categorical) + + est = Est(max_iter=20, categorical_features=categorical_features, + random_state=0).fit(X, y) + assert_array_equal(est.is_categorical_, categorical) + + y_pred = est.predict(X) + assert r2_score(y, y_pred) >= 0.8 + + X_test = np.zeros((1, X.shape[1]), dtype=float) + X_test[:, ::2] = 30 # unknown category + X_test[:, 5:] = np.nan # sets remaining + + # Does not error on unknown or missing categories + est.predict(X_test) + + +@pytest.mark.parametrize('Est', (HistGradientBoostingClassifier, + HistGradientBoostingRegressor)) +@pytest.mark.parametrize("categorical_features, monotonic_cst, expected_msg", [ + (["hello", "world"], None, + ("categorical_features must be an array-like of bools or array-like of " + "ints.")), + ([0, -1], None, + (r"categorical_features set as integer indices must be in " + r"\[0, n_features - 1\]")), + ([True, True, False, False, True], None, + r"categorical_features set as a boolean mask must have shape " + r"\(n_features,\)"), + ([True, True, False, False], [0, -1, 0, 1], + "categorical features can not have monotonic constraints"), +]) +def test_categorical_spec_errors(Est, categorical_features, monotonic_cst, + expected_msg): + # Test errors when categories are specified incorrectly + X, y = make_classification(random_state=0, n_features=4) + est = Est(categorical_features=categorical_features, + monotonic_cst=monotonic_cst) + + with pytest.raises(ValueError, match=expected_msg): + est.fit(X, y) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_grower.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_grower.py index 3a91c49b6583b..8679d88ea86f2 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_grower.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_grower.py @@ -2,9 +2,11 @@ import pytest from pytest import approx +from sklearn.preprocessing import OneHotEncoder from sklearn.ensemble._hist_gradient_boosting.grower import TreeGrower from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper from sklearn.ensemble._hist_gradient_boosting.common import X_BINNED_DTYPE +from sklearn.ensemble._hist_gradient_boosting.common import X_DTYPE from sklearn.ensemble._hist_gradient_boosting.common import Y_DTYPE from sklearn.ensemble._hist_gradient_boosting.common import G_H_DTYPE @@ -158,7 +160,7 @@ def test_predictor_from_grower(): # object to perform predictions at scale # We pass undefined num_thresholds because we won't use predict() anyway predictor = grower.make_predictor( - num_thresholds=np.zeros((X_binned.shape[1], n_bins)) + num_thresholds=list(np.zeros((X_binned.shape[1], n_bins))) ) assert predictor.nodes.shape[0] == 5 assert predictor.nodes['is_leaf'].sum() == 3 @@ -344,7 +346,7 @@ def test_missing_value_predict_only(): # We pass 
undefined num_thresholds because we won't use predict() anyway predictor = grower.make_predictor( - num_thresholds=np.zeros((X_binned.shape[1], X_binned.max() + 1)) + num_thresholds=list(np.zeros((X_binned.shape[1], X_binned.max() + 1))) ) # go from root to a leaf, always following node with the most samples. @@ -403,3 +405,90 @@ def test_split_on_nan_with_infinite_values(): X_binned, missing_values_bin_idx=bin_mapper.missing_values_bin_idx_) np.testing.assert_allclose(predictions, -gradients) np.testing.assert_allclose(predictions_binned, -gradients) + + +def test_grow_tree_categories(): + # Checks growing the tree with categorical features + X_binned = np.array([[0, 1] * 11 + [1]], dtype=X_BINNED_DTYPE).T + X_binned = np.asfortranarray(X_binned) + + all_gradients = np.array([1, 10] * 11 + [1], dtype=G_H_DTYPE) + all_hessians = np.ones(1, dtype=G_H_DTYPE) + is_categorical = np.ones(1, dtype=np.uint8) + + grower = TreeGrower(X_binned, all_gradients, all_hessians, + n_bins=4, shrinkage=1.0, min_samples_leaf=1, + is_categorical=is_categorical) + grower.grow() + assert grower.n_nodes == 3 + + num_thresholds = [np.array([4.0, 10.0], dtype=X_DTYPE)] + predictor = grower.make_predictor(num_thresholds=num_thresholds) + root = predictor.nodes[0] + assert root['count'] == 23 + assert root['depth'] == 0 + assert root['is_categorical'] + + raw_categories = predictor.predictor_bitset.get_raw_categories(0) + np.testing.assert_array_equal(raw_categories, [2**10]) + + # missing values with n_bins = 4 goes left because it has more samples + # and category 1 goes left -> bitset 0101000 -> 2 + 8 = 10 + expected_cat_bitset = [10] + [0] * 7 + cat_bitset = predictor.predictor_bitset.get_binned_categories(0) + np.testing.assert_array_equal(cat_bitset, expected_cat_bitset) + + +@pytest.mark.parametrize('min_samples_leaf', (1, 20)) +@pytest.mark.parametrize('n_unique_categories', (2, 10, 100)) +@pytest.mark.parametrize('target', ('binary', 'random', 'equal')) +def test_ohe_equivalence(min_samples_leaf, n_unique_categories, target): + # Make sure that native categorical splits are equivalent to using a OHE, + # when given enough depth + + rng = np.random.RandomState(0) + n_samples = 10_000 + X_binned = rng.randint(0, n_unique_categories, + size=(n_samples, 1), dtype=np.uint8) + + X_ohe = OneHotEncoder(sparse=False).fit_transform(X_binned) + X_ohe = np.asfortranarray(X_ohe).astype(np.uint8) + + if target == 'equal': + gradients = X_binned.reshape(-1) + elif target == 'binary': + gradients = (X_binned % 2).reshape(-1) + else: + gradients = rng.randn(n_samples) + gradients = gradients.astype(G_H_DTYPE) + + hessians = np.ones(shape=1, dtype=G_H_DTYPE) + + grower_params = { + 'min_samples_leaf': min_samples_leaf, + 'max_depth': None, + 'max_leaf_nodes': None, + } + + grower = TreeGrower(X_binned, gradients, hessians, is_categorical=[True], + **grower_params) + grower.grow() + # we pass undefined bin_thresholds because we won't use predict() + predictor = grower.make_predictor( + num_thresholds=list(np.zeros((1, n_unique_categories))) + ) + preds = predictor.predict_binned(X_binned, missing_values_bin_idx=255) + + grower_ohe = TreeGrower(X_ohe, gradients, hessians, **grower_params) + grower_ohe.grow() + predictor_ohe = grower_ohe.make_predictor( + num_thresholds=list(np.zeros((X_ohe.shape[1], n_unique_categories))) + ) + preds_ohe = predictor_ohe.predict_binned(X_ohe, missing_values_bin_idx=255) + + assert predictor.get_max_depth() <= predictor_ohe.get_max_depth() + if target == 'binary' and n_unique_categories > 2: + # 
OHE needs more splits to achieve the same predictions + assert predictor.get_max_depth() < predictor_ohe.get_max_depth() + + np.testing.assert_allclose(preds, preds_ohe) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py index c642bdd59335a..ca9ac521b11a1 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py @@ -178,7 +178,7 @@ def test_nodes_values(monotonic_cst, seed): # We pass undefined num_thresholds because we won't use predict() anyway predictor = grower.make_predictor( - num_thresholds=np.zeros((X_binned.shape[1], X_binned.max() + 1)) + num_thresholds=list(np.zeros((X_binned.shape[1], X_binned.max() + 1))) ) # The consistency of the bounds can only be checked on the tree grower @@ -302,14 +302,16 @@ def test_bounded_value_min_gain_to_split(): monotonic_cst = np.array( [MonotonicConstraint.NO_CST] * X_binned.shape[1], dtype=np.int8) + is_categorical = np.zeros_like(monotonic_cst, dtype=np.uint8) missing_values_bin_idx = n_bins - 1 children_lower_bound, children_upper_bound = -np.inf, np.inf min_gain_to_split = 2000 splitter = Splitter(X_binned, n_bins_non_missing, missing_values_bin_idx, - has_missing_values, monotonic_cst, l2_regularization, - min_hessian_to_split, min_samples_leaf, - min_gain_to_split, hessians_are_constant) + has_missing_values, is_categorical, monotonic_cst, + l2_regularization, min_hessian_to_split, + min_samples_leaf, min_gain_to_split, + hessians_are_constant) histograms = builder.compute_histograms_brute(sample_indices) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_predictor.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_predictor.py index 3be53c2c6c638..641d891fd582f 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_predictor.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_predictor.py @@ -1,4 +1,5 @@ import numpy as np +from numpy.testing import assert_allclose from sklearn.datasets import make_regression from sklearn.model_selection import train_test_split from sklearn.metrics import r2_score @@ -8,7 +9,10 @@ from sklearn.ensemble._hist_gradient_boosting.grower import TreeGrower from sklearn.ensemble._hist_gradient_boosting.predictor import TreePredictor from sklearn.ensemble._hist_gradient_boosting.common import ( - G_H_DTYPE, PREDICTOR_RECORD_DTYPE, ALMOST_INF) + G_H_DTYPE, PREDICTOR_RECORD_DTYPE, ALMOST_INF, X_BINNED_DTYPE, + X_BITSET_INNER_DTYPE, X_DTYPE) +from sklearn.ensemble._hist_gradient_boosting._predictor_bitset import \ + PredictorBitSet @pytest.mark.parametrize('n_bins', [200, 256]) @@ -33,7 +37,8 @@ def test_regression_dataset(n_bins): n_bins_non_missing=mapper.n_bins_non_missing_) grower.grow() - predictor = grower.make_predictor(num_thresholds=mapper.bin_thresholds_) + predictor = grower.make_predictor( + num_thresholds=mapper.bin_thresholds_) assert r2_score(y_train, predictor.predict(X_train)) > 0.82 assert r2_score(y_test, predictor.predict(X_test)) > 0.67 @@ -70,7 +75,75 @@ def test_infinite_values_and_thresholds(num_threshold, expected_predictions): nodes[2]['is_leaf'] = True nodes[2]['value'] = 1 - predictor = TreePredictor(nodes) + predictor = TreePredictor( + nodes, PredictorBitSet([], np.array([False], dtype=np.uint8))) predictions = predictor.predict(X) assert np.all(predictions == expected_predictions) + + +def _construct_bitset(bins_go_left): 
+ output = np.zeros(8, dtype=X_BITSET_INNER_DTYPE) + + for threshold in bins_go_left: + i1 = threshold // 32 + i2 = threshold % 32 + output[i1] |= X_BITSET_INNER_DTYPE(1) << X_BITSET_INNER_DTYPE(i2) + + return output + + +@pytest.mark.parametrize( + 'bins_go_left, expected_predictions', [ + ([0, 3, 4, 6], [1, 0, 0, 1, 1, 0]), + ([0, 1, 2, 6], [1, 1, 1, 0, 0, 0]), + ([3, 5, 6], [0, 0, 0, 1, 0, 1]) + ]) +def test_categorical_predictor(bins_go_left, expected_predictions): + # Test predictor outputs are correct with categorical features + + X_binned = np.array([[0, 1, 2, 3, 4, 5]], dtype=X_BINNED_DTYPE).T + category_bins = np.array([2, 5, 6, 8, 10, 15], dtype=X_DTYPE) + nodes = np.zeros(3, dtype=PREDICTOR_RECORD_DTYPE) + + # We just construct a simple tree with 1 root and 2 children + # parent node + nodes[0]['left'] = 1 + nodes[0]['right'] = 2 + nodes[0]['feature_idx'] = 0 + nodes[0]['is_categorical'] = True + nodes[0]['missing_go_to_left'] = True + + # left child + nodes[1]['is_leaf'] = True + nodes[1]['value'] = 1 + + # right child + nodes[2]['is_leaf'] = True + nodes[2]['value'] = 0 + + cat_bitset = _construct_bitset(bins_go_left) + predictor_bitset = PredictorBitSet([category_bins], + np.array([True], dtype=np.uint8)) + predictor_bitset.insert_categories_bitset(0, category_bins, cat_bitset) + predictor = TreePredictor(nodes, predictor_bitset) + + # Check binned data gives correct predictions + prediction_binned = predictor.predict_binned(X_binned, + missing_values_bin_idx=6) + assert_allclose(prediction_binned, expected_predictions) + + # Check with un-binned data + predictions = predictor.predict(category_bins.reshape(-1, 1)) + assert_allclose(predictions, expected_predictions) + + # Check missing goes left because missing_values_bin_idx=6 + X_binned_missing = np.array([[6]], dtype=X_BINNED_DTYPE).T + predictions = predictor.predict_binned(X_binned_missing, + missing_values_bin_idx=6) + assert_allclose(predictions, [1]) + + # missing and unknown go left + predictions = predictor.predict(np.array([[np.nan, 17.0]], + dtype=X_DTYPE).T) + assert_allclose(predictions, [1, 1]) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py index bcc19d750d9d3..9f2907eda0fbf 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from numpy.testing import assert_array_equal from sklearn.ensemble._hist_gradient_boosting.common import HISTOGRAM_DTYPE from sklearn.ensemble._hist_gradient_boosting.common import G_H_DTYPE @@ -50,11 +51,13 @@ def test_histogram_split(n_bins): monotonic_cst = np.array( [MonotonicConstraint.NO_CST] * X_binned.shape[1], dtype=np.int8) + is_categorical = np.zeros_like(monotonic_cst, dtype=np.uint8) missing_values_bin_idx = n_bins - 1 splitter = Splitter(X_binned, n_bins_non_missing, missing_values_bin_idx, has_missing_values, + is_categorical, monotonic_cst, l2_regularization, min_hessian_to_split, @@ -119,11 +122,12 @@ def test_gradient_and_hessian_sanity(constant_hessian): monotonic_cst = np.array( [MonotonicConstraint.NO_CST] * X_binned.shape[1], dtype=np.int8) + is_categorical = np.zeros_like(monotonic_cst, dtype=np.uint8) missing_values_bin_idx = n_bins - 1 splitter = Splitter(X_binned, n_bins_non_missing, missing_values_bin_idx, - has_missing_values, monotonic_cst, l2_regularization, - min_hessian_to_split, min_samples_leaf, - min_gain_to_split, 
constant_hessian) hists_parent = builder.compute_histograms_brute(sample_indices) value_parent = compute_node_value(sum_gradients, sum_hessians, @@ -233,11 +237,13 @@ def test_split_indices(): monotonic_cst = np.array( [MonotonicConstraint.NO_CST] * X_binned.shape[1], dtype=np.int8) + is_categorical = np.zeros_like(monotonic_cst, dtype=np.uint8) missing_values_bin_idx = n_bins - 1 splitter = Splitter(X_binned, n_bins_non_missing, missing_values_bin_idx, - has_missing_values, monotonic_cst, l2_regularization, - min_hessian_to_split, min_samples_leaf, - min_gain_to_split, hessians_are_constant) + has_missing_values, is_categorical, monotonic_cst, + l2_regularization, min_hessian_to_split, + min_samples_leaf, min_gain_to_split, + hessians_are_constant) assert np.all(sample_indices == splitter.partition) @@ -295,9 +301,11 @@ def test_min_gain_to_split(): monotonic_cst = np.array( [MonotonicConstraint.NO_CST] * X_binned.shape[1], dtype=np.int8) + is_categorical = np.zeros_like(monotonic_cst, dtype=np.uint8) missing_values_bin_idx = n_bins - 1 splitter = Splitter(X_binned, n_bins_non_missing, missing_values_bin_idx, - has_missing_values, monotonic_cst, l2_regularization, + has_missing_values, is_categorical, monotonic_cst, + l2_regularization, min_hessian_to_split, min_samples_leaf, min_gain_to_split, hessians_are_constant) @@ -437,10 +445,11 @@ def test_splitting_missing_values(X_binned, all_gradients, monotonic_cst = np.array( [MonotonicConstraint.NO_CST] * X_binned.shape[1], dtype=np.int8) + is_categorical = np.zeros_like(monotonic_cst, dtype=np.uint8) missing_values_bin_idx = n_bins - 1 splitter = Splitter(X_binned, n_bins_non_missing, missing_values_bin_idx, has_missing_values, - monotonic_cst, + is_categorical, monotonic_cst, l2_regularization, min_hessian_to_split, min_samples_leaf, min_gain_to_split, hessians_are_constant) @@ -478,3 +487,200 @@ def test_splitting_missing_values(X_binned, all_gradients, assert set(samples_right) == set(missing_samples_indices) assert set(samples_left) == set(non_missing_samples_indices) + + +@pytest.mark.parametrize( + 'X_binned, has_missing_values, n_bins_non_missing, ', [ + # one category + ([0] * 20, False, 1), + + # all categories appear fewer than MIN_CAT_SUPPORT times + ([0] * 9 + [1] * 8, False, 2), + + # only one category appears more than MIN_CAT_SUPPORT times + ([0] * 12 + [1] * 8, False, 2), + + # missing values + category appear fewer than MIN_CAT_SUPPORT times + # 9 is missing + ([0] * 9 + [1] * 8 + [9] * 4, True, 2), + + # 9 is missing + ([9] * 11, True, 0), + ]) +def test_splitting_categorical_no_splits(X_binned, has_missing_values, + n_bins_non_missing): + # Checks categorical splits are correct when there are no splits + + n_bins = max(X_binned) + 1 + n_samples = len(X_binned) + X_binned = np.array([X_binned], dtype=X_BINNED_DTYPE).T + X_binned = np.asfortranarray(X_binned) + + l2_regularization = 0.0 + min_hessian_to_split = 1e-3 + min_samples_leaf = 1 + min_gain_to_split = 0.0 + + sample_indices = np.arange(n_samples, dtype=np.uint32) + all_gradients = np.ones(n_samples, dtype=G_H_DTYPE) + has_missing_values = np.array([has_missing_values], dtype=np.uint8) + all_hessians = np.ones(1, dtype=G_H_DTYPE) + sum_gradients = all_gradients.sum() + sum_hessians = n_samples + hessians_are_constant = True + + builder = HistogramBuilder(X_binned, n_bins, all_gradients, + all_hessians, hessians_are_constant) + + n_bins_non_missing =
np.array([n_bins_non_missing], dtype=np.uint32) + monotonic_cst = np.array([MonotonicConstraint.NO_CST] * X_binned.shape[1], + dtype=np.int8) + is_categorical = np.ones_like(monotonic_cst, dtype=np.uint8) + missing_values_bin_idx = n_bins - 1 + + splitter = Splitter(X_binned, n_bins_non_missing, + missing_values_bin_idx, has_missing_values, + is_categorical, monotonic_cst, + l2_regularization, min_hessian_to_split, + min_samples_leaf, min_gain_to_split, + hessians_are_constant) + + histograms = builder.compute_histograms_brute(sample_indices) + value = compute_node_value(sum_gradients, sum_hessians, + -np.inf, np.inf, l2_regularization) + split_info = splitter.find_node_split(n_samples, histograms, + sum_gradients, sum_hessians, value) + + # no split found + assert split_info.gain == -1 + + +def _assert_threshold_equals_bitset(expected_thresholds, bitset): + # bitset is assumed to be an array 8 of uint32 + + # form bitset from threshold + expected_threshold_bitset = np.zeros(8, dtype=np.uint32) + for threshold in expected_thresholds: + i1 = threshold // 32 + i2 = threshold % 32 + expected_threshold_bitset[i1] |= 1 << i2 + + # check for equality + assert_array_equal(expected_threshold_bitset, bitset) + + +@pytest.mark.parametrize( + "X_binned, all_gradients, expected_categories_left, n_bins_non_missing," + "missing_values_bin_idx, has_missing_values", + [ + # 4 categories + # since there is no missing value during training, the + # missing values should go to the left bin with 22 samples but + # this is done in the grower + ([0, 1, 2, 3] * 11, # X_binned + [1, 20, 1, 1] * 11, # all_gradients + [1], # expected_categories_left + 4, # n_bins_non_missing + 4, # missing_values_bin_idx + False), # has_missing_values + + # 5 categories where the left node has more samples + # the grower would add the missing value bin to go to the left + ([0, 1, 2, 3, 4] * 11 + [1] * 50, # X_binned + [1, 10, 1, 1, 1] * 11 + [10] * 50, # all_gradients + [1], # expected_categories_left + 5, # n_bins_non_missing + 5, # missing_values_bin_idx + False), # has_missing_values + + # 4 categories (including missing value) + ([0, 1, 2] * 11 + [9] * 11, # X_binned + [1, 5, 1] * 11 + [1] * 11, # all_gradients + [1], # expected_categories_left + 3, # n_bins_non_missing + 9, # missing_values_bin_idx + True), # has_missing_values + + # split is on the missing value + ([0, 1, 2, 3, 4] * 11 + [255] * 12, # X_binned + [1, 1, 1, 1, 1] * 11 + [20] * 12, # all_gradients + [255], # expected_categories_left + 5, # n_bins_non_missing + 255, # missing_values_bin_idx + True), # has_missing_values + + # split on even categories + (list(range(60)) * 12, # X_binned + [1, 10] * 360, # all_gradients + list(range(1, 60, 2)), # expected_categories_left + 59, # n_bins_non_missing + 59, # missing_values_bin_idx + True), # has_missing_values + + # split on every 8 categories + (list(range(256)) * 12, # X_binned + [1, 1, 1, 1, 1, 1, 1, 10] * 384, # all_gradients + list(range(7, 256, 8)), # expected_categories_left + 255, # n_bins_non_missing + 255, # missing_values_bin_idx + True), # has_missing_values + ]) +def test_splitting_categorical_sanity(X_binned, all_gradients, + expected_categories_left, + n_bins_non_missing, + missing_values_bin_idx, + has_missing_values): + # Tests various combinations of categorical splits + + n_samples = len(X_binned) + n_bins = max(X_binned) + 1 + + X_binned = np.array(X_binned, dtype=X_BINNED_DTYPE).reshape(-1, 1) + X_binned = np.asfortranarray(X_binned) + + l2_regularization = 0.0 + min_hessian_to_split = 1e-3 + 
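Note: to see why the first parametrized case above expects category 1 to go left, here is the per-category histogram arithmetic in NumPy (an illustration only; the test itself goes through the real Splitter):

    import numpy as np

    # First case above: X_binned = [0, 1, 2, 3] * 11,
    # gradients = [1, 20, 1, 1] * 11, constant hessians (hessian = count).
    X_col = np.array([0, 1, 2, 3] * 11)
    grad = np.array([1., 20., 1., 1.] * 11)

    sum_grad = np.bincount(X_col, weights=grad)   # [11., 220., 11., 11.]
    sum_hess = np.bincount(X_col).astype(float)   # [11., 11., 11., 11.]
    ratio = sum_grad / sum_hess                   # [1., 20., 1., 1.]
    print(np.argsort(ratio))                      # category 1 sits alone at one end,
                                                  # so {1} vs {0, 2, 3} is a candidate
                                                  # partition, and the best one here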
min_samples_leaf = 1 + min_gain_to_split = 0. + + sample_indices = np.arange(n_samples, dtype=np.uint32) + all_gradients = np.array(all_gradients, dtype=G_H_DTYPE) + all_hessians = np.ones(1, dtype=G_H_DTYPE) + has_missing_values = np.array([has_missing_values], dtype=np.uint8) + sum_gradients = all_gradients.sum() + sum_hessians = n_samples + hessians_are_constant = True + + builder = HistogramBuilder(X_binned, n_bins, all_gradients, + all_hessians, hessians_are_constant) + + n_bins_non_missing = np.array([n_bins_non_missing], dtype=np.uint32) + monotonic_cst = np.array([MonotonicConstraint.NO_CST] * X_binned.shape[1], + dtype=np.int8) + is_categorical = np.ones_like(monotonic_cst, dtype=np.uint8) + + splitter = Splitter(X_binned, n_bins_non_missing, + missing_values_bin_idx, has_missing_values, + is_categorical, monotonic_cst, + l2_regularization, min_hessian_to_split, + min_samples_leaf, min_gain_to_split, + hessians_are_constant) + + histograms = builder.compute_histograms_brute(sample_indices) + + value = compute_node_value(sum_gradients, sum_hessians, + -np.inf, np.inf, l2_regularization) + split_info = splitter.find_node_split(n_samples, histograms, + sum_gradients, sum_hessians, value) + + assert split_info.is_categorical + _assert_threshold_equals_bitset(expected_categories_left, + split_info.cat_bitset) + + # make sure samples are split correctly + samples_left, samples_right, _ = splitter.split_indices( + split_info, splitter.partition) + + left_mask = np.isin(X_binned.ravel(), expected_categories_left) + assert_array_equal(sample_indices[left_mask], samples_left) + assert_array_equal(sample_indices[~left_mask], samples_right) diff --git a/sklearn/ensemble/setup.py b/sklearn/ensemble/setup.py index 4430cb129efcf..7d03a7482c83a 100644 --- a/sklearn/ensemble/setup.py +++ b/sklearn/ensemble/setup.py @@ -37,6 +37,20 @@ def configuration(parent_package="", top_path=None): sources=["_hist_gradient_boosting/_loss.pyx"], include_dirs=[numpy.get_include()]) + config.add_extension("_hist_gradient_boosting._bitset", + sources=["_hist_gradient_boosting/_bitset.pyx"], + include_dirs=[numpy.get_include()]) + + config.add_extension( + "_hist_gradient_boosting._predictor_bitset", + sources=["_hist_gradient_boosting/_predictor_bitset.pyx"], + include_dirs=[numpy.get_include()]) + + config.add_extension( + "_hist_gradient_boosting._cat_mapper", + sources=["_hist_gradient_boosting/_cat_mapper.pyx"], + include_dirs=[numpy.get_include()]) + config.add_extension("_hist_gradient_boosting.common", sources=["_hist_gradient_boosting/common.pyx"], include_dirs=[numpy.get_include()])
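Note: stepping back from the build configuration, at prediction time on raw (non-binned) data each categorical node routes a sample by looking up its raw category in the node's bitset. A toy decision rule consistent with what `test_categorical_predictor` above asserts (the real logic lives in the Cython predictor and `PredictorBitSet`; `go_left` and its arguments are invented for illustration): missing values and categories unseen during fit follow the missing-value routing, known categories are looked up directly:

    import numpy as np

    def go_left(raw_value, left_categories, known_categories, missing_go_to_left):
        # Toy decision rule at a single categorical split of the predictor.
        if np.isnan(raw_value) or raw_value not in known_categories:
            # missing values and categories unseen during fit follow the
            # missing-value routing
            return missing_go_to_left
        return raw_value in left_categories

    known = {2.0, 5.0, 6.0, 8.0, 10.0, 15.0}   # raw categories seen at fit time
    left = {2.0, 8.0, 10.0}                    # raw categories routed left
    print([go_left(v, left, known, missing_go_to_left=True)
           for v in (2.0, 5.0, np.nan, 17.0)])   # -> [True, False, True, True]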