
Balanced Random Forest #5181



Closed · potash wants to merge 20 commits

Conversation

@potash commented Aug 28, 2015

I have implemented balanced random forest as described in Chen, C., Liaw, A., and Breiman, L. (2004), "Using Random Forest to Learn Imbalanced Data", Tech. Rep. 666. It is enabled via the balanced=True parameter to RandomForestClassifier.

This is related to the class_weight='subsample' feature already available, but instead of down-weighting the majority class(es) it undersamples them. According to the referenced paper (and personal experience), balanced random forests perform well on very imbalanced data.

In order to do the balanced sampling we need some class summary data (the distribution of classes, etc.). For efficiency, this is precomputed in fit() by the _get_balance_class_data() function and then passed to _parallel_build_trees(), which, when balanced sampling is requested, calls _generate_balanced_sample_indices() instead of the default _generate_sample_indices().
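
For illustration, here is a minimal, self-contained sketch of the balanced bootstrap idea; it is a simplified stand-in, not the PR's actual helper, and the function name below is made up for this example:

```python
import numpy as np

def balanced_bootstrap_indices(y, random_state=None):
    # Draw a bootstrap sample of size n_min from every class, where n_min is
    # the size of the smallest class, and concatenate the per-class draws.
    rng = np.random.RandomState(random_state)
    _, y_encoded = np.unique(y, return_inverse=True)
    n_min = np.bincount(y_encoded).min()
    per_class = [
        rng.choice(np.flatnonzero(y_encoded == k), size=n_min, replace=True)
        for k in range(y_encoded.max() + 1)
    ]
    return np.concatenate(per_class)

y = np.array([0] * 95 + [1] * 5)
idx = balanced_bootstrap_indices(y, random_state=0)
print(np.bincount(y[idx]))  # exactly n_min draws per class: [5 5]
```

Drawing with replacement keeps this a bootstrap (as in the paper) rather than plain undersampling, so each tree still sees a resampled, class-balanced view of the data.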

If there is interest in this feature, I'd be happy to write some tests for it and discuss code style, etc. Thanks!

@amueller (Member)

Can you provide an example where this works better then class_weights? And is that generally your experience? cc @arjoly @glouppe

@potash (Author) commented Aug 31, 2015

The linked paper above has several examples. See Tables 3-7 for a comparison of performance, including Balanced Random Forest (BRF) and Weighted Random Forest (WRF), which corresponds to class_weight='auto'. The authors (n.b. Breiman is the creator of Random Forests) conclude:

Between WRF and BRF, however, there is no clear winner. By the construction of BRF and WRF, we found that BRF is computationally more efficient with large imbalanced data, since each tree only uses a small portion of the training set to grow, while WRF needs to use the entire training set. WRF assigns a weight to the minority class, possibly making it more vulnerable to noise (mis-labeled class) than BRF. A majority case that is mislabeled as belonging to the minority class may have a larger effect on the prediction accuracy of the majority class in WRF than in BRF.

In my experience, another benefit of BRF is improved precision at the top: take the top N predictions (as ranked by predicted probability), assume them to be positive, and then calculate precision. This is a common metric for resource allocation problems, and BRF does very well here.

Unfortunately my own examples are not using public data. Would you like me to build a public example using sklearn.datasets or is the above paper sufficient?

@amueller (Member)

It would be great to have an example using a public dataset.

@ogrisel (Member) commented Sep 1, 2015

+1 for an example. It would be great to include fit times for WRF vs BRF (to emphasize the computational advantage).

Could you give an example to explain the resource allocation / precision at the top case? I am not sure I understand. Maybe you could include that in your example as well.

For the example it would be great to have it run fast enough (e.g. less than 60s, ideally less than 10s).

@potash (Author) commented Sep 4, 2015

Hi all, sorry for the delay. I've added an example in examples/ensemble/balanced_random_forest.py. The dataset is KDD Cup '99:

The competition task was to build a network intrusion detector, a predictive model capable of distinguishing between "bad" connections, called intrusions or attacks, and "good" normal connections. This database contains a standard set of data to be audited, which includes a wide variety of intrusions simulated in a military network environment.

The original data has a dozen or so classes, but I've squashed all of the minority classes into one class, which makes up 1.7% of the training set. The script downloads the compressed example data (3 MB).

"Precision at k" is defined as follows: given a list of examples and a class of interest we look at the predicted probabilities. We take the k highest probabilities and mark them as members of the class and then compute the precision of this prediction. This metric is useful when one has limited resources to investigate a particular class.

For example if we are predicting a disease and we only have 1000 vaccines, precision at 1000 would tell us what proportion of the vaccines we distribute according to our model actually went to good use. Examples of this type are abundant.
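
To make the metric concrete, here is a small self-contained snippet (a hypothetical helper written for this discussion, not part of the PR or of scikit-learn):

```python
import numpy as np

def precision_at_k(y_true, y_score, k):
    # Precision among the k examples with the highest predicted score.
    top_k = np.argsort(y_score)[::-1][:k]
    return y_true[top_k].mean()

y_true = np.array([1, 0, 0, 1, 0, 1, 0, 0])
y_score = np.array([0.9, 0.8, 0.3, 0.7, 0.6, 0.2, 0.1, 0.4])
print(precision_at_k(y_true, y_score, k=3))  # top 3 by score contain 2 positives -> 0.666...
```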

Anyway, my example will print something like:

$ python balanced_random_forest.py
baseline: 0.0176308651472

forest parameters: {}
time elapsed: 66.118927
precision at 28318: 0.557560562187
auc: 0.776372816647

forest parameters: {'class_weight': 'auto'}
time elapsed: 54.718798
precision at 28318: 0.563705063917
auc: 0.7491969778

forest parameters: {'class_weight': 'balanced_subsample'}
time elapsed: 68.085352
precision at 28318: 0.731407585281
auc: 0.853682452421

forest parameters: {'balanced': True}
time elapsed: 10.367698
precision at 28318: 0.673988276008
auc: 0.906080228057

In this case I have chosen k to be the number of positive cases in the test set, and have printed out both precision at k and auc. As you can see, the existing 'balanced_subsample' as well as my 'balanced' feature (in this particular run the former did better, but generally they are similar) do much better than either the default or class_weight='auto'. Additionally, balanced random forest runs in about a fifth of the time of the others.

Thoughts? If there is interest in merging, I would suggest:

  • including a warning or documentation noting that when balanced=True, setting class_weight to 'auto' or 'balanced_subsample' is redundant.
  • considering moving the balanced sampling to fit(), where the trees are constructed, so that the entire X does not need to be copied to each subprocess when running in parallel. This might speed up balanced=True over the alternatives even more.

@jmschrei (Member) commented Sep 9, 2015

I'm not sure how I feel about modifying the dataset after it's passed to a classifier, especially when this is the major upgrade. Is it possible that a better approach would be to have a balanced dataset transformer which could be fed into the existing random forest code for the same performance?

@potash (Author) commented Sep 9, 2015

It's not "modifying" the dataset any more than bootstrap sampling does in the random forest already. It's just a variation on bootstrap to ensure that classes are equally represented. Unfortunately it's not possible to achieve this by passing a transformed dataset into RF.

@arjoly (Member) commented Sep 9, 2015

This looks interesting. Maybe @trevorstephens might want to comment about this pull request.

@amueller (Member) commented Sep 9, 2015

This does look very interesting. @glouppe any opinions? It would be nice to have a couple more datasets for comparison.

@trevorstephens (Contributor)

Thx for the ping @arjoly . I'll try to take a deeper look soon. But some initial thoughts...

Can this method easily be modified to support multi-output? Looks like only single output is currently coded in. At first thought, you could duplicate rows that are chosen for both outputs, or use the union of any generated subsets. Doubtful there's much literature covering such cases, but we do need to do something to support it I suppose.

Can you explain the rationale behind the "precision at 28318"? It would be nice to see a threshold/precision plot with each method overlaid on it I think.

I am very surprised by the difference between "auto" (I think we're using "balanced" now btw) and "balanced subsample". In my trials way back when, I saw very little difference between the two. How stable are these results with different random seeds?

I generally don't use the pre-written "preset" class weights in practice, but grid search for a better weighting scheme. Thus, I think it would be good to also compare this method to grid searching over a range of manually set class_weights.
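
For concreteness, grid searching manual class weights can be done along these lines with the current scikit-learn API (the weight values here are arbitrary and the dataset is synthetic; at the time of this thread the grid_search module would have been used instead of model_selection):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic imbalanced problem: roughly 5% positives.
X, y = make_classification(n_samples=2000, weights=[0.95, 0.05], random_state=0)

# Grid over the presets plus a few manually chosen weightings.
param_grid = {
    "class_weight": ["balanced", "balanced_subsample"]
    + [{0: 1, 1: w} for w in (1, 5, 10, 50, 100)]
}
search = GridSearchCV(
    RandomForestClassifier(n_estimators=50, random_state=0),
    param_grid,
    scoring="roc_auc",
    cv=3,
)
search.fit(X, y)
print(search.best_params_)
```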

Will take a look at the code and example more later.

@potash (Author) commented Sep 9, 2015

Can this method easily be modified to support multi-output? Looks like only single output is currently coded in. At first thought, you could duplicate rows that are chosen for both outputs, or use the union of any generated subsets. Doubtful there's much literature covering such cases, but we do need to do something to support it I suppose.

By multi-output do you mean multi-class? If so, then the natural generalization of the definition in the paper is simply to take bootstrap samples of size n of each class, where n = the size of the smallest class. That is already included in my implementation, see the loop in _generate_balanced_sample_indices(), line 105. If by multi-output you mean something else, pardon me and please explain.

Can you explain the rationale behind the "precision at 28318"? It would be nice to see a threshold/precision plot with each method overlaid on it I think.

28318 was chosen because that was the number of positive cases in the test set. Of course that is cheating because in practice you don't know that number but maybe you have some prior on the proportion of cases. A plot with k on the x-axis and precision or count on the y-axis is what I use in practice.

I am very surprised by the difference between "auto" (I think we're using "balanced" now btw) and "balanced subsample". In my trials way back when, I saw very little difference between the two. How stable are these results with different random seeds?

There is some variation, and maybe it decreases with increasing n_estimators. But with such a huge class imbalance it is not surprising that a simple 'auto' weighting does not suffice. The reason is that for very imbalanced data the class distribution passed to each tree in the forest (remember, it's a bootstrap sample) will vary. The difference between 'auto' and 'balanced_subsample' is that the latter takes this variability into account.

E.g. in my KDDCup example above, the training set has 500k examples with an incidence (proportion of positive examples) of 1%. The distribution of positives for each of the bootstrap samples is then binomial with n=500k and p=.01. Here's a typical histogram of 1000 such samples.
[histogram: number of positive examples in each of 1000 bootstrap samples]
The problem is that on the low and high ends we are training trees on ~4800 and ~5200 positive examples, respectively. This corresponds to a range of incidences of 0.96% to 1.04%. With class_weight='auto' the weights are always 100:1, but class_weight='balanced_subsample' will adjust for this with a range of class weights from 104:1 to 96:1.
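
That spread is easy to reproduce with a quick simulation (illustrative only; this is not the code that produced the histogram above):

```python
import numpy as np

rng = np.random.RandomState(0)
n, p = 500000, 0.01
# Number of positive examples landing in each of 1000 bootstrap samples.
positives = rng.binomial(n, p, size=1000)
print(positives.min(), positives.max())          # roughly 4800 and 5200
print(positives.min() / float(n), positives.max() / float(n))  # incidences near 0.96%-1.04%
```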

Balanced Random Forests, on the other hand, avoid this variability by ensuring that the samples are... balanced :)

I generally don't use the pre-written "preset" class weights in practice, but grid search for a better weighting scheme. Thus, I think it would be good to also compare this method to grid searching over a range of manually set class_weights.

Makes sense. Tomorrow I can add grid search and a plot to my example. Thanks for the comments.

@trevorstephens (Contributor)

By multi-output do you mean multi-class? If so, then the natural generalization of the definition in the paper is simply to take bootstrap samples of size n of each class, where n = the size of the smallest class. That is already included in my implementation, see the loop in _generate_balanced_sample_indices(), line 105. If by multi-output you mean something else, pardon me and please explain.

http://scikit-learn.org/stable/modules/tree.html#multi-output-problems

Basically, y is 2D and you predict multiple targets at once. It is supported by individual trees and random forests.

@trevorstephens (Contributor)

I ran your example code this evening (the difference between the auto and balanced_subsample weights really surprised me), and this particular problem appears to be extremely unstable on a mere 100 trees.

Over 20 trials:

forest parameters: {}
precision: 0.593090966876 +/- 0.0710233412747
auc: 0.782164324556 +/- 0.0525929649579
time: 307.525359154

forest parameters: {'class_weight': 'auto'}
precision: 0.613754502437 +/- 0.0903029879824
auc: 0.795725857085 +/- 0.0590804050987
time: 262.033653975

forest parameters: {'class_weight': 'balanced_subsample'}
precision: 0.600702733244 +/- 0.0761696971482
auc: 0.793003608462 +/- 0.0563624712942
time: 310.007385969

forest parameters: {'balanced': True}
precision: 0.630664241825 +/- 0.067713466876
auc: 0.798414825638 +/- 0.0432061839716
time: 92.456414938

Where the reported results are mean +/- stdev and the time taken is in seconds for all 20 trials to complete.

I do think there's value here. But the metric results don't appear to be as significant as your results above, at least on this problem. The time is indeed an improvement, but might potentially be worse than the existing solutions on more balanced and/or larger datasets due to recreating the dataset at each tree.

As another thought on API, this could probably be incorporated into the existing class_weight param as "undersample" or something (the existing ones are mostly akin to oversampling as far as I'm concerned). Would make for less clutter and easier grid searching too.

Some solution for multi-output behavior is still required as I mentioned.

@trevorstephens (Contributor)

Above results are with no random_state set, so just different seeds on each trial BTW

@glouppe (Contributor) commented Sep 11, 2015

Coming a bit late to the party...

Am I correctly understanding that this is effectively implementing class balancing (in terms of number of samples) with bootstrap for each class? It should be very similar to class_weight="auto" with bootstrap=True right?

In any case, given this was proposed by Breiman, I am not against adding such a balancing strategy.

@glouppe (Contributor) commented Sep 11, 2015

As another thought on API, this could probably be incorporated into the existing class_weight param as "undersample" or something (the existing ones are mostly akin to oversampling as far as I'm concerned). Would make for less clutter and easier grid searching too

I would prefer something like this too!

@glouppe (Contributor) commented Sep 11, 2015

Am I correctly understanding that this is effectively implementing class balancing (in terms of number of samples) with bootstrap for each class? It should be very similar to class_weight="auto" with bootstrap=True right?

Hmm, actually one of the reasons I believe this PR behaves differently from class_weight="auto" + bootstrap=True is that bootstrap replicates are drawn in master without taking sample weights into account. I think we should fix that, but this will not be backward compatible... :/ @arjoly @trevorstephens what do you think?

@arjoly (Member) commented Sep 11, 2015

Arf, could it be considered a bugfix? There are probably weird situations where the current behavior is sub-optimal. (For instance, in the presence of mostly zero sample weights?)

@trevorstephens (Contributor)

@glouppe @arjoly The 'balanced_subsample' option calculates class weights after the bootstrap is drawn, so this is already an option to users for the "presets". That does not take into account any sample_weight passed through fit though :-/ ... For user defined class weights through a dict (more useful in my experience) it would make the bootstrap biased with, or without, sample weights.
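
For reference, the distinction is visible directly with compute_sample_weight (a quick illustration written for this discussion, not code from the PR):

```python
import numpy as np
from sklearn.utils.class_weight import compute_sample_weight

y = np.array([0] * 95 + [1] * 5)
rng = np.random.RandomState(0)
boot = rng.randint(0, len(y), len(y))  # an ordinary bootstrap of row indices

# 'balanced': class frequencies computed once on the full y.
w_full = compute_sample_weight("balanced", y)
# 'balanced_subsample'-style: class frequencies recomputed on the bootstrap draw.
w_boot = compute_sample_weight("balanced", y, indices=boot)

print(np.unique(w_full))  # one weight per class, fixed across trees
print(np.unique(w_boot))  # shifts with the class counts of this particular bootstrap
```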

In terms of differences between under- and over-sampling (and this is purely arm-waving): from deep trees I've plotted with massive imbalance and heavy class weighting, the ones I've grown tend to go way off to one side, shedding negative cases on the way down until, right at the bottom, you start to get positive classifications. This PR's method of undersampling the majority class may create more balanced trees, and it probably has better regularization at first thought, but you're throwing away a ton of information for each tree. I do think it's a good option to have available, to try each way.

It would be pretty easy to create a bootstrap from the weighted samples, just draw floats (up to the sum of weights) instead of ints. Might be more expensive on the searchsorted though and would have to be done for each tree. For deprecating, maybe a temporary bootstrap='weighted' option, or something, would work to bring that online.
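
A rough sketch of that weighted-bootstrap idea (hypothetical code, not what master or this PR does):

```python
import numpy as np

def weighted_bootstrap_indices(sample_weight, n_samples, random_state=None):
    # Inverse-transform sampling: draw uniform floats up to the total weight
    # and map them back to row indices with searchsorted on the cumulative sum.
    rng = np.random.RandomState(random_state)
    cumulative = np.cumsum(sample_weight)
    draws = rng.uniform(0, cumulative[-1], size=n_samples)
    return np.searchsorted(cumulative, draws)

w = np.array([1.0, 1.0, 10.0, 1.0])
idx = weighted_bootstrap_indices(w, n_samples=10000, random_state=0)
print(np.bincount(idx, minlength=len(w)) / 10000.0)  # roughly proportional to w / w.sum()
```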

@potash (Author) commented Sep 15, 2015

@glouppe @arjoly I agree with @trevorstephens that this should not be considered a bugfix, because the definition of a bootstrap sample does not include class weights and that is what everyone expects from a random forest.

@trevorstephens While it would be nice to grid search through 'auto', 'balanced_subsample' and 'balanced', this feature is not, strictly speaking, a class_weight because it does not assign equal weight to all instances in a given class. It would not fit very nicely into the existing utils.compute_sample_weight call that currently handles class_weight. What do you think about making it accessible via bootstrap='balanced'? That logically makes the most sense to me because it is really changing the bootstrapping algorithm. But if you insist, I can make class_weight='balanced' (or 'balanced_bootstrap', since as you noted 'balanced'=='auto' now) work.

@trevorstephens Regarding multi-output, I do not see a great way of handling it, so I have simply raised an error when balanced bootstrapping is requested in that case.

@trevorstephens Yes, the output is unstable with just 100 trees. I chose that number because someone above expressed a desire for the example to run quickly. I have updated the example to do 10 runs and print the mean and standard deviation as you did. That parameter is n and can be changed. Note that while the variation in auc across models is probably insignificant, precision at k really is significantly better for balanced=True. And of course the more efficient fit() is hard to ignore.

I have also changed the example to, as requested, generate a precision chart for a range of k for each of the previous models as well as a few different class weights. The custom class weights are chosen to be near class_weight='auto'.

[figure: precision at k over a range of k, one curve per model]

And here is the printed output:

$ python balanced_random_forest.py
baseline: 0.0176308651472

forest parameters: {}
time elapsed: 71.4973599
precision at 10000: 0.80003 +/- 0.00610459662877
auc: 0.778167629683 +/- 0.0491199994664

forest parameters: {'class_weight': 'auto'}
time elapsed: 59.3670258
precision at 10000: 0.79339 +/- 0.0206917592292
auc: 0.789811530892 +/- 0.0573386601125

forest parameters: {'class_weight': 'balanced_subsample'}
time elapsed: 72.2302383
precision at 10000: 0.78824 +/- 0.015790516141
auc: 0.795884241264 +/- 0.0738124202678

forest parameters: {'balanced': True}
time elapsed: 14.3122445
precision at 10000: 0.82642 +/- 0.00962390773023
auc: 0.79435629339 +/- 0.0367644512771

forest parameters: {'class_weight': {False: 8710, True: 242655}}
time elapsed: 60.1485844
precision at 10000: 0.79488 +/- 0.00511699130349
auc: 0.811902051944 +/- 0.0471212316519

forest parameters: {'class_weight': {False: 8710, True: 363982}}
time elapsed: 59.5174899
precision at 10000: 0.79436 +/- 0.00675206635038
auc: 0.784691168168 +/- 0.0732038976236

forest parameters: {'class_weight': {False: 8710, True: 606637}}
time elapsed: 58.2950391
precision at 10000: 0.79058 +/- 0.018542534886
auc: 0.789464259055 +/- 0.0696026062929

forest parameters: {'class_weight': {False: 8710, True: 727965}}
time elapsed: 58.3999256
precision at 10000: 0.79317 +/- 0.0108844889637
auc: 0.803815016221 +/- 0.0599962680204

I also moved the example to its own repository that includes the data, because the server was not happy with me repeatedly grabbing the data. Find it here: https://github.com/potash/brf-example/

I would love to get this merged soon because I am getting busy with other projects, which is causing my responses to be increasingly delayed. (Thanks for your patience.) I think the only remaining item is the API. What are the thoughts on that?

@trevorstephens (Contributor)

I think that this is a really interesting proposition. Training time aside, in my work I do a lot of highly imbalanced classification, and alternative strategies for exploring different solutions are very welcome.

Support for multi-output should be incorporated though, even if it is sub-optimal. Having some parameter options available and others absent based on input shape doesn't make for a great user experience. While the multi-output case is a royal pain to support, it exists, and any PR that merely errors at its existence is likely not to merge. I would suggest two paths forward for that:

  1. have a set union of the randomized minor classes
  2. repeat the randomized minor classes if the multi-output indicates it should be so

I think option 2 is more appropriate in this case, though I have not put a ton of thought into it and would defer to the resident tree huggers @pprett and @glouppe for their insight.

On the API side: RF is getting pretty busy on the param side. While, perhaps, this solution is not necessarily an explicit class weighting as you mention, it is very much related to it, and it would make for a cleaner parameter listing IMO if this were moved to the existing class_weight parameter. I do not believe this fits into the bootstrap param well.

Would love to hear some core dev's opinions on the above as well.

@potash (Author) commented Oct 7, 2015

@trevorstephens I finally got around to implementing multi-output balanced random forest!

I didn't quite understand the ideas you listed above. What I did is to consider, for a given output i and label j, the class as the pair (i, j), and then find the minority class in that sense. Then I take a bootstrap sample of each such class, of the size of that smallest pair (i_min, j_min).

The result won't be completely balanced, as there could be a correlation between classes of different outputs. So if the minority class of the first output is correlated with a majority class of the second output, that majority class will still be a majority in the balanced sample.

Does this work for you? I included an example in brf_multioutput.py. If you prefer something else, could you expand on your algorithm above?
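
A stripped-down sketch of that scheme on a toy two-output problem (illustrative only, not the PR's implementation; the function name is made up):

```python
import numpy as np

def multioutput_balanced_indices(Y, random_state=None):
    # Treat every (output, label) pair as a class, find the globally smallest
    # count n_min, then bootstrap n_min rows for each pair and concatenate.
    rng = np.random.RandomState(random_state)
    Y = np.asarray(Y)
    counts = [np.bincount(np.unique(Y[:, i], return_inverse=True)[1])
              for i in range(Y.shape[1])]
    n_min = min(c.min() for c in counts)
    parts = []
    for i in range(Y.shape[1]):
        for label in np.unique(Y[:, i]):
            members = np.flatnonzero(Y[:, i] == label)
            parts.append(rng.choice(members, size=n_min, replace=True))
    return np.concatenate(parts)

# Two binary outputs with different imbalance; the rarest (output, label) pair has 2 rows.
Y = np.array([[1, 0], [0, 0], [0, 0], [0, 0], [0, 1],
              [1, 1], [0, 0], [1, 0], [0, 0], [0, 0]])
print(sorted(multioutput_balanced_indices(Y, random_state=0)))
```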

@trevorstephens (Contributor)

I will try to find time to look it over soon @potash

In the meantime you may like to pip install pep8 pyflakes and run both on the modified file, as your code looks to have a few missing spaces near commas and overly long lines in places. Adding some documentation on the new parameter in the constructor docstring would help other contributors and core devs weigh in as well.

Adding tests will be required for any eventual merge. There are some for class_weight in there that you could butcher. Testing the index-generating function independently on a small unbalanced y would be a good sanity check, for each of binary classification, multi-output, multi-label, and their combinations as well.

I still feel that this should be incorporated into the class_weight parameter somehow. It is definitely something I would want to grid-search over vs the existing solutions.

@trevorstephens (Contributor)

As for the "algorithm" I mentioned. Pretty simple idea:

idx: 0 1 2 3 4 5 6 7 8 9
y0:  1 0 0 0 0 1 0 1 0 0
y1:  0 0 0 0 1 1 0 0 0 0

So the independent balanced samples for y0 might be [0, 2, 5, 7, 8, 9] and for y1 might be [4, 5, 8, 9]. The multi-output balanced sample could then be either the set of the two: [0, 2, 4, 5, 7, 8, 9] or, the combination [0, 2, 5, 7, 8, 9, 4, 5, 8, 9]. The latter is more akin to the way I did it in the multi-output class_weight logic as the weights are multiplied and repeating indices is kinda sorta the same. While there probably isn't much literature to lean on to determine what the "correct" behavior should be for multi-output, IMO (I don't come across many use-cases) it just needs to be sensible, and not break.
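
In code, the two options would look roughly like this (simply taking the index lists above as given):

```python
import numpy as np

idx_y0 = np.array([0, 2, 5, 7, 8, 9])  # balanced sample drawn for y0
idx_y1 = np.array([4, 5, 8, 9])        # balanced sample drawn for y1

print(np.union1d(idx_y0, idx_y1))        # option 1: set union -> [0 2 4 5 7 8 9]
print(np.concatenate([idx_y0, idx_y1]))  # option 2: keep repeats -> [0 2 5 7 8 9 4 5 8 9]
```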

I'll have to look at your example and think about potential corner-cases for your solution to the multi-output case. It sounds rational at first glance though.

@potash (Author) commented Oct 8, 2015

@trevorstephens Ah, so my method is similar to your second one, except that instead of samples of size 3 and 2 for y0 and y1 respectively, it will just use size 2 (the global min) for all y_i. Let me know what you think. I'll get to the documentation and tests soon.

@potash force-pushed the feature/balanced-random-forest branch from 0c2a0ed to 8a5a645 on April 11, 2017 17:21
@potash force-pushed the feature/balanced-random-forest branch from 8a5a645 to 59f7c85 on April 11, 2017 18:10
@massich mentioned this pull request on Apr 12, 2017
@chkoar mentioned this pull request on Feb 22, 2019
@amueller added the Superseded (PR has been replaced by a newer PR) label on Aug 5, 2019
Base automatically changed from master to main January 22, 2021 10:48
@lorentzenchr (Member)

Succeeded by #13227.

Labels: module:ensemble, Superseded (PR has been replaced by a newer PR)