ENH compute histograms only for allowed features in HGBT #24856


Merged
merged 5 commits into scikit-learn:main from hgbt_seed_up_allowed_features on Nov 22, 2022

Conversation

@lorentzenchr (Member) commented Nov 7, 2022

Reference Issues/PRs

Follow-up of #21020.

What does this implement/fix? Explain your changes.

This PR restricts the computation of histograms in HistGradientBoostingRegressor and HistGradientBoostingClassifier to the features that are allowed to be split on. This improves performance (reduces fit time) when interaction constraints are used.
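
As an illustration (not part of the PR), a minimal usage sketch, assuming scikit-learn >= 1.2 where the interaction_cst parameter from #21020 is available; one singleton group per feature forbids all interactions, which is the setting the benchmark's --no-interactions flag exercises:

from sklearn.datasets import make_regression
from sklearn.ensemble import HistGradientBoostingRegressor

X, y = make_regression(n_samples=10_000, n_features=28, random_state=0)

# One singleton group per feature: no two features may interact, so each
# tree node only needs histograms for the features it is allowed to split on.
est = HistGradientBoostingRegressor(
    max_iter=100,
    interaction_cst=[{i} for i in range(X.shape[1])],
)
est.fit(X, y)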

Any other comments?

@thomasjpfan (Member) left a comment

Codewise, I suspect this will improve performance. Could you run a quick benchmark to verify?

@lorentzenchr (Member, Author) commented Nov 11, 2022

Summary: This PR clearly reduces fit time with interaction constraints (about 35% on the Higgs boson benchmark) and incurs no performance penalty without them.
Note: Numbers vary a lot from run to run.

With Interaction Constraints

No interactions allowed.

MAIN (commit 85a8aa6) with interaction constraints

% python scikit-learn/benchmarks/bench_hist_gradient_boosting_higgsboson.py --n-trees 100 --no-interactions 1
Training set with 8800000 records with 28 features.
Fitting a sklearn model...
Binning 1.971 GB of training data: 3.593 s
Fitting gradient boosted rounds:
...
Fit 100 trees in 56.812 s, (2484 total leaves)
Time spent computing histograms: 27.538s
Time spent finding best splits:  0.160s
Time spent applying splits:      6.929s
Time spent predicting:           1.833s
fitted in 56.990s
predicted in 7.205s, ROC AUC: 0.7755, ACC: 0.7028

This PR with interaction constraints

% python scikit-learn/benchmarks/bench_hist_gradient_boosting_higgsboson.py --n-trees 100 --no-interactions 1
Training set with 8800000 records with 28 features.
Fitting a sklearn model...
Binning 1.971 GB of training data: 3.816 s
Fitting gradient boosted rounds:
...
Fit 100 trees in 36.839 s, (2484 total leaves)
Time spent computing histograms: 6.840s
Time spent finding best splits:  0.194s
Time spent applying splits:      7.018s
Time spent predicting:           1.828s
fitted in 37.044s
predicted in 7.586s, ROC AUC: 0.7755, ACC: 0.7028

Without interaction constraints

MAIN

% python scikit-learn/benchmarks/bench_hist_gradient_boosting_higgsboson.py --n-trees 100                   
Training set with 8800000 records with 28 features.
Fitting a sklearn model...
Binning 1.971 GB of training data: 3.901 s
Fitting gradient boosted rounds:
...
Fit 100 trees in 63.442 s, (3100 total leaves)
Time spent computing histograms: 29.477s
Time spent finding best splits:  0.772s
Time spent applying splits:      8.813s
Time spent predicting:           1.730s
fitted in 63.638s
predicted in 7.028s, ROC AUC: 0.8228, ACC: 0.7415

This PR

% python scikit-learn/benchmarks/bench_hist_gradient_boosting_higgsboson.py --n-trees 100                   
Training set with 8800000 records with 28 features.
Fitting a sklearn model...
Binning 1.971 GB of training data: 3.847 s
Fitting gradient boosted rounds:
...
Fit 100 trees in 58.557 s, (3100 total leaves)
Time spent computing histograms: 26.800s
Time spent finding best splits:  0.407s
Time spent applying splits:      7.132s
Time spent predicting:           1.794s
fitted in 58.794s
predicted in 6.642s, ROC AUC: 0.8228, ACC: 0.7415

@jjerphan (Member) left a comment

LGTM modulo a few suggestions.

has_interaction_cst = allowed_features is not None
if has_interaction_cst:
    n_allowed_features = allowed_features.shape[0]

with nogil:
@jjerphan (Member)

This allows reusing threads, potentially reducing OpenMP's overhead.

Suggested change
with nogil:
with nogil, parallel(num_threads=n_threads):
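
For context, a minimal self-contained sketch of the suggested pattern (illustrative names, assuming a Cython module compiled with OpenMP): the single parallel region spawns the thread pool once, and the prange loops inside it become worksharing constructs that reuse those threads, so no per-loop num_threads is needed.

from cython.parallel import parallel, prange

def zero_then_increment(double[::1] out, int n_threads):
    cdef Py_ssize_t i
    with nogil, parallel(num_threads=n_threads):
        # First worksharing loop: threads come from the enclosing region.
        for i in prange(out.shape[0], schedule='static'):
            out[i] = 0.0
        # Second loop reuses the same thread pool instead of opening a
        # new OpenMP region.
        for i in prange(out.shape[0], schedule='static'):
            out[i] = out[i] + 1.0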

@lorentzenchr (Member, Author) commented Nov 18, 2022

Could we try this out in another PR?
I did not touch those implementation parts and would like to keep it that way.

Comment on lines +170 to +172
for f_idx in prange(
    n_allowed_features, schedule='static', num_threads=n_threads
):
@jjerphan (Member)

If the nogil, parallel context suggested above is used, this num_threads must be removed here and in the previous prange loops that I cannot attach suggestions to.

Suggested change
for f_idx in prange(
    n_allowed_features, schedule='static', num_threads=n_threads
):
for f_idx in prange(n_allowed_features, schedule='static'):

Note that this pattern …

@lorentzenchr (Member, Author)

In another PR, maybe. See comment above.

Comment on lines +173 to +176
if has_interaction_cst:
    feature_idx = allowed_features[f_idx]
else:
    feature_idx = f_idx
@jjerphan (Member)

Is there a performance benefit to moving the branching outside the prange loops and having one prange loop per branch?

@lorentzenchr (Member, Author)

I tried that, but I did not observe any runtime difference. Usually the number of features is small (fewer than 10, 100, maybe 1000) while the sample size is large. The inner loop then runs over n_samples, and that is where branching should be avoided.
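
To illustrate with a hedged sketch (simplified, hypothetical names, not the actual histogram code): the branch on has_interaction_cst is taken once per feature in the outer prange, while the hot inner loop over samples stays branch-free.

from cython.parallel import prange

def count_per_bin(const unsigned char[:, ::1] X_binned,
                  const long[::1] allowed_features,
                  long[:, ::1] counts,
                  bint has_interaction_cst,
                  int n_threads):
    cdef Py_ssize_t f_idx, i, feature_idx, n_allowed_features
    if has_interaction_cst:
        n_allowed_features = allowed_features.shape[0]
    else:
        n_allowed_features = X_binned.shape[1]
    for f_idx in prange(n_allowed_features, schedule='static',
                        num_threads=n_threads, nogil=True):
        # Cheap branch: evaluated once per feature, outside the hot loop.
        if has_interaction_cst:
            feature_idx = allowed_features[f_idx]
        else:
            feature_idx = f_idx
        # Hot loop over n_samples: no branching inside.
        for i in range(X_binned.shape[0]):
            counts[feature_idx, X_binned[i, feature_idx]] += 1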

@lorentzenchr added this to the 1.2 milestone on Nov 18, 2022
@lorentzenchr (Member, Author)

@jjerphan I hope it's ready now. All issues/merge conflicts are fixed. I also added a whatsnew entry.

@jjerphan added the "Waiting for Second Reviewer" label (First reviewer is done, need a second one!) on Nov 18, 2022
@jjerphan (Member)

Still LGTM, yes. I have just labelled this PR as "Waiting for Second Reviewer".

@ogrisel merged commit 2da7428 into scikit-learn:main on Nov 22, 2022
@ogrisel deleted the hgbt_seed_up_allowed_features branch on November 22, 2022 at 09:41
@ogrisel (Member) commented Nov 22, 2022

Thanks @lorentzenchr!
