ENH Adds Categorical Support to Histogram Gradient Boosting #16909
Changes from all commits
@@ -0,0 +1,88 @@
import argparse
from time import time

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_openml
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.ensemble._hist_gradient_boosting.utils import (
    get_equivalent_estimator)


parser = argparse.ArgumentParser()
parser.add_argument('--n-leaf-nodes', type=int, default=31)
parser.add_argument('--n-trees', type=int, default=40)
parser.add_argument('--lightgbm', action="store_true", default=False)
parser.add_argument('--learning-rate', type=float, default=1.)
parser.add_argument('--max-bins', type=int, default=255)
parser.add_argument('--no-predict', action="store_true", default=False)
args = parser.parse_args()

n_leaf_nodes = args.n_leaf_nodes
n_trees = args.n_trees
lr = args.learning_rate
max_bins = args.max_bins


def fit(est, data_train, target_train, libname, **fit_params):
    print(f"Fitting a {libname} model...")
    tic = time()
    est.fit(data_train, target_train, **fit_params)
    toc = time()
    print(f"fitted in {toc - tic:.3f}s")


def predict(est, data_test, target_test):
    if args.no_predict:
        return
    tic = time()
    predicted_test = est.predict(data_test)
    predicted_proba_test = est.predict_proba(data_test)
    toc = time()
    roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1])
    acc = accuracy_score(target_test, predicted_test)
    print(f"predicted in {toc - tic:.3f}s, "
          f"ROC AUC: {roc_auc:.4f}, ACC: {acc:.4f}")


data, target = fetch_openml(data_id=179, as_frame=True, return_X_y=True)

# categorical dtypes are not yet supported for the target y, so encode it
# as integer codes
target = target.cat.codes

n_features = data.shape[1]
is_categorical = data.dtypes == 'category'
n_categorical_features = is_categorical.sum()
n_numerical_features = (data.dtypes == 'float').sum()
print(f"Number of features: {data.shape[1]}")
print(f"Number of categorical features: {n_categorical_features}")
print(f"Number of numerical features: {n_numerical_features}")

categorical_features = np.flatnonzero(is_categorical)
for i in categorical_features:
    data.iloc[:, i] = data.iloc[:, i].cat.codes

data_train, data_test, target_train, target_test = train_test_split(
    data, target, test_size=.2, random_state=0)

est = HistGradientBoostingClassifier(loss='binary_crossentropy',
                                     learning_rate=lr,
                                     max_iter=n_trees,
                                     max_bins=max_bins,
                                     categorical_features=categorical_features,
                                     max_leaf_nodes=n_leaf_nodes,
                                     early_stopping=False,
                                     random_state=0,
                                     verbose=1)

fit(est, data_train, target_train, 'sklearn')
predict(est, data_test, target_test)

# LightGBM infers the categories from the dtype
if args.lightgbm:
    est = get_equivalent_estimator(est, lib='lightgbm')
    fit(est, data_train, target_train, 'lightgbm',
        categorical_feature=is_categorical[is_categorical].index.tolist())
    predict(est, data_test, target_test)
@@ -0,0 +1,91 @@
""" | ||
======================================== | ||
Categorical Support in Gradient Boosting | ||
======================================== | ||
|
||
.. currentmodule:: sklearn | ||
|
||
In this example, we will compare the performance of | ||
:class:`~ensemble.HistGradientBoostingRegressor` using one hot encoding | ||
and with native categorical support. | ||
|
||
We will work with the Ames Lowa Housing dataset which consists of numerical | ||
and categorical features, where the houses' sales prices is the target. | ||
Comment on lines +12 to +13:
    [review] Are the categorical features useful for this task? It may be worth adding another example where the categorical features are dropped; training should be faster, but predictive performance should be worse. Dropping the categorical features is another (if dummy) way to deal with them.
    [reply] For this dataset, the categories do not matter as much, so I will be on the lookout for a nicer dataset.
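    [note] A minimal sketch of the suggested dropped-categoricals baseline, assuming the same dataframe X as in this example; the transformer and the `hist_dropped` name are illustrative and not part of this diff:

    from sklearn.experimental import enable_hist_gradient_boosting  # noqa
    from sklearn.ensemble import HistGradientBoostingRegressor
    from sklearn.compose import make_column_transformer, make_column_selector
    from sklearn.pipeline import make_pipeline

    # Drop the categorical columns entirely and fit on numerical features only.
    drop_categoricals = make_column_transformer(
        ('drop', make_column_selector(dtype_include='category')),
        remainder='passthrough')

    hist_dropped = make_pipeline(
        drop_categoricals, HistGradientBoostingRegressor(random_state=42))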
"""
##############################################################################
# Load Ames Housing dataset
# -------------------------
# First, we load the Ames housing data as a pandas dataframe. The features
# are either categorical or numerical:
print(__doc__)

from sklearn.datasets import fetch_openml

X, y = fetch_openml(data_id=41211, as_frame=True, return_X_y=True)

n_features = X.shape[1]
n_categorical_features = (X.dtypes == 'category').sum()
n_numerical_features = (X.dtypes == 'float').sum()
print(f"Number of features: {X.shape[1]}")
print(f"Number of categorical features: {n_categorical_features}")
print(f"Number of numerical features: {n_numerical_features}")

##############################################################################
# Create gradient boosting estimator with one-hot encoding
# ---------------------------------------------------------
# Next, we create a pipeline that one-hot encodes the categorical features
# and lets the remaining numerical data pass through:

from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.pipeline import make_pipeline
from sklearn.compose import make_column_transformer
from sklearn.compose import make_column_selector
from sklearn.preprocessing import OneHotEncoder

preprocessor = make_column_transformer(
    (OneHotEncoder(sparse=False, handle_unknown='ignore'),
     make_column_selector(dtype_include='category')),
    remainder='passthrough')

hist_one_hot = make_pipeline(preprocessor,
                             HistGradientBoostingRegressor(random_state=42))

##############################################################################
# Create gradient boosting estimator with native categorical support
# -------------------------------------------------------------------
# The :class:`~ensemble.HistGradientBoostingRegressor` has native support
# for categorical features through the `categorical_features` parameter:

hist_native = HistGradientBoostingRegressor(categorical_features='pandas',
                                            random_state=42)

##############################################################################
# Train the models with cross-validation
# ---------------------------------------
# Finally, we train the models using cross-validation. Here we compare the
# models' performance in terms of :func:`~metrics.r2_score` and fit times. We
# show that fit times are faster with native categorical support and that the
# test scores and score times are comparable:

from sklearn.model_selection import cross_validate
import matplotlib.pyplot as plt
import numpy as np

one_hot_result = cross_validate(hist_one_hot, X, y)
native_result = cross_validate(hist_native, X, y)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))

plot_info = [('fit_time', 'Fit times (s)', ax1),
             ('test_score', 'Test Scores (r2 score)', ax2)]

x, width = np.arange(2), 0.9
for key, title, ax in plot_info:
    items = [native_result[key], one_hot_result[key]]
    ax.bar(x, [np.mean(item) for item in items],
           width, yerr=[np.std(item) for item in items],
           color=['b', 'r'])
    ax.set(title=title, xticks=[0, 1],
           xticklabels=['Native', "One Hot"])
plt.show()
@@ -15,15 +15,19 @@ from cython.parallel import prange
from libc.math cimport isnan

from .common cimport X_DTYPE_C, X_BINNED_DTYPE_C
from ._cat_mapper cimport CategoryMapper

np.import_array()


def _map_to_bins(const X_DTYPE_C [:, :] data,
                 list binning_thresholds,
                 const unsigned char missing_values_bin_idx,
                 CategoryMapper category_mapper,
                 const unsigned char[::1] is_categorical,
                 X_BINNED_DTYPE_C [::1, :] binned):
    """Bin numerical and categorical values to discrete integer-coded levels.

    Parameters
    ----------

@@ -32,17 +36,24 @@ def _map_to_bins(const X_DTYPE_C [:, :] data,
    binning_thresholds : list of arrays
        For each feature, stores the increasing numeric values that are
        used to separate the bins.
    is_categorical : ndarray, shape (n_features,)
        Indicates categorical features.
    binned : ndarray, shape (n_samples, n_features)
        Output array, must be fortran aligned.
    """
    cdef:
        int feature_idx

    for feature_idx in range(data.shape[1]):
        if is_categorical[feature_idx]:
            _map_cat_col_to_bins(data[:, feature_idx], feature_idx,
                                 category_mapper, missing_values_bin_idx,
                                 binned[:, feature_idx])
        else:
            _map_num_col_to_bins(data[:, feature_idx],
                                 binning_thresholds[feature_idx],
                                 missing_values_bin_idx,
                                 binned[:, feature_idx])


cdef void _map_num_col_to_bins(const X_DTYPE_C [:] data,

@@ -71,3 +82,14 @@ cdef void _map_num_col_to_bins(const X_DTYPE_C [:] data,
                else:
                    left = middle + 1
            binned[i] = left


cdef void _map_cat_col_to_bins(const X_DTYPE_C [:] data,
                               int feature_idx,
                               CategoryMapper category_mapper,
                               const unsigned char missing_values_bin_idx,
                               [review] this parameter seems to be unused
                               X_BINNED_DTYPE_C [:] binned):
    """Map from raw categories to bins."""
    cdef int i
    for i in prange(data.shape[0], schedule='static', nogil=True):
        binned[i] = category_mapper.map_to_bin(feature_idx, data[i])
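For intuition, the dispatch above can be sketched in NumPy terms. This is an illustration only, not the actual Cython code; `map_to_bins_sketch` is a hypothetical helper, and `category_mapper` stands for any object exposing a `map_to_bin(feature_idx, raw_value)` method:

import numpy as np

def map_to_bins_sketch(data, binning_thresholds, is_categorical,
                       category_mapper, missing_values_bin_idx):
    # Rough NumPy equivalent of _map_to_bins (illustrative only).
    binned = np.empty(data.shape, dtype=np.uint8, order='F')
    for f in range(data.shape[1]):
        col = data[:, f]
        if is_categorical[f]:
            # Categorical column: look up each raw category's bin.
            binned[:, f] = [category_mapper.map_to_bin(f, v) for v in col]
        else:
            # Numerical column: same result as the binary search above;
            # NaNs are sent to the dedicated missing-values bin.
            bins = np.searchsorted(binning_thresholds[f], col, side='left')
            bins[np.isnan(col)] = missing_values_bin_idx
            binned[:, f] = bins
    return binned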
@@ -0,0 +1,9 @@
# cython: language_level=3
from .common cimport X_BINNED_DTYPE_C
from .common cimport BITSET_DTYPE_C

cdef void init_bitset(BITSET_DTYPE_C bitset) nogil

cdef void set_bitset(BITSET_DTYPE_C bitset, X_BINNED_DTYPE_C val) nogil

cdef unsigned char in_bitset(BITSET_DTYPE_C bitset, X_BINNED_DTYPE_C val) nogil
@@ -0,0 +1,38 @@
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
# cython: language_level=3
from .common cimport BITSET_INNER_DTYPE_C

cdef inline void init_bitset(BITSET_DTYPE_C bitset) nogil:  # OUT
    cdef:
        unsigned int i

    for i in range(8):
        bitset[i] = 0

cdef inline void set_bitset(BITSET_DTYPE_C bitset,  # OUT
                            X_BINNED_DTYPE_C val) nogil:
    cdef:
        unsigned int i1 = val // 32
        unsigned int i2 = val % 32

    # It is assumed that val < 256 so that i1 < 8
    bitset[i1] |= (1 << i2)

cdef inline unsigned char in_bitset(BITSET_DTYPE_C bitset,
                                    X_BINNED_DTYPE_C val) nogil:
    cdef:
        unsigned int i1 = val // 32
        unsigned int i2 = val % 32

    return (bitset[i1] >> i2) & 1


def set_bitset_py(BITSET_INNER_DTYPE_C[:] bitset, X_BINNED_DTYPE_C val):
    cdef:
        unsigned int i1 = val // 32
        unsigned int i2 = val % 32

    # It is assumed that val < 256 so that i1 < 8
    bitset[i1] |= (1 << i2)
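To make the word/bit arithmetic above concrete, here is a rough NumPy stand-in (illustrative only; `set_bit` and `in_bit` are hypothetical names): the bitset is eight 32-bit words, so any bin value below 256 maps to word `val // 32` and bit `val % 32`.

import numpy as np

# The bitset layout used above: 8 x 32-bit words, enough for the 256
# possible bin indices (val < 256 implies val // 32 < 8).
bitset = np.zeros(8, dtype=np.uint32)

def set_bit(bitset, val):
    # Same arithmetic as set_bitset: pick the word, then the bit within it.
    bitset[val // 32] |= np.uint32(1) << np.uint32(val % 32)

def in_bit(bitset, val):
    # Same arithmetic as in_bitset: shift the word right and mask.
    return (bitset[val // 32] >> np.uint32(val % 32)) & 1

set_bit(bitset, 42)
assert in_bit(bitset, 42) == 1
assert in_bit(bitset, 43) == 0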
@@ -0,0 +1,19 @@
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
# cython: language_level=3
# cython: nonecheck=False
# distutils: language=c++

from libcpp.map cimport map
from libcpp.vector cimport vector
from .common cimport X_DTYPE_C
from .common cimport X_BINNED_DTYPE_C

cdef class CategoryMapper:
    [review] This object stores the mapping from raw category to bin
    cdef:
        map[int, map[int, X_BINNED_DTYPE_C]] raw_category_to_bin
        X_BINNED_DTYPE_C missing_values_bin_idx

    cdef X_BINNED_DTYPE_C map_to_bin(self, int feature_idx,
                                     X_DTYPE_C raw_category) nogil
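As the review note says, this object stores the per-feature mapping from raw category value to bin index. A rough Python stand-in (hypothetical `PyCategoryMapper`, a sketch rather than the actual implementation) could look like:

class PyCategoryMapper:
    """Dict-based sketch of CategoryMapper's nested-map lookup."""

    def __init__(self, missing_values_bin_idx):
        # {feature_idx: {raw_category: bin_idx}}
        self._raw_category_to_bin = {}
        self.missing_values_bin_idx = missing_values_bin_idx

    def insert(self, feature_idx, categories):
        # In this sketch, bin k is assigned to the k-th known category
        # of the feature.
        self._raw_category_to_bin[feature_idx] = {
            raw: bin_idx for bin_idx, raw in enumerate(categories)}

    def map_to_bin(self, feature_idx, raw_category):
        # Unknown or missing categories fall back to the missing-values bin.
        return self._raw_category_to_bin[feature_idx].get(
            raw_category, self.missing_values_bin_idx)

The Cython class uses a C++ `map` rather than a Python dict so that `map_to_bin` can be declared `nogil` and called from the parallel binning loop.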