ENH Add Categorical support for HistGradientBoosting #18394


Merged: 85 commits into scikit-learn:master, Nov 16, 2020

Conversation

@thomasjpfan (Member) commented Sep 14, 2020

Reference Issues/PRs

Alternative to #16909

What does this implement/fix? Explain your changes.

This version further restricts the number of input categories to 256, thus allowing the bitset in the predictors to be defined without using C++.
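
For illustration: with at most 256 input categories, a per-split category bitset fits in a small fixed-size array (256 bits = 8 x 32-bit words). The sketch below is only a rough Python/NumPy analogy of that idea; the names and layout are illustrative assumptions, not the actual Cython code in this PR.

import numpy as np

# A 256-bit bitset stored as 8 words of 32 bits each (illustrative only).
bitset = np.zeros(8, dtype=np.uint32)

def set_category(bitset, cat):
    # Mark category `cat` (0 <= cat < 256) as going to the left child.
    bitset[cat // 32] |= np.uint32(1 << (cat % 32))

def goes_left(bitset, cat):
    # Check whether category `cat` is in the bitset.
    return bool((bitset[cat // 32] >> (cat % 32)) & 1)

set_category(bitset, 3)
assert goes_left(bitset, 3) and not goes_left(bitset, 4)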

CC @NicolasHug

@NicolasHug (Member) left a comment

Thanks @thomasjpfan,

I can already appreciate the impact of the simplification ;)

I pushed a few minor things and I also have a few comments regarding the binning: I think we can simplify the code a bit and treat categorical and continuous features uniformly. I'm happy to give these changes a try if that helps.

Re splitting: I'm a bit concerned about adding some extra complexity (MAX_CAT_THRESHOLDS + scanning in both directions) unless we can identify strong incentives to do so.
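
For readers following along, the baseline being refined here is to sort the categories by their mean gradient and scan that ordering for the best partition; MAX_CAT_THRESHOLDS and the bidirectional scan are extra heuristics on top of that. A rough, purely illustrative NumPy sketch of the baseline (hypothetical numbers, not the actual implementation):

import numpy as np

# Hypothetical per-category histogram statistics for one feature.
sum_gradients = np.array([-4.0, 2.5, 0.3, -1.2])
sum_hessians = np.array([10.0, 8.0, 5.0, 7.0])

# Sort categories by their mean gradient (the "Fisher" ordering trick).
order = np.argsort(sum_gradients / sum_hessians)

g_total, h_total = sum_gradients.sum(), sum_hessians.sum()
g_left = h_left = 0.0
best_gain, best_k = -np.inf, None
for k, cat in enumerate(order[:-1]):
    g_left += sum_gradients[cat]
    h_left += sum_hessians[cat]
    g_right, h_right = g_total - g_left, h_total - h_left
    # Split gain without regularization terms, for illustration only.
    gain = g_left**2 / h_left + g_right**2 / h_right - g_total**2 / h_total
    if gain > best_gain:
        best_gain, best_k = gain, k

# The first (best_k + 1) categories in `order` would go to the left child.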

@amueller (Member)

Is the decision between this and #16909 still open?
I think the description up top has a typo btw ;)

@NicolasHug (Member) commented Sep 16, 2020

Is the decision between this and #16909 still open?

I think @thomasjpfan and I agreed that working on this simplified version is the simplest and fastest way towards merging a more complete implementation like #16909 in the future. An important point is that this simplified version will allow the next, more complete version to remain backward compatible (but these estimators are experimental anyway).

@NicolasHug (Member)

Another approach worth considering is to internally call an OrdinalEncoder within the GB estimator, which comes with another set of complications.

@lorentzenchr (Member) left a comment

Tests and examples remain to be reviewed.
There should be rewards for reviews of +1000 LOC :smirk:

# Reduces the effect of noise in categorical features,
# especially for categories with little data. Called cat_smooth in
# LightGBM. TODO: Make this user adjustable?
Y_DTYPE_C MIN_CAT_SUPPORT = 10.

A member left a comment

I assume that MIN_CAT_SUPPORT is one reason why native categorical support is not strictly equivalent to OHE (tree depth is another). In this case, giving the user control over this is desirable. Maybe a future PR.

Comment on lines 879 to 885
if sum_hessians_bin * support_factor >= MIN_CAT_SUPPORT:
    cat_infos[n_used_bins].bin_idx = bin_idx
    sum_gradients_bin = feature_hist[bin_idx].sum_gradients

    cat_infos[n_used_bins].value = \
        sum_gradients_bin / (sum_hessians_bin + MIN_CAT_SUPPORT)
    n_used_bins += 1

A member left a comment

It is very hard to find any documentation explaining this.

Looking into lightgbm and googling, is this a form of additive smoothing?
If so, I do not fully understand the formula for cat_infos[n_used_bins].value, i.e. why do we add MIN_CAT_SUPPORT to every sum_hessians_bin?

MIN_CAT_SUPPORT is used as a cut-off and, at the same time, seems to play a role similar to self.l2_regularization.

A member left a comment

I couldn't find doc on this either.

I just see it as some form of shrinkage; sum_hessians_bin is basically the number of samples in the bin.

A member left a comment

After a long walk, I'd be in favor of excluding this functionality here and introducing it in a separate PR that shows its usefulness.

A member left a comment

Would it be OK to keep it, with the option to remove it in the future depending on empirical results?

While I don't have a reference for the specific case of categorical splits in GBDTs, this kind of shrinkage is quite common (ref below). Also, "LightGBM does it" has been a pretty strong argument for this part of the code so far lol.

Categories are ordered by their mean y, and the main idea is that we don't want to put too much trust in a mean computed from too few samples. So we add this small constant in the denominator to "shrink" it and limit its strength (in other words, we regularize).

Such shrinkage is often used in a recommender-systems context as a form of regularization of similarities: in this paper they shrink similarities that don't have enough support, and they justify it from a Bayesian point of view.
(There's also the original ref "Regression Shrinkage and Selection via the Lasso" from Tibshirani...)
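
A tiny numerical illustration of the shrinkage described above (hypothetical numbers, not scikit-learn code): with few samples the shrunk value stays close to zero, while with many samples it approaches the raw mean.

MIN_CAT_SUPPORT = 10.0

def category_value(sum_gradients_bin, sum_hessians_bin):
    # Same form as in the PR: the constant in the denominator pulls the
    # category value towards 0 when the category has few samples.
    return sum_gradients_bin / (sum_hessians_bin + MIN_CAT_SUPPORT)

print(category_value(3.0, 2.0))      # 0.25, far from the raw mean 1.5
print(category_value(300.0, 200.0))  # ~1.43, close to the raw mean 1.5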

A member left a comment

Hmm...
We already have min_samples_leaf to control having too few samples and l2_regularization for some regularization. I'd still prefer to take this out and do it properly later.

What I'm missing in the implemented formula is that the smoothing should be size dependent: the more samples we have, the more we should trust the empirical averages. Something like:

prior_mean = 0
n = n_samples_in_category  # number of samples in the category
penalty = 10
alpha = n / (n + penalty)
value = alpha * sum_gradients_bin / sum_hessians_bin + (1 - alpha) * prior_mean

There is even theory on how to choose an optimal penalty or an empirical estimate of prior_mean, cf. credibility theory or linear mixed models with a random intercept.

Conclusion: as HGBT is still experimental, I'll also approve this PR without the removal of MIN_CAT_SUPPORT. In that case, I ask the implementors of this PR to take responsibility for investigating this in a future PR. 🙏

A member left a comment

We already have min_samples_leaf to control having too few samples and l2_regularization for some regularization

This is similar but unrelated: min_samples_leaf controls the tree size, and l2_regularization is used for the leaf values. The shrinkage here only affects the values by which we sort the categories.

What I'm missing in the implemented formula is that the smoothing should be size dependent: the more samples we have, the more we should trust the empirical averages

This is what happens here, I believe. The denominator is n_samples + CAT_SMOOTH. When n_samples is small, CAT_SMOOTH takes over and the final value is shrunk. When it is high, CAT_SMOOTH has almost no effect.

For reference, this is briefly discussed in microsoft/LightGBM#699 (comment). They didn't publish benchmarks, but this yielded the best performance according to the comments.
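
In fact, with prior_mean = 0 the suggested formula reduces to the implemented one, since alpha * g / n = (n / (n + penalty)) * (g / n) = g / (n + penalty). A quick illustrative check (variable names are placeholders, treating the hessian sum as the sample count):

def implemented(g, n, penalty=10.0):
    return g / (n + penalty)

def suggested(g, n, penalty=10.0, prior_mean=0.0):
    alpha = n / (n + penalty)
    return alpha * g / n + (1 - alpha) * prior_mean

assert abs(implemented(3.0, 2.0) - suggested(3.0, 2.0)) < 1e-12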

@scikit-learn deleted a comment from lorentzenchr on Nov 15, 2020
@lorentzenchr (Member) left a comment

And here comes another round for half of the tests.

@lorentzenchr (Member) left a comment

Finally! Review completed.

@lorentzenchr (Member) left a comment

LGTM. Incredibly nice work, @NicolasHug @thomasjpfan, thank you so much :rocket:. Categorical features come up very often in practice!

@lorentzenchr changed the title from "[MRG] ENH Adds Categorical support for HistGradientBoosting" to "ENH Add Categorical support for HistGradientBoosting" on Nov 16, 2020
@lorentzenchr merged commit b4453f1 into scikit-learn:master on Nov 16, 2020
@NicolasHug (Member)

Thanks a lot for the review @lorentzenchr !

@ogrisel I believe you still had comments/concerns; maybe you can still comment here and we can address them in a subsequent PR?

It seems that we're all a bit concerned in particular about the ease of use with the ColumnTransformer + OrdinalEncoder, and about how to specify the categorical features. I think we should address that before we start communicating about the categorical support?

It seems to me that we could implement a solution that resembles #18394 (comment) by first merging #18393 (or a variant), and by letting categorical_features accept callables as suggested by @thomasjpfan above. I'm not sure how close we are to the release already, but hopefully we can squeeze this in for 0.24?

@NicolasHug (Member)

It seems to me that we could implement a solution that resembles #18394 (comment) by first merging #18393 (or a variant), and by letting categorical_features accept callables as suggested by @thomasjpfan above. I'm not sure how close we are to the release already, but hopefully we can squeeze this in for 0.24?

Sadly, this doesn't work. It breaks when the pipeline is used in e.g. cross_val_score because the estimators will be cloned there, and thus the callable refers to an unfitted CT:

from sklearn.datasets import fetch_openml
from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.pipeline import make_pipeline
from sklearn.compose import make_column_transformer
from sklearn.compose import make_column_selector
from sklearn.preprocessing import OrdinalEncoder
from sklearn.model_selection import cross_val_score
import numpy as np


X, y = fetch_openml(data_id=41211, as_frame=True, return_X_y=True)

ct = make_column_transformer(
    (OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=np.nan),
     make_column_selector(dtype_include='category')),
    remainder='passthrough')

cat_features_getter = lambda: ct.output_indices_['ordinalencoder']  # see PR #18393
hist_native = make_pipeline(
    ct,
    HistGradientBoostingRegressor(random_state=42,
                                  categorical_features=cat_features_getter)
)

cross_val_score(hist_native, X, y)

This breaks with: AttributeError: 'ColumnTransformer' object has no attribute 'output_indices_'
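
To make the failure mode explicit under this hypothetical callable API (building on the snippet above): cloning, which cross_val_score does internally, produces a fresh, unfitted copy of ct, while the lambda still closes over the original ct object.

from sklearn.base import clone

cloned = clone(hist_native)   # what cross_val_score does internally
ct_clone = cloned.steps[0][1]
assert ct_clone is not ct     # a fresh, unfitted ColumnTransformer
# The lambda still refers to the original `ct`, which is never fitted,
# hence the AttributeError on `output_indices_` during cross-validation.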

@lorentzenchr (Member)

@NicolasHug @thomasjpfan Shall we continue this discussion in #4196? Or in a dedicated new issue?

@NicolasHug (Member)

#4196 is too general IMO. We can open an issue if there is any additional input.

@NicolasHug (Member)

If you do, please summarize/reference the attempt above ;)

@lorentzenchr (Member) commented Nov 21, 2020

I'll open a new issue.
Edit: #18894
