[MRG] MNT Initialize histograms in parallel and don't call np.zero in Hist-GBDT #18341


Merged: 5 commits into scikit-learn:master on Sep 4, 2020

Conversation

@NicolasHug (Member) commented Sep 4, 2020

To keep things more atomic, this PR extracts the histogram initialization out of #18242; this change was already proposed in #14392.
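
For context, a minimal pure-Python sketch of the idea (the actual change is Cython using an OpenMP prange; the ThreadPoolExecutor below only illustrates the structure, since the GIL keeps these NumPy assignments from truly running in parallel, and the dtype layout is illustrative):

import numpy as np
from concurrent.futures import ThreadPoolExecutor

# Struct-like histogram entry: per-bin gradient sum, hessian sum, count.
HISTOGRAM_DTYPE = np.dtype([
    ('sum_gradients', np.float64),
    ('sum_hessians', np.float64),
    ('count', np.uint32),
])

def build_histograms(n_features, n_bins=256):
    # np.empty instead of np.zeros: no zeroing pass before the parallel
    # section; each worker initializes its own feature's row instead.
    histograms = np.empty((n_features, n_bins), dtype=HISTOGRAM_DTYPE)

    def process_feature(f):
        for name in HISTOGRAM_DTYPE.names:
            histograms[f][name] = 0  # zeroing happens inside the worker
        # ... per-feature gradient/hessian accumulation would go here ...

    with ThreadPoolExecutor() as pool:
        list(pool.map(process_feature, range(n_features)))
    return histograms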

Memory usage is the same as in master but runs in 40s instead of 45s (across many runs) on the following benchmarks on my laptop (4 threads):

from sklearn.datasets import make_classification
from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import HistGradientBoostingClassifier
from memory_profiler import memory_usage

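# The settings below (many classes and features, deep trees) are chosen
# so that fitting builds lots of histograms.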
X, y = make_classification(n_classes=5,
                           n_samples=3_000,
                           n_features=400,
                           random_state=0,
                           n_clusters_per_class=1,
                           n_informative=5)

hgb = HistGradientBoostingClassifier(
    max_iter=100,
    max_leaf_nodes=256,
    learning_rate=.1,
    random_state=0,
    verbose=1,
)

mems = memory_usage((hgb.fit, (X, y)))
print(f"{max(mems):.2f}, {max(mems) - min(mems):.2f} MB")

Other benchmarks welcome, especially one with 88 hyperthreads as in #14392 (comment).

CC @ogrisel @thomasjpfan

@ogrisel (Member) commented Sep 4, 2020

OK, let me re-run a quick benchmark with many threads on this PR.

@NicolasHug changed the title from "[MRG] Initialize histograms in parallel" to "[MRG] MNT Initialize histograms in parallel and don't call np.zero in Hist-GBDT" on Sep 4, 2020
@ogrisel (Member) commented Sep 4, 2020

On a machine with 36 physical cores (72 HyperThreads):

  • command line:
$ OMP_NUM_THREADS=36 nice python benchmarks/bench_hist_gradient_boosting_higgsboson.py --n-trees 100 --n-leaf-nodes 255
  • master:
Fit 100 trees in 39.437 s, (25500 total leaves)
Time spent computing histograms: 17.777s
Time spent finding best splits:  1.128s
Time spent applying splits:      4.035s
Time spent predicting:           1.743s
fitted in 39.451s
predicted in 2.488s, ROC AUC: 0.8280, ACC: 0.7466
  • this branch (hist_parallel):
Fit 100 trees in 39.656 s, (25500 total leaves)
Fit 100 trees in 38.755 s, (25500 total leaves)
Time spent computing histograms: 16.972s
Time spent finding best splits:  1.147s
Time spent applying splits:      4.027s
Time spent predicting:           1.753s
fitted in 38.773s
predicted in 2.950s, ROC AUC: 0.8278, ACC: 0.7462

So, no significant improvement (I re-ran the benchmark several times and observed approximately +/- 0.5s variation in total fit time for both branches).

I was almost alone on the machine when running the benchmark.

For reference, LightGBM with 36 threads on the same machine yields:

fitted in 17.750s
predicted in 1.265s, ROC AUC: 0.8268, ACC: 0.7453

@ogrisel (Member) commented Sep 4, 2020

Without limiting the number of threads to the number of physical cores (no OMP_NUM_THREADS):

  • master
Fit 100 trees in 51.788 s, (25500 total leaves) 
Time spent computing histograms: 23.774s
Time spent finding best splits:  2.760s
Time spent applying splits:      6.218s
Time spent predicting:           1.838s
fitted in 51.808s
predicted in 2.669s, ROC AUC: 0.8283, ACC: 0.7468
  • this branch:
Fit 100 trees in 46.983 s, (25500 total leaves)
Time spent computing histograms: 21.686s
Time spent finding best splits:  2.015s
Time spent applying splits:      5.202s
Time spent predicting:           1.947s
fitted in 47.001s
predicted in 2.621s, ROC AUC: 0.8279, ACC: 0.7463
  • lightgbm
fitted in 20.194s
predicted in 1.118s, ROC AUC: 0.8274, ACC: 0.7459

So maybe when the number of threads is larger than the number of physical CPU cores (oversubscription), this branch makes it possible to mitigate some of the performance degradation caused by oversubscription. But this is probably in the noise, because I get +/- 2s variations in that regime.

So in conclusion: no significant performance impact on machines with many threads.
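
As an aside, a hedged sketch of capping the thread count from Python instead of exporting OMP_NUM_THREADS (threadpool_limits comes from the threadpoolctl package, only_physical_cores requires a recent joblib, and hgb, X, y refer to the snippet in the PR description):

from joblib import cpu_count
from threadpoolctl import threadpool_limits

# Limit OpenMP (and BLAS) thread pools to one thread per physical core
# for the duration of the fit, avoiding hyperthread oversubscription.
with threadpool_limits(limits=cpu_count(only_physical_cores=True)):
    hgb.fit(X, y)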

@NicolasHug (Member, Author)

Thanks for the benchmarks @ogrisel.

Just to make sure: did you use the snippet provided in the PR description? It's different from the other benchmarks, as I designed it to build lots of histograms.

I tried on my desktop computer with 16 threads and the execution time drops from 28s (master) to 10s (this PR). So that's quite a significant improvement in this case.

Since you did not observe a regression in your benchmarks, I'd say we still have a strong incentive to merge this as a stand-alone change.

@NicolasHug (Member, Author)

My CPU with 16 threads has only 8 physical cores. Using OMP_NUM_THREADS=8, across multiple runs:

  • 22s on master
  • 10s on this PR

So in my case, I observe a significant improvement not just in the case of oversubscription.

@ogrisel (Member) commented Sep 4, 2020

For the code snippet in the description, this yields a huge improvement.

@NicolasHug (Member, Author)

Fantastic! Thanks for retrying.

@ogrisel (Member) left a review:

Let's merge this. It can either be very beneficial or the same speed, but never slower. And it's a very simple code change.

@lorentzenchr (Member)

On macOS without OpenMP, no regression either.

  • master
    Fit 500 trees in 42.757 s, (47303 total leaves)
    Time spent computing histograms: 17.263s
    Time spent finding best splits:  19.739s
    Time spent applying splits:      1.075s
    Time spent predicting:           0.029s
    164.78, 34.19 MB
    
  • this PR
    Fit 500 trees in 43.890 s, (47303 total leaves)
    Time spent computing histograms: 16.605s
    Time spent finding best splits:  21.601s
    Time spent applying splits:      1.024s
    Time spent predicting:           0.029s
    164.91, 35.30 MB
    

@ogrisel (Member) commented Sep 4, 2020

I have updated #18341 (comment) to also add bench results for #18242's branch with parallel zero init + histogram pool.

The histogram pool is performance neutral compared to this branch.
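
For readers who haven't followed #18242, a rough sketch of the histogram pool idea (not the actual implementation; the class and method names below are made up for illustration):

import numpy as np

class HistogramPool:
    """Reuse released histogram arrays to amortize per-node allocations."""

    def __init__(self, shape, dtype):
        self._shape = shape
        self._dtype = dtype
        self._free = []

    def get(self):
        # Hand back a released array when available; callers must
        # (re-)initialize it, e.g. zero it inside the parallel loop.
        if self._free:
            return self._free.pop()
        return np.empty(self._shape, dtype=self._dtype)

    def release(self, histograms):
        self._free.append(histograms)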

@NicolasHug (Member, Author)

@ogrisel I believe this PR may close #14306, WDYT?

@ogrisel (Member) commented Sep 4, 2020

Thanks for crediting Egor indeed.

:class:`ensemble.HistGradientBoostingClassifier` which results in speed
improvement for problems that build a lot of nodes on multicore machines.
:pr:`18341` by `Olivier Grisel`_, `Nicolas Hug`_, `Thomas Fan`_, and
:user:`Egor Smirnov <SmirnovEgorRu>`.
@NicolasHug (Member, Author) commented on this changelog entry:

Crediting @SmirnovEgorRu here, since you opened a similar PR a while ago (#14380).

@ogrisel (Member) commented Sep 4, 2020

> @ogrisel I believe this PR may close #14306, WDYT?

I am ambivalent. But the memory usage problem observed by @thomasjpfan on his Mac might be too rare to justify the added code complexity. Let's wait for his opinion. In the meantime, we can already merge this one.

@ogrisel merged commit 4a2de5b into scikit-learn:master on Sep 4, 2020
@NicolasHug (Member, Author)

I believe #14306 isn't about memory but about thread scalability. In particular, @SmirnovEgorRu had identified that np.zeros might be a problem.
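
To illustrate why np.zeros can look cheap while still costing a serial pass (a micro-benchmark under assumptions: large allocations go through calloc, so the zeroed pages are only materialized at first touch; the shape is a made-up order of magnitude, not the actual histogram size):

import numpy as np
from time import perf_counter

shape = (400, 256, 255)  # features x leaves x bins, illustrative only

t0 = perf_counter()
h = np.zeros(shape)  # calloc: the allocation itself is nearly free
t1 = perf_counter()
h += 1.0             # first touch: pages are zero-filled here, serially
t2 = perf_counter()
print(f"allocate: {t1 - t0:.4f}s, first touch: {t2 - t1:.4f}s")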

@ogrisel (Member) commented Sep 4, 2020

Oops, I merged without waiting for the CI to be green, but I am confident this will not break the build on master. Fingers crossed.

@ogrisel (Member) commented Sep 4, 2020

> I believe #14306 isn't about memory but about thread scalability. In particular, @SmirnovEgorRu had identified that np.zeros might be a problem.

Indeed, I was mistaken. It's an improvement, but it's still not as good as LightGBM in terms of scalability. But I agree we should update the benchmark results in that issue to reflect the new master.
