diff --git a/.coveragerc b/.coveragerc index 6d76a5bca8235..1133065a5b248 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,7 +1,7 @@ [run] branch = True source = sklearn -include = */sklearn/* +parallel = True omit = */sklearn/externals/* */benchmarks/* diff --git a/benchmarks/bench_hist_gradient_boosting.py b/benchmarks/bench_hist_gradient_boosting.py new file mode 100644 index 0000000000000..8d055b22c2252 --- /dev/null +++ b/benchmarks/bench_hist_gradient_boosting.py @@ -0,0 +1,241 @@ +from time import time +import argparse + +import matplotlib.pyplot as plt +from sklearn.model_selection import train_test_split +# To use this experimental feature, we need to explicitly ask for it: +from sklearn.experimental import enable_hist_gradient_boosting # noqa +from sklearn.ensemble import HistGradientBoostingRegressor +from sklearn.ensemble import HistGradientBoostingClassifier +from sklearn.datasets import make_classification +from sklearn.datasets import make_regression +from sklearn.ensemble._hist_gradient_boosting.utils import ( + get_equivalent_estimator) + + +parser = argparse.ArgumentParser() +parser.add_argument('--n-leaf-nodes', type=int, default=31) +parser.add_argument('--n-trees', type=int, default=10) +parser.add_argument('--lightgbm', action="store_true", default=False, + help='also plot lightgbm') +parser.add_argument('--xgboost', action="store_true", default=False, + help='also plot xgboost') +parser.add_argument('--catboost', action="store_true", default=False, + help='also plot catboost') +parser.add_argument('--learning-rate', type=float, default=.1) +parser.add_argument('--problem', type=str, default='classification', + choices=['classification', 'regression']) +parser.add_argument('--n-classes', type=int, default=2) +parser.add_argument('--n-samples-max', type=int, default=int(1e6)) +parser.add_argument('--n-features', type=int, default=20) +parser.add_argument('--max-bins', type=int, default=255) +args = parser.parse_args() + +n_leaf_nodes = args.n_leaf_nodes +n_trees = args.n_trees +lr = args.learning_rate +max_bins = args.max_bins + + +def get_estimator_and_data(): + if args.problem == 'classification': + X, y = make_classification(args.n_samples_max * 2, + n_features=args.n_features, + n_classes=args.n_classes, + n_clusters_per_class=1, + random_state=0) + return X, y, HistGradientBoostingClassifier + elif args.problem == 'regression': + X, y = make_regression(args.n_samples_max * 2, + n_features=args.n_features, random_state=0) + return X, y, HistGradientBoostingRegressor + + +X, y, Estimator = get_estimator_and_data() +X_train_, X_test_, y_train_, y_test_ = train_test_split( + X, y, test_size=0.5, random_state=0) + + +def one_run(n_samples): + X_train = X_train_[:n_samples] + X_test = X_test_[:n_samples] + y_train = y_train_[:n_samples] + y_test = y_test_[:n_samples] + assert X_train.shape[0] == n_samples + assert X_test.shape[0] == n_samples + print("Data size: %d samples train, %d samples test." 
+ % (n_samples, n_samples)) + print("Fitting a sklearn model...") + tic = time() + est = Estimator(learning_rate=lr, + max_iter=n_trees, + max_bins=max_bins, + max_leaf_nodes=n_leaf_nodes, + n_iter_no_change=None, + random_state=0, + verbose=0) + est.fit(X_train, y_train) + sklearn_fit_duration = time() - tic + tic = time() + sklearn_score = est.score(X_test, y_test) + sklearn_score_duration = time() - tic + print("score: {:.4f}".format(sklearn_score)) + print("fit duration: {:.3f}s,".format(sklearn_fit_duration)) + print("score duration: {:.3f}s,".format(sklearn_score_duration)) + + lightgbm_score = None + lightgbm_fit_duration = None + lightgbm_score_duration = None + if args.lightgbm: + print("Fitting a LightGBM model...") + # get_lightgbm does not accept loss='auto' + if args.problem == 'classification': + loss = 'binary_crossentropy' if args.n_classes == 2 else \ + 'categorical_crossentropy' + est.set_params(loss=loss) + lightgbm_est = get_equivalent_estimator(est, lib='lightgbm') + + tic = time() + lightgbm_est.fit(X_train, y_train) + lightgbm_fit_duration = time() - tic + tic = time() + lightgbm_score = lightgbm_est.score(X_test, y_test) + lightgbm_score_duration = time() - tic + print("score: {:.4f}".format(lightgbm_score)) + print("fit duration: {:.3f}s,".format(lightgbm_fit_duration)) + print("score duration: {:.3f}s,".format(lightgbm_score_duration)) + + xgb_score = None + xgb_fit_duration = None + xgb_score_duration = None + if args.xgboost: + print("Fitting an XGBoost model...") + # get_xgb does not accept loss='auto' + if args.problem == 'classification': + loss = 'binary_crossentropy' if args.n_classes == 2 else \ + 'categorical_crossentropy' + est.set_params(loss=loss) + xgb_est = get_equivalent_estimator(est, lib='xgboost') + + tic = time() + xgb_est.fit(X_train, y_train) + xgb_fit_duration = time() - tic + tic = time() + xgb_score = xgb_est.score(X_test, y_test) + xgb_score_duration = time() - tic + print("score: {:.4f}".format(xgb_score)) + print("fit duration: {:.3f}s,".format(xgb_fit_duration)) + print("score duration: {:.3f}s,".format(xgb_score_duration)) + + cat_score = None + cat_fit_duration = None + cat_score_duration = None + if args.catboost: + print("Fitting a CatBoost model...") + # get_cat does not accept loss='auto' + if args.problem == 'classification': + loss = 'binary_crossentropy' if args.n_classes == 2 else \ + 'categorical_crossentropy' + est.set_params(loss=loss) + cat_est = get_equivalent_estimator(est, lib='catboost') + + tic = time() + cat_est.fit(X_train, y_train) + cat_fit_duration = time() - tic + tic = time() + cat_score = cat_est.score(X_test, y_test) + cat_score_duration = time() - tic + print("score: {:.4f}".format(cat_score)) + print("fit duration: {:.3f}s,".format(cat_fit_duration)) + print("score duration: {:.3f}s,".format(cat_score_duration)) + + return (sklearn_score, sklearn_fit_duration, sklearn_score_duration, + lightgbm_score, lightgbm_fit_duration, lightgbm_score_duration, + xgb_score, xgb_fit_duration, xgb_score_duration, + cat_score, cat_fit_duration, cat_score_duration) + + +n_samples_list = [1000, 10000, 100000, 500000, 1000000, 5000000, 10000000] +n_samples_list = [n_samples for n_samples in n_samples_list + if n_samples <= args.n_samples_max] + +sklearn_scores = [] +sklearn_fit_durations = [] +sklearn_score_durations = [] +lightgbm_scores = [] +lightgbm_fit_durations = [] +lightgbm_score_durations = [] +xgb_scores = [] +xgb_fit_durations = [] +xgb_score_durations = [] +cat_scores = [] +cat_fit_durations = [] 
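# The lists above and below collect one entry per sample size and per library,
# which are then plotted against n_samples. Illustrative command lines (the
# flags are the ones defined by the argparse block at the top of this script):
#
#   python benchmarks/bench_hist_gradient_boosting.py --problem regression --n-samples-max 500000
#   python benchmarks/bench_hist_gradient_boosting.py --n-classes 3 --lightgbm --xgboost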
+cat_score_durations = [] + +for n_samples in n_samples_list: + (sklearn_score, + sklearn_fit_duration, + sklearn_score_duration, + lightgbm_score, + lightgbm_fit_duration, + lightgbm_score_duration, + xgb_score, + xgb_fit_duration, + xgb_score_duration, + cat_score, + cat_fit_duration, + cat_score_duration) = one_run(n_samples) + + for scores, score in ( + (sklearn_scores, sklearn_score), + (sklearn_fit_durations, sklearn_fit_duration), + (sklearn_score_durations, sklearn_score_duration), + (lightgbm_scores, lightgbm_score), + (lightgbm_fit_durations, lightgbm_fit_duration), + (lightgbm_score_durations, lightgbm_score_duration), + (xgb_scores, xgb_score), + (xgb_fit_durations, xgb_fit_duration), + (xgb_score_durations, xgb_score_duration), + (cat_scores, cat_score), + (cat_fit_durations, cat_fit_duration), + (cat_score_durations, cat_score_duration)): + scores.append(score) + +fig, axs = plt.subplots(3, sharex=True) + +axs[0].plot(n_samples_list, sklearn_scores, label='sklearn') +axs[1].plot(n_samples_list, sklearn_fit_durations, label='sklearn') +axs[2].plot(n_samples_list, sklearn_score_durations, label='sklearn') + +if args.lightgbm: + axs[0].plot(n_samples_list, lightgbm_scores, label='lightgbm') + axs[1].plot(n_samples_list, lightgbm_fit_durations, label='lightgbm') + axs[2].plot(n_samples_list, lightgbm_score_durations, label='lightgbm') + +if args.xgboost: + axs[0].plot(n_samples_list, xgb_scores, label='XGBoost') + axs[1].plot(n_samples_list, xgb_fit_durations, label='XGBoost') + axs[2].plot(n_samples_list, xgb_score_durations, label='XGBoost') + +if args.catboost: + axs[0].plot(n_samples_list, cat_scores, label='CatBoost') + axs[1].plot(n_samples_list, cat_fit_durations, label='CatBoost') + axs[2].plot(n_samples_list, cat_score_durations, label='CatBoost') + +for ax in axs: + ax.set_xscale('log') + ax.legend(loc='best') + ax.set_xlabel('n_samples') + +axs[0].set_title('scores') +axs[1].set_title('fit duration (s)') +axs[2].set_title('score duration (s)') + +title = args.problem +if args.problem == 'classification': + title += ' n_classes = {}'.format(args.n_classes) +fig.suptitle(title) + + +plt.tight_layout() +plt.show() diff --git a/benchmarks/bench_hist_gradient_boosting_higgsboson.py b/benchmarks/bench_hist_gradient_boosting_higgsboson.py new file mode 100644 index 0000000000000..ec75760cd39f7 --- /dev/null +++ b/benchmarks/bench_hist_gradient_boosting_higgsboson.py @@ -0,0 +1,123 @@ +from urllib.request import urlretrieve +import os +from gzip import GzipFile +from time import time +import argparse + +import numpy as np +import pandas as pd +from joblib import Memory +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score, roc_auc_score +# To use this experimental feature, we need to explicitly ask for it: +from sklearn.experimental import enable_hist_gradient_boosting # noqa +from sklearn.ensemble import HistGradientBoostingClassifier +from sklearn.ensemble._hist_gradient_boosting.utils import ( + get_equivalent_estimator) + + +parser = argparse.ArgumentParser() +parser.add_argument('--n-leaf-nodes', type=int, default=31) +parser.add_argument('--n-trees', type=int, default=10) +parser.add_argument('--lightgbm', action="store_true", default=False) +parser.add_argument('--xgboost', action="store_true", default=False) +parser.add_argument('--catboost', action="store_true", default=False) +parser.add_argument('--learning-rate', type=float, default=1.) 
+parser.add_argument('--subsample', type=int, default=None) +parser.add_argument('--max-bins', type=int, default=255) +args = parser.parse_args() + +HERE = os.path.dirname(__file__) +URL = ("https://archive.ics.uci.edu/ml/machine-learning-databases/00280/" + "HIGGS.csv.gz") +m = Memory(location='/tmp', mmap_mode='r') + +n_leaf_nodes = args.n_leaf_nodes +n_trees = args.n_trees +subsample = args.subsample +lr = args.learning_rate +max_bins = args.max_bins + + +@m.cache +def load_data(): + filename = os.path.join(HERE, URL.rsplit('/', 1)[-1]) + if not os.path.exists(filename): + print(f"Downloading {URL} to {filename} (2.6 GB)...") + urlretrieve(URL, filename) + print("done.") + + print(f"Parsing {filename}...") + tic = time() + with GzipFile(filename) as f: + df = pd.read_csv(f, header=None, dtype=np.float32) + toc = time() + print(f"Loaded {df.values.nbytes / 1e9:0.3f} GB in {toc - tic:0.3f}s") + return df + + +df = load_data() +target = df.values[:, 0] +data = np.ascontiguousarray(df.values[:, 1:]) +data_train, data_test, target_train, target_test = train_test_split( + data, target, test_size=.2, random_state=0) + +if subsample is not None: + data_train, target_train = data_train[:subsample], target_train[:subsample] + +n_samples, n_features = data_train.shape +print(f"Training set with {n_samples} records with {n_features} features.") + +print("Fitting a sklearn model...") +tic = time() +est = HistGradientBoostingClassifier(loss='binary_crossentropy', + learning_rate=lr, + max_iter=n_trees, + max_bins=max_bins, + max_leaf_nodes=n_leaf_nodes, + n_iter_no_change=None, + random_state=0, + verbose=1) +est.fit(data_train, target_train) +toc = time() +predicted_test = est.predict(data_test) +predicted_proba_test = est.predict_proba(data_test) +roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1]) +acc = accuracy_score(target_test, predicted_test) +print(f"done in {toc - tic:.3f}s, ROC AUC: {roc_auc:.4f}, ACC: {acc :.4f}") + +if args.lightgbm: + print("Fitting a LightGBM model...") + tic = time() + lightgbm_est = get_equivalent_estimator(est, lib='lightgbm') + lightgbm_est.fit(data_train, target_train) + toc = time() + predicted_test = lightgbm_est.predict(data_test) + predicted_proba_test = lightgbm_est.predict_proba(data_test) + roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1]) + acc = accuracy_score(target_test, predicted_test) + print(f"done in {toc - tic:.3f}s, ROC AUC: {roc_auc:.4f}, ACC: {acc :.4f}") + +if args.xgboost: + print("Fitting an XGBoost model...") + tic = time() + xgboost_est = get_equivalent_estimator(est, lib='xgboost') + xgboost_est.fit(data_train, target_train) + toc = time() + predicted_test = xgboost_est.predict(data_test) + predicted_proba_test = xgboost_est.predict_proba(data_test) + roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1]) + acc = accuracy_score(target_test, predicted_test) + print(f"done in {toc - tic:.3f}s, ROC AUC: {roc_auc:.4f}, ACC: {acc :.4f}") + +if args.catboost: + print("Fitting a Catboost model...") + tic = time() + catboost_est = get_equivalent_estimator(est, lib='catboost') + catboost_est.fit(data_train, target_train) + toc = time() + predicted_test = catboost_est.predict(data_test) + predicted_proba_test = catboost_est.predict_proba(data_test) + roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1]) + acc = accuracy_score(target_test, predicted_test) + print(f"done in {toc - tic:.3f}s, ROC AUC: {roc_auc:.4f}, ACC: {acc :.4f}") diff --git a/build_tools/azure/test_pytest_soft_dependency.sh 
b/build_tools/azure/test_pytest_soft_dependency.sh index 7fd522cf4b1c5..28eacacc27d42 100755 --- a/build_tools/azure/test_pytest_soft_dependency.sh +++ b/build_tools/azure/test_pytest_soft_dependency.sh @@ -8,8 +8,10 @@ conda remove -y py pytest || pip uninstall -y py pytest if [[ "$COVERAGE" == "true" ]]; then # Need to append the coverage to the existing .coverage generated by - # running the tests - CMD="coverage run --append" + # running the tests. Make sure to reuse the same coverage + # configuration as the one used by the main pytest run to be + # able to combine the results. + CMD="coverage run --rcfile=$BUILD_SOURCESDIRECTORY/.coveragerc" else CMD="python" fi diff --git a/build_tools/azure/test_script.sh b/build_tools/azure/test_script.sh index c720f6e387c87..4fd3e70da7362 100755 --- a/build_tools/azure/test_script.sh +++ b/build_tools/azure/test_script.sh @@ -21,10 +21,11 @@ except ImportError: python -c "import multiprocessing as mp; print('%d CPUs' % mp.cpu_count())" pip list -TEST_CMD="python -m pytest --showlocals --durations=20 --junitxml=$JUNITXML --pyargs" +TEST_CMD="python -m pytest --showlocals --durations=20 --junitxml=$JUNITXML" if [[ "$COVERAGE" == "true" ]]; then - TEST_CMD="$TEST_CMD --cov sklearn" + export COVERAGE_PROCESS_START="$BUILD_SOURCESDIRECTORY/.coveragerc" + TEST_CMD="$TEST_CMD --cov-config=$COVERAGE_PROCESS_START --cov sklearn" fi if [[ -n "$CHECK_WARNINGS" ]]; then @@ -36,5 +37,5 @@ cp setup.cfg $TEST_DIR cd $TEST_DIR set -x -$TEST_CMD sklearn +$TEST_CMD --pyargs sklearn set +x diff --git a/build_tools/azure/upload_codecov.sh b/build_tools/azure/upload_codecov.sh index e9f801b3be5f5..ab6c14082ea7a 100755 --- a/build_tools/azure/upload_codecov.sh +++ b/build_tools/azure/upload_codecov.sh @@ -8,6 +8,9 @@ source activate $VIRTUALENV # Need to run codecov from a git checkout, so we copy .coverage # from TEST_DIR where pytest has been run +pushd $TEST_DIR +coverage combine +popd cp $TEST_DIR/.coverage $BUILD_REPOSITORY_LOCALPATH codecov --root $BUILD_REPOSITORY_LOCALPATH -t $CODECOV_TOKEN || echo "codecov upload failed" diff --git a/doc/conf.py b/doc/conf.py index 8fec9c3549f21..27a6bf2ee30c2 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -263,6 +263,11 @@ 'sphx_glr_plot_compare_methods_001.png': 349} +# enable experimental module so that the new GBDTs estimators can be +# discovered properly by sphinx +from sklearn.experimental import enable_hist_gradient_boosting # noqa + + def make_carousel_thumbs(app, exception): """produces the final resized carousel images""" if exception is not None: diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 42e2ca5bc18f9..70a3c40fad0f1 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -422,6 +422,9 @@ Samples generator ensemble.RandomTreesEmbedding ensemble.VotingClassifier ensemble.VotingRegressor + ensemble.HistGradientBoostingRegressor + ensemble.HistGradientBoostingClassifier + .. autosummary:: :toctree: generated/ @@ -453,6 +456,22 @@ Samples generator exceptions.NonBLASDotWarning exceptions.UndefinedMetricWarning + +:mod:`sklearn.experimental`: Experimental +========================================= + +.. automodule:: sklearn.experimental + :no-members: + :no-inherited-members: + +.. currentmodule:: sklearn + +.. autosummary:: + :toctree: generated/ + + experimental.enable_hist_gradient_boosting + + .. 
_feature_extraction_ref: :mod:`sklearn.feature_extraction`: Feature Extraction @@ -1489,6 +1508,7 @@ Utilities from joblib: utils.parallel_backend utils.register_parallel_backend + Recently deprecated =================== diff --git a/doc/modules/ensemble.rst b/doc/modules/ensemble.rst index 7a16dc95098a1..5a38164cb4bd8 100644 --- a/doc/modules/ensemble.rst +++ b/doc/modules/ensemble.rst @@ -456,6 +456,39 @@ The module :mod:`sklearn.ensemble` provides methods for both classification and regression via gradient boosted regression trees. + +.. note:: + + Scikit-learn 0.21 introduces two new experimental implementation of + gradient boosting trees, namely :class:`HistGradientBoostingClassifier` + and :class:`HistGradientBoostingRegressor`, inspired by + `LightGBM `_. These fast estimators + first bin the input samples ``X`` into integer-valued bins (typically 256 + bins) which tremendously reduces the number of splitting points to + consider, and allow the algorithm to leverage integer-based data + structures (histograms) instead of relying on sorted continuous values. + + The new histogram-based estimators can be orders of magnitude faster than + their continuous counterparts when the number of samples is larger than + tens of thousands of samples. The API of these new estimators is slightly + different, and some of the features from :class:`GradientBoostingClassifier` + and :class:`GradientBoostingRegressor` are not yet supported. + + These new estimators are still **experimental** for now: their predictions + and their API might change without any deprecation cycle. To use them, you + need to explicitly import ``enable_hist_gradient_boosting``:: + + >>> # explicitly require this experimental feature + >>> from sklearn.experimental import enable_hist_gradient_boosting # noqa + >>> # now you can import normally from ensemble + >>> from sklearn.ensemble import HistGradientBoostingClassifier + + The following guide focuses on :class:`GradientBoostingClassifier` and + :class:`GradientBoostingRegressor` only, which might be preferred for small + sample sizes since binning may lead to split points that are too approximate + in this setting. + + Classification --------------- diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index ba095eedf1331..cc39d921724fe 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -274,6 +274,29 @@ Support for Python 3.4 and below has been officially dropped. :pr:`12513` by :user:`Ramil Nugmanov ` and :user:`Mohamed Ali Jamaoui `. +- |MajorFeature| Add two new implementations of + gradient boosting trees: :class:`ensemble.HistGradientBoostingClassifier` + and :class:`ensemble.HistGradientBoostingRegressor`. The implementation of + these estimators is inspired by + `LightGBM `_ and can be orders of + magnitude faster than :class:`ensemble.GradientBoostingRegressor` and + :class:`ensemble.GradientBoostingClassifier` when the number of samples is + larger than tens of thousands of samples. The API of these new estimators + is slightly different, and some of the features from + :class:`ensemble.GradientBoostingClassifier` and + :class:`ensemble.GradientBoostingRegressor` are not yet supported. + + These new estimators are experimental, which means that their results or + their API might change without any deprecation cycle. 
To use them, you + need to explicitly import ``enable_hist_gradient_boosting``:: + + >>> # explicitly require this experimental feature + >>> from sklearn.experimental import enable_hist_gradient_boosting # noqa + >>> # now you can import normally from ensemble + >>> from sklearn.ensemble import HistGradientBoostingClassifier + + :issue:`12807` by :user:`Nicolas Hug`. + :mod:`sklearn.externals` ........................ diff --git a/sklearn/__init__.py b/sklearn/__init__.py index 9a6f4e4e29deb..c52910ff2121a 100644 --- a/sklearn/__init__.py +++ b/sklearn/__init__.py @@ -80,14 +80,14 @@ __all__ = ['calibration', 'cluster', 'covariance', 'cross_decomposition', 'datasets', 'decomposition', 'dummy', 'ensemble', 'exceptions', - 'externals', 'feature_extraction', 'feature_selection', - 'gaussian_process', 'inspection', 'isotonic', - 'kernel_approximation', 'kernel_ridge', 'linear_model', - 'manifold', 'metrics', 'mixture', 'model_selection', - 'multiclass', 'multioutput', 'naive_bayes', 'neighbors', - 'neural_network', 'pipeline', 'preprocessing', - 'random_projection', 'semi_supervised', 'svm', 'tree', - 'discriminant_analysis', 'impute', 'compose', + 'experimental', 'externals', 'feature_extraction', + 'feature_selection', 'gaussian_process', 'inspection', + 'isotonic', 'kernel_approximation', 'kernel_ridge', + 'linear_model', 'manifold', 'metrics', 'mixture', + 'model_selection', 'multiclass', 'multioutput', + 'naive_bayes', 'neighbors', 'neural_network', 'pipeline', + 'preprocessing', 'random_projection', 'semi_supervised', + 'svm', 'tree', 'discriminant_analysis', 'impute', 'compose', # Non-modules: 'clone', 'get_config', 'set_config', 'config_context', 'show_versions'] diff --git a/sklearn/ensemble/_hist_gradient_boosting/__init__.py b/sklearn/ensemble/_hist_gradient_boosting/__init__.py new file mode 100644 index 0000000000000..879fae1189f87 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/__init__.py @@ -0,0 +1,5 @@ +"""This module implements histogram-based gradient boosting estimators. + +The implementation is a port from pygbm which is itself strongly inspired +from LightGBM. +""" diff --git a/sklearn/ensemble/_hist_gradient_boosting/_binning.pyx b/sklearn/ensemble/_hist_gradient_boosting/_binning.pyx new file mode 100644 index 0000000000000..be958948bec6a --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/_binning.pyx @@ -0,0 +1,58 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False +# cython: nonecheck=False +# cython: language_level=3 + +# Author: Nicolas Hug + +cimport cython + +import numpy as np +cimport numpy as np +from cython.parallel import prange + +from .types cimport X_DTYPE_C, X_BINNED_DTYPE_C + +cpdef _map_to_bins(const X_DTYPE_C [:, :] data, list binning_thresholds, + X_BINNED_DTYPE_C [::1, :] binned): + """Bin numerical values to discrete integer-coded levels. + + Parameters + ---------- + data : ndarray, shape (n_samples, n_features) + The numerical data to bin. + binning_thresholds : list of arrays + For each feature, stores the increasing numeric values that are + used to separate the bins. + binned : ndarray, shape (n_samples, n_features) + Output array, must be fortran aligned. 
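    A pure-NumPy reference for what this kernel computes (a sketch, assuming
    the binned dtype is an 8-bit unsigned integer, which is consistent with
    the 256-bin limit enforced by the bin mapper)::

        import numpy as np

        def _map_to_bins_reference(data, binning_thresholds):
            # Bin index of a value = number of thresholds strictly smaller
            # than that value (ties go to the lower bin), i.e. a searchsorted
            # with side='left', feature by feature.
            binned = np.zeros(data.shape, dtype=np.uint8, order='F')
            for f_idx in range(data.shape[1]):
                binned[:, f_idx] = np.searchsorted(
                    binning_thresholds[f_idx], data[:, f_idx], side='left')
            return binned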
+ """ + cdef: + int feature_idx + + for feature_idx in range(data.shape[1]): + _map_num_col_to_bins(data[:, feature_idx], + binning_thresholds[feature_idx], + binned[:, feature_idx]) + + +cpdef void _map_num_col_to_bins(const X_DTYPE_C [:] data, + const X_DTYPE_C [:] binning_thresholds, + X_BINNED_DTYPE_C [:] binned): + """Binary search to find the bin index for each value in the data.""" + cdef: + int i + int left + int right + int middle + + for i in prange(data.shape[0], schedule='static', nogil=True): + left, right = 0, binning_thresholds.shape[0] + while left < right: + middle = (right + left - 1) // 2 + if data[i] <= binning_thresholds[middle]: + right = middle + else: + left = middle + 1 + binned[i] = left diff --git a/sklearn/ensemble/_hist_gradient_boosting/_gradient_boosting.pyx b/sklearn/ensemble/_hist_gradient_boosting/_gradient_boosting.pyx new file mode 100644 index 0000000000000..eb7517139beec --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/_gradient_boosting.pyx @@ -0,0 +1,60 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False +# cython: language_level=3 + +# Author: Nicolas Hug + +cimport cython +from cython.parallel import prange +import numpy as np +cimport numpy as np + +from .types import Y_DTYPE +from .types cimport Y_DTYPE_C + + +def _update_raw_predictions( + Y_DTYPE_C [::1] raw_predictions, # OUT + grower): + """Update raw_predictions with the predictions of the newest tree. + + This is equivalent to (and much faster than): + raw_predictions += last_estimator.predict(X_train) + + It's only possible for data X_train that is used to train the trees (it + isn't usable for e.g. X_val). + """ + cdef: + unsigned int [::1] starts # start of each leaf in partition + unsigned int [::1] stops # end of each leaf in partition + Y_DTYPE_C [::1] values # value of each leaf + const unsigned int [::1] partition = grower.splitter.partition + list leaves + + leaves = grower.finalized_leaves + starts = np.array([leaf.partition_start for leaf in leaves], + dtype=np.uint32) + stops = np.array([leaf.partition_stop for leaf in leaves], + dtype=np.uint32) + values = np.array([leaf.value for leaf in leaves], dtype=Y_DTYPE) + + _update_raw_predictions_helper(raw_predictions, starts, stops, partition, + values) + + +cdef inline void _update_raw_predictions_helper( + Y_DTYPE_C [::1] raw_predictions, # OUT + const unsigned int [::1] starts, + const unsigned int [::1] stops, + const unsigned int [::1] partition, + const Y_DTYPE_C [::1] values): + + cdef: + unsigned int position + int leaf_idx + int n_leaves = starts.shape[0] + + for leaf_idx in prange(n_leaves, nogil=True): + for position in range(starts[leaf_idx], stops[leaf_idx]): + raw_predictions[partition[position]] += values[leaf_idx] diff --git a/sklearn/ensemble/_hist_gradient_boosting/_loss.pyx b/sklearn/ensemble/_hist_gradient_boosting/_loss.pyx new file mode 100644 index 0000000000000..91c3e53101ed6 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/_loss.pyx @@ -0,0 +1,106 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False +# cython: language_level=3 + +# Author: Nicolas Hug + +cimport cython +from cython.parallel import prange +import numpy as np +cimport numpy as np + +from libc.math cimport exp + +from .types cimport Y_DTYPE_C +from .types cimport G_H_DTYPE_C + + +def _update_gradients_least_squares( + G_H_DTYPE_C [::1] gradients, # OUT + const Y_DTYPE_C [::1] y_true, # IN + const Y_DTYPE_C [::1] raw_predictions): # IN + + cdef: + int 
n_samples + int i + + n_samples = raw_predictions.shape[0] + for i in prange(n_samples, schedule='static', nogil=True): + # Note: a more correct exp is 2 * (raw_predictions - y_true) but + # since we use 1 for the constant hessian value (and not 2) this + # is strictly equivalent for the leaves values. + gradients[i] = raw_predictions[i] - y_true[i] + + +def _update_gradients_hessians_binary_crossentropy( + G_H_DTYPE_C [::1] gradients, # OUT + G_H_DTYPE_C [::1] hessians, # OUT + const Y_DTYPE_C [::1] y_true, # IN + const Y_DTYPE_C [::1] raw_predictions): # IN + cdef: + int n_samples + Y_DTYPE_C p_i # proba that ith sample belongs to positive class + int i + + n_samples = raw_predictions.shape[0] + for i in prange(n_samples, schedule='static', nogil=True): + p_i = _cexpit(raw_predictions[i]) + gradients[i] = p_i - y_true[i] + hessians[i] = p_i * (1. - p_i) + + +def _update_gradients_hessians_categorical_crossentropy( + G_H_DTYPE_C [:, ::1] gradients, # OUT + G_H_DTYPE_C [:, ::1] hessians, # OUT + const Y_DTYPE_C [::1] y_true, # IN + const Y_DTYPE_C [:, ::1] raw_predictions): # IN + cdef: + int prediction_dim = raw_predictions.shape[0] + int n_samples = raw_predictions.shape[1] + int k # class index + int i # sample index + # p[i, k] is the probability that class(ith sample) == k. + # It's the softmax of the raw predictions + Y_DTYPE_C [:, ::1] p = np.empty(shape=(n_samples, prediction_dim)) + Y_DTYPE_C p_i_k + + for i in prange(n_samples, schedule='static', nogil=True): + # first compute softmaxes of sample i for each class + for k in range(prediction_dim): + p[i, k] = raw_predictions[k, i] # prepare softmax + _compute_softmax(p, i) + # then update gradients and hessians + for k in range(prediction_dim): + p_i_k = p[i, k] + gradients[k, i] = p_i_k - (y_true[i] == k) + hessians[k, i] = p_i_k * (1. - p_i_k) + + +cdef inline void _compute_softmax(Y_DTYPE_C [:, ::1] p, const int i) nogil: + """Compute softmaxes of values in p[i, :].""" + # i needs to be passed (and stays constant) because otherwise Cython does + # not generate optimal code + + cdef: + Y_DTYPE_C max_value = p[i, 0] + Y_DTYPE_C sum_exps = 0. + unsigned int k + unsigned prediction_dim = p.shape[1] + + # Compute max value of array for numerical stability + for k in range(1, prediction_dim): + if max_value < p[i, k]: + max_value = p[i, k] + + for k in range(prediction_dim): + p[i, k] = exp(p[i, k] - max_value) + sum_exps += p[i, k] + + for k in range(prediction_dim): + p[i, k] /= sum_exps + + +cdef inline Y_DTYPE_C _cexpit(const Y_DTYPE_C x) nogil: + """Custom expit (logistic sigmoid function)""" + return 1. / (1. + exp(-x)) diff --git a/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx b/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx new file mode 100644 index 0000000000000..45ba70095c3c7 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx @@ -0,0 +1,100 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False +# cython: language_level=3 + +# Author: Nicolas Hug + +cimport cython +from cython.parallel import prange +import numpy as np +cimport numpy as np + +from .types cimport X_DTYPE_C +from .types cimport Y_DTYPE_C +from .types cimport X_BINNED_DTYPE_C + + +cdef packed struct node_struct: + # Equivalent struct to PREDICTOR_RECORD_DTYPE to use in memory views. 
It + # needs to be packed since by default numpy dtypes aren't aligned + Y_DTYPE_C value + unsigned int count + unsigned int feature_idx + X_DTYPE_C threshold + unsigned int left + unsigned int right + Y_DTYPE_C gain + unsigned int depth + unsigned char is_leaf + X_BINNED_DTYPE_C bin_threshold + + +def _predict_from_numeric_data(nodes, numeric_data, out): + _predict_from_numeric_data_parallel(nodes, numeric_data, out) + + +def _predict_from_binned_data(nodes, binned_data, out): + _predict_from_binned_data_parallel(nodes, binned_data, out) + + +cdef void _predict_from_numeric_data_parallel( + node_struct [:] nodes, + const X_DTYPE_C [:, :] numeric_data, + Y_DTYPE_C [:] out): + + cdef: + int i + + for i in prange(numeric_data.shape[0], schedule='static', nogil=True): + out[i] = _predict_one_from_numeric_data(nodes, numeric_data, i) + + +cdef inline Y_DTYPE_C _predict_one_from_numeric_data( + node_struct [:] nodes, + const X_DTYPE_C [:, :] numeric_data, + const int row) nogil: + # Need to pass the whole array and the row index, else prange won't work. + # See issue Cython #2798 + + cdef: + node_struct node = nodes[0] + + while True: + if node.is_leaf: + return node.value + if numeric_data[row, node.feature_idx] <= node.threshold: + node = nodes[node.left] + else: + node = nodes[node.right] + + +cdef void _predict_from_binned_data_parallel( + node_struct [:] nodes, + const X_BINNED_DTYPE_C [:, :] binned_data, + Y_DTYPE_C [:] out): + + cdef: + int i + + for i in prange(binned_data.shape[0], schedule='static', nogil=True): + out[i] = _predict_one_from_binned_data(nodes, binned_data, i) + + +cdef inline Y_DTYPE_C _predict_one_from_binned_data( + node_struct [:] nodes, + const X_BINNED_DTYPE_C [:, :] binned_data, + const int row) nogil: + # Need to pass the whole array and the row index, else prange won't work. + # See issue Cython #2798 + + cdef: + node_struct node = nodes[0] + + while True: + if node.is_leaf: + return node.value + if binned_data[row, node.feature_idx] <= node.bin_threshold: + node = nodes[node.left] + else: + node = nodes[node.right] diff --git a/sklearn/ensemble/_hist_gradient_boosting/binning.py b/sklearn/ensemble/_hist_gradient_boosting/binning.py new file mode 100644 index 0000000000000..075ed4f175ac3 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/binning.py @@ -0,0 +1,155 @@ +""" +This module contains the BinMapper class. + +BinMapper is used for mapping a real-valued dataset into integer-valued bins. +Bin thresholds are computed with the quantiles so that each bin contains +approximately the same number of samples. +""" +# Author: Nicolas Hug + +import numpy as np + +from ...utils import check_random_state, check_array +from ...base import BaseEstimator, TransformerMixin +from ...utils.validation import check_is_fitted +from ._binning import _map_to_bins +from .types import X_DTYPE, X_BINNED_DTYPE + + +def _find_binning_thresholds(data, max_bins, subsample, random_state): + """Extract feature-wise quantiles from numerical data. + + Parameters + ---------- + data : array-like, shape (n_samples, n_features) + The data to bin. + max_bins : int + The maximum number of bins to use. If for a given feature the number of + unique values is less than ``max_bins``, then those unique values + will be used to compute the bin thresholds, instead of the quantiles. + subsample : int or None + If ``n_samples > subsample``, then ``sub_samples`` samples will be + randomly choosen to compute the quantiles. If ``None``, the whole data + is used. 
+ random_state: int or numpy.random.RandomState or None + Pseudo-random number generator to control the random sub-sampling. + See :term:`random_state`. + + Return + ------ + binning_thresholds: list of arrays + For each feature, stores the increasing numeric values that can + be used to separate the bins. Thus ``len(binning_thresholds) == + n_features``. + """ + if not (2 <= max_bins <= 256): + raise ValueError('max_bins={} should be no smaller than 2 ' + 'and no larger than 256.'.format(max_bins)) + rng = check_random_state(random_state) + if subsample is not None and data.shape[0] > subsample: + subset = rng.choice(np.arange(data.shape[0]), subsample, replace=False) + data = data.take(subset, axis=0) + + percentiles = np.linspace(0, 100, num=max_bins + 1) + percentiles = percentiles[1:-1] + binning_thresholds = [] + for f_idx in range(data.shape[1]): + col_data = np.ascontiguousarray(data[:, f_idx], dtype=X_DTYPE) + distinct_values = np.unique(col_data) + if len(distinct_values) <= max_bins: + midpoints = distinct_values[:-1] + distinct_values[1:] + midpoints *= .5 + else: + # We sort again the data in this case. We could compute + # approximate midpoint percentiles using the output of + # np.unique(col_data, return_counts) instead but this is more + # work and the performance benefit will be limited because we + # work on a fixed-size subsample of the full data. + midpoints = np.percentile(col_data, percentiles, + interpolation='midpoint').astype(X_DTYPE) + binning_thresholds.append(midpoints) + return binning_thresholds + + +class _BinMapper(BaseEstimator, TransformerMixin): + """Transformer that maps a dataset into integer-valued bins. + + The bins are created in a feature-wise fashion, using quantiles so that + each bins contains approximately the same number of samples. + + For large datasets, quantiles are computed on a subset of the data to + speed-up the binning, but the quantiles should remain stable. + + If the number of unique values for a given feature is less than + ``max_bins``, then the unique values of this feature are used instead of + the quantiles. + + Parameters + ---------- + max_bins : int, optional (default=256) + The maximum number of bins to use. If for a given feature the number of + unique values is less than ``max_bins``, then those unique values + will be used to compute the bin thresholds, instead of the quantiles. + subsample : int or None, optional (default=2e5) + If ``n_samples > subsample``, then ``sub_samples`` samples will be + randomly choosen to compute the quantiles. If ``None``, the whole data + is used. + random_state: int or numpy.random.RandomState or None, \ + optional (default=None) + Pseudo-random number generator to control the random sub-sampling. + See :term:`random_state`. + """ + def __init__(self, max_bins=256, subsample=int(2e5), random_state=None): + self.max_bins = max_bins + self.subsample = subsample + self.random_state = random_state + + def fit(self, X, y=None): + """Fit data X by computing the binning thresholds. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + The data to bin. + y: None + Ignored. + + Returns + ------- + self : object + """ + X = check_array(X, dtype=[X_DTYPE]) + self.bin_thresholds_ = _find_binning_thresholds( + X, self.max_bins, subsample=self.subsample, + random_state=self.random_state) + + self.actual_n_bins_ = np.array( + [thresholds.shape[0] + 1 for thresholds in self.bin_thresholds_], + dtype=np.uint32) + + return self + + def transform(self, X): + """Bin data X. 
+ + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + The data to bin. + + Returns + ------- + X_binned : array-like, shape (n_samples, n_features) + The binned data. + """ + X = check_array(X, dtype=[X_DTYPE]) + check_is_fitted(self, ['bin_thresholds_', 'actual_n_bins_']) + if X.shape[1] != self.actual_n_bins_.shape[0]: + raise ValueError( + 'This estimator was fitted with {} features but {} got passed ' + 'to transform()'.format(self.actual_n_bins_.shape[0], + X.shape[1]) + ) + binned = np.zeros_like(X, dtype=X_BINNED_DTYPE, order='F') + _map_to_bins(X, self.bin_thresholds_, binned) + return binned diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py new file mode 100644 index 0000000000000..719756061f896 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -0,0 +1,863 @@ +"""Fast Gradient Boosting decision trees for classification and regression.""" +# Author: Nicolas Hug + +from abc import ABC, abstractmethod + +import numpy as np +from timeit import default_timer as time +from sklearn.base import BaseEstimator, RegressorMixin, ClassifierMixin +from sklearn.utils import check_X_y, check_random_state, check_array +from sklearn.utils.validation import check_is_fitted +from sklearn.utils.multiclass import check_classification_targets +from sklearn.metrics import check_scoring +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import LabelEncoder +from ._gradient_boosting import _update_raw_predictions +from .types import Y_DTYPE, X_DTYPE, X_BINNED_DTYPE + +from .binning import _BinMapper +from .grower import TreeGrower +from .loss import _LOSSES + + +class BaseHistGradientBoosting(BaseEstimator, ABC): + """Base class for histogram-based gradient boosting estimators.""" + + @abstractmethod + def __init__(self, loss, learning_rate, max_iter, max_leaf_nodes, + max_depth, min_samples_leaf, l2_regularization, max_bins, + scoring, validation_fraction, n_iter_no_change, tol, verbose, + random_state): + self.loss = loss + self.learning_rate = learning_rate + self.max_iter = max_iter + self.max_leaf_nodes = max_leaf_nodes + self.max_depth = max_depth + self.min_samples_leaf = min_samples_leaf + self.l2_regularization = l2_regularization + self.max_bins = max_bins + self.n_iter_no_change = n_iter_no_change + self.validation_fraction = validation_fraction + self.scoring = scoring + self.tol = tol + self.verbose = verbose + self.random_state = random_state + + def _validate_parameters(self): + """Validate parameters passed to __init__. + + The parameters that are directly passed to the grower are checked in + TreeGrower.""" + + if self.loss not in self._VALID_LOSSES: + raise ValueError( + "Loss {} is not supported for {}. 
Accepted losses: " + "{}.".format(self.loss, self.__class__.__name__, + ', '.join(self._VALID_LOSSES))) + + if self.learning_rate <= 0: + raise ValueError('learning_rate={} must ' + 'be strictly positive'.format(self.learning_rate)) + if self.max_iter < 1: + raise ValueError('max_iter={} must not be smaller ' + 'than 1.'.format(self.max_iter)) + if self.n_iter_no_change is not None and self.n_iter_no_change < 0: + raise ValueError('n_iter_no_change={} must be ' + 'positive.'.format(self.n_iter_no_change)) + if (self.validation_fraction is not None and + self.validation_fraction <= 0): + raise ValueError( + 'validation_fraction={} must be strictly ' + 'positive, or None.'.format(self.validation_fraction)) + if self.tol is not None and self.tol < 0: + raise ValueError('tol={} ' + 'must not be smaller than 0.'.format(self.tol)) + + def fit(self, X, y): + """Fit the gradient boosting model. + + Parameters + ---------- + X : array-like, shape=(n_samples, n_features) + The input samples. + + y : array-like, shape=(n_samples,) + Target values. + + Returns + ------- + self : object + """ + + fit_start_time = time() + acc_find_split_time = 0. # time spent finding the best splits + acc_apply_split_time = 0. # time spent splitting nodes + acc_compute_hist_time = 0. # time spent computing histograms + # time spent predicting X for gradient and hessians update + acc_prediction_time = 0. + X, y = check_X_y(X, y, dtype=[X_DTYPE]) + y = self._encode_y(y) + rng = check_random_state(self.random_state) + + self._validate_parameters() + self.n_features_ = X.shape[1] # used for validation in predict() + + # we need this stateful variable to tell raw_predict() that it was + # called from fit() (this current method), and that the data it has + # received is pre-binned. + # predicting is faster on pre-binned data, so we want early stopping + # predictions to be made on pre-binned data. Unfortunately the scorer_ + # can only call predict() or predict_proba(), not raw_predict(), and + # there's no way to tell the scorer that it needs to predict binned + # data. + self._in_fit = True + + # bin the data + if self.verbose: + print("Binning {:.3f} GB of data: ".format(X.nbytes / 1e9), end="", + flush=True) + tic = time() + self.bin_mapper_ = _BinMapper(max_bins=self.max_bins, random_state=rng) + X_binned = self.bin_mapper_.fit_transform(X) + toc = time() + if self.verbose: + duration = toc - tic + print("{:.3f} s".format(duration)) + + self.loss_ = self._get_loss() + + self.do_early_stopping_ = (self.n_iter_no_change is not None and + self.n_iter_no_change > 0) + + # create validation data if needed + self._use_validation_data = self.validation_fraction is not None + if self.do_early_stopping_ and self._use_validation_data: + # stratify for classification + stratify = y if hasattr(self.loss_, 'predict_proba') else None + + X_binned_train, X_binned_val, y_train, y_val = train_test_split( + X_binned, y, test_size=self.validation_fraction, + stratify=stratify, random_state=rng) + + # Predicting is faster of C-contiguous arrays, training is faster + # on Fortran arrays. + X_binned_val = np.ascontiguousarray(X_binned_val) + X_binned_train = np.asfortranarray(X_binned_train) + else: + X_binned_train, y_train = X_binned, y + X_binned_val, y_val = None, None + + if self.verbose: + print("Fitting gradient boosted rounds:") + + # initialize raw_predictions: those are the accumulated values + # predicted by the trees for the training data. 
raw_predictions has + # shape (n_trees_per_iteration, n_samples) where + # n_trees_per_iterations is n_classes in multiclass classification, + # else 1. + n_samples = X_binned_train.shape[0] + self._baseline_prediction = self.loss_.get_baseline_prediction( + y_train, self.n_trees_per_iteration_ + ) + raw_predictions = np.zeros( + shape=(self.n_trees_per_iteration_, n_samples), + dtype=self._baseline_prediction.dtype + ) + raw_predictions += self._baseline_prediction + + # initialize gradients and hessians (empty arrays). + # shape = (n_trees_per_iteration, n_samples). + gradients, hessians = self.loss_.init_gradients_and_hessians( + n_samples=n_samples, + prediction_dim=self.n_trees_per_iteration_ + ) + + # predictors is a matrix (list of lists) of TreePredictor objects + # with shape (n_iter_, n_trees_per_iteration) + self._predictors = predictors = [] + + # Initialize structures and attributes related to early stopping + self.scorer_ = None # set if scoring != loss + raw_predictions_val = None # set if scoring == loss and use val + self.train_score_ = [] + self.validation_score_ = [] + if self.do_early_stopping_: + # populate train_score and validation_score with the predictions + # of the initial model (before the first tree) + + if self.scoring == 'loss': + # we're going to compute scoring w.r.t the loss. As losses + # take raw predictions as input (unlike the scorers), we can + # optimize a bit and avoid repeating computing the predictions + # of the previous trees. We'll re-use raw_predictions (as it's + # needed for training anyway) for evaluating the training + # loss, and create raw_predictions_val for storing the + # raw predictions of the validation data. + + if self._use_validation_data: + raw_predictions_val = np.zeros( + shape=(self.n_trees_per_iteration_, + X_binned_val.shape[0]), + dtype=self._baseline_prediction.dtype + ) + + raw_predictions_val += self._baseline_prediction + + self._check_early_stopping_loss(raw_predictions, y_train, + raw_predictions_val, y_val) + else: + self.scorer_ = check_scoring(self, self.scoring) + # scorer_ is a callable with signature (est, X, y) and calls + # est.predict() or est.predict_proba() depending on its nature. + # Unfortunately, each call to scorer_() will compute + # the predictions of all the trees. So we use a subset of the + # training set to compute train scores. + subsample_size = 10000 # should we expose this parameter? + indices = np.arange(X_binned_train.shape[0]) + if X_binned_train.shape[0] > subsample_size: + # TODO: not critical but stratify using resample() + indices = rng.choice(indices, subsample_size, + replace=False) + X_binned_small_train = X_binned_train[indices] + y_small_train = y_train[indices] + # Predicting is faster on C-contiguous arrays. + X_binned_small_train = np.ascontiguousarray( + X_binned_small_train) + + self._check_early_stopping_scorer( + X_binned_small_train, y_small_train, + X_binned_val, y_val, + ) + + for iteration in range(self.max_iter): + + if self.verbose: + iteration_start_time = time() + print("[{}/{}] ".format(iteration + 1, self.max_iter), + end='', flush=True) + + # Update gradients and hessians, inplace + self.loss_.update_gradients_and_hessians(gradients, hessians, + y_train, raw_predictions) + + # Append a list since there may be more than 1 predictor per iter + predictors.append([]) + + # Build `n_trees_per_iteration` trees. 
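            # For regression and binary classification a single tree is grown
            # per iteration (n_trees_per_iteration_ == 1); for multiclass
            # classification one tree is grown per class, each one fitted on
            # its own row of the gradients/hessians arrays (gradients[k, :],
            # hessians[k, :]).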
+ for k in range(self.n_trees_per_iteration_): + + grower = TreeGrower( + X_binned_train, gradients[k, :], hessians[k, :], + max_bins=self.max_bins, + actual_n_bins=self.bin_mapper_.actual_n_bins_, + max_leaf_nodes=self.max_leaf_nodes, + max_depth=self.max_depth, + min_samples_leaf=self.min_samples_leaf, + l2_regularization=self.l2_regularization, + shrinkage=self.learning_rate) + grower.grow() + + acc_apply_split_time += grower.total_apply_split_time + acc_find_split_time += grower.total_find_split_time + acc_compute_hist_time += grower.total_compute_hist_time + + predictor = grower.make_predictor( + bin_thresholds=self.bin_mapper_.bin_thresholds_ + ) + predictors[-1].append(predictor) + + # Update raw_predictions with the predictions of the newly + # created tree. + tic_pred = time() + _update_raw_predictions(raw_predictions[k, :], grower) + toc_pred = time() + acc_prediction_time += toc_pred - tic_pred + + should_early_stop = False + if self.do_early_stopping_: + if self.scoring == 'loss': + # Update raw_predictions_val with the newest tree(s) + if self._use_validation_data: + for k, pred in enumerate(self._predictors[-1]): + raw_predictions_val[k, :] += ( + pred.predict_binned(X_binned_val)) + + should_early_stop = self._check_early_stopping_loss( + raw_predictions, y_train, + raw_predictions_val, y_val + ) + + else: + should_early_stop = self._check_early_stopping_scorer( + X_binned_small_train, y_small_train, + X_binned_val, y_val, + ) + + if self.verbose: + self._print_iteration_stats(iteration_start_time) + + # maybe we could also early stop if all the trees are stumps? + if should_early_stop: + break + + if self.verbose: + duration = time() - fit_start_time + n_total_leaves = sum( + predictor.get_n_leaf_nodes() + for predictors_at_ith_iteration in self._predictors + for predictor in predictors_at_ith_iteration + ) + n_predictors = sum( + len(predictors_at_ith_iteration) + for predictors_at_ith_iteration in self._predictors) + print("Fit {} trees in {:.3f} s, ({} total leaves)".format( + n_predictors, duration, n_total_leaves)) + print("{:<32} {:.3f}s".format('Time spent computing histograms:', + acc_compute_hist_time)) + print("{:<32} {:.3f}s".format('Time spent finding best splits:', + acc_find_split_time)) + print("{:<32} {:.3f}s".format('Time spent applying splits:', + acc_apply_split_time)) + print("{:<32} {:.3f}s".format('Time spent predicting:', + acc_prediction_time)) + + self.train_score_ = np.asarray(self.train_score_) + self.validation_score_ = np.asarray(self.validation_score_) + del self._in_fit # hard delete so we're sure it can't be used anymore + return self + + def _check_early_stopping_scorer(self, X_binned_small_train, y_small_train, + X_binned_val, y_val): + """Check if fitting should be early-stopped based on scorer. + + Scores are computed on validation data or on training data. + """ + + self.train_score_.append( + self.scorer_(self, X_binned_small_train, y_small_train) + ) + + if self._use_validation_data: + self.validation_score_.append( + self.scorer_(self, X_binned_val, y_val) + ) + return self._should_stop(self.validation_score_) + else: + return self._should_stop(self.train_score_) + + def _check_early_stopping_loss(self, + raw_predictions, + y_train, + raw_predictions_val, + y_val): + """Check if fitting should be early-stopped based on loss. + + Scores are computed on validation data or on training data. 
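        The decision itself is delegated to ``_should_stop`` (defined below);
        a minimal pure-Python sketch of that rule::

            def should_stop_reference(scores, n_iter_no_change, tol=1e-7):
                # Stop when none of the last `n_iter_no_change` scores is
                # better than the score recorded just before that window,
                # up to `tol` (higher scores are always better).
                if len(scores) < n_iter_no_change + 1:
                    return False
                reference = scores[-(n_iter_no_change + 1)] + tol
                return not any(score > reference
                               for score in scores[-n_iter_no_change:])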
+ """ + + self.train_score_.append( + -self.loss_(y_train, raw_predictions) + ) + + if self._use_validation_data: + self.validation_score_.append( + -self.loss_(y_val, raw_predictions_val) + ) + return self._should_stop(self.validation_score_) + else: + return self._should_stop(self.train_score_) + + def _should_stop(self, scores): + """ + Return True (do early stopping) if the last n scores aren't better + than the (n-1)th-to-last score, up to some tolerance. + """ + reference_position = self.n_iter_no_change + 1 + if len(scores) < reference_position: + return False + + # A higher score is always better. Higher tol means that it will be + # harder for subsequent iteration to be considered an improvement upon + # the reference score, and therefore it is more likely to early stop + # because of the lack of significant improvement. + tol = 0 if self.tol is None else self.tol + reference_score = scores[-reference_position] + tol + recent_scores = scores[-reference_position + 1:] + recent_improvements = [score > reference_score + for score in recent_scores] + return not any(recent_improvements) + + def _print_iteration_stats(self, iteration_start_time): + """Print info about the current fitting iteration.""" + log_msg = '' + + predictors_of_ith_iteration = [ + predictors_list for predictors_list in self._predictors[-1] + if predictors_list + ] + n_trees = len(predictors_of_ith_iteration) + max_depth = max(predictor.get_max_depth() + for predictor in predictors_of_ith_iteration) + n_leaves = sum(predictor.get_n_leaf_nodes() + for predictor in predictors_of_ith_iteration) + + if n_trees == 1: + log_msg += ("{} tree, {} leaves, ".format(n_trees, n_leaves)) + else: + log_msg += ("{} trees, {} leaves ".format(n_trees, n_leaves)) + log_msg += ("({} on avg), ".format(int(n_leaves / n_trees))) + + log_msg += "max depth = {}, ".format(max_depth) + + if self.do_early_stopping_: + if self.scoring == 'loss': + factor = -1 # score_ arrays contain the negative loss + name = 'loss' + else: + factor = 1 + name = 'score' + log_msg += "train {}: {:.5f}, ".format(name, factor * + self.train_score_[-1]) + if self._use_validation_data: + log_msg += "val {}: {:.5f}, ".format( + name, factor * self.validation_score_[-1]) + + iteration_time = time() - iteration_start_time + log_msg += "in {:0.3f}s".format(iteration_time) + + print(log_msg) + + def _raw_predict(self, X): + """Return the sum of the leaves values over all predictors. + + Parameters + ---------- + X : array-like, shape=(n_samples, n_features) + The input samples. + + Returns + ------- + raw_predictions : array, shape (n_samples * n_trees_per_iteration,) + The raw predicted values. 
+ """ + X = check_array(X, dtype=[X_DTYPE, X_BINNED_DTYPE]) + check_is_fitted(self, '_predictors') + if X.shape[1] != self.n_features_: + raise ValueError( + 'X has {} features but this estimator was trained with ' + '{} features.'.format(X.shape[1], self.n_features_) + ) + is_binned = getattr(self, '_in_fit', False) + n_samples = X.shape[0] + raw_predictions = np.zeros( + shape=(self.n_trees_per_iteration_, n_samples), + dtype=self._baseline_prediction.dtype + ) + raw_predictions += self._baseline_prediction + for predictors_of_ith_iteration in self._predictors: + for k, predictor in enumerate(predictors_of_ith_iteration): + predict = (predictor.predict_binned if is_binned + else predictor.predict) + raw_predictions[k, :] += predict(X) + + return raw_predictions + + @abstractmethod + def _get_loss(self): + pass + + @abstractmethod + def _encode_y(self, y=None): + pass + + @property + def n_iter_(self): + check_is_fitted(self, '_predictors') + return len(self._predictors) + + +class HistGradientBoostingRegressor(BaseHistGradientBoosting, RegressorMixin): + """Histogram-based Gradient Boosting Regression Tree. + + This estimator is much faster than + :class:`GradientBoostingRegressor` + for big datasets (n_samples >= 10 000). The input data ``X`` is pre-binned + into integer-valued bins, which considerably reduces the number of + splitting points to consider, and allows the algorithm to leverage + integer-based data structures. For small sample sizes, + :class:`GradientBoostingRegressor` + might be preferred since binning may lead to split points that are too + approximate in this setting. + + This implementation is inspired by + `LightGBM `_. + + .. note:: + + This estimator is still **experimental** for now: the predictions + and the API might change without any deprecation cycle. To use it, + you need to explicitly import ``enable_hist_gradient_boosting``:: + + >>> # explicitly require this experimental feature + >>> from sklearn.experimental import enable_hist_gradient_boosting # noqa + >>> # now you can import normally from ensemble + >>> from sklearn.ensemble import HistGradientBoostingClassifier + + + Parameters + ---------- + loss : {'least_squares'}, optional (default='least_squares') + The loss function to use in the boosting process. Note that the + "least squares" loss actually implements an "half least squares loss" + to simplify the computation of the gradient. + learning_rate : float, optional (default=0.1) + The learning rate, also known as *shrinkage*. This is used as a + multiplicative factor for the leaves values. Use ``1`` for no + shrinkage. + max_iter : int, optional (default=100) + The maximum number of iterations of the boosting process, i.e. the + maximum number of trees. + max_leaf_nodes : int or None, optional (default=31) + The maximum number of leaves for each tree. Must be strictly greater + than 1. If None, there is no maximum limit. + max_depth : int or None, optional (default=None) + The maximum depth of each tree. The depth of a tree is the number of + nodes to go from the root to the deepest leaf. Must be strictly greater + than 1. Depth isn't constrained by default. + min_samples_leaf : int, optional (default=20) + The minimum number of samples per leaf. For small datasets with less + than a few hundred samples, it is recommended to lower this value + since only very shallow trees would be built. + l2_regularization : float, optional (default=0) + The L2 regularization parameter. Use ``0`` for no regularization + (default). 
+ max_bins : int, optional (default=256) + The maximum number of bins to use. Before training, each feature of + the input array ``X`` is binned into at most ``max_bins`` bins, which + allows for a much faster training stage. Features with a small + number of unique values may use less than ``max_bins`` bins. Must be no + larger than 256. + scoring : str or callable or None, optional (default=None) + Scoring parameter to use for early stopping. It can be a single + string (see :ref:`scoring_parameter`) or a callable (see + :ref:`scoring`). If None, the estimator's default scorer is used. If + ``scoring='loss'``, early stopping is checked w.r.t the loss value. + Only used if ``n_iter_no_change`` is not None. + validation_fraction : int or float or None, optional (default=0.1) + Proportion (or absolute size) of training data to set aside as + validation data for early stopping. If None, early stopping is done on + the training data. Only used if ``n_iter_no_change`` is not None. + n_iter_no_change : int or None, optional (default=None) + Used to determine when to "early stop". The fitting process is + stopped when none of the last ``n_iter_no_change`` scores are better + than the ``n_iter_no_change - 1``th-to-last one, up to some + tolerance. If None or 0, no early-stopping is done. + tol : float or None, optional (default=1e-7) + The absolute tolerance to use when comparing scores during early + stopping. The higher the tolerance, the more likely we are to early + stop: higher tolerance means that it will be harder for subsequent + iterations to be considered an improvement upon the reference score. + verbose: int, optional (default=0) + The verbosity level. If not zero, print some information about the + fitting process. + random_state : int, np.random.RandomStateInstance or None, \ + optional (default=None) + Pseudo-random number generator to control the subsampling in the + binning process, and the train/validation data split if early stopping + is enabled. See :term:`random_state`. + + Attributes + ---------- + n_iter_ : int + The number of iterations as selected by early stopping (if + n_iter_no_change is not None). Otherwise it corresponds to max_iter. + n_trees_per_iteration_ : int + The number of tree that are built at each iteration. For regressors, + this is always 1. + train_score_ : ndarray, shape (max_iter + 1,) + The scores at each iteration on the training data. The first entry + is the score of the ensemble before the first iteration. Scores are + computed according to the ``scoring`` parameter. If ``scoring`` is + not 'loss', scores are computed on a subset of at most 10 000 + samples. Empty if no early stopping. + validation_score_ : ndarray, shape (max_iter + 1,) + The scores at each iteration on the held-out validation data. The + first entry is the score of the ensemble before the first iteration. + Scores are computed according to the ``scoring`` parameter. Empty if + no early stopping or if ``validation_fraction`` is None. + + Examples + -------- + >>> # To use this experimental feature, we need to explicitly ask for it: + >>> from sklearn.experimental import enable_hist_gradient_boosting # noqa + >>> from sklearn.ensemble import HistGradientBoostingRegressor + >>> from sklearn.datasets import load_boston + >>> X, y = load_boston(return_X_y=True) + >>> est = HistGradientBoostingRegressor().fit(X, y) + >>> est.score(X, y) + 0.98... 
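    Early stopping can be enabled through ``n_iter_no_change``; an
    illustrative configuration (the parameter values are arbitrary)::

        est = HistGradientBoostingRegressor(max_iter=200, n_iter_no_change=5,
                                            scoring='loss',
                                            validation_fraction=0.1)
        est.fit(X, y)  # stops early if the validation loss does not improve
                       # for 5 consecutive iterations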
+ """ + + _VALID_LOSSES = ('least_squares',) + + def __init__(self, loss='least_squares', learning_rate=0.1, + max_iter=100, max_leaf_nodes=31, max_depth=None, + min_samples_leaf=20, l2_regularization=0., max_bins=256, + scoring=None, validation_fraction=0.1, n_iter_no_change=None, + tol=1e-7, verbose=0, random_state=None): + super(HistGradientBoostingRegressor, self).__init__( + loss=loss, learning_rate=learning_rate, max_iter=max_iter, + max_leaf_nodes=max_leaf_nodes, max_depth=max_depth, + min_samples_leaf=min_samples_leaf, + l2_regularization=l2_regularization, max_bins=max_bins, + scoring=scoring, validation_fraction=validation_fraction, + n_iter_no_change=n_iter_no_change, tol=tol, verbose=verbose, + random_state=random_state) + + def predict(self, X): + """Predict values for X. + + Parameters + ---------- + X : array-like, shape (n_samples, n_features) + The input samples. + + Returns + ------- + y : ndarray, shape (n_samples,) + The predicted values. + """ + # Return raw predictions after converting shape + # (n_samples, 1) to (n_samples,) + return self._raw_predict(X).ravel() + + def _encode_y(self, y): + # Just convert y to the expected dtype + self.n_trees_per_iteration_ = 1 + y = y.astype(Y_DTYPE, copy=False) + return y + + def _get_loss(self): + return _LOSSES[self.loss]() + + +class HistGradientBoostingClassifier(BaseHistGradientBoosting, + ClassifierMixin): + """Histogram-based Gradient Boosting Classification Tree. + + This estimator is much faster than + :class:`GradientBoostingClassifier` + for big datasets (n_samples >= 10 000). The input data ``X`` is pre-binned + into integer-valued bins, which considerably reduces the number of + splitting points to consider, and allows the algorithm to leverage + integer-based data structures. For small sample sizes, + :class:`GradientBoostingClassifier` + might be preferred since binning may lead to split points that are too + approximate in this setting. + + This implementation is inspired by + `LightGBM `_. + + .. note:: + + This estimator is still **experimental** for now: the predictions + and the API might change without any deprecation cycle. To use it, + you need to explicitly import ``enable_hist_gradient_boosting``:: + + >>> # explicitly require this experimental feature + >>> from sklearn.experimental import enable_hist_gradient_boosting # noqa + >>> # now you can import normally from ensemble + >>> from sklearn.ensemble import HistGradientBoostingClassifier + + Parameters + ---------- + loss : {'auto', 'binary_crossentropy', 'categorical_crossentropy'}, \ + optional (default='auto') + The loss function to use in the boosting process. 'binary_crossentropy' + (also known as logistic loss) is used for binary classification and + generalizes to 'categorical_crossentropy' for multiclass + classification. 'auto' will automatically choose either loss depending + on the nature of the problem. + learning_rate : float, optional (default=1) + The learning rate, also known as *shrinkage*. This is used as a + multiplicative factor for the leaves values. Use ``1`` for no + shrinkage. + max_iter : int, optional (default=100) + The maximum number of iterations of the boosting process, i.e. the + maximum number of trees for binary classification. For multiclass + classification, `n_classes` trees per iteration are built. + max_leaf_nodes : int or None, optional (default=31) + The maximum number of leaves for each tree. Must be strictly greater + than 1. If None, there is no maximum limit. 
+ max_depth : int or None, optional (default=None) + The maximum depth of each tree. The depth of a tree is the number of + nodes to go from the root to the deepest leaf. Must be strictly greater + than 1. Depth isn't constrained by default. + min_samples_leaf : int, optional (default=20) + The minimum number of samples per leaf. For small datasets with less + than a few hundred samples, it is recommended to lower this value + since only very shallow trees would be built. + l2_regularization : float, optional (default=0) + The L2 regularization parameter. Use 0 for no regularization. + max_bins : int, optional (default=256) + The maximum number of bins to use. Before training, each feature of + the input array ``X`` is binned into at most ``max_bins`` bins, which + allows for a much faster training stage. Features with a small + number of unique values may use less than ``max_bins`` bins. Must be no + larger than 256. + scoring : str or callable or None, optional (default=None) + Scoring parameter to use for early stopping. It can be a single + string (see :ref:`scoring_parameter`) or a callable (see + :ref:`scoring`). If None, the estimator's default scorer + is used. If ``scoring='loss'``, early stopping is checked + w.r.t the loss value. Only used if ``n_iter_no_change`` is not None. + validation_fraction : int or float or None, optional (default=0.1) + Proportion (or absolute size) of training data to set aside as + validation data for early stopping. If None, early stopping is done on + the training data. + n_iter_no_change : int or None, optional (default=None) + Used to determine when to "early stop". The fitting process is + stopped when none of the last ``n_iter_no_change`` scores are better + than the ``n_iter_no_change - 1``th-to-last one, up to some + tolerance. If None or 0, no early-stopping is done. + tol : float or None, optional (default=1e-7) + The absolute tolerance to use when comparing scores. The higher the + tolerance, the more likely we are to early stop: higher tolerance + means that it will be harder for subsequent iterations to be + considered an improvement upon the reference score. + verbose: int, optional (default=0) + The verbosity level. If not zero, print some information about the + fitting process. + random_state : int, np.random.RandomStateInstance or None, \ + optional (default=None) + Pseudo-random number generator to control the subsampling in the + binning process, and the train/validation data split if early stopping + is enabled. See :term:`random_state`. + + Attributes + ---------- + n_iter_ : int + The number of estimators as selected by early stopping (if + n_iter_no_change is not None). Otherwise it corresponds to max_iter. + n_trees_per_iteration_ : int + The number of tree that are built at each iteration. This is equal to 1 + for binary classification, and to ``n_classes`` for multiclass + classification. + train_score_ : ndarray, shape (max_iter + 1,) + The scores at each iteration on the training data. The first entry + is the score of the ensemble before the first iteration. Scores are + computed according to the ``scoring`` parameter. If ``scoring`` is + not 'loss', scores are computed on a subset of at most 10 000 + samples. Empty if no early stopping. + validation_score_ : ndarray, shape (max_iter + 1,) + The scores at each iteration on the held-out validation data. The + first entry is the score of the ensemble before the first iteration. + Scores are computed according to the ``scoring`` parameter. 
Empty if
+        no early stopping or if ``validation_fraction`` is None.
+
+    Examples
+    --------
+    >>> # To use this experimental feature, we need to explicitly ask for it:
+    >>> from sklearn.experimental import enable_hist_gradient_boosting  # noqa
+    >>> from sklearn.ensemble import HistGradientBoostingClassifier
+    >>> from sklearn.datasets import load_iris
+    >>> X, y = load_iris(return_X_y=True)
+    >>> clf = HistGradientBoostingClassifier().fit(X, y)
+    >>> clf.score(X, y)
+    1.0
+    """
+
+    _VALID_LOSSES = ('binary_crossentropy', 'categorical_crossentropy',
+                     'auto')
+
+    def __init__(self, loss='auto', learning_rate=0.1, max_iter=100,
+                 max_leaf_nodes=31, max_depth=None, min_samples_leaf=20,
+                 l2_regularization=0., max_bins=256, scoring=None,
+                 validation_fraction=0.1, n_iter_no_change=None, tol=1e-7,
+                 verbose=0, random_state=None):
+        super(HistGradientBoostingClassifier, self).__init__(
+            loss=loss, learning_rate=learning_rate, max_iter=max_iter,
+            max_leaf_nodes=max_leaf_nodes, max_depth=max_depth,
+            min_samples_leaf=min_samples_leaf,
+            l2_regularization=l2_regularization, max_bins=max_bins,
+            scoring=scoring, validation_fraction=validation_fraction,
+            n_iter_no_change=n_iter_no_change, tol=tol, verbose=verbose,
+            random_state=random_state)
+
+    def predict(self, X):
+        """Predict classes for X.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            The input samples.
+
+        Returns
+        -------
+        y : ndarray, shape (n_samples,)
+            The predicted classes.
+        """
+        # TODO: This could be done in parallel
+        encoded_classes = np.argmax(self.predict_proba(X), axis=1)
+        return self.classes_[encoded_classes]
+
+    def predict_proba(self, X):
+        """Predict class probabilities for X.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            The input samples.
+
+        Returns
+        -------
+        p : ndarray, shape (n_samples, n_classes)
+            The class probabilities of the input samples.
+        """
+        raw_predictions = self._raw_predict(X)
+        return self.loss_.predict_proba(raw_predictions)
+
+    def decision_function(self, X):
+        """Compute the decision function of X.
+
+        Parameters
+        ----------
+        X : array-like, shape (n_samples, n_features)
+            The input samples.
+
+        Returns
+        -------
+        decision : ndarray, shape (n_samples,) or \
+                (n_samples, n_trees_per_iteration)
+            The raw predicted values (i.e. the sum of the trees leaves) for
+            each sample. n_trees_per_iteration is equal to the number of
+            classes in multiclass classification.
+        """
+        decision = self._raw_predict(X)
+        if decision.shape[0] == 1:
+            decision = decision.ravel()
+        return decision.T
+
+    def _encode_y(self, y):
+        # encode classes into 0 ... n_classes - 1 and set attributes classes_
+        # and n_trees_per_iteration_
+        check_classification_targets(y)
+
+        label_encoder = LabelEncoder()
+        encoded_y = label_encoder.fit_transform(y)
+        self.classes_ = label_encoder.classes_
+        n_classes = self.classes_.shape[0]
+        # only 1 tree for binary classification. For multiclass classification,
+        # we build 1 tree per class.
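+        # For illustration (hypothetical input): y = ['cat', 'dog', 'cat']
+        # gives classes_ == ['cat', 'dog'] and encoded_y == [0., 1., 0.];
+        # with only 2 classes, a single tree is built per iteration.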
+        self.n_trees_per_iteration_ = 1 if n_classes <= 2 else n_classes
+        encoded_y = encoded_y.astype(Y_DTYPE, copy=False)
+        return encoded_y
+
+    def _get_loss(self):
+        if self.loss == 'auto':
+            if self.n_trees_per_iteration_ == 1:
+                return _LOSSES['binary_crossentropy']()
+            else:
+                return _LOSSES['categorical_crossentropy']()
+
+        return _LOSSES[self.loss]()
diff --git a/sklearn/ensemble/_hist_gradient_boosting/grower.py b/sklearn/ensemble/_hist_gradient_boosting/grower.py
new file mode 100644
index 0000000000000..ce7ac7116030a
--- /dev/null
+++ b/sklearn/ensemble/_hist_gradient_boosting/grower.py
@@ -0,0 +1,465 @@
+"""
+This module contains the TreeGrower class.
+
+TreeGrower builds a regression tree fitting a Newton-Raphson step, based on
+the gradients and hessians of the training data.
+"""
+# Author: Nicolas Hug
+
+from heapq import heappush, heappop
+import numpy as np
+from timeit import default_timer as time
+import numbers
+
+from .splitting import Splitter
+from .histogram import HistogramBuilder
+from .predictor import TreePredictor, PREDICTOR_RECORD_DTYPE
+from .utils import sum_parallel
+
+
+class TreeNode:
+    """Tree Node class used in TreeGrower.
+
+    This isn't used for prediction purposes, only for training (see
+    TreePredictor).
+
+    Parameters
+    ----------
+    depth : int
+        The depth of the node, i.e. its distance from the root.
+    sample_indices : ndarray of unsigned int, shape (n_samples_at_node,)
+        The indices of the samples at the node.
+    sum_gradients : float
+        The sum of the gradients of the samples at the node.
+    sum_hessians : float
+        The sum of the hessians of the samples at the node.
+    parent : TreeNode or None, optional (default=None)
+        The parent of the node. None for root.
+
+    Attributes
+    ----------
+    depth : int
+        The depth of the node, i.e. its distance from the root.
+    sample_indices : ndarray of unsigned int, shape (n_samples_at_node,)
+        The indices of the samples at the node.
+    sum_gradients : float
+        The sum of the gradients of the samples at the node.
+    sum_hessians : float
+        The sum of the hessians of the samples at the node.
+    parent : TreeNode or None
+        The parent of the node. None for root.
+    split_info : SplitInfo or None
+        The result of the split evaluation.
+    left_child : TreeNode or None
+        The left child of the node. None for leaves.
+    right_child : TreeNode or None
+        The right child of the node. None for leaves.
+    value : float or None
+        The value of the leaf, as computed in finalize_leaf(). None for
+        non-leaf nodes.
+    partition_start : int
+        start position of the node's sample_indices in splitter.partition.
+    partition_stop : int
+        stop position of the node's sample_indices in splitter.partition.
+    """
+
+    split_info = None
+    left_child = None
+    right_child = None
+    value = None
+    histograms = None
+    sibling = None
+    parent = None
+
+    # start and stop indices of the node in the splitter.partition
+    # array. Concretely,
+    # self.sample_indices = view(self.splitter.partition[start:stop])
+    # Please see the comments about splitter.partition and
+    # splitter.split_indices for more info about this design.
+    # These 2 attributes are only used in _update_raw_prediction, because we
+    # need to iterate over the leaves and I don't know how to efficiently
+    # store the sample_indices views because they're all of different sizes.
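+    # Illustration with hypothetical values: if splitter.partition is
+    # [3, 7, 1, 0, 2, 4, 5, 6] and this node owns the first three entries,
+    # then partition_start == 0, partition_stop == 3 and sample_indices is
+    # a view on splitter.partition[0:3], i.e. [3, 7, 1].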
+ partition_start = 0 + partition_stop = 0 + + def __init__(self, depth, sample_indices, sum_gradients, + sum_hessians, parent=None): + self.depth = depth + self.sample_indices = sample_indices + self.n_samples = sample_indices.shape[0] + self.sum_gradients = sum_gradients + self.sum_hessians = sum_hessians + self.parent = parent + + def __lt__(self, other_node): + """Comparison for priority queue. + + Nodes with high gain are higher priority than nodes with low gain. + + heapq.heappush only need the '<' operator. + heapq.heappop take the smallest item first (smaller is higher + priority). + + Parameters + ----------- + other_node : TreeNode + The node to compare with. + """ + return self.split_info.gain > other_node.split_info.gain + + +class TreeGrower: + """Tree grower class used to build a tree. + + The tree is fitted to predict the values of a Newton-Raphson step. The + splits are considered in a best-first fashion, and the quality of a + split is defined in splitting._split_gain. + + Parameters + ---------- + X_binned : ndarray of int, shape (n_samples, n_features) + The binned input samples. Must be Fortran-aligned. + gradients : ndarray, shape (n_samples,) + The gradients of each training sample. Those are the gradients of the + loss w.r.t the predictions, evaluated at iteration ``i - 1``. + hessians : ndarray, shape (n_samples,) + The hessians of each training sample. Those are the hessians of the + loss w.r.t the predictions, evaluated at iteration ``i - 1``. + max_leaf_nodes : int or None, optional (default=None) + The maximum number of leaves for each tree. If None, there is no + maximum limit. + max_depth : int or None, optional (default=None) + The maximum depth of each tree. The depth of a tree is the number of + nodes to go from the root to the deepest leaf. + min_samples_leaf : int, optional (default=20) + The minimum number of samples per leaf. + min_gain_to_split : float, optional (default=0.) + The minimum gain needed to split a node. Splits with lower gain will + be ignored. + max_bins : int, optional (default=256) + The maximum number of bins. Used to define the shape of the + histograms. + actual_n_bins : ndarray of int or int, optional (default=None) + The actual number of bins needed for each feature, which is lower or + equal to ``max_bins``. If it's an int, all features are considered to + have the same number of bins. If None, all features are considered to + have ``max_bins`` bins. + l2_regularization : float, optional (default=0) + The L2 regularization parameter. + min_hessian_to_split : float, optional (default=1e-3) + The minimum sum of hessians needed in each node. Splits that result in + at least one child having a sum of hessians less than + ``min_hessian_to_split`` are discarded. + shrinkage : float, optional (default=1) + The shrinkage parameter to apply to the leaves values, also known as + learning rate. 
+ """ + def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None, + max_depth=None, min_samples_leaf=20, min_gain_to_split=0., + max_bins=256, actual_n_bins=None, l2_regularization=0., + min_hessian_to_split=1e-3, shrinkage=1.): + + self._validate_parameters(X_binned, max_leaf_nodes, max_depth, + min_samples_leaf, min_gain_to_split, + l2_regularization, min_hessian_to_split) + + if actual_n_bins is None: + actual_n_bins = max_bins + + if isinstance(actual_n_bins, numbers.Integral): + actual_n_bins = np.array( + [actual_n_bins] * X_binned.shape[1], + dtype=np.uint32) + else: + actual_n_bins = np.asarray(actual_n_bins, dtype=np.uint32) + + hessians_are_constant = hessians.shape[0] == 1 + self.histogram_builder = HistogramBuilder( + X_binned, max_bins, gradients, hessians, hessians_are_constant) + self.splitter = Splitter( + X_binned, max_bins, actual_n_bins, l2_regularization, + min_hessian_to_split, min_samples_leaf, min_gain_to_split, + hessians_are_constant) + self.max_leaf_nodes = max_leaf_nodes + self.max_bins = max_bins + self.n_features = X_binned.shape[1] + self.max_depth = max_depth + self.min_samples_leaf = min_samples_leaf + self.X_binned = X_binned + self.min_gain_to_split = min_gain_to_split + self.shrinkage = shrinkage + self.splittable_nodes = [] + self.finalized_leaves = [] + self.total_find_split_time = 0. # time spent finding the best splits + self.total_compute_hist_time = 0. # time spent computing histograms + self.total_apply_split_time = 0. # time spent splitting nodes + self._intilialize_root(gradients, hessians, hessians_are_constant) + self.n_nodes = 1 + + def _validate_parameters(self, X_binned, max_leaf_nodes, max_depth, + min_samples_leaf, min_gain_to_split, + l2_regularization, min_hessian_to_split): + """Validate parameters passed to __init__. + + Also validate parameters passed to splitter. 
+ """ + if X_binned.dtype != np.uint8: + raise NotImplementedError( + "X_binned must be of type uint8.") + if not X_binned.flags.f_contiguous: + raise ValueError( + "X_binned should be passed as Fortran contiguous " + "array for maximum efficiency.") + if max_leaf_nodes is not None and max_leaf_nodes <= 1: + raise ValueError('max_leaf_nodes={} should not be' + ' smaller than 2'.format(max_leaf_nodes)) + if max_depth is not None and max_depth <= 1: + raise ValueError('max_depth={} should not be' + ' smaller than 2'.format(max_depth)) + if min_samples_leaf < 1: + raise ValueError('min_samples_leaf={} should ' + 'not be smaller than 1'.format(min_samples_leaf)) + if min_gain_to_split < 0: + raise ValueError('min_gain_to_split={} ' + 'must be positive.'.format(min_gain_to_split)) + if l2_regularization < 0: + raise ValueError('l2_regularization={} must be ' + 'positive.'.format(l2_regularization)) + if min_hessian_to_split < 0: + raise ValueError('min_hessian_to_split={} ' + 'must be positive.'.format(min_hessian_to_split)) + + def grow(self): + """Grow the tree, from root to leaves.""" + while self.splittable_nodes: + self.split_next() + + def _intilialize_root(self, gradients, hessians, hessians_are_constant): + """Initialize root node and finalize it if needed.""" + n_samples = self.X_binned.shape[0] + depth = 0 + sum_gradients = sum_parallel(gradients) + if self.histogram_builder.hessians_are_constant: + sum_hessians = hessians[0] * n_samples + else: + sum_hessians = sum_parallel(hessians) + self.root = TreeNode( + depth=depth, + sample_indices=self.splitter.partition, + sum_gradients=sum_gradients, + sum_hessians=sum_hessians + ) + + self.root.partition_start = 0 + self.root.partition_stop = n_samples + + if self.root.n_samples < 2 * self.min_samples_leaf: + # Do not even bother computing any splitting statistics. + self._finalize_leaf(self.root) + return + if sum_hessians < self.splitter.min_hessian_to_split: + self._finalize_leaf(self.root) + return + + self.root.histograms = self.histogram_builder.compute_histograms_brute( + self.root.sample_indices) + self._compute_best_split_and_push(self.root) + + def _compute_best_split_and_push(self, node): + """Compute the best possible split (SplitInfo) of a given node. + + Also push it in the heap of splittable nodes if gain isn't zero. + The gain of a node is 0 if either all the leaves are pure + (best gain = 0), or if no split would satisfy the constraints, + (min_hessians_to_split, min_gain_to_split, min_samples_leaf) + """ + + node.split_info = self.splitter.find_node_split( + node.sample_indices, node.histograms, node.sum_gradients, + node.sum_hessians) + + if node.split_info.gain <= 0: # no valid split + self._finalize_leaf(node) + else: + heappush(self.splittable_nodes, node) + + def split_next(self): + """Split the node with highest potential gain. + + Returns + ------- + left : TreeNode + The resulting left child. + right : TreeNode + The resulting right child. + """ + # Consider the node with the highest loss reduction (a.k.a. 
gain) + node = heappop(self.splittable_nodes) + + tic = time() + (sample_indices_left, + sample_indices_right, + right_child_pos) = self.splitter.split_indices(node.split_info, + node.sample_indices) + self.total_apply_split_time += time() - tic + + depth = node.depth + 1 + n_leaf_nodes = len(self.finalized_leaves) + len(self.splittable_nodes) + n_leaf_nodes += 2 + + left_child_node = TreeNode(depth, + sample_indices_left, + node.split_info.sum_gradient_left, + node.split_info.sum_hessian_left, + parent=node) + right_child_node = TreeNode(depth, + sample_indices_right, + node.split_info.sum_gradient_right, + node.split_info.sum_hessian_right, + parent=node) + left_child_node.sibling = right_child_node + right_child_node.sibling = left_child_node + node.right_child = right_child_node + node.left_child = left_child_node + + # set start and stop indices + left_child_node.partition_start = node.partition_start + left_child_node.partition_stop = node.partition_start + right_child_pos + right_child_node.partition_start = left_child_node.partition_stop + right_child_node.partition_stop = node.partition_stop + + self.n_nodes += 2 + + if self.max_depth is not None and depth == self.max_depth: + self._finalize_leaf(left_child_node) + self._finalize_leaf(right_child_node) + return left_child_node, right_child_node + + if (self.max_leaf_nodes is not None + and n_leaf_nodes == self.max_leaf_nodes): + self._finalize_leaf(left_child_node) + self._finalize_leaf(right_child_node) + self._finalize_splittable_nodes() + return left_child_node, right_child_node + + if left_child_node.n_samples < self.min_samples_leaf * 2: + self._finalize_leaf(left_child_node) + if right_child_node.n_samples < self.min_samples_leaf * 2: + self._finalize_leaf(right_child_node) + + # Compute histograms of childs, and compute their best possible split + # (if needed) + should_split_left = left_child_node.value is None # node isn't a leaf + should_split_right = right_child_node.value is None + if should_split_left or should_split_right: + + # We will compute the histograms of both nodes even if one of them + # is a leaf, since computing the second histogram is very cheap + # (using histogram subtraction). + n_samples_left = left_child_node.sample_indices.shape[0] + n_samples_right = right_child_node.sample_indices.shape[0] + if n_samples_left < n_samples_right: + smallest_child = left_child_node + largest_child = right_child_node + else: + smallest_child = right_child_node + largest_child = left_child_node + + # We use the brute O(n_samples) method on the child that has the + # smallest number of samples, and the subtraction trick O(n_bins) + # on the other one. + tic = time() + smallest_child.histograms = \ + self.histogram_builder.compute_histograms_brute( + smallest_child.sample_indices) + largest_child.histograms = \ + self.histogram_builder.compute_histograms_subtraction( + node.histograms, smallest_child.histograms) + self.total_compute_hist_time += time() - tic + + tic = time() + if should_split_left: + self._compute_best_split_and_push(left_child_node) + if should_split_right: + self._compute_best_split_and_push(right_child_node) + self.total_find_split_time += time() - tic + + return left_child_node, right_child_node + + def _finalize_leaf(self, node): + """Compute the prediction value that minimizes the objective function. + + This sets the node.value attribute (node is a leaf iff node.value is + not None). + + See Equation 5 of: + XGBoost: A Scalable Tree Boosting System, T. Chen, C. 
Guestrin, 2016 + https://arxiv.org/abs/1603.02754 + """ + node.value = -self.shrinkage * node.sum_gradients / ( + node.sum_hessians + self.splitter.l2_regularization) + self.finalized_leaves.append(node) + + def _finalize_splittable_nodes(self): + """Transform all splittable nodes into leaves. + + Used when some constraint is met e.g. maximum number of leaves or + maximum depth.""" + while len(self.splittable_nodes) > 0: + node = self.splittable_nodes.pop() + self._finalize_leaf(node) + + def make_predictor(self, bin_thresholds=None): + """Make a TreePredictor object out of the current tree. + + Parameters + ---------- + bin_thresholds : array-like of floats, optional (default=None) + The actual thresholds values of each bin. + + Returns + ------- + A TreePredictor object. + """ + predictor_nodes = np.zeros(self.n_nodes, dtype=PREDICTOR_RECORD_DTYPE) + _fill_predictor_node_array(predictor_nodes, self.root, + bin_thresholds=bin_thresholds) + return TreePredictor(predictor_nodes) + + +def _fill_predictor_node_array(predictor_nodes, grower_node, + bin_thresholds, next_free_idx=0): + """Helper used in make_predictor to set the TreePredictor fields.""" + node = predictor_nodes[next_free_idx] + node['count'] = grower_node.n_samples + node['depth'] = grower_node.depth + if grower_node.split_info is not None: + node['gain'] = grower_node.split_info.gain + else: + node['gain'] = -1 + + if grower_node.value is not None: + # Leaf node + node['is_leaf'] = True + node['value'] = grower_node.value + return next_free_idx + 1 + else: + # Decision node + split_info = grower_node.split_info + feature_idx, bin_idx = split_info.feature_idx, split_info.bin_idx + node['feature_idx'] = feature_idx + node['bin_threshold'] = bin_idx + if bin_thresholds is not None: + threshold = bin_thresholds[feature_idx][bin_idx] + node['threshold'] = threshold + next_free_idx += 1 + + node['left'] = next_free_idx + next_free_idx = _fill_predictor_node_array( + predictor_nodes, grower_node.left_child, + bin_thresholds=bin_thresholds, next_free_idx=next_free_idx) + + node['right'] = next_free_idx + return _fill_predictor_node_array( + predictor_nodes, grower_node.right_child, + bin_thresholds=bin_thresholds, next_free_idx=next_free_idx) diff --git a/sklearn/ensemble/_hist_gradient_boosting/histogram.pyx b/sklearn/ensemble/_hist_gradient_boosting/histogram.pyx new file mode 100644 index 0000000000000..cf7d0fd7a7607 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/histogram.pyx @@ -0,0 +1,476 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False +# cython: language_level=3 +"""This module contains routines for building histograms.""" + +# Author: Nicolas Hug + +cimport cython +from cython.parallel import prange + +import numpy as np +cimport numpy as np + +from .types import HISTOGRAM_DTYPE +from .types cimport hist_struct +from .types cimport X_BINNED_DTYPE_C +from .types cimport G_H_DTYPE_C +from .types cimport hist_struct + +# Notes: +# - IN views are read-only, OUT views are write-only +# - In a lot of functions here, we pass feature_idx and the whole 2d +# histograms arrays instead of just histograms[feature_idx]. This is because +# Cython generated C code will have strange Python interactions (likely +# related to the GIL release and the custom histogram dtype) when using 1d +# histogram arrays that come from 2d arrays. 
+# - The for loops are unrolled, for example:
+#
+#   for i in range(n):
+#       array[i] = i
+#
+#   will become
+#
+#   for i in range(0, n, 4):
+#       array[i] = i
+#       array[i + 1] = i + 1
+#       array[i + 2] = i + 2
+#       array[i + 3] = i + 3
+#
+#   This is to hint gcc that it can auto-vectorize these 4 operations and
+#   perform them all at once.
+
+
+@cython.final
+cdef class HistogramBuilder:
+    """A Histogram builder... used to build histograms.
+
+    A histogram is an array with n_bins entries of type HISTOGRAM_DTYPE. Each
+    feature has its own histogram. A histogram contains the sum of gradients
+    and hessians of all the samples belonging to each bin.
+
+    There are different ways to build a histogram:
+    - by subtraction: hist(child) = hist(parent) - hist(sibling)
+    - from scratch. In this case we have routines that update the hessians
+      or not (not useful when hessians are constant for some losses e.g.
+      least squares). Also, there's a special case for the root which
+      contains all the samples, leading to some possible optimizations.
+      Overall all the implementations look the same, and are optimized for
+      cache hit.
+
+    Parameters
+    ----------
+    X_binned : ndarray of int, shape (n_samples, n_features)
+        The binned input samples. Must be Fortran-aligned.
+    max_bins : int
+        The maximum number of bins. Used to define the shape of the
+        histograms.
+    gradients : ndarray, shape (n_samples,)
+        The gradients of each training sample. Those are the gradients of the
+        loss w.r.t the predictions, evaluated at iteration i - 1.
+    hessians : ndarray, shape (n_samples,)
+        The hessians of each training sample. Those are the hessians of the
+        loss w.r.t the predictions, evaluated at iteration i - 1.
+    hessians_are_constant : bool
+        Whether hessians are constant.
+    """
+    cdef public:
+        const X_BINNED_DTYPE_C [::1, :] X_binned
+        unsigned int n_features
+        unsigned int max_bins
+        G_H_DTYPE_C [::1] gradients
+        G_H_DTYPE_C [::1] hessians
+        G_H_DTYPE_C [::1] ordered_gradients
+        G_H_DTYPE_C [::1] ordered_hessians
+        unsigned char hessians_are_constant
+
+    def __init__(self, const X_BINNED_DTYPE_C [::1, :] X_binned,
+                 unsigned int max_bins, G_H_DTYPE_C [::1] gradients,
+                 G_H_DTYPE_C [::1] hessians,
+                 unsigned char hessians_are_constant):
+
+        self.X_binned = X_binned
+        self.n_features = X_binned.shape[1]
+        # Note: all histograms will have ``max_bins`` bins, but some of the
+        # last bins may be unused if actual_n_bins[f] < max_bins
+        self.max_bins = max_bins
+        self.gradients = gradients
+        self.hessians = hessians
+        # for root node, gradients and hessians are already ordered
+        self.ordered_gradients = gradients.copy()
+        self.ordered_hessians = hessians.copy()
+        self.hessians_are_constant = hessians_are_constant
+
+    def compute_histograms_brute(
+            HistogramBuilder self,
+            const unsigned int [::1] sample_indices):  # IN
+        """Compute the histograms of the node by scanning through all the data.
+
+        For a given feature, the complexity is O(n_samples).
+
+        Parameters
+        ----------
+        sample_indices : array of int, shape (n_samples_at_node,)
+            The indices of the samples at the node to split.
+
+        Returns
+        -------
+        histograms : ndarray of HISTOGRAM_DTYPE, shape (n_features, max_bins)
+            The computed histograms of the current node.
+ """ + cdef: + int n_samples + int feature_idx + int i + # need local views to avoid python interactions + unsigned char hessians_are_constant = \ + self.hessians_are_constant + int n_features = self.n_features + G_H_DTYPE_C [::1] ordered_gradients = self.ordered_gradients + G_H_DTYPE_C [::1] gradients = self.gradients + G_H_DTYPE_C [::1] ordered_hessians = self.ordered_hessians + G_H_DTYPE_C [::1] hessians = self.hessians + hist_struct [:, ::1] histograms = np.zeros( + shape=(self.n_features, self.max_bins), + dtype=HISTOGRAM_DTYPE + ) + + with nogil: + n_samples = sample_indices.shape[0] + + # Populate ordered_gradients and ordered_hessians. (Already done + # for root) Ordering the gradients and hessians helps to improve + # cache hit. + if sample_indices.shape[0] != gradients.shape[0]: + if hessians_are_constant: + for i in prange(n_samples, schedule='static'): + ordered_gradients[i] = gradients[sample_indices[i]] + else: + for i in prange(n_samples, schedule='static'): + ordered_gradients[i] = gradients[sample_indices[i]] + ordered_hessians[i] = hessians[sample_indices[i]] + + for feature_idx in prange(n_features, schedule='static'): + # Compute histogram of each feature + self._compute_histogram_brute_single_feature( + feature_idx, sample_indices, histograms) + + return histograms + + cdef void _compute_histogram_brute_single_feature( + HistogramBuilder self, + const int feature_idx, + const unsigned int [::1] sample_indices, # IN + hist_struct [:, ::1] histograms) nogil: # OUT + """Compute the histogram for a given feature.""" + + cdef: + unsigned int n_samples = sample_indices.shape[0] + const X_BINNED_DTYPE_C [::1] X_binned = \ + self.X_binned[:, feature_idx] + unsigned int root_node = X_binned.shape[0] == n_samples + G_H_DTYPE_C [::1] ordered_gradients = \ + self.ordered_gradients[:n_samples] + G_H_DTYPE_C [::1] ordered_hessians = \ + self.ordered_hessians[:n_samples] + unsigned char hessians_are_constant = \ + self.hessians_are_constant + + if root_node: + if hessians_are_constant: + _build_histogram_root_no_hessian(feature_idx, X_binned, + ordered_gradients, + histograms) + else: + _build_histogram_root(feature_idx, X_binned, + ordered_gradients, ordered_hessians, + histograms) + else: + if hessians_are_constant: + _build_histogram_no_hessian(feature_idx, + sample_indices, X_binned, + ordered_gradients, histograms) + else: + _build_histogram(feature_idx, sample_indices, + X_binned, ordered_gradients, + ordered_hessians, histograms) + + def compute_histograms_subtraction( + HistogramBuilder self, + hist_struct [:, ::1] parent_histograms, # IN + hist_struct [:, ::1] sibling_histograms): # IN + """Compute the histograms of the node using the subtraction trick. + + hist(parent) = hist(left_child) + hist(right_child) + + For a given feature, the complexity is O(n_bins). This is much more + efficient than compute_histograms_brute, but it's only possible for one + of the siblings. + + Parameters + ---------- + parent_histograms : ndarray of HISTOGRAM_DTYPE, \ + shape (n_features, max_bins) + The histograms of the parent. + sibling_histograms : ndarray of HISTOGRAM_DTYPE, \ + shape (n_features, max_bins) + The histograms of the sibling. + + Returns + ------- + histograms : ndarray of HISTOGRAM_DTYPE, shape(n_features, max_bins) + The computed histograms of the current node. 
+ """ + + cdef: + int feature_idx + int n_features = self.n_features + hist_struct [:, ::1] histograms = np.zeros( + shape=(self.n_features, self.max_bins), + dtype=HISTOGRAM_DTYPE + ) + + for feature_idx in prange(n_features, schedule='static', nogil=True): + # Compute histogram of each feature + _subtract_histograms(feature_idx, + self.max_bins, + parent_histograms, + sibling_histograms, + histograms) + return histograms + + +cpdef void _build_histogram_naive( + const int feature_idx, + unsigned int [:] sample_indices, # IN + X_BINNED_DTYPE_C [:] binned_feature, # IN + G_H_DTYPE_C [:] ordered_gradients, # IN + G_H_DTYPE_C [:] ordered_hessians, # IN + hist_struct [:, :] out) nogil: # OUT + """Build histogram in a naive way, without optimizing for cache hit. + + Used in tests to compare with the optimized version.""" + cdef: + unsigned int i + unsigned int n_samples = sample_indices.shape[0] + unsigned int sample_idx + unsigned int bin_idx + + for i in range(n_samples): + sample_idx = sample_indices[i] + bin_idx = binned_feature[sample_idx] + out[feature_idx, bin_idx].sum_gradients += ordered_gradients[i] + out[feature_idx, bin_idx].sum_hessians += ordered_hessians[i] + out[feature_idx, bin_idx].count += 1 + + +cpdef void _subtract_histograms( + const int feature_idx, + unsigned int n_bins, + hist_struct [:, ::1] hist_a, # IN + hist_struct [:, ::1] hist_b, # IN + hist_struct [:, ::1] out) nogil: # OUT + """compute (hist_a - hist_b) in out""" + cdef: + unsigned int i = 0 + for i in range(n_bins): + out[feature_idx, i].sum_gradients = ( + hist_a[feature_idx, i].sum_gradients - + hist_b[feature_idx, i].sum_gradients + ) + out[feature_idx, i].sum_hessians = ( + hist_a[feature_idx, i].sum_hessians - + hist_b[feature_idx, i].sum_hessians + ) + out[feature_idx, i].count = ( + hist_a[feature_idx, i].count - + hist_b[feature_idx, i].count + ) + + +cpdef void _build_histogram( + const int feature_idx, + const unsigned int [::1] sample_indices, # IN + const X_BINNED_DTYPE_C [::1] binned_feature, # IN + const G_H_DTYPE_C [::1] ordered_gradients, # IN + const G_H_DTYPE_C [::1] ordered_hessians, # IN + hist_struct [:, ::1] out) nogil: # OUT + """Return histogram for a given feature.""" + cdef: + unsigned int i = 0 + unsigned int n_node_samples = sample_indices.shape[0] + unsigned int unrolled_upper = (n_node_samples // 4) * 4 + + unsigned int bin_0 + unsigned int bin_1 + unsigned int bin_2 + unsigned int bin_3 + unsigned int bin_idx + + for i in range(0, unrolled_upper, 4): + bin_0 = binned_feature[sample_indices[i]] + bin_1 = binned_feature[sample_indices[i + 1]] + bin_2 = binned_feature[sample_indices[i + 2]] + bin_3 = binned_feature[sample_indices[i + 3]] + + out[feature_idx, bin_0].sum_gradients += ordered_gradients[i] + out[feature_idx, bin_1].sum_gradients += ordered_gradients[i + 1] + out[feature_idx, bin_2].sum_gradients += ordered_gradients[i + 2] + out[feature_idx, bin_3].sum_gradients += ordered_gradients[i + 3] + + out[feature_idx, bin_0].sum_hessians += ordered_hessians[i] + out[feature_idx, bin_1].sum_hessians += ordered_hessians[i + 1] + out[feature_idx, bin_2].sum_hessians += ordered_hessians[i + 2] + out[feature_idx, bin_3].sum_hessians += ordered_hessians[i + 3] + + out[feature_idx, bin_0].count += 1 + out[feature_idx, bin_1].count += 1 + out[feature_idx, bin_2].count += 1 + out[feature_idx, bin_3].count += 1 + + for i in range(unrolled_upper, n_node_samples): + bin_idx = binned_feature[sample_indices[i]] + out[feature_idx, bin_idx].sum_gradients += ordered_gradients[i] + 
out[feature_idx, bin_idx].sum_hessians += ordered_hessians[i] + out[feature_idx, bin_idx].count += 1 + + +cpdef void _build_histogram_no_hessian( + const int feature_idx, + const unsigned int [::1] sample_indices, # IN + const X_BINNED_DTYPE_C [::1] binned_feature, # IN + const G_H_DTYPE_C [::1] ordered_gradients, # IN + hist_struct [:, ::1] out) nogil: # OUT + """Return histogram for a given feature, not updating hessians. + + Used when the hessians of the loss are constant (typically LS loss). + """ + + cdef: + unsigned int i = 0 + unsigned int n_node_samples = sample_indices.shape[0] + unsigned int unrolled_upper = (n_node_samples // 4) * 4 + + unsigned int bin_0 + unsigned int bin_1 + unsigned int bin_2 + unsigned int bin_3 + unsigned int bin_idx + + for i in range(0, unrolled_upper, 4): + bin_0 = binned_feature[sample_indices[i]] + bin_1 = binned_feature[sample_indices[i + 1]] + bin_2 = binned_feature[sample_indices[i + 2]] + bin_3 = binned_feature[sample_indices[i + 3]] + + out[feature_idx, bin_0].sum_gradients += ordered_gradients[i] + out[feature_idx, bin_1].sum_gradients += ordered_gradients[i + 1] + out[feature_idx, bin_2].sum_gradients += ordered_gradients[i + 2] + out[feature_idx, bin_3].sum_gradients += ordered_gradients[i + 3] + + out[feature_idx, bin_0].count += 1 + out[feature_idx, bin_1].count += 1 + out[feature_idx, bin_2].count += 1 + out[feature_idx, bin_3].count += 1 + + for i in range(unrolled_upper, n_node_samples): + bin_idx = binned_feature[sample_indices[i]] + out[feature_idx, bin_idx].sum_gradients += ordered_gradients[i] + out[feature_idx, bin_idx].count += 1 + + +cpdef void _build_histogram_root( + const int feature_idx, + const X_BINNED_DTYPE_C [::1] binned_feature, # IN + const G_H_DTYPE_C [::1] all_gradients, # IN + const G_H_DTYPE_C [::1] all_hessians, # IN + hist_struct [:, ::1] out) nogil: # OUT + """Compute histogram of the root node. + + Unlike other nodes, the root node has to find the split among *all* the + samples from the training set. binned_feature and all_gradients / + all_hessians already have a consistent ordering. 
+ """ + + cdef: + unsigned int i = 0 + unsigned int n_samples = binned_feature.shape[0] + unsigned int unrolled_upper = (n_samples // 4) * 4 + + unsigned int bin_0 + unsigned int bin_1 + unsigned int bin_2 + unsigned int bin_3 + unsigned int bin_idx + + for i in range(0, unrolled_upper, 4): + + bin_0 = binned_feature[i] + bin_1 = binned_feature[i + 1] + bin_2 = binned_feature[i + 2] + bin_3 = binned_feature[i + 3] + + out[feature_idx, bin_0].sum_gradients += all_gradients[i] + out[feature_idx, bin_1].sum_gradients += all_gradients[i + 1] + out[feature_idx, bin_2].sum_gradients += all_gradients[i + 2] + out[feature_idx, bin_3].sum_gradients += all_gradients[i + 3] + + out[feature_idx, bin_0].sum_hessians += all_hessians[i] + out[feature_idx, bin_1].sum_hessians += all_hessians[i + 1] + out[feature_idx, bin_2].sum_hessians += all_hessians[i + 2] + out[feature_idx, bin_3].sum_hessians += all_hessians[i + 3] + + out[feature_idx, bin_0].count += 1 + out[feature_idx, bin_1].count += 1 + out[feature_idx, bin_2].count += 1 + out[feature_idx, bin_3].count += 1 + + for i in range(unrolled_upper, n_samples): + bin_idx = binned_feature[i] + out[feature_idx, bin_idx].sum_gradients += all_gradients[i] + out[feature_idx, bin_idx].sum_hessians += all_hessians[i] + out[feature_idx, bin_idx].count += 1 + + +cpdef void _build_histogram_root_no_hessian( + const int feature_idx, + const X_BINNED_DTYPE_C [::1] binned_feature, # IN + const G_H_DTYPE_C [::1] all_gradients, # IN + hist_struct [:, ::1] out) nogil: # OUT + """Compute histogram of the root node, not updating hessians. + + Used when the hessians of the loss are constant (typically LS loss). + """ + + cdef: + unsigned int i = 0 + unsigned int n_samples = binned_feature.shape[0] + unsigned int unrolled_upper = (n_samples // 4) * 4 + + unsigned int bin_0 + unsigned int bin_1 + unsigned int bin_2 + unsigned int bin_3 + unsigned int bin_idx + + for i in range(0, unrolled_upper, 4): + bin_0 = binned_feature[i] + bin_1 = binned_feature[i + 1] + bin_2 = binned_feature[i + 2] + bin_3 = binned_feature[i + 3] + + out[feature_idx, bin_0].sum_gradients += all_gradients[i] + out[feature_idx, bin_1].sum_gradients += all_gradients[i + 1] + out[feature_idx, bin_2].sum_gradients += all_gradients[i + 2] + out[feature_idx, bin_3].sum_gradients += all_gradients[i + 3] + + out[feature_idx, bin_0].count += 1 + out[feature_idx, bin_1].count += 1 + out[feature_idx, bin_2].count += 1 + out[feature_idx, bin_3].count += 1 + + for i in range(unrolled_upper, n_samples): + bin_idx = binned_feature[i] + out[feature_idx, bin_idx].sum_gradients += all_gradients[i] + out[feature_idx, bin_idx].count += 1 diff --git a/sklearn/ensemble/_hist_gradient_boosting/loss.py b/sklearn/ensemble/_hist_gradient_boosting/loss.py new file mode 100644 index 0000000000000..5d7c68ea0b38f --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/loss.py @@ -0,0 +1,247 @@ +""" +This module contains the loss classes. + +Specific losses are used for regression, binary classification or multiclass +classification. 
+""" +# Author: Nicolas Hug + +from abc import ABC, abstractmethod + +import numpy as np +from scipy.special import expit +try: # logsumexp was moved from mist to special in 0.19 + from scipy.special import logsumexp +except ImportError: + from scipy.misc import logsumexp + +from .types import Y_DTYPE +from .types import G_H_DTYPE +from ._loss import _update_gradients_least_squares +from ._loss import _update_gradients_hessians_binary_crossentropy +from ._loss import _update_gradients_hessians_categorical_crossentropy + + +class BaseLoss(ABC): + """Base class for a loss.""" + + def init_gradients_and_hessians(self, n_samples, prediction_dim): + """Return initial gradients and hessians. + + Unless hessians are constant, arrays are initialized with undefined + values. + + Parameters + ---------- + n_samples : int + The number of samples passed to `fit()`. + prediction_dim : int + The dimension of a raw prediction, i.e. the number of trees + built at each iteration. Equals 1 for regression and binary + classification, or K where K is the number of classes for + multiclass classification. + + Returns + ------- + gradients : ndarray, shape (prediction_dim, n_samples) + The initial gradients. The array is not initialized. + hessians : ndarray, shape (prediction_dim, n_samples) + If hessians are constant (e.g. for `LeastSquares` loss, the + array is initialized to ``1``. Otherwise, the array is allocated + without being initialized. + """ + shape = (prediction_dim, n_samples) + gradients = np.empty(shape=shape, dtype=G_H_DTYPE) + if self.hessians_are_constant: + # if the hessians are constant, we consider they are equal to 1. + # this is correct as long as we adjust the gradients. See e.g. LS + # loss + hessians = np.ones(shape=(1, 1), dtype=G_H_DTYPE) + else: + hessians = np.empty(shape=shape, dtype=G_H_DTYPE) + + return gradients, hessians + + @abstractmethod + def get_baseline_prediction(self, y_train, prediction_dim): + """Return initial predictions (before the first iteration). + + Parameters + ---------- + y_train : ndarray, shape (n_samples,) + The target training values. + prediction_dim : int + The dimension of one prediction: 1 for binary classification and + regression, n_classes for multiclass classification. + + Returns + ------- + baseline_prediction : float or ndarray, shape (1, prediction_dim) + The baseline prediction. + """ + + @abstractmethod + def update_gradients_and_hessians(self, gradients, hessians, y_true, + raw_predictions): + """Update gradients and hessians arrays, inplace. + + The gradients (resp. hessians) are the first (resp. second) order + derivatives of the loss for each sample with respect to the + predictions of model, evaluated at iteration ``i - 1``. + + Parameters + ---------- + gradients : ndarray, shape (prediction_dim, n_samples) + The gradients (treated as OUT array). + hessians : ndarray, shape (prediction_dim, n_samples) or \ + (1,) + The hessians (treated as OUT array). + y_true : ndarray, shape (n_samples,) + The true target values or each training sample. + raw_predictions : ndarray, shape (prediction_dim, n_samples) + The raw_predictions (i.e. values from the trees) of the tree + ensemble at iteration ``i - 1``. + """ + + +class LeastSquares(BaseLoss): + """Least squares loss, for regression. 
+
+    For a given sample x_i, least squares loss is defined as::
+
+        loss(x_i) = 0.5 * (y_true_i - raw_pred_i)**2
+
+    This actually computes half the least squares loss to simplify the
+    computation of the gradients and get a unit hessian (and be consistent
+    with what is done in LightGBM).
+    """
+
+    hessians_are_constant = True
+
+    def __call__(self, y_true, raw_predictions, average=True):
+        # shape (1, n_samples) --> (n_samples,). reshape(-1) is more likely to
+        # return a view.
+        raw_predictions = raw_predictions.reshape(-1)
+        loss = 0.5 * np.power(y_true - raw_predictions, 2)
+        return loss.mean() if average else loss
+
+    def get_baseline_prediction(self, y_train, prediction_dim):
+        return np.mean(y_train)
+
+    @staticmethod
+    def inverse_link_function(raw_predictions):
+        return raw_predictions
+
+    def update_gradients_and_hessians(self, gradients, hessians, y_true,
+                                      raw_predictions):
+        # shape (1, n_samples) --> (n_samples,). reshape(-1) is more likely to
+        # return a view.
+        raw_predictions = raw_predictions.reshape(-1)
+        gradients = gradients.reshape(-1)
+        _update_gradients_least_squares(gradients, y_true, raw_predictions)
+
+
+class BinaryCrossEntropy(BaseLoss):
+    """Binary cross-entropy loss, for binary classification.
+
+    For a given sample x_i, the binary cross-entropy loss is defined as the
+    negative log-likelihood of the model which can be expressed as::
+
+        loss(x_i) = log(1 + exp(raw_pred_i)) - y_true_i * raw_pred_i
+
+    See The Elements of Statistical Learning, by Hastie, Tibshirani, Friedman,
+    section 4.4.1 (about logistic regression).
+    """
+
+    hessians_are_constant = False
+    inverse_link_function = staticmethod(expit)
+
+    def __call__(self, y_true, raw_predictions, average=True):
+        # shape (1, n_samples) --> (n_samples,). reshape(-1) is more likely to
+        # return a view.
+        raw_predictions = raw_predictions.reshape(-1)
+        # logaddexp(0, x) = log(1 + exp(x))
+        loss = np.logaddexp(0, raw_predictions) - y_true * raw_predictions
+        return loss.mean() if average else loss
+
+    def get_baseline_prediction(self, y_train, prediction_dim):
+        if prediction_dim > 2:
+            raise ValueError(
+                "loss='binary_crossentropy' is not defined for multiclass"
+                " classification with n_classes=%d, use"
+                " loss='categorical_crossentropy' instead" % prediction_dim)
+        proba_positive_class = np.mean(y_train)
+        eps = np.finfo(y_train.dtype).eps
+        proba_positive_class = np.clip(proba_positive_class, eps, 1 - eps)
+        # log(x / (1 - x)) is the inverse of the sigmoid, i.e. the link
+        # function of the Binomial model.
+        return np.log(proba_positive_class / (1 - proba_positive_class))
+
+    def update_gradients_and_hessians(self, gradients, hessians, y_true,
+                                      raw_predictions):
+        # shape (1, n_samples) --> (n_samples,). reshape(-1) is more likely to
+        # return a view.
+        raw_predictions = raw_predictions.reshape(-1)
+        gradients = gradients.reshape(-1)
+        hessians = hessians.reshape(-1)
+        _update_gradients_hessians_binary_crossentropy(
+            gradients, hessians, y_true, raw_predictions)
+
+    def predict_proba(self, raw_predictions):
+        # shape (1, n_samples) --> (n_samples,). reshape(-1) is more likely to
+        # return a view.
+        raw_predictions = raw_predictions.reshape(-1)
+        proba = np.empty((raw_predictions.shape[0], 2), dtype=Y_DTYPE)
+        proba[:, 1] = expit(raw_predictions)
+        proba[:, 0] = 1 - proba[:, 1]
+        return proba
+
+
+class CategoricalCrossEntropy(BaseLoss):
+    """Categorical cross-entropy loss, for multiclass classification.
+ + For a given sample x_i, the categorical cross-entropy loss is defined as + the negative log-likelihood of the model and generalizes the binary + cross-entropy to more than 2 classes. + """ + + hessians_are_constant = False + + def __call__(self, y_true, raw_predictions, average=True): + one_hot_true = np.zeros_like(raw_predictions) + prediction_dim = raw_predictions.shape[0] + for k in range(prediction_dim): + one_hot_true[k, :] = (y_true == k) + + loss = (logsumexp(raw_predictions, axis=0) - + (one_hot_true * raw_predictions).sum(axis=0)) + return loss.mean() if average else loss + + def get_baseline_prediction(self, y_train, prediction_dim): + init_value = np.zeros(shape=(prediction_dim, 1), dtype=Y_DTYPE) + eps = np.finfo(y_train.dtype).eps + for k in range(prediction_dim): + proba_kth_class = np.mean(y_train == k) + proba_kth_class = np.clip(proba_kth_class, eps, 1 - eps) + init_value[k, :] += np.log(proba_kth_class) + + return init_value + + def update_gradients_and_hessians(self, gradients, hessians, y_true, + raw_predictions): + _update_gradients_hessians_categorical_crossentropy( + gradients, hessians, y_true, raw_predictions) + + def predict_proba(self, raw_predictions): + # TODO: This could be done in parallel + # compute softmax (using exp(log(softmax))) + proba = np.exp(raw_predictions - + logsumexp(raw_predictions, axis=0)[np.newaxis, :]) + return proba.T + + +_LOSSES = { + 'least_squares': LeastSquares, + 'binary_crossentropy': BinaryCrossEntropy, + 'categorical_crossentropy': CategoricalCrossEntropy +} diff --git a/sklearn/ensemble/_hist_gradient_boosting/predictor.py b/sklearn/ensemble/_hist_gradient_boosting/predictor.py new file mode 100644 index 0000000000000..5b18048cc24e2 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/predictor.py @@ -0,0 +1,80 @@ +""" +This module contains the TreePredictor class which is used for prediction. +""" +# Author: Nicolas Hug + +import numpy as np + +from .types import X_DTYPE +from .types import Y_DTYPE +from .types import X_BINNED_DTYPE +from ._predictor import _predict_from_numeric_data +from ._predictor import _predict_from_binned_data + + +PREDICTOR_RECORD_DTYPE = np.dtype([ + ('value', Y_DTYPE), + ('count', np.uint32), + ('feature_idx', np.uint32), + ('threshold', X_DTYPE), + ('left', np.uint32), + ('right', np.uint32), + ('gain', Y_DTYPE), + ('depth', np.uint32), + ('is_leaf', np.uint8), + ('bin_threshold', X_BINNED_DTYPE), +]) + + +class TreePredictor: + """Tree class used for predictions. + + Parameters + ---------- + nodes : list of PREDICTOR_RECORD_DTYPE + The nodes of the tree. + """ + def __init__(self, nodes): + self.nodes = nodes + + def get_n_leaf_nodes(self): + """Return number of leaves.""" + return int(self.nodes['is_leaf'].sum()) + + def get_max_depth(self): + """Return maximum depth among all leaves.""" + return int(self.nodes['depth'].max()) + + def predict(self, X): + """Predict raw values for non-binned data. + + Parameters + ---------- + X : ndarray, shape (n_samples, n_features) + The input samples. + + Returns + ------- + y : ndarray, shape (n_samples,) + The raw predicted values. + """ + out = np.empty(X.shape[0], dtype=Y_DTYPE) + _predict_from_numeric_data(self.nodes, X, out) + return out + + def predict_binned(self, X): + """Predict raw values for binned data. + + Parameters + ---------- + X : ndarray, shape (n_samples, n_features) + The input samples. + + Returns + ------- + y : ndarray, shape (n_samples,) + The raw predicted values. 
+ """ + out = np.empty(X.shape[0], dtype=Y_DTYPE) + _predict_from_binned_data(self.nodes, X, out) + return out diff --git a/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx b/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx new file mode 100644 index 0000000000000..2f7c7d3453326 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx @@ -0,0 +1,514 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False +# cython: language_level=3 +"""This module contains routines and data structures to: + +- Find the best possible split of a node. For a given node, a split is + characterized by a feature and a bin. +- Apply a split to a node, i.e. split the indices of the samples at the node + into the newly created left and right childs. +""" +# Author: Nicolas Hug + +cimport cython +from cython.parallel import prange +import numpy as np +cimport numpy as np +IF SKLEARN_OPENMP_SUPPORTED: + from openmp cimport omp_get_max_threads +from libc.stdlib cimport malloc, free +from libc.string cimport memcpy + +from .types cimport X_BINNED_DTYPE_C +from .types cimport Y_DTYPE_C +from .types cimport hist_struct +from .types import HISTOGRAM_DTYPE + + +cdef struct split_info_struct: + # Same as the SplitInfo class, but we need a C struct to use it in the + # nogil sections and to use in arrays. + Y_DTYPE_C gain + int feature_idx + unsigned int bin_idx + Y_DTYPE_C sum_gradient_left + Y_DTYPE_C sum_gradient_right + Y_DTYPE_C sum_hessian_left + Y_DTYPE_C sum_hessian_right + unsigned int n_samples_left + unsigned int n_samples_right + + +class SplitInfo: + """Pure data class to store information about a potential split. + + Parameters + ---------- + gain : float + The gain of the split. + feature_idx : int + The index of the feature to be split. + bin_idx : int + The index of the bin on which the split is made. + sum_gradient_left : float + The sum of the gradients of all the samples in the left child. + sum_hessian_left : float + The sum of the hessians of all the samples in the left child. + sum_gradient_right : float + The sum of the gradients of all the samples in the right child. + sum_hessian_right : float + The sum of the hessians of all the samples in the right child. + n_samples_left : int, default=0 + The number of samples in the left child. + n_samples_right : int + The number of samples in the right child. + """ + def __init__(self, gain, feature_idx, bin_idx, sum_gradient_left, + sum_hessian_left, sum_gradient_right, sum_hessian_right, + n_samples_left, n_samples_right): + self.gain = gain + self.feature_idx = feature_idx + self.bin_idx = bin_idx + self.sum_gradient_left = sum_gradient_left + self.sum_hessian_left = sum_hessian_left + self.sum_gradient_right = sum_gradient_right + self.sum_hessian_right = sum_hessian_right + self.n_samples_left = n_samples_left + self.n_samples_right = n_samples_right + + +@cython.final +cdef class Splitter: + """Splitter used to find the best possible split at each node. + + A split (see SplitInfo) is characterized by a feature and a bin. + + The Splitter is also responsible for partitioning the samples among the + leaves of the tree (see split_indices() and the partition attribute). + + Parameters + ---------- + X_binned : ndarray of int, shape (n_samples, n_features) + The binned input samples. Must be Fortran-aligned. + max_bins : int + The maximum number of bins. Used to define the shape of the + histograms. 
+ actual_n_bins : ndarray, shape (n_features,) + The actual number of bins needed for each feature, which is lower or + equal to max_bins. + l2_regularization : float + The L2 regularization parameter. + min_hessian_to_split : float, default=1e-3 + The minimum sum of hessians needed in each node. Splits that result in + at least one child having a sum of hessians less than + min_hessian_to_split are discarded. + min_samples_leaf : int, default=20 + The minimum number of samples per leaf. + min_gain_to_split : float, default=0.0 + The minimum gain needed to split a node. Splits with lower gain will + be ignored. + hessians_are_constant: bool, default is False + Whether hessians are constant. + """ + cdef public: + const X_BINNED_DTYPE_C [::1, :] X_binned + unsigned int n_features + unsigned int max_bins + unsigned int [::1] actual_n_bins + unsigned char hessians_are_constant + Y_DTYPE_C l2_regularization + Y_DTYPE_C min_hessian_to_split + unsigned int min_samples_leaf + Y_DTYPE_C min_gain_to_split + + unsigned int [::1] partition + unsigned int [::1] left_indices_buffer + unsigned int [::1] right_indices_buffer + + def __init__(self, const X_BINNED_DTYPE_C [::1, :] X_binned, unsigned int + max_bins, np.ndarray[np.uint32_t] actual_n_bins, + Y_DTYPE_C l2_regularization, Y_DTYPE_C + min_hessian_to_split=1e-3, unsigned int + min_samples_leaf=20, Y_DTYPE_C min_gain_to_split=0., + unsigned char hessians_are_constant=False): + + self.X_binned = X_binned + self.n_features = X_binned.shape[1] + # Note: all histograms will have bins, but some of the + # last bins may be unused if actual_n_bins[f] < max_bins + self.max_bins = max_bins + self.actual_n_bins = actual_n_bins + self.l2_regularization = l2_regularization + self.min_hessian_to_split = min_hessian_to_split + self.min_samples_leaf = min_samples_leaf + self.min_gain_to_split = min_gain_to_split + self.hessians_are_constant = hessians_are_constant + + # The partition array maps each sample index into the leaves of the + # tree (a leaf in this context is a node that isn't splitted yet, not + # necessarily a 'finalized' leaf). Initially, the root contains all + # the indices, e.g.: + # partition = [abcdefghijkl] + # After a call to split_indices, it may look e.g. like this: + # partition = [cef|abdghijkl] + # we have 2 leaves, the left one is at position 0 and the second one at + # position 3. The order of the samples is irrelevant. + self.partition = np.arange(X_binned.shape[0], dtype=np.uint32) + # buffers used in split_indices to support parallel splitting. + self.left_indices_buffer = np.empty_like(self.partition) + self.right_indices_buffer = np.empty_like(self.partition) + + def split_indices(Splitter self, split_info, unsigned int [::1] + sample_indices): + """Split samples into left and right arrays. + + The split is performed according to the best possible split + (split_info). + + Ultimately, this is nothing but a partition of the sample_indices + array with a given pivot, exactly like a quicksort subroutine. + + Parameters + ---------- + split_info : SplitInfo + The SplitInfo of the node to split. + sample_indices : ndarray of unsigned int, shape (n_samples_at_node,) + The indices of the samples at the node to split. This is a view + on self.partition, and it is modified inplace by placing the + indices of the left child at the beginning, and the indices of + the right child at the end. + + Returns + ------- + left_indices : ndarray of int, shape (n_left_samples,) + The indices of the samples in the left child. 
This is a view on + self.partition. + right_indices : ndarray of int, shape (n_right_samples,) + The indices of the samples in the right child. This is a view on + self.partition. + right_child_position : int + The position of the right child in ``sample_indices``. + """ + # This is a multi-threaded implementation inspired by lightgbm. Here + # is a quick break down. Let's suppose we want to split a node with 24 + # samples named from a to x. self.partition looks like this (the * are + # indices in other leaves that we don't care about): + # partition = [*************abcdefghijklmnopqrstuvwx****************] + # ^ ^ + # node_position node_position + node.n_samples + + # Ultimately, we want to reorder the samples inside the boundaries of + # the leaf (which becomes a node) to now represent the samples in its + # left and right child. For example: + # partition = [*************abefilmnopqrtuxcdghjksvw*****************] + # ^ ^ + # left_child_pos right_child_pos + # Note that left_child_pos always takes the value of node_position, + # and right_child_pos = left_child_pos + left_child.n_samples. The + # order of the samples inside a leaf is irrelevant. + + # 1. sample_indices is a view on this region a..x. We conceptually + # divide it into n_threads regions. Each thread will be responsible + # for its own region. Here is an example with 4 threads: + # sample_indices = [abcdef|ghijkl|mnopqr|stuvwx] + # 2. Each thread processes 6 = 24 // 4 entries and maps them into + # left_indices_buffer or right_indices_buffer. For example, we could + # have the following mapping ('.' denotes an undefined entry): + # - left_indices_buffer = [abef..|il....|mnopqr|tux...] + # - right_indices_buffer = [cd....|ghjk..|......|svw...] + # 3. We keep track of the start positions of the regions (the '|') in + # ``offset_in_buffers`` as well as the size of each region. We also + # keep track of the number of samples put into the left/right child + # by each thread. Concretely: + # - left_counts = [4, 2, 6, 3] + # - right_counts = [2, 4, 0, 3] + # 4. Finally, we put left/right_indices_buffer back into the + # sample_indices, without any undefined entries and the partition + # looks as expected + # partition = [*************abefilmnopqrtuxcdghjksvw***************] + + # Note: We here show left/right_indices_buffer as being the same size + # as sample_indices for simplicity, but in reality they are of the + # same size as partition. 
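# A minimal pure-NumPy sketch of the bookkeeping described in steps 1-4
# above, with a sequential loop standing in for the prange over threads
# and, as in the note above, buffers of the same size as sample_indices.
# This is only an illustration of the algorithm, not the implementation
# below; the function name and signature are made up for the sketch.
import numpy as np

def split_indices_sketch(X_binned_feature, sample_indices, bin_idx,
                         n_threads=4):
    n_samples = sample_indices.shape[0]
    left_buffer = np.empty_like(sample_indices)
    right_buffer = np.empty_like(sample_indices)

    # 1. sizes and start offsets of each thread's region
    sizes = np.full(n_threads, n_samples // n_threads, dtype=int)
    sizes[:n_samples % n_threads] += 1
    offsets = np.zeros(n_threads, dtype=int)
    offsets[1:] = np.cumsum(sizes)[:-1]

    # 2. each "thread" maps its region into the left/right buffers
    left_counts = np.zeros(n_threads, dtype=int)
    right_counts = np.zeros(n_threads, dtype=int)
    for t in range(n_threads):  # a prange in the real code
        start, stop = offsets[t], offsets[t] + sizes[t]
        for i in range(start, stop):
            sample_idx = sample_indices[i]
            if X_binned_feature[sample_idx] <= bin_idx:
                left_buffer[start + left_counts[t]] = sample_idx
                left_counts[t] += 1
            else:
                right_buffer[start + right_counts[t]] = sample_idx
                right_counts[t] += 1

    # 3. where each thread will write back into sample_indices
    right_child_position = int(left_counts.sum())
    left_offsets = np.zeros(n_threads, dtype=int)
    left_offsets[1:] = np.cumsum(left_counts)[:-1]
    right_offsets = right_child_position + np.concatenate(
        ([0], np.cumsum(right_counts)[:-1]))

    # 4. copy the buffers back (the memcpy calls further down)
    for t in range(n_threads):
        start = offsets[t]
        sample_indices[left_offsets[t]:left_offsets[t] + left_counts[t]] = \
            left_buffer[start:start + left_counts[t]]
        sample_indices[right_offsets[t]:right_offsets[t] + right_counts[t]] = \
            right_buffer[start:start + right_counts[t]]

    return (sample_indices[:right_child_position],
            sample_indices[right_child_position:],
            right_child_position)

# For instance, with X_binned_feature = [0, 1, 0, 1, 1, 0, 0, 1],
# sample_indices = arange(8) and bin_idx = 0, this returns
# left == [0, 2, 5, 6], right == [1, 3, 4, 7] and position 4.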
+ + cdef: + int n_samples = sample_indices.shape[0] + X_BINNED_DTYPE_C bin_idx = split_info.bin_idx + int feature_idx = split_info.feature_idx + const X_BINNED_DTYPE_C [::1] X_binned = \ + self.X_binned[:, feature_idx] + unsigned int [::1] left_indices_buffer = self.left_indices_buffer + unsigned int [::1] right_indices_buffer = self.right_indices_buffer + + IF SKLEARN_OPENMP_SUPPORTED: + int n_threads = omp_get_max_threads() + ELSE: + int n_threads = 1 + + int [:] sizes = np.full(n_threads, n_samples // n_threads, + dtype=np.int32) + int [:] offset_in_buffers = np.zeros(n_threads, dtype=np.int32) + int [:] left_counts = np.empty(n_threads, dtype=np.int32) + int [:] right_counts = np.empty(n_threads, dtype=np.int32) + int left_count + int right_count + int start + int stop + int i + int thread_idx + int sample_idx + int right_child_position + int [:] left_offset = np.zeros(n_threads, dtype=np.int32) + int [:] right_offset = np.zeros(n_threads, dtype=np.int32) + + with nogil: + for thread_idx in range(n_samples % n_threads): + sizes[thread_idx] += 1 + + for thread_idx in range(1, n_threads): + offset_in_buffers[thread_idx] = \ + offset_in_buffers[thread_idx - 1] + sizes[thread_idx - 1] + + # map indices from sample_indices to left/right_indices_buffer + for thread_idx in prange(n_threads, schedule='static', + chunksize=1): + left_count = 0 + right_count = 0 + + start = offset_in_buffers[thread_idx] + stop = start + sizes[thread_idx] + for i in range(start, stop): + sample_idx = sample_indices[i] + if X_binned[sample_idx] <= bin_idx: + left_indices_buffer[start + left_count] = sample_idx + left_count = left_count + 1 + else: + right_indices_buffer[start + right_count] = sample_idx + right_count = right_count + 1 + + left_counts[thread_idx] = left_count + right_counts[thread_idx] = right_count + + # position of right child = just after the left child + right_child_position = 0 + for thread_idx in range(n_threads): + right_child_position += left_counts[thread_idx] + + # offset of each thread in sample_indices for left and right + # child, i.e. where each thread will start to write. + right_offset[0] = right_child_position + for thread_idx in range(1, n_threads): + left_offset[thread_idx] = \ + left_offset[thread_idx - 1] + left_counts[thread_idx - 1] + right_offset[thread_idx] = \ + right_offset[thread_idx - 1] + right_counts[thread_idx - 1] + + # map indices in left/right_indices_buffer back into + # sample_indices. This also updates self.partition since + # sample_indices is a view. + for thread_idx in prange(n_threads, schedule='static', + chunksize=1): + memcpy( + &sample_indices[left_offset[thread_idx]], + &left_indices_buffer[offset_in_buffers[thread_idx]], + sizeof(unsigned int) * left_counts[thread_idx] + ) + memcpy( + &sample_indices[right_offset[thread_idx]], + &right_indices_buffer[offset_in_buffers[thread_idx]], + sizeof(unsigned int) * right_counts[thread_idx] + ) + + return (sample_indices[:right_child_position], + sample_indices[right_child_position:], + right_child_position) + + def find_node_split( + Splitter self, + const unsigned int [::1] sample_indices, # IN + hist_struct [:, ::1] histograms, # IN + const Y_DTYPE_C sum_gradients, + const Y_DTYPE_C sum_hessians): + """For each feature, find the best bin to split on at a given node. + + Return the best split info among all features. + + Parameters + ---------- + sample_indices : ndarray of unsigned int, shape (n_samples_at_node,) + The indices of the samples at the node to split. 
+ histograms : ndarray of HISTOGRAM_DTYPE of \ + shape (n_features, max_bins) + The histograms of the current node. + sum_gradients : float + The sum of the gradients for each sample at the node. + sum_hessians : float + The sum of the hessians for each sample at the node. + + Returns + ------- + best_split_info : SplitInfo + The info about the best possible split among all features. + """ + cdef: + int n_samples + int feature_idx + int best_feature_idx + int n_features = self.n_features + split_info_struct split_info + split_info_struct * split_infos + + with nogil: + n_samples = sample_indices.shape[0] + + split_infos = malloc( + self.n_features * sizeof(split_info_struct)) + + for feature_idx in prange(n_features, schedule='static'): + # For each feature, find best bin to split on + split_info = self._find_best_bin_to_split_helper( + feature_idx, histograms, n_samples, + sum_gradients, sum_hessians) + split_infos[feature_idx] = split_info + + # then compute best possible split among all features + best_feature_idx = self._find_best_feature_to_split_helper( + split_infos) + split_info = split_infos[best_feature_idx] + + out = SplitInfo( + split_info.gain, + split_info.feature_idx, + split_info.bin_idx, + split_info.sum_gradient_left, + split_info.sum_hessian_left, + split_info.sum_gradient_right, + split_info.sum_hessian_right, + split_info.n_samples_left, + split_info.n_samples_right, + ) + free(split_infos) + return out + + cdef int _find_best_feature_to_split_helper( + self, + split_info_struct * split_infos) nogil: # IN + """Returns the best feature among those in splits_infos.""" + cdef: + int feature_idx + int best_feature_idx = 0 + + for feature_idx in range(1, self.n_features): + if (split_infos[feature_idx].gain > + split_infos[best_feature_idx].gain): + best_feature_idx = feature_idx + return best_feature_idx + + cdef split_info_struct _find_best_bin_to_split_helper( + self, + unsigned int feature_idx, + const hist_struct [:, ::1] histograms, # IN + unsigned int n_samples, + Y_DTYPE_C sum_gradients, + Y_DTYPE_C sum_hessians) nogil: + """Find best bin to split on for a given feature. + + Splits that do not satisfy the splitting constraints + (min_gain_to_split, etc.) are discarded here. If no split can + satisfy the constraints, a SplitInfo with a gain of -1 is returned. + If for a given node the best SplitInfo has a gain of -1, it is + finalized into a leaf in the grower. + """ + cdef: + unsigned int bin_idx + unsigned int n_samples_left + unsigned int n_samples_right + unsigned int n_samples_ = n_samples + Y_DTYPE_C sum_hessian_left + Y_DTYPE_C sum_hessian_right + Y_DTYPE_C sum_gradient_left + Y_DTYPE_C sum_gradient_right + Y_DTYPE_C gain + split_info_struct best_split + + best_split.gain = -1. + sum_gradient_left, sum_hessian_left = 0., 0. 
+ n_samples_left = 0 + + for bin_idx in range(self.actual_n_bins[feature_idx]): + n_samples_left += histograms[feature_idx, bin_idx].count + n_samples_right = n_samples_ - n_samples_left + + if self.hessians_are_constant: + sum_hessian_left += histograms[feature_idx, bin_idx].count + else: + sum_hessian_left += \ + histograms[feature_idx, bin_idx].sum_hessians + sum_hessian_right = sum_hessians - sum_hessian_left + + sum_gradient_left += histograms[feature_idx, bin_idx].sum_gradients + sum_gradient_right = sum_gradients - sum_gradient_left + + if n_samples_left < self.min_samples_leaf: + continue + if n_samples_right < self.min_samples_leaf: + # won't get any better + break + + if sum_hessian_left < self.min_hessian_to_split: + continue + if sum_hessian_right < self.min_hessian_to_split: + # won't get any better (hessians are > 0 since loss is convex) + break + + gain = _split_gain(sum_gradient_left, sum_hessian_left, + sum_gradient_right, sum_hessian_right, + sum_gradients, sum_hessians, + self.l2_regularization) + + if gain > best_split.gain and gain > self.min_gain_to_split: + best_split.gain = gain + best_split.feature_idx = feature_idx + best_split.bin_idx = bin_idx + best_split.sum_gradient_left = sum_gradient_left + best_split.sum_gradient_right = sum_gradient_right + best_split.sum_hessian_left = sum_hessian_left + best_split.sum_hessian_right = sum_hessian_right + best_split.n_samples_left = n_samples_left + best_split.n_samples_right = n_samples_right + + return best_split + + +cdef inline Y_DTYPE_C _split_gain( + Y_DTYPE_C sum_gradient_left, + Y_DTYPE_C sum_hessian_left, + Y_DTYPE_C sum_gradient_right, + Y_DTYPE_C sum_hessian_right, + Y_DTYPE_C sum_gradients, + Y_DTYPE_C sum_hessians, + Y_DTYPE_C l2_regularization) nogil: + """Loss reduction + + Compute the reduction in loss after taking a split, compared to keeping + the node a leaf of the tree. + + See Equation 7 of: + XGBoost: A Scalable Tree Boosting System, T. Chen, C. 
Guestrin, 2016 + https://arxiv.org/abs/1603.02754 + """ + cdef: + Y_DTYPE_C gain + gain = negative_loss(sum_gradient_left, sum_hessian_left, + l2_regularization) + gain += negative_loss(sum_gradient_right, sum_hessian_right, + l2_regularization) + gain -= negative_loss(sum_gradients, sum_hessians, l2_regularization) + return gain + +cdef inline Y_DTYPE_C negative_loss( + Y_DTYPE_C gradient, + Y_DTYPE_C hessian, + Y_DTYPE_C l2_regularization) nogil: + return (gradient * gradient) / (hessian + l2_regularization) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_binning.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_binning.py new file mode 100644 index 0000000000000..4f4def6199411 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_binning.py @@ -0,0 +1,242 @@ +import numpy as np +from numpy.testing import assert_array_equal, assert_allclose +import pytest + +from sklearn.ensemble._hist_gradient_boosting.binning import ( + _BinMapper, + _find_binning_thresholds as _find_binning_thresholds_orig, + _map_to_bins +) +from sklearn.ensemble._hist_gradient_boosting.types import X_DTYPE +from sklearn.ensemble._hist_gradient_boosting.types import X_BINNED_DTYPE + + +DATA = np.random.RandomState(42).normal( + loc=[0, 10], scale=[1, 0.01], size=(int(1e6), 2) +).astype(X_DTYPE) + + +def _find_binning_thresholds(data, max_bins=256, subsample=int(2e5), + random_state=None): + # Just a redef to avoid having to pass arguments all the time (as the + # function is private we don't use default values for parameters) + return _find_binning_thresholds_orig(data, max_bins, subsample, + random_state) + + +def test_find_binning_thresholds_regular_data(): + data = np.linspace(0, 10, 1001).reshape(-1, 1) + bin_thresholds = _find_binning_thresholds(data, max_bins=10) + assert_allclose(bin_thresholds[0], [1, 2, 3, 4, 5, 6, 7, 8, 9]) + assert len(bin_thresholds) == 1 + + bin_thresholds = _find_binning_thresholds(data, max_bins=5) + assert_allclose(bin_thresholds[0], [2, 4, 6, 8]) + assert len(bin_thresholds) == 1 + + +def test_find_binning_thresholds_small_regular_data(): + data = np.linspace(0, 10, 11).reshape(-1, 1) + + bin_thresholds = _find_binning_thresholds(data, max_bins=5) + assert_allclose(bin_thresholds[0], [2, 4, 6, 8]) + + bin_thresholds = _find_binning_thresholds(data, max_bins=10) + assert_allclose(bin_thresholds[0], [1, 2, 3, 4, 5, 6, 7, 8, 9]) + + bin_thresholds = _find_binning_thresholds(data, max_bins=11) + assert_allclose(bin_thresholds[0], np.arange(10) + .5) + + bin_thresholds = _find_binning_thresholds(data, max_bins=255) + assert_allclose(bin_thresholds[0], np.arange(10) + .5) + + +def test_find_binning_thresholds_random_data(): + bin_thresholds = _find_binning_thresholds(DATA, random_state=0) + assert len(bin_thresholds) == 2 + for i in range(len(bin_thresholds)): + assert bin_thresholds[i].shape == (255,) # 256 - 1 + assert bin_thresholds[i].dtype == DATA.dtype + + assert_allclose(bin_thresholds[0][[64, 128, 192]], + np.array([-0.7, 0.0, 0.7]), atol=1e-1) + + assert_allclose(bin_thresholds[1][[64, 128, 192]], + np.array([9.99, 10.00, 10.01]), atol=1e-2) + + +def test_find_binning_thresholds_low_n_bins(): + bin_thresholds = _find_binning_thresholds(DATA, max_bins=128, + random_state=0) + assert len(bin_thresholds) == 2 + for i in range(len(bin_thresholds)): + assert bin_thresholds[i].shape == (127,) # 128 - 1 + assert bin_thresholds[i].dtype == DATA.dtype + + +def test_find_binning_thresholds_invalid_n_bins(): + err_msg = 'no smaller than 2 and no 
larger than 256' + with pytest.raises(ValueError, match=err_msg): + _find_binning_thresholds(DATA, max_bins=1024) + + +def test_bin_mapper_n_features_transform(): + mapper = _BinMapper(max_bins=42, random_state=42).fit(DATA) + err_msg = 'This estimator was fitted with 2 features but 4 got passed' + with pytest.raises(ValueError, match=err_msg): + mapper.transform(np.repeat(DATA, 2, axis=1)) + + +@pytest.mark.parametrize('n_bins', [16, 128, 256]) +def test_map_to_bins(n_bins): + bin_thresholds = _find_binning_thresholds(DATA, max_bins=n_bins, + random_state=0) + binned = np.zeros_like(DATA, dtype=X_BINNED_DTYPE, order='F') + _map_to_bins(DATA, bin_thresholds, binned) + assert binned.shape == DATA.shape + assert binned.dtype == np.uint8 + assert binned.flags.f_contiguous + + min_indices = DATA.argmin(axis=0) + max_indices = DATA.argmax(axis=0) + + for feature_idx, min_idx in enumerate(min_indices): + assert binned[min_idx, feature_idx] == 0 + for feature_idx, max_idx in enumerate(max_indices): + assert binned[max_idx, feature_idx] == n_bins - 1 + + +@pytest.mark.parametrize("n_bins", [5, 10, 42]) +def test_bin_mapper_random_data(n_bins): + n_samples, n_features = DATA.shape + + expected_count_per_bin = n_samples // n_bins + tol = int(0.05 * expected_count_per_bin) + + mapper = _BinMapper(max_bins=n_bins, random_state=42).fit(DATA) + binned = mapper.transform(DATA) + + assert binned.shape == (n_samples, n_features) + assert binned.dtype == np.uint8 + assert_array_equal(binned.min(axis=0), np.array([0, 0])) + assert_array_equal(binned.max(axis=0), np.array([n_bins - 1, n_bins - 1])) + assert len(mapper.bin_thresholds_) == n_features + for bin_thresholds_feature in mapper.bin_thresholds_: + assert bin_thresholds_feature.shape == (n_bins - 1,) + assert bin_thresholds_feature.dtype == DATA.dtype + assert np.all(mapper.actual_n_bins_ == n_bins) + + # Check that the binned data is approximately balanced across bins. 
+ for feature_idx in range(n_features): + for bin_idx in range(n_bins): + count = (binned[:, feature_idx] == bin_idx).sum() + assert abs(count - expected_count_per_bin) < tol + + +@pytest.mark.parametrize("n_samples, n_bins", [ + (5, 5), + (5, 10), + (5, 11), + (42, 255) +]) +def test_bin_mapper_small_random_data(n_samples, n_bins): + data = np.random.RandomState(42).normal(size=n_samples).reshape(-1, 1) + assert len(np.unique(data)) == n_samples + + mapper = _BinMapper(max_bins=n_bins, random_state=42) + binned = mapper.fit_transform(data) + + assert binned.shape == data.shape + assert binned.dtype == np.uint8 + assert_array_equal(binned.ravel()[np.argsort(data.ravel())], + np.arange(n_samples)) + + +@pytest.mark.parametrize("n_bins, n_distinct, multiplier", [ + (5, 5, 1), + (5, 5, 3), + (255, 12, 42), +]) +def test_bin_mapper_identity_repeated_values(n_bins, n_distinct, multiplier): + data = np.array(list(range(n_distinct)) * multiplier).reshape(-1, 1) + binned = _BinMapper(max_bins=n_bins).fit_transform(data) + assert_array_equal(data, binned) + + +@pytest.mark.parametrize('n_distinct', [2, 7, 42]) +def test_bin_mapper_repeated_values_invariance(n_distinct): + rng = np.random.RandomState(42) + distinct_values = rng.normal(size=n_distinct) + assert len(np.unique(distinct_values)) == n_distinct + + repeated_indices = rng.randint(low=0, high=n_distinct, size=1000) + data = distinct_values[repeated_indices] + rng.shuffle(data) + assert_array_equal(np.unique(data), np.sort(distinct_values)) + + data = data.reshape(-1, 1) + + mapper_1 = _BinMapper(max_bins=n_distinct) + binned_1 = mapper_1.fit_transform(data) + assert_array_equal(np.unique(binned_1[:, 0]), np.arange(n_distinct)) + + # Adding more bins to the mapper yields the same results (same thresholds) + mapper_2 = _BinMapper(max_bins=min(256, n_distinct * 3)) + binned_2 = mapper_2.fit_transform(data) + + assert_allclose(mapper_1.bin_thresholds_[0], mapper_2.bin_thresholds_[0]) + assert_array_equal(binned_1, binned_2) + + +@pytest.mark.parametrize("n_bins, scale, offset", [ + (3, 2, -1), + (42, 1, 0), + (256, 0.3, 42), +]) +def test_bin_mapper_identity_small(n_bins, scale, offset): + data = np.arange(n_bins).reshape(-1, 1) * scale + offset + binned = _BinMapper(max_bins=n_bins).fit_transform(data) + assert_array_equal(binned, np.arange(n_bins).reshape(-1, 1)) + + +@pytest.mark.parametrize('n_bins_small, n_bins_large', [ + (2, 2), + (3, 3), + (4, 4), + (42, 42), + (256, 256), + (5, 17), + (42, 256), +]) +def test_bin_mapper_idempotence(n_bins_small, n_bins_large): + assert n_bins_large >= n_bins_small + data = np.random.RandomState(42).normal(size=30000).reshape(-1, 1) + mapper_small = _BinMapper(max_bins=n_bins_small) + mapper_large = _BinMapper(max_bins=n_bins_large) + binned_small = mapper_small.fit_transform(data) + binned_large = mapper_large.fit_transform(binned_small) + assert_array_equal(binned_small, binned_large) + + +@pytest.mark.parametrize('max_bins', [10, 100, 256]) +@pytest.mark.parametrize('diff', [-5, 0, 5]) +def test_actual_n_bins(max_bins, diff): + # Check that actual_n_bins is n_unique_values when + # n_unique_values <= max_bins, else max_bins. 
+ + n_unique_values = max_bins + diff + X = list(range(n_unique_values)) * 2 + X = np.array(X).reshape(-1, 1) + mapper = _BinMapper(max_bins=max_bins).fit(X) + assert np.all(mapper.actual_n_bins_ == min(max_bins, n_unique_values)) + + +def test_subsample(): + # Make sure bin thresholds are different when applying subsampling + mapper_no_subsample = _BinMapper(subsample=None, random_state=0).fit(DATA) + mapper_subsample = _BinMapper(subsample=256, random_state=0).fit(DATA) + + for feature in range(DATA.shape[1]): + assert not np.allclose(mapper_no_subsample.bin_thresholds_[feature], + mapper_subsample.bin_thresholds_[feature], + rtol=1e-4) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py new file mode 100644 index 0000000000000..95672a60e5c40 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py @@ -0,0 +1,216 @@ +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score +from sklearn.datasets import make_classification, make_regression +import numpy as np +import pytest + +# To use this experimental feature, we need to explicitly ask for it: +from sklearn.experimental import enable_hist_gradient_boosting # noqa +from sklearn.ensemble import HistGradientBoostingRegressor +from sklearn.ensemble import HistGradientBoostingClassifier +from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper +from sklearn.ensemble._hist_gradient_boosting.utils import ( + get_equivalent_estimator) + + +pytest.importorskip("lightgbm") + + +@pytest.mark.parametrize('seed', range(5)) +@pytest.mark.parametrize('min_samples_leaf', (1, 20)) +@pytest.mark.parametrize('n_samples, max_leaf_nodes', [ + (255, 4096), + (1000, 8), +]) +def test_same_predictions_regression(seed, min_samples_leaf, n_samples, + max_leaf_nodes): + # Make sure sklearn has the same predictions as lightgbm for easy targets. + # + # In particular when the size of the trees are bound and the number of + # samples is large enough, the structure of the prediction trees found by + # LightGBM and sklearn should be exactly identical. + # + # Notes: + # - Several candidate splits may have equal gains when the number of + # samples in a node is low (and because of float errors). Therefore the + # predictions on the test set might differ if the structure of the tree + # is not exactly the same. To avoid this issue we only compare the + # predictions on the test set when the number of samples is large enough + # and max_leaf_nodes is low enough. + # - To ignore discrepancies caused by small differences the binning + # strategy, data is pre-binned if n_samples > 255. 
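# The pre-binning trick from the last note above, in isolation: binning
# the data ourselves and casting the integer bin codes back to float32
# means both libraries see exactly the same discretized feature values,
# so any remaining discrepancy cannot come from the binning strategy.
# Minimal sketch mirroring what the test body does below; the helper
# name is made up for the illustration.
import numpy as np
from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper

def prebin(X, max_bins=256):
    # integer bin codes, cast to float so the estimators treat them as
    # plain numerical (not pre-binned) input
    return _BinMapper(max_bins=max_bins).fit_transform(X).astype(np.float32)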
+ + rng = np.random.RandomState(seed=seed) + n_samples = n_samples + max_iter = 1 + max_bins = 256 + + X, y = make_regression(n_samples=n_samples, n_features=5, + n_informative=5, random_state=0) + + if n_samples > 255: + # bin data and convert it to float32 so that the estimator doesn't + # treat it as pre-binned + X = _BinMapper(max_bins=max_bins).fit_transform(X).astype(np.float32) + + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng) + + est_sklearn = HistGradientBoostingRegressor( + max_iter=max_iter, + max_bins=max_bins, + learning_rate=1, + n_iter_no_change=None, + min_samples_leaf=min_samples_leaf, + max_leaf_nodes=max_leaf_nodes) + est_lightgbm = get_equivalent_estimator(est_sklearn, lib='lightgbm') + + est_lightgbm.fit(X_train, y_train) + est_sklearn.fit(X_train, y_train) + + # We need X to be treated an numerical data, not pre-binned data. + X_train, X_test = X_train.astype(np.float32), X_test.astype(np.float32) + + pred_lightgbm = est_lightgbm.predict(X_train) + pred_sklearn = est_sklearn.predict(X_train) + # less than 1% of the predictions are different up to the 3rd decimal + assert np.mean(abs(pred_lightgbm - pred_sklearn) > 1e-3) < .011 + + if max_leaf_nodes < 10 and n_samples >= 1000: + pred_lightgbm = est_lightgbm.predict(X_test) + pred_sklearn = est_sklearn.predict(X_test) + # less than 1% of the predictions are different up to the 4th decimal + assert np.mean(abs(pred_lightgbm - pred_sklearn) > 1e-4) < .01 + + +@pytest.mark.parametrize('seed', range(5)) +@pytest.mark.parametrize('min_samples_leaf', (1, 20)) +@pytest.mark.parametrize('n_samples, max_leaf_nodes', [ + (255, 4096), + (1000, 8), +]) +def test_same_predictions_classification(seed, min_samples_leaf, n_samples, + max_leaf_nodes): + # Same as test_same_predictions_regression but for classification + + rng = np.random.RandomState(seed=seed) + n_samples = n_samples + max_iter = 1 + max_bins = 256 + + X, y = make_classification(n_samples=n_samples, n_classes=2, n_features=5, + n_informative=5, n_redundant=0, random_state=0) + + if n_samples > 255: + # bin data and convert it to float32 so that the estimator doesn't + # treat it as pre-binned + X = _BinMapper(max_bins=max_bins).fit_transform(X).astype(np.float32) + + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng) + + est_sklearn = HistGradientBoostingClassifier( + loss='binary_crossentropy', + max_iter=max_iter, + max_bins=max_bins, + learning_rate=1, + n_iter_no_change=None, + min_samples_leaf=min_samples_leaf, + max_leaf_nodes=max_leaf_nodes) + est_lightgbm = get_equivalent_estimator(est_sklearn, lib='lightgbm') + + est_lightgbm.fit(X_train, y_train) + est_sklearn.fit(X_train, y_train) + + # We need X to be treated an numerical data, not pre-binned data. 
+ X_train, X_test = X_train.astype(np.float32), X_test.astype(np.float32) + + pred_lightgbm = est_lightgbm.predict(X_train) + pred_sklearn = est_sklearn.predict(X_train) + assert np.mean(pred_sklearn == pred_lightgbm) > .89 + + acc_lightgbm = accuracy_score(y_train, pred_lightgbm) + acc_sklearn = accuracy_score(y_train, pred_sklearn) + np.testing.assert_almost_equal(acc_lightgbm, acc_sklearn) + + if max_leaf_nodes < 10 and n_samples >= 1000: + + pred_lightgbm = est_lightgbm.predict(X_test) + pred_sklearn = est_sklearn.predict(X_test) + assert np.mean(pred_sklearn == pred_lightgbm) > .89 + + acc_lightgbm = accuracy_score(y_test, pred_lightgbm) + acc_sklearn = accuracy_score(y_test, pred_sklearn) + np.testing.assert_almost_equal(acc_lightgbm, acc_sklearn, decimal=2) + + +@pytest.mark.parametrize('seed', range(5)) +@pytest.mark.parametrize('min_samples_leaf', (1, 20)) +@pytest.mark.parametrize('n_samples, max_leaf_nodes', [ + (255, 4096), + (10000, 8), +]) +def test_same_predictions_multiclass_classification( + seed, min_samples_leaf, n_samples, max_leaf_nodes): + # Same as test_same_predictions_regression but for classification + + rng = np.random.RandomState(seed=seed) + n_samples = n_samples + max_iter = 1 + max_bins = 256 + lr = 1 + + X, y = make_classification(n_samples=n_samples, n_classes=3, n_features=5, + n_informative=5, n_redundant=0, + n_clusters_per_class=1, random_state=0) + + if n_samples > 255: + # bin data and convert it to float32 so that the estimator doesn't + # treat it as pre-binned + X = _BinMapper(max_bins=max_bins).fit_transform(X).astype(np.float32) + + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng) + + est_sklearn = HistGradientBoostingClassifier( + loss='categorical_crossentropy', + max_iter=max_iter, + max_bins=max_bins, + learning_rate=lr, + n_iter_no_change=None, + min_samples_leaf=min_samples_leaf, + max_leaf_nodes=max_leaf_nodes) + est_lightgbm = get_equivalent_estimator(est_sklearn, lib='lightgbm') + + est_lightgbm.fit(X_train, y_train) + est_sklearn.fit(X_train, y_train) + + # We need X to be treated an numerical data, not pre-binned data. 
+ X_train, X_test = X_train.astype(np.float32), X_test.astype(np.float32) + + pred_lightgbm = est_lightgbm.predict(X_train) + pred_sklearn = est_sklearn.predict(X_train) + assert np.mean(pred_sklearn == pred_lightgbm) > .89 + + proba_lightgbm = est_lightgbm.predict_proba(X_train) + proba_sklearn = est_sklearn.predict_proba(X_train) + # assert more than 75% of the predicted probabilities are the same up to + # the second decimal + assert np.mean(np.abs(proba_lightgbm - proba_sklearn) < 1e-2) > .75 + + acc_lightgbm = accuracy_score(y_train, pred_lightgbm) + acc_sklearn = accuracy_score(y_train, pred_sklearn) + np.testing.assert_almost_equal(acc_lightgbm, acc_sklearn, decimal=2) + + if max_leaf_nodes < 10 and n_samples >= 1000: + + pred_lightgbm = est_lightgbm.predict(X_test) + pred_sklearn = est_sklearn.predict(X_test) + assert np.mean(pred_sklearn == pred_lightgbm) > .89 + + proba_lightgbm = est_lightgbm.predict_proba(X_train) + proba_sklearn = est_sklearn.predict_proba(X_train) + # assert more than 75% of the predicted probabilities are the same up + # to the second decimal + assert np.mean(np.abs(proba_lightgbm - proba_sklearn) < 1e-2) > .75 + + acc_lightgbm = accuracy_score(y_test, pred_lightgbm) + acc_sklearn = accuracy_score(y_test, pred_sklearn) + np.testing.assert_almost_equal(acc_lightgbm, acc_sklearn, decimal=2) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py new file mode 100644 index 0000000000000..790597b07fa15 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -0,0 +1,147 @@ +import numpy as np +import pytest +from sklearn.datasets import make_classification, make_regression + +# To use this experimental feature, we need to explicitly ask for it: +from sklearn.experimental import enable_hist_gradient_boosting # noqa +from sklearn.ensemble import HistGradientBoostingRegressor +from sklearn.ensemble import HistGradientBoostingClassifier + + +X_classification, y_classification = make_classification(random_state=0) +X_regression, y_regression = make_regression(random_state=0) + + +@pytest.mark.parametrize('GradientBoosting, X, y', [ + (HistGradientBoostingClassifier, X_classification, y_classification), + (HistGradientBoostingRegressor, X_regression, y_regression) +]) +@pytest.mark.parametrize( + 'params, err_msg', + [({'loss': 'blah'}, 'Loss blah is not supported for'), + ({'learning_rate': 0}, 'learning_rate=0 must be strictly positive'), + ({'learning_rate': -1}, 'learning_rate=-1 must be strictly positive'), + ({'max_iter': 0}, 'max_iter=0 must not be smaller than 1'), + ({'max_leaf_nodes': 0}, 'max_leaf_nodes=0 should not be smaller than 2'), + ({'max_leaf_nodes': 1}, 'max_leaf_nodes=1 should not be smaller than 2'), + ({'max_depth': 0}, 'max_depth=0 should not be smaller than 2'), + ({'max_depth': 1}, 'max_depth=1 should not be smaller than 2'), + ({'min_samples_leaf': 0}, 'min_samples_leaf=0 should not be smaller'), + ({'l2_regularization': -1}, 'l2_regularization=-1 must be positive'), + ({'max_bins': 1}, 'max_bins=1 should be no smaller than 2 and no larger'), + ({'max_bins': 257}, 'max_bins=257 should be no smaller than 2 and no'), + ({'n_iter_no_change': -1}, 'n_iter_no_change=-1 must be positive'), + ({'validation_fraction': -1}, 'validation_fraction=-1 must be strictly'), + ({'validation_fraction': 0}, 'validation_fraction=0 must be strictly'), + ({'tol': -1}, 'tol=-1 must not be smaller than 0')] +) +def 
test_init_parameters_validation(GradientBoosting, X, y, params, err_msg): + + with pytest.raises(ValueError, match=err_msg): + GradientBoosting(**params).fit(X, y) + + +def test_invalid_classification_loss(): + binary_clf = HistGradientBoostingClassifier(loss="binary_crossentropy") + err_msg = ("loss='binary_crossentropy' is not defined for multiclass " + "classification with n_classes=3, use " + "loss='categorical_crossentropy' instead") + with pytest.raises(ValueError, match=err_msg): + binary_clf.fit(np.zeros(shape=(3, 2)), np.arange(3)) + + +@pytest.mark.parametrize( + 'scoring, validation_fraction, n_iter_no_change, tol', [ + ('neg_mean_squared_error', .1, 5, 1e-7), # use scorer + ('neg_mean_squared_error', None, 5, 1e-1), # use scorer on train data + (None, .1, 5, 1e-7), # same with default scorer + (None, None, 5, 1e-1), + ('loss', .1, 5, 1e-7), # use loss + ('loss', None, 5, 1e-1), # use loss on training data + (None, None, None, None), # no early stopping + ]) +def test_early_stopping_regression(scoring, validation_fraction, + n_iter_no_change, tol): + + max_iter = 200 + + X, y = make_regression(random_state=0) + + gb = HistGradientBoostingRegressor( + verbose=1, # just for coverage + min_samples_leaf=5, # easier to overfit fast + scoring=scoring, + tol=tol, + validation_fraction=validation_fraction, + max_iter=max_iter, + n_iter_no_change=n_iter_no_change, + random_state=0 + ) + gb.fit(X, y) + + if n_iter_no_change is not None: + assert n_iter_no_change <= gb.n_iter_ < max_iter + else: + assert gb.n_iter_ == max_iter + + +@pytest.mark.parametrize('data', ( + make_classification(random_state=0), + make_classification(n_classes=3, n_clusters_per_class=1, random_state=0) +)) +@pytest.mark.parametrize( + 'scoring, validation_fraction, n_iter_no_change, tol', [ + ('accuracy', .1, 5, 1e-7), # use scorer + ('accuracy', None, 5, 1e-1), # use scorer on training data + (None, .1, 5, 1e-7), # same with default scorerscor + (None, None, 5, 1e-1), + ('loss', .1, 5, 1e-7), # use loss + ('loss', None, 5, 1e-1), # use loss on training data + (None, None, None, None), # no early stopping + ]) +def test_early_stopping_classification(data, scoring, validation_fraction, + n_iter_no_change, tol): + + max_iter = 50 + + X, y = data + + gb = HistGradientBoostingClassifier( + verbose=1, # just for coverage + min_samples_leaf=5, # easier to overfit fast + scoring=scoring, + tol=tol, + validation_fraction=validation_fraction, + max_iter=max_iter, + n_iter_no_change=n_iter_no_change, + random_state=0 + ) + gb.fit(X, y) + + if n_iter_no_change is not None: + assert n_iter_no_change <= gb.n_iter_ < max_iter + else: + assert gb.n_iter_ == max_iter + + +@pytest.mark.parametrize( + 'scores, n_iter_no_change, tol, stopping', + [ + ([], 1, 0.001, False), # not enough iterations + ([1, 1, 1], 5, 0.001, False), # not enough iterations + ([1, 1, 1, 1, 1], 5, 0.001, False), # not enough iterations + ([1, 2, 3, 4, 5, 6], 5, 0.001, False), # significant improvement + ([1, 2, 3, 4, 5, 6], 5, 0., False), # significant improvement + ([1, 2, 3, 4, 5, 6], 5, 0.999, False), # significant improvement + ([1, 2, 3, 4, 5, 6], 5, 5 - 1e-5, False), # significant improvement + ([1] * 6, 5, 0., True), # no significant improvement + ([1] * 6, 5, 0.001, True), # no significant improvement + ([1] * 6, 5, 5, True), # no significant improvement + ] +) +def test_should_stop(scores, n_iter_no_change, tol, stopping): + + gbdt = HistGradientBoostingClassifier( + n_iter_no_change=n_iter_no_change, tol=tol + ) + assert gbdt._should_stop(scores) 
== stopping diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_grower.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_grower.py new file mode 100644 index 0000000000000..49b19ce2778dd --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_grower.py @@ -0,0 +1,309 @@ +import numpy as np +import pytest +from pytest import approx + +from sklearn.ensemble._hist_gradient_boosting.grower import TreeGrower +from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper +from sklearn.ensemble._hist_gradient_boosting.types import X_BINNED_DTYPE +from sklearn.ensemble._hist_gradient_boosting.types import Y_DTYPE +from sklearn.ensemble._hist_gradient_boosting.types import G_H_DTYPE + + +def _make_training_data(n_bins=256, constant_hessian=True): + rng = np.random.RandomState(42) + n_samples = 10000 + + # Generate some test data directly binned so as to test the grower code + # independently of the binning logic. + X_binned = rng.randint(0, n_bins - 1, size=(n_samples, 2), + dtype=X_BINNED_DTYPE) + X_binned = np.asfortranarray(X_binned) + + def true_decision_function(input_features): + """Ground truth decision function + + This is a very simple yet asymmetric decision tree. Therefore the + grower code should have no trouble recovering the decision function + from 10000 training samples. + """ + if input_features[0] <= n_bins // 2: + return -1 + else: + return -1 if input_features[1] <= n_bins // 3 else 1 + + target = np.array([true_decision_function(x) for x in X_binned], + dtype=Y_DTYPE) + + # Assume a square loss applied to an initial model that always predicts 0 + # (hardcoded for this test): + all_gradients = target.astype(G_H_DTYPE) + shape_hessians = 1 if constant_hessian else all_gradients.shape + all_hessians = np.ones(shape=shape_hessians, dtype=G_H_DTYPE) + + return X_binned, all_gradients, all_hessians + + +def _check_children_consistency(parent, left, right): + # Make sure the samples are correctly dispatched from a parent to its + # children + assert parent.left_child is left + assert parent.right_child is right + + # each sample from the parent is propagated to one of the two children + assert (len(left.sample_indices) + len(right.sample_indices) + == len(parent.sample_indices)) + + assert (set(left.sample_indices).union(set(right.sample_indices)) + == set(parent.sample_indices)) + + # samples are sent either to the left or the right node, never to both + assert (set(left.sample_indices).intersection(set(right.sample_indices)) + == set()) + + +@pytest.mark.parametrize( + 'n_bins, constant_hessian, stopping_param, shrinkage', + [ + (11, True, "min_gain_to_split", 0.5), + (11, False, "min_gain_to_split", 1.), + (11, True, "max_leaf_nodes", 1.), + (11, False, "max_leaf_nodes", 0.1), + (42, True, "max_leaf_nodes", 0.01), + (42, False, "max_leaf_nodes", 1.), + (256, True, "min_gain_to_split", 1.), + (256, True, "max_leaf_nodes", 0.1), + ] +) +def test_grow_tree(n_bins, constant_hessian, stopping_param, shrinkage): + X_binned, all_gradients, all_hessians = _make_training_data( + n_bins=n_bins, constant_hessian=constant_hessian) + n_samples = X_binned.shape[0] + + if stopping_param == "max_leaf_nodes": + stopping_param = {"max_leaf_nodes": 3} + else: + stopping_param = {"min_gain_to_split": 0.01} + + grower = TreeGrower(X_binned, all_gradients, all_hessians, + max_bins=n_bins, shrinkage=shrinkage, + min_samples_leaf=1, **stopping_param) + + # The root node is not yet splitted, but the best possible split has + # already been evaluated: + assert 
grower.root.left_child is None + assert grower.root.right_child is None + + root_split = grower.root.split_info + assert root_split.feature_idx == 0 + assert root_split.bin_idx == n_bins // 2 + assert len(grower.splittable_nodes) == 1 + + # Calling split next applies the next split and computes the best split + # for each of the two newly introduced children nodes. + left_node, right_node = grower.split_next() + + # All training samples have ben splitted in the two nodes, approximately + # 50%/50% + _check_children_consistency(grower.root, left_node, right_node) + assert len(left_node.sample_indices) > 0.4 * n_samples + assert len(left_node.sample_indices) < 0.6 * n_samples + + if grower.min_gain_to_split > 0: + # The left node is too pure: there is no gain to split it further. + assert left_node.split_info.gain < grower.min_gain_to_split + assert left_node in grower.finalized_leaves + + # The right node can still be splitted further, this time on feature #1 + split_info = right_node.split_info + assert split_info.gain > 1. + assert split_info.feature_idx == 1 + assert split_info.bin_idx == n_bins // 3 + assert right_node.left_child is None + assert right_node.right_child is None + + # The right split has not been applied yet. Let's do it now: + assert len(grower.splittable_nodes) == 1 + right_left_node, right_right_node = grower.split_next() + _check_children_consistency(right_node, right_left_node, right_right_node) + assert len(right_left_node.sample_indices) > 0.1 * n_samples + assert len(right_left_node.sample_indices) < 0.2 * n_samples + + assert len(right_right_node.sample_indices) > 0.2 * n_samples + assert len(right_right_node.sample_indices) < 0.4 * n_samples + + # All the leafs are pure, it is not possible to split any further: + assert not grower.splittable_nodes + + # Check the values of the leaves: + assert grower.root.left_child.value == approx(shrinkage) + assert grower.root.right_child.left_child.value == approx(shrinkage) + assert grower.root.right_child.right_child.value == approx(-shrinkage, + rel=1e-3) + + +def test_predictor_from_grower(): + # Build a tree on the toy 3-leaf dataset to extract the predictor. 
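# A readable sketch of how a TreePredictor nodes array (see
# predictor.py above) can be walked for one binned sample, using the
# same "go left when the binned value is <= bin_threshold" convention
# as Splitter.split_indices. Illustration only -- this is not the
# actual _predict_from_binned_data implementation, and the helper name
# is made up.
import numpy as np

def predict_one_binned(nodes, x_binned):
    node = nodes[0]                      # root is stored first
    while not node['is_leaf']:
        if x_binned[node['feature_idx']] <= node['bin_threshold']:
            node = nodes[node['left']]
        else:
            node = nodes[node['right']]
    return node['value']

# e.g. np.array([predict_one_binned(predictor.nodes, row)
#                for row in X_binned])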
+ n_bins = 256 + X_binned, all_gradients, all_hessians = _make_training_data( + n_bins=n_bins) + grower = TreeGrower(X_binned, all_gradients, all_hessians, + max_bins=n_bins, shrinkage=1., + max_leaf_nodes=3, min_samples_leaf=5) + grower.grow() + assert grower.n_nodes == 5 # (2 decision nodes + 3 leaves) + + # Check that the node structure can be converted into a predictor + # object to perform predictions at scale + predictor = grower.make_predictor() + assert predictor.nodes.shape[0] == 5 + assert predictor.nodes['is_leaf'].sum() == 3 + + # Probe some predictions for each leaf of the tree + # each group of 3 samples corresponds to a condition in _make_training_data + input_data = np.array([ + [0, 0], + [42, 99], + [128, 255], + + [129, 0], + [129, 85], + [255, 85], + + [129, 86], + [129, 255], + [242, 100], + ], dtype=np.uint8) + predictions = predictor.predict_binned(input_data) + expected_targets = [1, 1, 1, 1, 1, 1, -1, -1, -1] + assert np.allclose(predictions, expected_targets) + + # Check that training set can be recovered exactly: + predictions = predictor.predict_binned(X_binned) + assert np.allclose(predictions, -all_gradients) + + +@pytest.mark.parametrize( + 'n_samples, min_samples_leaf, n_bins, constant_hessian, noise', + [ + (11, 10, 7, True, 0), + (13, 10, 42, False, 0), + (56, 10, 255, True, 0.1), + (101, 3, 7, True, 0), + (200, 42, 42, False, 0), + (300, 55, 255, True, 0.1), + (300, 301, 255, True, 0.1), + ] +) +def test_min_samples_leaf(n_samples, min_samples_leaf, n_bins, + constant_hessian, noise): + rng = np.random.RandomState(seed=0) + # data = linear target, 3 features, 1 irrelevant. + X = rng.normal(size=(n_samples, 3)) + y = X[:, 0] - X[:, 1] + if noise: + y_scale = y.std() + y += rng.normal(scale=noise, size=n_samples) * y_scale + mapper = _BinMapper(max_bins=n_bins) + X = mapper.fit_transform(X) + + all_gradients = y.astype(G_H_DTYPE) + shape_hessian = 1 if constant_hessian else all_gradients.shape + all_hessians = np.ones(shape=shape_hessian, dtype=G_H_DTYPE) + grower = TreeGrower(X, all_gradients, all_hessians, + max_bins=n_bins, shrinkage=1., + min_samples_leaf=min_samples_leaf, + max_leaf_nodes=n_samples) + grower.grow() + predictor = grower.make_predictor( + bin_thresholds=mapper.bin_thresholds_) + + if n_samples >= min_samples_leaf: + for node in predictor.nodes: + if node['is_leaf']: + assert node['count'] >= min_samples_leaf + else: + assert predictor.nodes.shape[0] == 1 + assert predictor.nodes[0]['is_leaf'] + assert predictor.nodes[0]['count'] == n_samples + + +@pytest.mark.parametrize('n_samples, min_samples_leaf', [ + (99, 50), + (100, 50)]) +def test_min_samples_leaf_root(n_samples, min_samples_leaf): + # Make sure root node isn't split if n_samples is not at least twice + # min_samples_leaf + rng = np.random.RandomState(seed=0) + + max_bins = 255 + + # data = linear target, 3 features, 1 irrelevant. 
+ X = rng.normal(size=(n_samples, 3)) + y = X[:, 0] - X[:, 1] + mapper = _BinMapper(max_bins=max_bins) + X = mapper.fit_transform(X) + + all_gradients = y.astype(G_H_DTYPE) + all_hessians = np.ones(shape=1, dtype=G_H_DTYPE) + grower = TreeGrower(X, all_gradients, all_hessians, + max_bins=max_bins, shrinkage=1., + min_samples_leaf=min_samples_leaf, + max_leaf_nodes=n_samples) + grower.grow() + if n_samples >= min_samples_leaf * 2: + assert len(grower.finalized_leaves) >= 2 + else: + assert len(grower.finalized_leaves) == 1 + + +@pytest.mark.parametrize('max_depth', [2, 3]) +def test_max_depth(max_depth): + # Make sure max_depth parameter works as expected + rng = np.random.RandomState(seed=0) + + max_bins = 255 + n_samples = 1000 + + # data = linear target, 3 features, 1 irrelevant. + X = rng.normal(size=(n_samples, 3)) + y = X[:, 0] - X[:, 1] + mapper = _BinMapper(max_bins=max_bins) + X = mapper.fit_transform(X) + + all_gradients = y.astype(G_H_DTYPE) + all_hessians = np.ones(shape=1, dtype=G_H_DTYPE) + grower = TreeGrower(X, all_gradients, all_hessians, max_depth=max_depth) + grower.grow() + + depth = max(leaf.depth for leaf in grower.finalized_leaves) + assert depth == max_depth + + +def test_input_validation(): + + X_binned, all_gradients, all_hessians = _make_training_data() + + X_binned_float = X_binned.astype(np.float32) + with pytest.raises(NotImplementedError, + match="X_binned must be of type uint8"): + TreeGrower(X_binned_float, all_gradients, all_hessians) + + X_binned_C_array = np.ascontiguousarray(X_binned) + with pytest.raises( + ValueError, + match="X_binned should be passed as Fortran contiguous array"): + TreeGrower(X_binned_C_array, all_gradients, all_hessians) + + +def test_init_parameters_validation(): + X_binned, all_gradients, all_hessians = _make_training_data() + with pytest.raises(ValueError, + match="min_gain_to_split=-1 must be positive"): + + TreeGrower(X_binned, all_gradients, all_hessians, + min_gain_to_split=-1) + + with pytest.raises(ValueError, + match="min_hessian_to_split=-1 must be positive"): + TreeGrower(X_binned, all_gradients, all_hessians, + min_hessian_to_split=-1) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_histogram.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_histogram.py new file mode 100644 index 0000000000000..c425a0389a789 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_histogram.py @@ -0,0 +1,202 @@ +import numpy as np +import pytest + +from numpy.testing import assert_allclose +from numpy.testing import assert_array_equal + +from sklearn.ensemble._hist_gradient_boosting.histogram import ( + _build_histogram_naive, + _build_histogram, + _build_histogram_no_hessian, + _build_histogram_root_no_hessian, + _build_histogram_root, + _subtract_histograms +) +from sklearn.ensemble._hist_gradient_boosting.types import HISTOGRAM_DTYPE +from sklearn.ensemble._hist_gradient_boosting.types import G_H_DTYPE +from sklearn.ensemble._hist_gradient_boosting.types import X_BINNED_DTYPE + + +@pytest.mark.parametrize( + 'build_func', [_build_histogram_naive, _build_histogram]) +def test_build_histogram(build_func): + binned_feature = np.array([0, 2, 0, 1, 2, 0, 2, 1], dtype=X_BINNED_DTYPE) + + # Small sample_indices (below unrolling threshold) + ordered_gradients = np.array([0, 1, 3], dtype=G_H_DTYPE) + ordered_hessians = np.array([1, 1, 2], dtype=G_H_DTYPE) + + sample_indices = np.array([0, 2, 3], dtype=np.uint32) + hist = np.zeros((1, 3), dtype=HISTOGRAM_DTYPE) + build_func(0, sample_indices, 
binned_feature, ordered_gradients, + ordered_hessians, hist) + hist = hist[0] + assert_array_equal(hist['count'], [2, 1, 0]) + assert_allclose(hist['sum_gradients'], [1, 3, 0]) + assert_allclose(hist['sum_hessians'], [2, 2, 0]) + + # Larger sample_indices (above unrolling threshold) + sample_indices = np.array([0, 2, 3, 6, 7], dtype=np.uint32) + ordered_gradients = np.array([0, 1, 3, 0, 1], dtype=G_H_DTYPE) + ordered_hessians = np.array([1, 1, 2, 1, 0], dtype=G_H_DTYPE) + + hist = np.zeros((1, 3), dtype=HISTOGRAM_DTYPE) + build_func(0, sample_indices, binned_feature, ordered_gradients, + ordered_hessians, hist) + hist = hist[0] + assert_array_equal(hist['count'], [2, 2, 1]) + assert_allclose(hist['sum_gradients'], [1, 4, 0]) + assert_allclose(hist['sum_hessians'], [2, 2, 1]) + + +def test_histogram_sample_order_independence(): + # Make sure the order of the samples has no impact on the histogram + # computations + rng = np.random.RandomState(42) + n_sub_samples = 100 + n_samples = 1000 + n_bins = 256 + + binned_feature = rng.randint(0, n_bins - 1, size=n_samples, + dtype=X_BINNED_DTYPE) + sample_indices = rng.choice(np.arange(n_samples, dtype=np.uint32), + n_sub_samples, replace=False) + ordered_gradients = rng.randn(n_sub_samples).astype(G_H_DTYPE) + hist_gc = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE) + _build_histogram_no_hessian(0, sample_indices, binned_feature, + ordered_gradients, hist_gc) + + ordered_hessians = rng.exponential(size=n_sub_samples).astype(G_H_DTYPE) + hist_ghc = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE) + _build_histogram(0, sample_indices, binned_feature, + ordered_gradients, ordered_hessians, hist_ghc) + + permutation = rng.permutation(n_sub_samples) + hist_gc_perm = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE) + _build_histogram_no_hessian(0, sample_indices[permutation], + binned_feature, ordered_gradients[permutation], + hist_gc_perm) + + hist_ghc_perm = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE) + _build_histogram(0, sample_indices[permutation], binned_feature, + ordered_gradients[permutation], + ordered_hessians[permutation], hist_ghc_perm) + + hist_gc = hist_gc[0] + hist_ghc = hist_ghc[0] + hist_gc_perm = hist_gc_perm[0] + hist_ghc_perm = hist_ghc_perm[0] + + assert_allclose(hist_gc['sum_gradients'], hist_gc_perm['sum_gradients']) + assert_array_equal(hist_gc['count'], hist_gc_perm['count']) + + assert_allclose(hist_ghc['sum_gradients'], hist_ghc_perm['sum_gradients']) + assert_allclose(hist_ghc['sum_hessians'], hist_ghc_perm['sum_hessians']) + assert_array_equal(hist_ghc['count'], hist_ghc_perm['count']) + + +@pytest.mark.parametrize("constant_hessian", [True, False]) +def test_unrolled_equivalent_to_naive(constant_hessian): + # Make sure the different unrolled histogram computations give the same + # results as the naive one. 
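# What the "naive" computation amounts to, as a pure NumPy sketch
# (assuming, as in these tests, that ordered_gradients[k] and
# ordered_hessians[k] correspond to sample_indices[k]): per bin, the
# sample count and the sums of gradients and hessians. The unrolled
# variants exercised below must produce the same three arrays.
import numpy as np

def naive_histogram(n_bins, sample_indices, binned_feature,
                    ordered_gradients, ordered_hessians):
    bins = binned_feature[sample_indices]
    count = np.bincount(bins, minlength=n_bins)
    sum_gradients = np.bincount(bins, weights=ordered_gradients,
                                minlength=n_bins)
    sum_hessians = np.bincount(bins, weights=ordered_hessians,
                               minlength=n_bins)
    return count, sum_gradients, sum_hessians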
+ rng = np.random.RandomState(42) + n_samples = 10 + n_bins = 5 + sample_indices = np.arange(n_samples).astype(np.uint32) + binned_feature = rng.randint(0, n_bins - 1, size=n_samples, dtype=np.uint8) + ordered_gradients = rng.randn(n_samples).astype(G_H_DTYPE) + if constant_hessian: + ordered_hessians = np.ones(n_samples, dtype=G_H_DTYPE) + else: + ordered_hessians = rng.lognormal(size=n_samples).astype(G_H_DTYPE) + + hist_gc_root = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE) + hist_ghc_root = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE) + hist_gc = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE) + hist_ghc = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE) + hist_naive = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE) + + _build_histogram_root_no_hessian(0, binned_feature, + ordered_gradients, hist_gc_root) + _build_histogram_root(0, binned_feature, ordered_gradients, + ordered_hessians, hist_ghc_root) + _build_histogram_no_hessian(0, sample_indices, binned_feature, + ordered_gradients, hist_gc) + _build_histogram(0, sample_indices, binned_feature, + ordered_gradients, ordered_hessians, hist_ghc) + _build_histogram_naive(0, sample_indices, binned_feature, + ordered_gradients, ordered_hessians, hist_naive) + + hist_naive = hist_naive[0] + hist_gc_root = hist_gc_root[0] + hist_ghc_root = hist_ghc_root[0] + hist_gc = hist_gc[0] + hist_ghc = hist_ghc[0] + for hist in (hist_gc_root, hist_ghc_root, hist_gc, hist_ghc): + assert_array_equal(hist['count'], hist_naive['count']) + assert_allclose(hist['sum_gradients'], hist_naive['sum_gradients']) + for hist in (hist_ghc_root, hist_ghc): + assert_allclose(hist['sum_hessians'], hist_naive['sum_hessians']) + for hist in (hist_gc_root, hist_gc): + assert_array_equal(hist['sum_hessians'], np.zeros(n_bins)) + + +@pytest.mark.parametrize("constant_hessian", [True, False]) +def test_hist_subtraction(constant_hessian): + # Make sure the histogram subtraction trick gives the same result as the + # classical method. 
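# The subtraction trick in a nutshell: histograms are additive, so
# hist(parent) == hist(left) + hist(right) bin by bin for 'count',
# 'sum_gradients' and 'sum_hessians'. One child's histogram can thus be
# deduced from the parent's and the sibling's instead of being rebuilt
# from its samples. A minimal structured-array sketch (illustration
# only, not the _subtract_histograms implementation):
import numpy as np

def subtract_histograms_sketch(hist_parent, hist_sibling):
    hist = np.zeros_like(hist_parent)  # same HISTOGRAM_DTYPE layout
    for field in ('count', 'sum_gradients', 'sum_hessians'):
        hist[field] = hist_parent[field] - hist_sibling[field]
    return hist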
+ rng = np.random.RandomState(42) + n_samples = 10 + n_bins = 5 + sample_indices = np.arange(n_samples).astype(np.uint32) + binned_feature = rng.randint(0, n_bins - 1, size=n_samples, dtype=np.uint8) + ordered_gradients = rng.randn(n_samples).astype(G_H_DTYPE) + if constant_hessian: + ordered_hessians = np.ones(n_samples, dtype=G_H_DTYPE) + else: + ordered_hessians = rng.lognormal(size=n_samples).astype(G_H_DTYPE) + + hist_parent = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE) + if constant_hessian: + _build_histogram_no_hessian(0, sample_indices, binned_feature, + ordered_gradients, hist_parent) + else: + _build_histogram(0, sample_indices, binned_feature, + ordered_gradients, ordered_hessians, hist_parent) + + mask = rng.randint(0, 2, n_samples).astype(np.bool) + + sample_indices_left = sample_indices[mask] + ordered_gradients_left = ordered_gradients[mask] + ordered_hessians_left = ordered_hessians[mask] + hist_left = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE) + if constant_hessian: + _build_histogram_no_hessian(0, sample_indices_left, + binned_feature, ordered_gradients_left, + hist_left) + else: + _build_histogram(0, sample_indices_left, binned_feature, + ordered_gradients_left, ordered_hessians_left, + hist_left) + + sample_indices_right = sample_indices[~mask] + ordered_gradients_right = ordered_gradients[~mask] + ordered_hessians_right = ordered_hessians[~mask] + hist_right = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE) + if constant_hessian: + _build_histogram_no_hessian(0, sample_indices_right, + binned_feature, ordered_gradients_right, + hist_right) + else: + _build_histogram(0, sample_indices_right, binned_feature, + ordered_gradients_right, ordered_hessians_right, + hist_right) + + hist_left_sub = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE) + hist_right_sub = np.zeros((1, n_bins), dtype=HISTOGRAM_DTYPE) + _subtract_histograms(0, n_bins, hist_parent, hist_right, hist_left_sub) + _subtract_histograms(0, n_bins, hist_parent, hist_left, hist_right_sub) + + for key in ('count', 'sum_hessians', 'sum_gradients'): + assert_allclose(hist_left[key], hist_left_sub[key], rtol=1e-6) + assert_allclose(hist_right[key], hist_right_sub[key], rtol=1e-6) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py new file mode 100644 index 0000000000000..575095beb4883 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py @@ -0,0 +1,192 @@ +import numpy as np +from numpy.testing import assert_almost_equal +from scipy.optimize import newton +from sklearn.utils import assert_all_finite +from sklearn.utils.fixes import sp_version +import pytest + +from sklearn.ensemble._hist_gradient_boosting.loss import _LOSSES +from sklearn.ensemble._hist_gradient_boosting.types import Y_DTYPE +from sklearn.ensemble._hist_gradient_boosting.types import G_H_DTYPE + + +def get_derivatives_helper(loss): + """Return get_gradients() and get_hessians() functions for a given loss. 
+ """ + + def get_gradients(y_true, raw_predictions): + # create gradients and hessians array, update inplace, and return + gradients = np.empty_like(raw_predictions, dtype=G_H_DTYPE) + hessians = np.empty_like(raw_predictions, dtype=G_H_DTYPE) + loss.update_gradients_and_hessians(gradients, hessians, y_true, + raw_predictions) + return gradients + + def get_hessians(y_true, raw_predictions): + # create gradients and hessians array, update inplace, and return + gradients = np.empty_like(raw_predictions, dtype=G_H_DTYPE) + hessians = np.empty_like(raw_predictions, dtype=G_H_DTYPE) + loss.update_gradients_and_hessians(gradients, hessians, y_true, + raw_predictions) + + if loss.__class__.__name__ == 'LeastSquares': + # hessians aren't updated because they're constant: + # the value is 1 because the loss is actually an half + # least squares loss. + hessians = np.full_like(raw_predictions, fill_value=1) + + return hessians + + return get_gradients, get_hessians + + +@pytest.mark.parametrize('loss, x0, y_true', [ + ('least_squares', -2., 42), + ('least_squares', 117., 1.05), + ('least_squares', 0., 0.), + # I don't understand why but y_true == 0 fails :/ + # ('binary_crossentropy', 0.3, 0), + ('binary_crossentropy', -12, 1), + ('binary_crossentropy', 30, 1), +]) +@pytest.mark.skipif(sp_version == (1, 2, 0), + reason='bug in scipy 1.2.0, see scipy issue #9608') +@pytest.mark.skipif(Y_DTYPE != np.float64, + reason='Newton internally uses float64 != Y_DTYPE') +def test_derivatives(loss, x0, y_true): + # Check that gradients are zero when the loss is minimized on 1D array + # using Halley's method with the first and second order derivatives + # computed by the Loss instance. + + loss = _LOSSES[loss]() + y_true = np.array([y_true], dtype=Y_DTYPE) + x0 = np.array([x0], dtype=Y_DTYPE).reshape(1, 1) + get_gradients, get_hessians = get_derivatives_helper(loss) + + def func(x): + return loss(y_true, x) + + def fprime(x): + return get_gradients(y_true, x) + + def fprime2(x): + return get_hessians(y_true, x) + + optimum = newton(func, x0=x0, fprime=fprime, fprime2=fprime2) + assert np.allclose(loss.inverse_link_function(optimum), y_true) + assert np.allclose(loss(y_true, optimum), 0) + assert np.allclose(get_gradients(y_true, optimum), 0) + + +@pytest.mark.parametrize('loss, n_classes, prediction_dim', [ + ('least_squares', 0, 1), + ('binary_crossentropy', 2, 1), + ('categorical_crossentropy', 3, 3), +]) +@pytest.mark.skipif(Y_DTYPE != np.float64, + reason='Need 64 bits float precision for numerical checks') +def test_numerical_gradients(loss, n_classes, prediction_dim): + # Make sure gradients and hessians computed in the loss are correct, by + # comparing with their approximations computed with finite central + # differences. + # See https://en.wikipedia.org/wiki/Finite_difference. + + rng = np.random.RandomState(0) + n_samples = 100 + if loss == 'least_squares': + y_true = rng.normal(size=n_samples).astype(Y_DTYPE) + else: + y_true = rng.randint(0, n_classes, size=n_samples).astype(Y_DTYPE) + raw_predictions = rng.normal( + size=(prediction_dim, n_samples) + ).astype(Y_DTYPE) + loss = _LOSSES[loss]() + get_gradients, get_hessians = get_derivatives_helper(loss) + + # only take gradients and hessians of first tree / class. 
+ gradients = get_gradients(y_true, raw_predictions)[0, :].ravel() + hessians = get_hessians(y_true, raw_predictions)[0, :].ravel() + + # Approximate gradients + # For multiclass loss, we should only change the predictions of one tree + # (here the first), hence the use of offset[:, 0] += eps + # As a softmax is computed, offsetting the whole array by a constant would + # have no effect on the probabilities, and thus on the loss + eps = 1e-9 + offset = np.zeros_like(raw_predictions) + offset[0, :] = eps + f_plus_eps = loss(y_true, raw_predictions + offset / 2, average=False) + f_minus_eps = loss(y_true, raw_predictions - offset / 2, average=False) + numerical_gradients = (f_plus_eps - f_minus_eps) / eps + + # Approximate hessians + eps = 1e-4 # need big enough eps as we divide by its square + offset[0, :] = eps + f_plus_eps = loss(y_true, raw_predictions + offset, average=False) + f_minus_eps = loss(y_true, raw_predictions - offset, average=False) + f = loss(y_true, raw_predictions, average=False) + numerical_hessians = (f_plus_eps + f_minus_eps - 2 * f) / eps**2 + + def relative_error(a, b): + return np.abs(a - b) / np.maximum(np.abs(a), np.abs(b)) + + assert np.allclose(numerical_gradients, gradients, rtol=1e-5) + assert np.allclose(numerical_hessians, hessians, rtol=1e-5) + + +def test_baseline_least_squares(): + rng = np.random.RandomState(0) + + loss = _LOSSES['least_squares']() + y_train = rng.normal(size=100) + baseline_prediction = loss.get_baseline_prediction(y_train, 1) + assert baseline_prediction.shape == tuple() # scalar + assert baseline_prediction.dtype == y_train.dtype + # Make sure baseline prediction is the mean of all targets + assert_almost_equal(baseline_prediction, y_train.mean()) + + +def test_baseline_binary_crossentropy(): + rng = np.random.RandomState(0) + + loss = _LOSSES['binary_crossentropy']() + for y_train in (np.zeros(shape=100), np.ones(shape=100)): + y_train = y_train.astype(np.float64) + baseline_prediction = loss.get_baseline_prediction(y_train, 1) + assert_all_finite(baseline_prediction) + assert np.allclose(loss.inverse_link_function(baseline_prediction), + y_train[0]) + + # Make sure baseline prediction is equal to link_function(p), where p + # is the proba of the positive class. We want predict_proba() to return p, + # and by definition + # p = inverse_link_function(raw_prediction) = sigmoid(raw_prediction) + # So we want raw_prediction = link_function(p) = log(p / (1 - p)) + y_train = rng.randint(0, 2, size=100).astype(np.float64) + baseline_prediction = loss.get_baseline_prediction(y_train, 1) + assert baseline_prediction.shape == tuple() # scalar + assert baseline_prediction.dtype == y_train.dtype + p = y_train.mean() + assert np.allclose(baseline_prediction, np.log(p / (1 - p))) + + +def test_baseline_categorical_crossentropy(): + rng = np.random.RandomState(0) + + prediction_dim = 4 + loss = _LOSSES['categorical_crossentropy']() + for y_train in (np.zeros(shape=100), np.ones(shape=100)): + y_train = y_train.astype(np.float64) + baseline_prediction = loss.get_baseline_prediction(y_train, + prediction_dim) + assert baseline_prediction.dtype == y_train.dtype + assert_all_finite(baseline_prediction) + + # Same logic as for above test. 
Here inverse_link_function = softmax and + # link_function = log + y_train = rng.randint(0, prediction_dim + 1, size=100).astype(np.float32) + baseline_prediction = loss.get_baseline_prediction(y_train, prediction_dim) + assert baseline_prediction.shape == (prediction_dim, 1) + for k in range(prediction_dim): + p = (y_train == k).mean() + assert np.allclose(baseline_prediction[k, :], np.log(p)) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_predictor.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_predictor.py new file mode 100644 index 0000000000000..80a56bfe78ded --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_predictor.py @@ -0,0 +1,36 @@ +import numpy as np +from sklearn.datasets import load_boston +from sklearn.model_selection import train_test_split +from sklearn.metrics import r2_score +import pytest + +from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper +from sklearn.ensemble._hist_gradient_boosting.grower import TreeGrower +from sklearn.ensemble._hist_gradient_boosting.types import G_H_DTYPE + + +@pytest.mark.parametrize('max_bins', [200, 256]) +def test_boston_dataset(max_bins): + boston = load_boston() + X_train, X_test, y_train, y_test = train_test_split( + boston.data, boston.target, random_state=42) + + mapper = _BinMapper(max_bins=max_bins, random_state=42) + X_train_binned = mapper.fit_transform(X_train) + + # Init gradients and hessians to that of least squares loss + gradients = -y_train.astype(G_H_DTYPE) + hessians = np.ones(1, dtype=G_H_DTYPE) + + min_samples_leaf = 8 + max_leaf_nodes = 31 + grower = TreeGrower(X_train_binned, gradients, hessians, + min_samples_leaf=min_samples_leaf, + max_leaf_nodes=max_leaf_nodes, max_bins=max_bins, + actual_n_bins=mapper.actual_n_bins_) + grower.grow() + + predictor = grower.make_predictor(bin_thresholds=mapper.bin_thresholds_) + + assert r2_score(y_train, predictor.predict(X_train)) > 0.85 + assert r2_score(y_test, predictor.predict(X_test)) > 0.70 diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py new file mode 100644 index 0000000000000..d34f5ef064137 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py @@ -0,0 +1,262 @@ +import numpy as np +import pytest + +from sklearn.ensemble._hist_gradient_boosting.types import HISTOGRAM_DTYPE +from sklearn.ensemble._hist_gradient_boosting.types import G_H_DTYPE +from sklearn.ensemble._hist_gradient_boosting.types import X_BINNED_DTYPE +from sklearn.ensemble._hist_gradient_boosting.splitting import Splitter +from sklearn.ensemble._hist_gradient_boosting.histogram import HistogramBuilder + + +@pytest.mark.parametrize('n_bins', [3, 32, 256]) +def test_histogram_split(n_bins): + rng = np.random.RandomState(42) + feature_idx = 0 + l2_regularization = 0 + min_hessian_to_split = 1e-3 + min_samples_leaf = 1 + min_gain_to_split = 0. 
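For each candidate bin the splitter scores a standard second-order split gain, roughly G_left**2 / (H_left + l2) + G_right**2 / (H_right + l2) - G_parent**2 / (H_parent + l2); the implementation may include extra constant factors. A small self-contained sketch (toy numbers, hypothetical helper) of why the test below expects the split at the bin where the gradient sign flips:

    import numpy as np

    def split_gain(g_left, h_left, g_parent, h_parent, reg=0.0):
        def value(g, h):
            return g ** 2 / (h + reg)
        g_right, h_right = g_parent - g_left, h_parent - h_left
        return (value(g_left, h_left) + value(g_right, h_right)
                - value(g_parent, h_parent))

    # Per-bin gradient sums: the sign flips after bin 2 (unit hessians).
    sum_g = np.array([-1., -1., -1., 1., 1.])
    sum_h = np.ones_like(sum_g)
    g_parent, h_parent = sum_g.sum(), sum_h.sum()
    gains = [split_gain(sum_g[:b + 1].sum(), sum_h[:b + 1].sum(),
                        g_parent, h_parent) for b in range(len(sum_g) - 1)]
    assert int(np.argmax(gains)) == 2  # best bin_idx is where the sign flips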
+ X_binned = np.asfortranarray( + rng.randint(0, n_bins, size=(int(1e4), 1)), dtype=X_BINNED_DTYPE) + binned_feature = X_binned.T[feature_idx] + sample_indices = np.arange(binned_feature.shape[0], dtype=np.uint32) + ordered_hessians = np.ones_like(binned_feature, dtype=G_H_DTYPE) + all_hessians = ordered_hessians + sum_hessians = all_hessians.sum() + hessians_are_constant = False + + for true_bin in range(1, n_bins - 1): + for sign in [-1, 1]: + ordered_gradients = np.full_like(binned_feature, sign, + dtype=G_H_DTYPE) + ordered_gradients[binned_feature <= true_bin] *= -1 + all_gradients = ordered_gradients + sum_gradients = all_gradients.sum() + + actual_n_bins = np.array([n_bins] * X_binned.shape[1], + dtype=np.uint32) + builder = HistogramBuilder(X_binned, + n_bins, + all_gradients, + all_hessians, + hessians_are_constant) + splitter = Splitter(X_binned, + n_bins, + actual_n_bins, + l2_regularization, + min_hessian_to_split, + min_samples_leaf, min_gain_to_split, + hessians_are_constant) + + histograms = builder.compute_histograms_brute(sample_indices) + split_info = splitter.find_node_split( + sample_indices, histograms, sum_gradients, + sum_hessians) + + assert split_info.bin_idx == true_bin + assert split_info.gain >= 0 + assert split_info.feature_idx == feature_idx + assert (split_info.n_samples_left + split_info.n_samples_right + == sample_indices.shape[0]) + # Constant hessian: 1. per sample. + assert split_info.n_samples_left == split_info.sum_hessian_left + + +@pytest.mark.parametrize('constant_hessian', [True, False]) +def test_gradient_and_hessian_sanity(constant_hessian): + # This test checks that the values of gradients and hessians are + # consistent in different places: + # - in split_info: si.sum_gradient_left + si.sum_gradient_right must be + # equal to the gradient at the node. Same for hessians. + # - in the histograms: summing 'sum_gradients' over the bins must be + # constant across all features, and those sums must be equal to the + # node's gradient. Same for hessians. + + rng = np.random.RandomState(42) + + n_bins = 10 + n_features = 20 + n_samples = 500 + l2_regularization = 0. + min_hessian_to_split = 1e-3 + min_samples_leaf = 1 + min_gain_to_split = 0. 
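The invariant checked below boils down to the fact that each sample contributes to exactly one bin per feature, so summing a histogram over its bins must give back the node total for every feature. A toy NumPy sketch of that invariant (not the actual HistogramBuilder):

    import numpy as np

    rng = np.random.RandomState(0)
    gradients = rng.randn(100)
    bins_f0 = rng.randint(0, 10, size=100)  # two unrelated binned features
    bins_f1 = rng.randint(0, 10, size=100)
    per_bin_f0 = np.bincount(bins_f0, weights=gradients, minlength=10)
    per_bin_f1 = np.bincount(bins_f1, weights=gradients, minlength=10)
    assert np.isclose(per_bin_f0.sum(), gradients.sum())
    assert np.isclose(per_bin_f1.sum(), per_bin_f0.sum())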
+ + X_binned = rng.randint(0, n_bins, size=(n_samples, n_features), + dtype=X_BINNED_DTYPE) + X_binned = np.asfortranarray(X_binned) + sample_indices = np.arange(n_samples, dtype=np.uint32) + all_gradients = rng.randn(n_samples).astype(G_H_DTYPE) + sum_gradients = all_gradients.sum() + if constant_hessian: + all_hessians = np.ones(1, dtype=G_H_DTYPE) + sum_hessians = 1 * n_samples + else: + all_hessians = rng.lognormal(size=n_samples).astype(G_H_DTYPE) + sum_hessians = all_hessians.sum() + + actual_n_bins = np.array([n_bins] * X_binned.shape[1], + dtype=np.uint32) + builder = HistogramBuilder(X_binned, n_bins, all_gradients, + all_hessians, constant_hessian) + splitter = Splitter(X_binned, n_bins, actual_n_bins, + l2_regularization, min_hessian_to_split, + min_samples_leaf, min_gain_to_split, constant_hessian) + + hists_parent = builder.compute_histograms_brute(sample_indices) + si_parent = splitter.find_node_split(sample_indices, hists_parent, + sum_gradients, sum_hessians) + sample_indices_left, sample_indices_right, _ = splitter.split_indices( + si_parent, sample_indices) + + hists_left = builder.compute_histograms_brute(sample_indices_left) + hists_right = builder.compute_histograms_brute(sample_indices_right) + si_left = splitter.find_node_split(sample_indices_left, hists_left, + si_parent.sum_gradient_left, + si_parent.sum_hessian_left) + si_right = splitter.find_node_split(sample_indices_right, hists_right, + si_parent.sum_gradient_right, + si_parent.sum_hessian_right) + + # make sure that si.sum_gradient_left + si.sum_gradient_right have their + # expected value, same for hessians + for si, indices in ( + (si_parent, sample_indices), + (si_left, sample_indices_left), + (si_right, sample_indices_right)): + gradient = si.sum_gradient_right + si.sum_gradient_left + expected_gradient = all_gradients[indices].sum() + hessian = si.sum_hessian_right + si.sum_hessian_left + if constant_hessian: + expected_hessian = indices.shape[0] * all_hessians[0] + else: + expected_hessian = all_hessians[indices].sum() + + assert np.isclose(gradient, expected_gradient) + assert np.isclose(hessian, expected_hessian) + + # make sure sum of gradients in histograms are the same for all features, + # and make sure they're equal to their expected value + hists_parent = np.asarray(hists_parent, dtype=HISTOGRAM_DTYPE) + hists_left = np.asarray(hists_left, dtype=HISTOGRAM_DTYPE) + hists_right = np.asarray(hists_right, dtype=HISTOGRAM_DTYPE) + for hists, indices in ( + (hists_parent, sample_indices), + (hists_left, sample_indices_left), + (hists_right, sample_indices_right)): + # note: gradients and hessians have shape (n_features,), + # we're comparing them to *scalars*. This has the benefit of also + # making sure that all the entries are equal across features. + gradients = hists['sum_gradients'].sum(axis=1) # shape = (n_features,) + expected_gradient = all_gradients[indices].sum() # scalar + hessians = hists['sum_hessians'].sum(axis=1) + if constant_hessian: + # 0 is not the actual hessian, but it's not computed in this case + expected_hessian = 0. + else: + expected_hessian = all_hessians[indices].sum() + + assert np.allclose(gradients, expected_gradient) + assert np.allclose(hessians, expected_hessian) + + +def test_split_indices(): + # Check that split_indices returns the correct splits and that + # splitter.partition is consistent with what is returned. + rng = np.random.RandomState(421) + + n_bins = 5 + n_samples = 10 + l2_regularization = 0. 
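As a plain NumPy sketch of what the test below asserts (same toy data as in the test; variable names here are illustrative): samples whose binned value on the chosen feature is <= bin_idx form the left block of the partition, the remaining samples form the right block, and position_right marks the boundary between the two blocks.

    import numpy as np

    binned_feature = np.array([0, 3, 4, 0, 0, 0, 0, 4, 0, 4], dtype=np.uint8)
    bin_idx = 3  # threshold found on feature 1
    sample_indices = np.arange(binned_feature.shape[0], dtype=np.uint32)
    go_left = binned_feature <= bin_idx
    left, right = sample_indices[go_left], sample_indices[~go_left]
    partition = np.concatenate([left, right])  # left block first, then right
    position_right = left.shape[0]
    assert set(left) == {0, 1, 3, 4, 5, 6, 8}
    assert set(right) == {2, 7, 9}
    assert list(partition[:position_right]) == list(left)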
+ min_hessian_to_split = 1e-3 + min_samples_leaf = 1 + min_gain_to_split = 0. + + # split will happen on feature 1 and on bin 3 + X_binned = [[0, 0], + [0, 3], + [0, 4], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 4], + [0, 0], + [0, 4]] + X_binned = np.asfortranarray(X_binned, dtype=X_BINNED_DTYPE) + sample_indices = np.arange(n_samples, dtype=np.uint32) + all_gradients = rng.randn(n_samples).astype(G_H_DTYPE) + all_hessians = np.ones(1, dtype=G_H_DTYPE) + sum_gradients = all_gradients.sum() + sum_hessians = 1 * n_samples + hessians_are_constant = True + + actual_n_bins = np.array([n_bins] * X_binned.shape[1], + dtype=np.uint32) + builder = HistogramBuilder(X_binned, n_bins, + all_gradients, all_hessians, + hessians_are_constant) + splitter = Splitter(X_binned, n_bins, actual_n_bins, + l2_regularization, min_hessian_to_split, + min_samples_leaf, min_gain_to_split, + hessians_are_constant) + + assert np.all(sample_indices == splitter.partition) + + histograms = builder.compute_histograms_brute(sample_indices) + si_root = splitter.find_node_split(sample_indices, histograms, + sum_gradients, sum_hessians) + + # sanity checks for best split + assert si_root.feature_idx == 1 + assert si_root.bin_idx == 3 + + samples_left, samples_right, position_right = splitter.split_indices( + si_root, splitter.partition) + assert set(samples_left) == set([0, 1, 3, 4, 5, 6, 8]) + assert set(samples_right) == set([2, 7, 9]) + + assert list(samples_left) == list(splitter.partition[:position_right]) + assert list(samples_right) == list(splitter.partition[position_right:]) + + # Check that the resulting split indices sizes are consistent with the + # count statistics anticipated when looking for the best split. + assert samples_left.shape[0] == si_root.n_samples_left + assert samples_right.shape[0] == si_root.n_samples_right + + +def test_min_gain_to_split(): + # Try to split a pure node (all gradients are equal, same for hessians) + # with min_gain_to_split = 0 and make sure that the node is not split (best + # possible gain = -1). Note: before the strict inequality comparison, this + # test would fail because the node would be split with a gain of 0. + rng = np.random.RandomState(42) + l2_regularization = 0 + min_hessian_to_split = 0 + min_samples_leaf = 1 + min_gain_to_split = 0. 
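Why a pure node cannot gain from splitting: with n identical gradients g and unit hessians (and no L2 regularization, as here), any left/right split gives (n_l*g)**2/n_l + (n_r*g)**2/n_r - (n*g)**2/n = g**2 * (n_l + n_r - n) = 0, so the best achievable gain is 0 and the strict comparison against min_gain_to_split == 0 rejects it; the test then expects find_node_split to report a gain of -1 (no split). A quick numeric check of that identity (illustrative only):

    n, g = 100, 1.0
    for n_left in range(1, n):
        n_right = n - n_left
        gain = ((n_left * g) ** 2 / n_left + (n_right * g) ** 2 / n_right
                - (n * g) ** 2 / n)
        assert abs(gain) < 1e-12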
+ n_bins = 255 + n_samples = 100 + X_binned = np.asfortranarray( + rng.randint(0, n_bins, size=(n_samples, 1)), dtype=X_BINNED_DTYPE) + binned_feature = X_binned[:, 0] + sample_indices = np.arange(n_samples, dtype=np.uint32) + all_hessians = np.ones_like(binned_feature, dtype=G_H_DTYPE) + all_gradients = np.ones_like(binned_feature, dtype=G_H_DTYPE) + sum_gradients = all_gradients.sum() + sum_hessians = all_hessians.sum() + hessians_are_constant = False + + actual_n_bins = np.array([n_bins] * X_binned.shape[1], + dtype=np.uint32) + builder = HistogramBuilder(X_binned, n_bins, all_gradients, + all_hessians, hessians_are_constant) + splitter = Splitter(X_binned, n_bins, actual_n_bins, + l2_regularization, min_hessian_to_split, + min_samples_leaf, min_gain_to_split, + hessians_are_constant) + + histograms = builder.compute_histograms_brute(sample_indices) + split_info = splitter.find_node_split(sample_indices, histograms, + sum_gradients, sum_hessians) + assert split_info.gain == -1 diff --git a/sklearn/ensemble/_hist_gradient_boosting/types.pxd b/sklearn/ensemble/_hist_gradient_boosting/types.pxd new file mode 100644 index 0000000000000..1dd1fbee4273c --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/types.pxd @@ -0,0 +1,16 @@ +# cython: language_level=3 +import numpy as np +cimport numpy as np + + +ctypedef np.npy_float64 X_DTYPE_C +ctypedef np.npy_uint8 X_BINNED_DTYPE_C +ctypedef np.npy_float64 Y_DTYPE_C +ctypedef np.npy_float32 G_H_DTYPE_C + +cdef packed struct hist_struct: + # Same as histogram dtype but we need a struct to declare views. It needs + # to be packed since by default numpy dtypes aren't aligned + Y_DTYPE_C sum_gradients + Y_DTYPE_C sum_hessians + unsigned int count diff --git a/sklearn/ensemble/_hist_gradient_boosting/types.pyx b/sklearn/ensemble/_hist_gradient_boosting/types.pyx new file mode 100644 index 0000000000000..e13b5320bad32 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/types.pyx @@ -0,0 +1,16 @@ +import numpy as np + +# Y_DYTPE is the dtype to which the targets y are converted to. This is also +# dtype for leaf values, gains, and sums of gradients / hessians. The gradients +# and hessians arrays are stored as floats to avoid using too much memory. +Y_DTYPE = np.float64 +X_DTYPE = np.float64 +X_BINNED_DTYPE = np.uint8 # hence max_bins == 256 +# dtypes for gradients and hessians arrays +G_H_DTYPE = np.float32 + +HISTOGRAM_DTYPE = np.dtype([ + ('sum_gradients', Y_DTYPE), # sum of sample gradients in bin + ('sum_hessians', Y_DTYPE), # sum of sample hessians in bin + ('count', np.uint32), # number of samples in bin +]) diff --git a/sklearn/ensemble/_hist_gradient_boosting/utils.pyx b/sklearn/ensemble/_hist_gradient_boosting/utils.pyx new file mode 100644 index 0000000000000..fa9556ef9efb5 --- /dev/null +++ b/sklearn/ensemble/_hist_gradient_boosting/utils.pyx @@ -0,0 +1,151 @@ +# cython: cdivision=True +# cython: boundscheck=False +# cython: wraparound=False +# cython: language_level=3 +"""This module contains utility routines.""" +# Author: Nicolas Hug + +from cython.parallel import prange + +from ...base import is_classifier +from .binning import _BinMapper +from .types cimport G_H_DTYPE_C +from .types cimport Y_DTYPE_C + + +def get_equivalent_estimator(estimator, lib='lightgbm'): + """Return an unfitted estimator from another lib with matching hyperparams. + + This utility function takes care of renaming the sklearn parameters into + their LightGBM, XGBoost or CatBoost equivalent parameters. 
+ + # unmapped XGB parameters: + # - min_samples_leaf + # - min_data_in_bin + # - min_split_gain (there is min_split_loss though?) + + # unmapped Catboost parameters: + # max_leaves + # min_* + """ + + if lib not in ('lightgbm', 'xgboost', 'catboost'): + raise ValueError('accepted libs are lightgbm, xgboost, and catboost. ' + ' got {}'.format(lib)) + + sklearn_params = estimator.get_params() + + if sklearn_params['loss'] == 'auto': + raise ValueError('auto loss is not accepted. We need to know if ' + 'the problem is binary or multiclass classification.') + if sklearn_params['n_iter_no_change'] is not None: + raise NotImplementedError('Early stopping should be deactivated.') + + lightgbm_loss_mapping = { + 'least_squares': 'regression_l2', + 'binary_crossentropy': 'binary', + 'categorical_crossentropy': 'multiclass' + } + + lightgbm_params = { + 'objective': lightgbm_loss_mapping[sklearn_params['loss']], + 'learning_rate': sklearn_params['learning_rate'], + 'n_estimators': sklearn_params['max_iter'], + 'num_leaves': sklearn_params['max_leaf_nodes'], + 'max_depth': sklearn_params['max_depth'], + 'min_child_samples': sklearn_params['min_samples_leaf'], + 'reg_lambda': sklearn_params['l2_regularization'], + 'max_bin': sklearn_params['max_bins'], + 'min_data_in_bin': 1, + 'min_child_weight': 1e-3, + 'min_sum_hessian_in_leaf': 1e-3, + 'min_split_gain': 0, + 'verbosity': 10 if sklearn_params['verbose'] else -10, + 'boost_from_average': True, + 'enable_bundle': False, # also makes feature order consistent + 'min_data_in_bin': 1, + 'subsample_for_bin': _BinMapper().subsample, + } + + if sklearn_params['loss'] == 'categorical_crossentropy': + # LightGBM multiplies hessians by 2 in multiclass loss. + lightgbm_params['min_sum_hessian_in_leaf'] *= 2 + lightgbm_params['learning_rate'] *= 2 + + # XGB + xgboost_loss_mapping = { + 'least_squares': 'reg:linear', + 'binary_crossentropy': 'reg:logistic', + 'categorical_crossentropy': 'multi:softmax' + } + + xgboost_params = { + 'tree_method': 'hist', + 'grow_policy': 'lossguide', # so that we can set max_leaves + 'objective': xgboost_loss_mapping[sklearn_params['loss']], + 'learning_rate': sklearn_params['learning_rate'], + 'n_estimators': sklearn_params['max_iter'], + 'max_leaves': sklearn_params['max_leaf_nodes'], + 'max_depth': sklearn_params['max_depth'] or 0, + 'lambda': sklearn_params['l2_regularization'], + 'max_bin': sklearn_params['max_bins'], + 'min_child_weight': 1e-3, + 'verbosity': 2 if sklearn_params['verbose'] else 0, + 'silent': sklearn_params['verbose'] == 0, + 'n_jobs': -1, + } + + # Catboost + catboost_loss_mapping = { + 'least_squares': 'RMSE', + 'binary_crossentropy': 'Logloss', + 'categorical_crossentropy': 'MultiClass' + } + + catboost_params = { + 'loss_function': catboost_loss_mapping[sklearn_params['loss']], + 'learning_rate': sklearn_params['learning_rate'], + 'iterations': sklearn_params['max_iter'], + 'depth': sklearn_params['max_depth'], + 'reg_lambda': sklearn_params['l2_regularization'], + 'max_bin': sklearn_params['max_bins'], + 'feature_border_type': 'Median', + 'leaf_estimation_method': 'Newton', + 'verbose': bool(sklearn_params['verbose']), + } + + if lib == 'lightgbm': + from lightgbm import LGBMRegressor + from lightgbm import LGBMClassifier + if is_classifier(estimator): + return LGBMClassifier(**lightgbm_params) + else: + return LGBMRegressor(**lightgbm_params) + + elif lib == 'xgboost': + from xgboost import XGBRegressor + from xgboost import XGBClassifier + if is_classifier(estimator): + return 
XGBClassifier(**xgboost_params) + else: + return XGBRegressor(**xgboost_params) + + else: + from catboost import CatBoostRegressor + from catboost import CatBoostClassifier + if is_classifier(estimator): + return CatBoostClassifier(**catboost_params) + else: + return CatBoostRegressor(**catboost_params) + + +def sum_parallel(G_H_DTYPE_C [:] array): + + cdef: + Y_DTYPE_C out = 0. + int i = 0 + + for i in prange(array.shape[0], schedule='static', nogil=True): + out += array[i] + + return out diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index b4199526bb4e3..3ce0eb7f456da 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -2012,6 +2012,7 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): See also -------- + sklearn.ensemble.HistGradientBoostingClassifier, sklearn.tree.DecisionTreeClassifier, RandomForestClassifier AdaBoostClassifier @@ -2472,7 +2473,8 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin): See also -------- - DecisionTreeRegressor, RandomForestRegressor + sklearn.ensemble.HistGradientBoostingRegressor, + sklearn.tree.DecisionTreeRegressor, RandomForestRegressor References ---------- diff --git a/sklearn/ensemble/setup.py b/sklearn/ensemble/setup.py index 34fb63b906d0a..88e1b2e32d98d 100644 --- a/sklearn/ensemble/setup.py +++ b/sklearn/ensemble/setup.py @@ -4,12 +4,49 @@ def configuration(parent_package="", top_path=None): config = Configuration("ensemble", parent_package, top_path) + config.add_extension("_gradient_boosting", sources=["_gradient_boosting.pyx"], include_dirs=[numpy.get_include()]) config.add_subpackage("tests") + # Histogram-based gradient boosting files + config.add_extension( + "_hist_gradient_boosting._gradient_boosting", + sources=["_hist_gradient_boosting/_gradient_boosting.pyx"], + include_dirs=[numpy.get_include()]) + + config.add_extension("_hist_gradient_boosting.histogram", + sources=["_hist_gradient_boosting/histogram.pyx"], + include_dirs=[numpy.get_include()]) + + config.add_extension("_hist_gradient_boosting.splitting", + sources=["_hist_gradient_boosting/splitting.pyx"], + include_dirs=[numpy.get_include()]) + + config.add_extension("_hist_gradient_boosting._binning", + sources=["_hist_gradient_boosting/_binning.pyx"], + include_dirs=[numpy.get_include()]) + + config.add_extension("_hist_gradient_boosting._predictor", + sources=["_hist_gradient_boosting/_predictor.pyx"], + include_dirs=[numpy.get_include()]) + + config.add_extension("_hist_gradient_boosting._loss", + sources=["_hist_gradient_boosting/_loss.pyx"], + include_dirs=[numpy.get_include()]) + + config.add_extension("_hist_gradient_boosting.types", + sources=["_hist_gradient_boosting/types.pyx"], + include_dirs=[numpy.get_include()]) + + config.add_extension("_hist_gradient_boosting.utils", + sources=["_hist_gradient_boosting/utils.pyx"], + include_dirs=[numpy.get_include()]) + + config.add_subpackage("_hist_gradient_boosting.tests") + return config if __name__ == "__main__": diff --git a/sklearn/experimental/__init__.py b/sklearn/experimental/__init__.py new file mode 100644 index 0000000000000..0effaf5b05fa0 --- /dev/null +++ b/sklearn/experimental/__init__.py @@ -0,0 +1,7 @@ +""" +The :mod:`sklearn.experimental` module provides importable modules that enable +the use of experimental features or estimators. + +The features and estimators that are experimental aren't subject to +deprecation cycles. Use them at your own risks! 
+""" diff --git a/sklearn/experimental/enable_hist_gradient_boosting.py b/sklearn/experimental/enable_hist_gradient_boosting.py new file mode 100644 index 0000000000000..6b0a6ad8a28bb --- /dev/null +++ b/sklearn/experimental/enable_hist_gradient_boosting.py @@ -0,0 +1,32 @@ +"""Enables histogram-based gradient boosting estimators. + +The API and results of these estimators might change without any deprecation +cycle. + +Importing this file dynamically sets the +:class:`sklearn.ensemble.HistGradientBoostingClassifier` and +:class:`sklearn.ensemble.HistGradientBoostingRegressor` as attributes of the +ensemble module:: + + >>> # explicitly require this experimental feature + >>> from sklearn.experimental import enable_hist_gradient_boosting # noqa + >>> # now you can import normally from ensemble + >>> from sklearn.ensemble import HistGradientBoostingClassifier + >>> from sklearn.ensemble import HistGradientBoostingRegressor + + +The ``# noqa`` comment comment can be removed: it just tells linters like +flake8 to ignore the import, which appears as unused. +""" + +from ..ensemble._hist_gradient_boosting.gradient_boosting import ( + HistGradientBoostingClassifier, + HistGradientBoostingRegressor +) + +from .. import ensemble + +ensemble.HistGradientBoostingClassifier = HistGradientBoostingClassifier +ensemble.HistGradientBoostingRegressor = HistGradientBoostingRegressor +ensemble.__all__ += ['HistGradientBoostingClassifier', + 'HistGradientBoostingRegressor'] diff --git a/sklearn/experimental/tests/__init__.py b/sklearn/experimental/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sklearn/experimental/tests/test_enable_hist_gradient_boosting.py b/sklearn/experimental/tests/test_enable_hist_gradient_boosting.py new file mode 100644 index 0000000000000..eff4f53d810a9 --- /dev/null +++ b/sklearn/experimental/tests/test_enable_hist_gradient_boosting.py @@ -0,0 +1,45 @@ +"""Tests for making sure experimental imports work as expected.""" + +import textwrap + +from sklearn.utils.testing import assert_run_python_script + + +def test_imports_strategies(): + # Make sure different import strategies work or fail as expected. + + # Since Python caches the imported modules, we need to run a child process + # for every test case. Else, the tests would not be independent + # (manually removing the imports from the cache (sys.modules) is not + # recommended and can lead to many complications). 
+
+ good_import = """
+ from sklearn.experimental import enable_hist_gradient_boosting
+ from sklearn.ensemble import HistGradientBoostingClassifier
+ from sklearn.ensemble import HistGradientBoostingRegressor
+ """
+ assert_run_python_script(textwrap.dedent(good_import))
+
+ good_import_with_ensemble_first = """
+ import sklearn.ensemble
+ from sklearn.experimental import enable_hist_gradient_boosting
+ from sklearn.ensemble import HistGradientBoostingClassifier
+ from sklearn.ensemble import HistGradientBoostingRegressor
+ """
+ assert_run_python_script(textwrap.dedent(good_import_with_ensemble_first))
+
+ bad_imports = """
+ import pytest
+
+ with pytest.raises(ImportError):
+ from sklearn.ensemble import HistGradientBoostingClassifier
+
+ with pytest.raises(ImportError):
+ from sklearn.ensemble._hist_gradient_boosting import (
+ HistGradientBoostingClassifier)
+
+ import sklearn.experimental
+ with pytest.raises(ImportError):
+ from sklearn.ensemble import HistGradientBoostingClassifier
+ """
+ assert_run_python_script(textwrap.dedent(bad_imports))
diff --git a/sklearn/setup.py b/sklearn/setup.py
index 6ea9fecf83a76..e6f10cad77d9f 100644
--- a/sklearn/setup.py
+++ b/sklearn/setup.py
@@ -45,6 +45,10 @@ def configuration(parent_package='', top_path=None):
 config.add_subpackage('preprocessing/tests')
 config.add_subpackage('semi_supervised')
 config.add_subpackage('semi_supervised/tests')
+ config.add_subpackage('experimental')
+ config.add_subpackage('experimental/tests')
+ config.add_subpackage('ensemble/_hist_gradient_boosting')
+ config.add_subpackage('ensemble/_hist_gradient_boosting/tests')
 
 # submodules which have their own setup.py
 config.add_subpackage('cluster')
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 26360c1ef07c1..d5d59a041fdf4 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -396,6 +396,12 @@ def set_checking_parameters(estimator):
 # which is more feature than we have in most case.
 estimator.set_params(k=1)
 
+ if name in ('HistGradientBoostingClassifier',
+ 'HistGradientBoostingRegressor'):
+ # The default min_samples_leaf (20) isn't appropriate for the small
+ # datasets that the checks use (only very shallow trees would be built).
+ estimator.set_params(min_samples_leaf=5) + class NotAnArray: """An object that is convertible to an array @@ -2462,6 +2468,7 @@ def check_fit_idempotent(name, estimator_orig): if hasattr(estimator, method)} # Fit again + set_random_state(estimator) estimator.fit(X_train, y_train) for method in check_methods: diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 1662294189690..65bed4c7ecef8 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -11,11 +11,16 @@ # Thierry Guillemot # License: BSD 3 clause import os +import os.path as op import inspect import pkgutil import warnings import sys import functools +import tempfile +from subprocess import check_output, STDOUT, CalledProcessError +from subprocess import TimeoutExpired + import scipy as sp import scipy.io @@ -27,7 +32,6 @@ import tempfile import shutil -import os.path as op import atexit import unittest @@ -82,7 +86,8 @@ "assert_array_almost_equal", "assert_array_less", "assert_less", "assert_less_equal", "assert_greater", "assert_greater_equal", - "assert_approx_equal", "assert_allclose", "SkipTest"] + "assert_approx_equal", "assert_allclose", + "assert_run_python_script", "SkipTest"] __all__.extend(additional_names_in_all) _dummy = TestCase('__init__') @@ -970,3 +975,70 @@ def check_docstring_parameters(func, doc=None, ignore=None, class_name=None): if n1 != n2: incorrect += [func_name + ' ' + n1 + ' != ' + n2] return incorrect + + +def assert_run_python_script(source_code, timeout=60): + """Utility to check assertions in an independent Python subprocess. + + The script provided in the source code should return 0 and not print + anything on stderr or stdout. + + This is a port from cloudpickle https://github.com/cloudpipe/cloudpickle + + Parameters + ---------- + source_code : str + The Python source code to execute. + timeout : int + Time in seconds before timeout. + """ + fd, source_file = tempfile.mkstemp(suffix='_src_test_sklearn.py') + os.close(fd) + try: + with open(source_file, 'wb') as f: + f.write(source_code.encode('utf-8')) + cmd = [sys.executable, source_file] + cwd = op.normpath(op.join(op.dirname(sklearn.__file__), '..')) + env = os.environ.copy() + kwargs = { + 'cwd': cwd, + 'stderr': STDOUT, + 'env': env, + } + # If coverage is running, pass the config file to the subprocess + coverage_rc = os.environ.get("COVERAGE_PROCESS_START") + if coverage_rc: + kwargs['env']['COVERAGE_PROCESS_START'] = coverage_rc + + kwargs['timeout'] = timeout + try: + try: + out = check_output(cmd, **kwargs) + except CalledProcessError as e: + raise RuntimeError(u"script errored with output:\n%s" + % e.output.decode('utf-8')) + if out != b"": + raise AssertionError(out.decode('utf-8')) + except TimeoutExpired as e: + raise RuntimeError(u"script timeout, output so far:\n%s" + % e.output.decode('utf-8')) + finally: + os.unlink(source_file) + + +def close_figure(fig=None): + """Close a matplotlibt figure. + + Parameters + ---------- + fig : int or str or Figure, optional (default=None) + The figure, figure number or figure name to close. If ``None``, all + current figures are closed. + """ + from matplotlib.pyplot import get_fignums, close as _close # noqa + + if fig is None: + for fig in get_fignums(): + _close(fig) + else: + _close(fig)
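A minimal usage sketch for the new assert_run_python_script helper, assuming scikit-learn is installed: the child script must exit with status 0 and write nothing to stdout or stderr, otherwise the helper raises.

    import textwrap
    from sklearn.utils.testing import assert_run_python_script

    source = """
    import sklearn
    # Any output or a non-zero exit code makes the check fail.
    assert hasattr(sklearn, '__version__')
    """
    assert_run_python_script(textwrap.dedent(source), timeout=60)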