
ENH Adds Categorical Support to Histogram Gradient Boosting #16909


Closed
wants to merge 89 commits into from
Changes from all commits
89 commits
02d89d7
ENH Adds categorical support
thomasjpfan Apr 13, 2020
8472f60
DOC Improves english
thomasjpfan Apr 13, 2020
1198340
REV Less diffs
thomasjpfan Apr 13, 2020
63f56fd
DOC Adds comment
thomasjpfan Apr 13, 2020
f34087e
DOC Adds performance comment
thomasjpfan Apr 13, 2020
0b2ed9c
DOC Adds performance comment
thomasjpfan Apr 13, 2020
5eaf099
STY Fix
thomasjpfan Apr 13, 2020
43822ab
ENH Much faster bin mapping when transforming categories
thomasjpfan Apr 13, 2020
0d6012a
CLN Removes uneeded code
thomasjpfan Apr 13, 2020
b22151f
BUG Code actually is being used
thomasjpfan Apr 13, 2020
8432bac
Merge remote-tracking branch 'upstream/master' into cat_hgbt_rb
thomasjpfan Apr 29, 2020
ae9be56
CLN Address comments
thomasjpfan Apr 30, 2020
d0557a5
Merge remote-tracking branch 'upstream/master' into cat_hgbt_rb
thomasjpfan May 4, 2020
7692325
WIP Address more comments
thomasjpfan May 5, 2020
590d95f
CLN Address comments
thomasjpfan May 6, 2020
95e79f2
CLN Address comments
thomasjpfan May 6, 2020
e62479b
STY Linting
thomasjpfan May 6, 2020
9086fad
ENH Adds new method to binner
thomasjpfan May 7, 2020
3e323b2
CLN Binner refactor once again
thomasjpfan May 7, 2020
197fac0
CLN Address comments
thomasjpfan May 7, 2020
63af0d5
CLN More comments lol
thomasjpfan May 7, 2020
7ef6a8d
CLN Adds test for predict
thomasjpfan May 7, 2020
e6a03c6
ENH Adds categorical indicies support
thomasjpfan May 7, 2020
eabcfae
ENH Fix qsort
thomasjpfan May 7, 2020
2abe579
CLN Move missing_go_left code into grower
thomasjpfan May 7, 2020
9a5a3f4
Merge remote-tracking branch 'upstream/master' into cat_hgbt_rb
thomasjpfan May 8, 2020
ebb68e5
BUG Fix
thomasjpfan May 8, 2020
470c146
DOC More comments
thomasjpfan May 8, 2020
0fc4c24
DOC Update failing example
thomasjpfan May 8, 2020
cebd6c0
BUG Fixes
thomasjpfan May 8, 2020
d1478ba
DOC Update comment
thomasjpfan May 8, 2020
ba00644
BUG Fix
thomasjpfan May 8, 2020
1806c2b
BUG Fix
thomasjpfan May 8, 2020
95919e3
DOC Fix
thomasjpfan May 8, 2020
38966d5
WIP Try 32 bit
thomasjpfan May 8, 2020
c4869ba
WIP Fix bug
thomasjpfan May 8, 2020
f63ad6a
WIP Fix bug
thomasjpfan May 8, 2020
26d0796
WIP Fix bug
thomasjpfan May 8, 2020
17afb0f
WIP Fix bug
thomasjpfan May 8, 2020
5246cc1
REV Revert
thomasjpfan May 8, 2020
60523a3
DOC Fix
thomasjpfan May 8, 2020
b014d6e
Merge remote-tracking branch 'upstream/master' into cat_hgbt_rb
thomasjpfan May 25, 2020
96d0687
CLN Address comments
thomasjpfan May 28, 2020
dc0a3a4
WIP Updates binning
thomasjpfan May 28, 2020
e10b346
WIP Address more comments
thomasjpfan May 28, 2020
af58498
DOC Fix
thomasjpfan May 29, 2020
3c2f672
WIP moving to a predictor for bitset
thomasjpfan May 30, 2020
c8f31f9
ENH Fix binning tests
thomasjpfan May 30, 2020
fe16b42
ENH Removes binning in predict
thomasjpfan May 30, 2020
cf5bb6d
WIP Do not look still iterating lol
thomasjpfan May 31, 2020
3d9e449
WIP Do not look still iterating lol
thomasjpfan May 31, 2020
3615dc2
WIP Do not look still iterating lol
thomasjpfan May 31, 2020
6608715
WIP
thomasjpfan Jun 1, 2020
2c384e6
WIP Adds unknown category encoding
thomasjpfan Jun 1, 2020
8c6e985
WIP Removes binning during predict
thomasjpfan Jun 1, 2020
2357ae9
STY Update
thomasjpfan Jun 1, 2020
52048af
CLN Clean up commets
thomasjpfan Jun 1, 2020
f70416e
DOC Improve doc
thomasjpfan Jun 1, 2020
a4159cf
BUG Fix test
thomasjpfan Jun 1, 2020
a398786
ENH Only go in one direction when finding best split
thomasjpfan Jun 1, 2020
8ea46cc
ENH Do not include bitset if the split is not categorical
thomasjpfan Jun 1, 2020
3dcbd31
Fix
thomasjpfan Jun 2, 2020
c3b5eef
WIP Test seg fault
thomasjpfan Jun 2, 2020
280784a
BUG Fix
thomasjpfan Jun 2, 2020
019de8a
DOC Update doc
thomasjpfan Jun 5, 2020
24d0711
DOC Address comments
thomasjpfan Jun 5, 2020
2d0e79d
ENH Enables openmp
thomasjpfan Jun 5, 2020
966379c
BUG Fix
thomasjpfan Jun 5, 2020
1c920f7
Some comments
NicolasHug Jul 18, 2020
3966432
Merge branch 'master' of github.com:scikit-learn/scikit-learn into ca…
NicolasHug Jul 20, 2020
6c1af62
pep8
NicolasHug Jul 20, 2020
f535c33
Merge remote-tracking branch 'upstream/master' into cat_hgbt_rb
thomasjpfan Jul 31, 2020
2afca55
CLN Apply suggestions
thomasjpfan Aug 1, 2020
9b44d82
CLN More comments
thomasjpfan Aug 8, 2020
9f3fa46
categorical => categorical_features
ogrisel Aug 10, 2020
1054754
Merge branch 'master' of github.com:scikit-learn/scikit-learn into ca…
NicolasHug Aug 17, 2020
bbae955
Added grower test for OHE equivalent
NicolasHug Aug 17, 2020
c003b76
ENH Change sorting ordre to match lightgbm
thomasjpfan Aug 24, 2020
40c3f9b
CLN Less than equal
thomasjpfan Aug 24, 2020
bb0e899
CLN Adds splitting in both directions
thomasjpfan Aug 24, 2020
f47da15
Merge remote-tracking branch 'upstream/master' into cat_hgbt_rb
thomasjpfan Aug 24, 2020
bb5877d
CLN Fixes merge conflicts
thomasjpfan Aug 25, 2020
c3061b5
ENH Uses mask instead of pandas features in benchmark
thomasjpfan Aug 25, 2020
8762e88
DOC Remove reference to pandas in user guide
thomasjpfan Aug 25, 2020
6d7ec60
DOC Benchmark update
thomasjpfan Aug 25, 2020
69f3f9a
ENH Update benchmark parameters
thomasjpfan Aug 25, 2020
f9f837c
Merge branch 'master' of github.com:scikit-learn/scikit-learn into ca…
NicolasHug Aug 27, 2020
b913ff1
Remove pandas support for categorical features
NicolasHug Aug 27, 2020
730d69f
Merge branch 'master' of github.com:scikit-learn/scikit-learn into ca…
NicolasHug Sep 4, 2020
88 changes: 88 additions & 0 deletions benchmarks/bench_hist_gradient_boosting_adult.py
@@ -0,0 +1,88 @@
import argparse
from time import time

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_openml
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.experimental import enable_hist_gradient_boosting # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.ensemble._hist_gradient_boosting.utils import (
    get_equivalent_estimator)


parser = argparse.ArgumentParser()
parser.add_argument('--n-leaf-nodes', type=int, default=31)
parser.add_argument('--n-trees', type=int, default=40)
parser.add_argument('--lightgbm', action="store_true", default=False)
parser.add_argument('--learning-rate', type=float, default=1.)
parser.add_argument('--max-bins', type=int, default=255)
parser.add_argument('--no-predict', action="store_true", default=False)
args = parser.parse_args()

n_leaf_nodes = args.n_leaf_nodes
n_trees = args.n_trees
lr = args.learning_rate
max_bins = args.max_bins


def fit(est, data_train, target_train, libname, **fit_params):
    print(f"Fitting a {libname} model...")
    tic = time()
    est.fit(data_train, target_train, **fit_params)
    toc = time()
    print(f"fitted in {toc - tic:.3f}s")


def predict(est, data_test, target_test):
    if args.no_predict:
        return
    tic = time()
    predicted_test = est.predict(data_test)
    predicted_proba_test = est.predict_proba(data_test)
    toc = time()
    roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1])
    acc = accuracy_score(target_test, predicted_test)
    print(f"predicted in {toc - tic:.3f}s, "
          f"ROC AUC: {roc_auc:.4f}, ACC: {acc:.4f}")


data, target = fetch_openml(data_id=179, as_frame=True, return_X_y=True)

# encoding a categorical target is not supported yet, so use the codes
target = target.cat.codes

n_features = data.shape[1]
is_categorical = data.dtypes == 'category'
n_categorical_features = is_categorical.sum()
n_numerical_features = (data.dtypes == 'float').sum()
print(f"Number of features: {data.shape[1]}")
print(f"Number of categorical features: {n_categorical_features}")
print(f"Number of numerical features: {n_numerical_features}")

categorical_features = np.flatnonzero(is_categorical)
for i in categorical_features:
    data.iloc[:, i] = data.iloc[:, i].cat.codes

data_train, data_test, target_train, target_test = train_test_split(
    data, target, test_size=.2, random_state=0)

est = HistGradientBoostingClassifier(loss='binary_crossentropy',
                                     learning_rate=lr,
                                     max_iter=n_trees,
                                     max_bins=max_bins,
                                     categorical_features=categorical_features,
                                     max_leaf_nodes=n_leaf_nodes,
                                     early_stopping=False,
                                     random_state=0,
                                     verbose=1)

fit(est, data_train, target_train, 'sklearn')
predict(est, data_test, target_test)

# lightgbm infers the categories from the dtype
if args.lightgbm:
    est = get_equivalent_estimator(est, lib='lightgbm')
    fit(est, data_train, target_train, 'lightgbm',
        categorical_feature=is_categorical[is_categorical].index.tolist())
    predict(est, data_test, target_test)
39 changes: 39 additions & 0 deletions doc/modules/ensemble.rst
@@ -1051,6 +1051,43 @@ multiplying the gradients (and the hessians) by the sample weights. Note that
the binning stage (specifically the quantiles computation) does not take the
weights into account.

.. _categorical_support_gbdt:

Categorical Features Support
----------------------------

For datasets with categorical data, :class:`HistGradientBoostingClassifier`
and :class:`HistGradientBoostingRegressor` have native support for splitting
on categorical features. This is often better than one hot encoding because
it leads to faster training times and to shallower trees. The canonical way
of handling categorical splits is to consider all of the :math:`2^{K - 1} - 1`
partitions, where `K` is the number of categories. This can quickly become
prohibitive when `K` is large. Fortunately, since gradient boosting trees are
always regression trees (even for classification problems), there exists a
faster strategy that can yield equivalent splits. First, the categories of a
feature are sorted according to the ratio `sum_gradient_k / sum_hessians_k`
of each category `k`. Once the categories are sorted, one can consider
*continuous partitions*, i.e. treat the categories as if they were ordered
continuous values (see Fisher [Fisher1958]_ for a formal proof). As a result,
only `K - 1` splits need to be considered instead of :math:`2^{K - 1} - 1`.
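
As a rough illustration of this trick (an editor's sketch in plain NumPy, not
scikit-learn's actual implementation; the helper name and the regularization
constant `lam` are hypothetical), sorting the categories by their
gradient/hessian ratio reduces the search to a single scan over contiguous
prefixes::

    import numpy as np

    def best_categorical_split(sum_gradients, sum_hessians, lam=1.0):
        # sum_gradients[k], sum_hessians[k]: per-category sums for one feature
        K = len(sum_gradients)
        order = np.argsort(sum_gradients / sum_hessians)  # Fisher ordering
        g, h = sum_gradients[order], sum_hessians[order]
        g_total, h_total = g.sum(), h.sum()

        def gain(gl, hl):
            # split gain, up to constant factors
            gr, hr = g_total - gl, h_total - hl
            return (gl ** 2 / (hl + lam) + gr ** 2 / (hr + lam)
                    - g_total ** 2 / (h_total + lam))

        best_gain, best_left = -np.inf, None
        gl = hl = 0.0
        for i in range(K - 1):  # only K - 1 candidate splits, not 2**(K-1) - 1
            gl += g[i]
            hl += h[i]
            if gain(gl, hl) > best_gain:
                best_gain = gain(gl, hl)
                best_left = set(order[:i + 1].tolist())
        return best_gain, best_left  # categories sent to the left child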

If there are missing values during training, the missing values are treated
as a single category. Categories that were unknown at fit time are treated
as missing during prediction. If the cardinality of a categorical feature is
greater than `max_bins`, the `max_bins` most frequent categories are kept
and the remaining, less frequent categories are treated as missing.

To enable categorical support, a boolean mask can be passed to the
`categorical_features` parameter. In the following, the first feature will be
treated as categorical and the second feature as nummerical::
Member:

Suggested change:
-   treated as categorical and the second feature as nummerical::
+   treated as categorical and the second feature as numerical::


>>> gbdt = HistGradientBoostingClassifier(categorical_features=[True, False])
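
A slightly fuller usage sketch (an editorial addition; the toy data are
assumptions, and the API requires a scikit-learn version that includes this
PR)::

    import numpy as np
    from sklearn.experimental import enable_hist_gradient_boosting  # noqa
    from sklearn.ensemble import HistGradientBoostingClassifier

    rng = np.random.RandomState(0)
    # first column: categorical codes in {0, 1, 2}; second column: numerical
    X = np.column_stack([rng.randint(0, 3, size=100), rng.randn(100)])
    y = (X[:, 0] == 1).astype(int)  # target driven by the categorical feature

    gbdt = HistGradientBoostingClassifier(categorical_features=[True, False])
    gbdt.fit(X, y)
    print(gbdt.predict(X[:5]))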

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_ensemble_plot_gradient_boosting_categorical.py`

.. _monotonic_cst_gbdt:

Monotonic Constraints
@@ -1158,6 +1195,8 @@ Finally, many parts of the implementation of
.. [LightGBM] Ke et al. `"LightGBM: A Highly Efficient Gradient
   Boosting Decision Tree" <https://papers.nips.cc/paper/
   6907-lightgbm-a-highly-efficient-gradient-boosting-decision-tree>`_
.. [Fisher1958] Walter D. Fisher. `"On Grouping for Maximum Homogeneity"
   <http://www.csiss.org/SPACE/workshops/2004/SAC/files/fisher.pdf>`_

.. _voting_classifier:

91 changes: 91 additions & 0 deletions examples/ensemble/plot_gradient_boosting_categorical.py
@@ -0,0 +1,91 @@
"""
========================================
Categorical Support in Gradient Boosting
========================================

.. currentmodule:: sklearn

In this example, we will compare the performance of
:class:`~ensemble.HistGradientBoostingRegressor` with one hot encoding
and with native categorical support.

We will work with the Ames Iowa Housing dataset, which consists of numerical
and categorical features, where the houses' sale prices are the target.
Comment on lines +12 to +13

Contributor:

Are the categorical features useful for this classification task? It may be
worth it to add another example where the categorical features are dropped;
training should be faster but predictive performance should be worse.
Dropping categorical features is another way to deal with them (in a dummy
way).

Member Author:

For this dataset, the categories do not matter as much. So I will be on the
lookout for a nicer dataset.

"""
##############################################################################
# Load Ames Housing dataset
# -------------------------
# First, we load the Ames Housing data as a pandas dataframe. The features
# are either categorical or numerical:
print(__doc__)

from sklearn.datasets import fetch_openml

X, y = fetch_openml(data_id=41211, as_frame=True, return_X_y=True)

n_features = X.shape[1]
n_categorical_features = (X.dtypes == 'category').sum()
n_numerical_features = (X.dtypes == 'float').sum()
print(f"Number of features: {X.shape[1]}")
print(f"Number of categorical featuers: {n_categorical_features}")
print(f"Number of numerical featuers: {n_numerical_features}")

##############################################################################
# Create gradient boosting estimator with one hot encoding
# --------------------------------------------------------
# Next, we create a pipeline that will one hot encode the categorical features
# and let the rest of the numerical data pass through:

from sklearn.experimental import enable_hist_gradient_boosting # noqa
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.pipeline import make_pipeline
from sklearn.compose import make_column_transformer
from sklearn.compose import make_column_selector
from sklearn.preprocessing import OneHotEncoder

preprocessor = make_column_transformer(
    (OneHotEncoder(sparse=False, handle_unknown='ignore'),
     make_column_selector(dtype_include='category')),
    remainder='passthrough')

hist_one_hot = make_pipeline(preprocessor,
                             HistGradientBoostingRegressor(random_state=42))

##############################################################################
# Create gradient boosting estimator with native categorical support
# ------------------------------------------------------------------
# The :class:`~ensemble.HistGradientBoostingRegressor` has native support
# for categorical features using the `categorical_features` parameter:

hist_native = HistGradientBoostingRegressor(categorical_features='pandas',
                                            random_state=42)

##############################################################################
# Train the models with cross-validation
# --------------------------------------
# Finally, we train the models using cross-validation. Here we compare the
# models' performance in terms of :func:`~metrics.r2_score` and fit times. We
# show that fit times are faster with native categorical support and that the
# test scores and score times are comparable:

from sklearn.model_selection import cross_validate
import matplotlib.pyplot as plt
import numpy as np

one_hot_result = cross_validate(hist_one_hot, X, y)
native_result = cross_validate(hist_native, X, y)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))

plot_info = [('fit_time', 'Fit times (s)', ax1),
             ('test_score', 'Test Scores (r2 score)', ax2)]

x, width = np.arange(2), 0.9
for key, title, ax in plot_info:
    items = [native_result[key], one_hot_result[key]]
    ax.bar(x, [np.mean(item) for item in items],
           width, yerr=[np.std(item) for item in items],
           color=['b', 'r'])
    ax.set(xlabel='Split number', title=title, xticks=[0, 1],
           xticklabels=['Native', "One Hot"])
plt.show()
30 changes: 26 additions & 4 deletions sklearn/ensemble/_hist_gradient_boosting/_binning.pyx
@@ -15,15 +15,19 @@ from cython.parallel import prange
from libc.math cimport isnan

from .common cimport X_DTYPE_C, X_BINNED_DTYPE_C
from ._cat_mapper cimport CategoryMapper

np.import_array()


def _map_to_bins(const X_DTYPE_C [:, :] data,
                 list binning_thresholds,
                 const unsigned char missing_values_bin_idx,
                 CategoryMapper category_mapper,
                 const unsigned char[::1] is_categorical,
                 X_BINNED_DTYPE_C [::1, :] binned):
"""Bin numerical values to discrete integer-coded levels.
TODO docstring needs update

Parameters
----------
@@ -32,17 +36,24 @@ def _map_to_bins(const X_DTYPE_C [:, :] data,
    binning_thresholds : list of arrays
        For each feature, stores the increasing numeric values that are
        used to separate the bins.
    is_categorical : ndarray, shape (n_features,)
        Indicates categorical features.
    binned : ndarray, shape (n_samples, n_features)
        Output array, must be fortran aligned.
    """
    cdef:
        int feature_idx

    for feature_idx in range(data.shape[1]):
        if is_categorical[feature_idx]:
            _map_cat_col_to_bins(data[:, feature_idx], feature_idx,
                                 category_mapper, missing_values_bin_idx,
                                 binned[:, feature_idx])
        else:
            _map_num_col_to_bins(data[:, feature_idx],
                                 binning_thresholds[feature_idx],
                                 missing_values_bin_idx,
                                 binned[:, feature_idx])


cdef void _map_num_col_to_bins(const X_DTYPE_C [:] data,
@@ -71,3 +82,14 @@
                else:
                    left = middle + 1
            binned[i] = left


cdef void _map_cat_col_to_bins(const X_DTYPE_C [:] data,
                               int feature_idx,
                               CategoryMapper category_mapper,
                               const unsigned char missing_values_bin_idx,
Member:
this parameter seems to be unused

                               X_BINNED_DTYPE_C [:] binned):
    """Map from raw categories to bins."""
    cdef int i
    for i in prange(data.shape[0], schedule='static', nogil=True):
        binned[i] = category_mapper.map_to_bin(feature_idx, data[i])
9 changes: 9 additions & 0 deletions sklearn/ensemble/_hist_gradient_boosting/_bitset.pxd
@@ -0,0 +1,9 @@
# cython: language_level=3
from .common cimport X_BINNED_DTYPE_C
from .common cimport BITSET_DTYPE_C

cdef void init_bitset(BITSET_DTYPE_C bitset) nogil

cdef void set_bitset(BITSET_DTYPE_C bitset, X_BINNED_DTYPE_C val) nogil

cdef unsigned char in_bitset(BITSET_DTYPE_C bitset, X_BINNED_DTYPE_C val) nogil
38 changes: 38 additions & 0 deletions sklearn/ensemble/_hist_gradient_boosting/_bitset.pyx
@@ -0,0 +1,38 @@
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
# cython: language_level=3
from .common cimport BITSET_INNER_DTYPE_C

cdef inline void init_bitset(BITSET_DTYPE_C bitset) nogil:  # OUT
    cdef:
        unsigned int i

    for i in range(8):
        bitset[i] = 0

cdef inline void set_bitset(BITSET_DTYPE_C bitset,  # OUT
                            X_BINNED_DTYPE_C val) nogil:
    cdef:
        unsigned int i1 = val // 32
        unsigned int i2 = val % 32

    # It is assumed that val < 256 so that i1 < 8
    bitset[i1] |= (1 << i2)

cdef inline unsigned char in_bitset(BITSET_DTYPE_C bitset,
                                    X_BINNED_DTYPE_C val) nogil:
    cdef:
        unsigned int i1 = val // 32
        unsigned int i2 = val % 32

    return (bitset[i1] >> i2) & 1


def set_bitset_py(BITSET_INNER_DTYPE_C[:] bitset, X_BINNED_DTYPE_C val):
    cdef:
        unsigned int i1 = val // 32
        unsigned int i2 = val % 32

    # It is assumed that val < 256 so that i1 < 8
    bitset[i1] |= (1 << i2)
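
The bitset layout is easy to mimic in plain Python (an editor's sketch, not
part of the PR): eight 32-bit words cover the 256 possible bin values, with
word index val // 32 and bit offset val % 32.

    import numpy as np

    def set_bit(bitset, val):
        # bitset: eight uint32 words covering bin values 0..255
        bitset[val // 32] |= np.uint32(1) << np.uint32(val % 32)

    def in_bit(bitset, val):
        return int((bitset[val // 32] >> np.uint32(val % 32)) & np.uint32(1))

    bitset = np.zeros(8, dtype=np.uint32)
    set_bit(bitset, 200)
    assert in_bit(bitset, 200) == 1
    assert in_bit(bitset, 3) == 0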
19 changes: 19 additions & 0 deletions sklearn/ensemble/_hist_gradient_boosting/_cat_mapper.pxd
@@ -0,0 +1,19 @@
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
# cython: language_level=3
# cython: nonecheck=False
# distutils: language=c++

from libcpp.map cimport map
from libcpp.vector cimport vector
from .common cimport X_DTYPE_C
from .common cimport X_BINNED_DTYPE_C

cdef class CategoryMapper:
Member Author:
This object stores the mapping from raw category to bin

    cdef:
        map[int, map[int, X_BINNED_DTYPE_C]] raw_category_to_bin
        X_BINNED_DTYPE_C missing_values_bin_idx

    cdef X_BINNED_DTYPE_C map_to_bin(self, int feature_idx,
                                     X_DTYPE_C raw_category) nogil
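
Conceptually, the lookup behaves like the following plain-Python sketch (an
editorial illustration; the insert() helper is hypothetical and this is not
the actual Cython implementation):

    class CategoryMapperSketch:
        """Nested-map lookup from (feature_idx, raw_category) to bin."""

        def __init__(self, missing_values_bin_idx):
            self.missing_values_bin_idx = missing_values_bin_idx
            self.raw_category_to_bin = {}  # feature_idx -> {category -> bin}

        def insert(self, feature_idx, raw_category, bin_idx):
            self.raw_category_to_bin.setdefault(
                feature_idx, {})[raw_category] = bin_idx

        def map_to_bin(self, feature_idx, raw_category):
            # categories unseen at fit time fall into the missing-values bin
            return self.raw_category_to_bin.get(feature_idx, {}).get(
                raw_category, self.missing_values_bin_idx)

    mapper = CategoryMapperSketch(missing_values_bin_idx=255)
    mapper.insert(feature_idx=0, raw_category=3.0, bin_idx=1)
    assert mapper.map_to_bin(0, 3.0) == 1
    assert mapper.map_to_bin(0, 42.0) == 255  # unknown -> missing bin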