[MRG+1] Merge PresortBestSplitter and BestSplitter #5252

Merged
merged 1 commit into from
Sep 29, 2015

Conversation

jmschrei
Member

This pull request merges PresortBestSplitter into BestSplitter without loss of functionality. All tree-based estimators (Decision Trees, Random Forests, Extra Trees, and Gradient Boosting) now have an optional parameter "presort", which defaults to False for all of them except Gradient Boosting.

In addition to allowing the other estimators to turn presorting on for smaller datasets, this lets Gradient Boosting turn presorting off on large datasets. A good default for this switch still needs to be found; at the moment it is set manually. Gradient Boosting can now also work on sparse data by turning the presorting option off.
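
As a usage sketch of the option described above (a hedged example; which estimators expose presort, and its default value, are revised later in this thread):

# Hedged sketch of the presort option as described in this PR.
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=2000, n_features=20, random_state=0)

# Turn presorting on for a single tree on a small, dense dataset...
tree = DecisionTreeClassifier(presort=True).fit(X, y)

# ...and off for gradient boosting on large or sparse data.
gbt = GradientBoostingClassifier(presort=False).fit(X, y)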

Here is a checklist of what needs to be done (much more manageable than my last one):

  • Merge PresortBestSplitter into BestSplitter
  • Write unit tests to test sparse data, presorting options for RF and DT, and non-presorting for GBTs
  • Find a good default switch for GBT to switch from presorting to non-presorting

As a side note, I'm getting pretty high variance running this benchmark, with my branch sometimes faster and sometimes slower by up to 50s. Can someone else run the benchmarks a few times and see what they get?

@arjoly @pprett @glouppe

MASTER
Classification performance:
===========================
Classifier   train-time test-time error-rate
--------------------------------------------
RandomForest  32.1754s   0.4057s     0.0330  
ExtraTrees    39.5981s   0.5889s     0.0372  
CART          12.1820s   0.0209s     0.0424  
GBRT         584.7982s   0.4412s     0.1777

BRANCH
Classification performance:
===========================
Classifier   train-time test-time error-rate
--------------------------------------------
RandomForest  36.0459s   0.3898s     0.0330  
ExtraTrees    35.2664s   0.5764s     0.0372  
CART          15.3547s   0.2031s     0.0424  
GBRT         557.0054s   0.3247s     0.1777  

@jmschrei jmschrei mentioned this pull request Sep 11, 2015
@@ -251,6 +255,11 @@ def fit(self, X, y, sample_weight=None):

random_state = check_random_state(self.random_state)

X_idx_sorted = None
if self.presort == True:
Contributor

`if self.presort:` is more pythonic
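
For illustration, a hedged sketch of the truthiness idiom being suggested (a generic, standalone example, not the PR's code):

import numpy as np

def maybe_presort(X, presort):
    # Rely on truthiness ("if presort:") rather than comparing "== True".
    if presort:
        # One plausible way to build a presorted index of X (illustrative).
        return np.asfortranarray(np.argsort(X, axis=0), dtype=np.int32)
    return None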

@jmschrei
Member Author

Ran some time tests with and without presorting for GradientBoostingClassifier.

max_depth = 3

with presorting
dataset    score  train time (s)
spambase   0.91     0.11
Gaussian   0.853    8.51
mnist      0.837  191.26
covtypes   0.708  111.45

without presorting
dataset    score  train time (s)
spambase   0.91     0.14
Gaussian   0.853   20.29
mnist      0.837  263.88
covtypes   0.708  134.32

max_depth = 6

with presorting
dataset    score  train time (s)
spambase   0.925    0.29
Gaussian   0.911   22.29
mnist      0.912  718.27
covtypes   0.774  306.24

without presorting
dataset    score  train time (s)
spambase   0.926    0.28
Gaussian   0.911   38.73
mnist      0.913  549.19
covtypes   0.774  249.28

So it looks like training time depends on some function of max_depth, n_samples, and n_features. Do you think it's worth finding a good switch, or mentioning it in the docs and leaving it to users to decide whether they want to presort or not?
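
For reference, a hedged sketch of how such a timing comparison can be run with the presort parameter added in this PR (synthetic data here; not the datasets or script used above):

import time

from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

X, y = make_classification(n_samples=20000, n_features=50, random_state=0)

for presort in (True, False):
    clf = GradientBoostingClassifier(max_depth=3, presort=presort, random_state=0)
    start = time.time()
    clf.fit(X, y)
    print("presort=%s  score=%.3f  train time=%.2fs"
          % (presort, clf.score(X, y), time.time() - start))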

@glouppe
Contributor

glouppe commented Sep 14, 2015

So it looks like training time depends on some function of max_depth, n_samples, and n_features. Do you think it's worth finding a good switch, or mentioning it in the docs and leaving it to users to decide whether they want to presort or not?

I would leave that to the user and only enable presort for boosting. Let's keep things simple :)

@jmschrei
Member Author

I have added unit tests for Gradient Boosting and Forests to ensure that presorting returns the same results as not using presorting. I have also added unit tests for Gradient Boosting testing sparse inputs, and ensuring that they give the same results as dense inputs. Please review when you have the time.

@jmschrei jmschrei force-pushed the presort branch 3 times, most recently from 8fd9f2a to ba674b6 Compare September 16, 2015 12:45
@jmschrei jmschrei changed the title [WIP] Merge PresortBestSplitter and BestSplitter [MRG] Merge PresortBestSplitter and BestSplitter Sep 16, 2015
@arjoly
Member

arjoly commented Sep 17, 2015

First, thanks @jmschrei for giving some love to this code. Here are a few thoughts about this pull request.

I think it's great that you add sparsity support to GBRT (though I would have done that in another PR). As pointed out by some people in @pprett's pull request, the sparse X matrix needs to be stored in both csr and csc format to handle sparse matrices efficiently; this avoids costly conversions between the csr and csc formats.

I understand the need for, and the benefits of, merging more of the logic between the presort best splitter and the best splitter. In this PR, the proposal is to interleave the code paths with if/else. If we push this rationalization further, for example by merging the random splitter and the sparse splitters, this approach might start leading to spaghetti code. What do you think of having inline methods that would be overridden by each concrete implementation? The advantage is that it would add code reuse with some semantics.

I am not a fan of adding a presort parameter to the random forest estimators. Since most (all?) people build fully developed trees, it might not be used much, and it adds some complexity in terms of usage and maintenance. In the gradient boosting case, I wouldn't let the user choose; I would switch presorting on or off depending on the matrix sparsity.

@jmschrei
Member Author

Thanks for the comments @arjoly.

To be clear, I haven't added sparsity support for presorting. I added an option to turn presorting off for gradient boosting, and when you do, you get the same sparsity support that decision trees/random forests get. Hence there is no separate PR for it.

I can understand the concern about spaghetti code. Presorting and best splitting are so similar that merging them makes much more sense than any other merger, because it's not an algorithmic difference in how the split is calculated, just in how Xf is prepared for the split. This requires only the addition of 31 lines of code (15 of which are initializing and declaring variables) versus having an entirely new object and ~250 lines of repeated code.

Having inline methods may be the correct way to merge more code in the future. I was thinking about factorizing the node_split method into several methods, such as "sort_array", "sort_array_w_presorting", and "sort_sparse_array", which would return Xf. Then have a "best_split" method which takes in Xf and finds the best split, or a "random_split" method which finds the best random split. Lastly, a "partition_samples" method which is shared. I haven't fully formed the idea yet, but if you have ideas I'd love to see them.
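
To make the idea concrete, here is a rough, hedged sketch in plain Python of the factorization described above (the real splitters are Cython classes in _splitter.pyx, and these method names are only the hypothetical ones from this comment): the split search is shared, and only the preparation of Xf differs.

import numpy as np


class BestSplitter:
    """Toy sketch: subclasses only change how the sorted feature values
    (Xf) are prepared; the split search itself is shared."""

    def __init__(self, X, y):
        self.X, self.y = X, y

    def sort_array(self, samples, feature):
        """Prepare Xf by sorting this node's samples on `feature`."""
        order = np.argsort(self.X[samples, feature], kind="mergesort")
        samples = samples[order]
        return samples, self.X[samples, feature]

    def best_split(self, Xf, y):
        """Scan the sorted values for the variance-minimizing threshold."""
        best_impurity, best_threshold = np.inf, None
        for pos in range(1, len(Xf)):
            if Xf[pos] <= Xf[pos - 1]:
                continue
            impurity = pos * y[:pos].var() + (len(y) - pos) * y[pos:].var()
            if impurity < best_impurity:
                best_impurity = impurity
                best_threshold = (Xf[pos - 1] + Xf[pos]) / 2.0
        return best_threshold

    def node_split(self, samples, feature):
        samples, Xf = self.sort_array(samples, feature)
        return self.best_split(Xf, self.y[samples])


class PresortBestSplitter(BestSplitter):
    """Same search; Xf comes from an index computed once, up front."""

    def __init__(self, X, y):
        super().__init__(X, y)
        self.X_idx_sorted = np.argsort(X, axis=0, kind="mergesort")

    def sort_array(self, samples, feature):
        in_node = np.zeros(len(self.X), dtype=bool)
        in_node[samples] = True
        ordered = self.X_idx_sorted[:, feature]
        ordered = ordered[in_node[ordered]]  # keep this node's samples, presorted
        return ordered, self.X[ordered, feature]

With the same data and node, both classes return the same threshold; only sort_array differs, which is essentially the relationship between BestSplitter and PresortBestSplitter that this merge exploits.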

I like the addition of presorting for these other methods, but am not tied to it--it seems like a "why not" case. It's easy to add in, the default behaviour is not to use it, so users won't notice a difference unless they know what they're doing, and should it become difficult to maintain it can easily be removed. Maybe @glouppe will have an opinion.

@glouppe
Contributor

glouppe commented Sep 21, 2015

I am not a fan of adding a presort parameter to the random forest estimators. Since most (all?) people build fully developed trees, it might not be used much, and it adds some complexity in terms of usage and maintenance. In the gradient boosting case, I wouldn't let the user choose; I would switch presorting on or off depending on the matrix sparsity.

It is true that presorting is not very useful outside of GBRT. +1 for not adding this parameter to forests and keeping the public API as it is.

@jmschrei
Member Author

Presorting has been removed as an option for forests. Code, unit tests, and documentation have been updated to reflect this.

@jmschrei jmschrei force-pushed the presort branch 2 times, most recently from 48e826f to ec28f15 Compare September 21, 2015 07:58
@@ -934,6 +935,13 @@ def fit(self, X, y, sample_weight=None, monitor=None):
computing held-out estimates, early stopping, model introspect, and
snapshoting.

presort : bool, optional (default=False)

Contributor

remove this empty line

@glouppe
Contributor

glouppe commented Sep 21, 2015

You should test for presorting in test_tree.py. You can do that I believe by changing

"Presort-DecisionTreeClassifier": partial(DecisionTreeClassifier, splitter="presort-best"),

to

"Presort-DecisionTreeClassifier": partial(DecisionTreeClassifier, presort=True),

(and the same for the regressor)

@jmschrei
Member Author

Thanks for the review @glouppe! Good catch, I forgot to change that over. I've incorporated all your comments. It looks like a bunch of TSNE unit tests are failing. :(

@glouppe
Contributor

glouppe commented Sep 22, 2015

Great, thanks for the changes! +1 for merge on my side.

@glouppe glouppe changed the title [MRG] Merge PresortBestSplitter and BestSplitter [MRG+1] Merge PresortBestSplitter and BestSplitter Sep 22, 2015
@glouppe
Contributor

glouppe commented Sep 22, 2015

@jmschrei It seems there are conflicts in your branch? (I imagine this comes from the what's new file.)

@arjoly Any more comments?

@jmschrei
Member Author

Yep, rebased this branch onto master.

@jmschrei
Member Author

I've addressed the comments, added an 'auto' default option, and added some more unit tests. Since this PR is getting rather hefty, I am going to submit another PR in the future applying the same style changes to _splitter.pyx as were made to _criterion.pyx, instead of handling that here.
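
As a hedged sketch of the intended 'auto' behaviour (presort dense input, skip presorting for sparse input; the exact rule lives in the PR, not here):

from scipy.sparse import issparse

def resolve_presort(presort, X):
    # 'auto' means: presort when the data is dense, skip it when sparse,
    # since presorting is only supported on dense input.
    if presort == 'auto':
        return not issparse(X)
    return bool(presort)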

@arjoly
Member

arjoly commented Sep 24, 2015

To grow a tree on sparse data, you need the data to be in csc format. However, to make a prediction, you need csr data. This means that whenever you build a boosting model, at each step you: (1) convert the X matrix to csc if needed, (2) build the tree, (3) convert the X matrix to csr, (4) predict with the current tree, and (5) repeat from (1) until n_estimators trees have been built. These conversions between csc and csr induce computational costs that should be avoided.
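
A minimal sketch of the convert-once alternative (fit_stage and predict_stage are hypothetical callables standing in for the PR's _fit_stage and per-stage prediction, not the actual API):

import numpy as np
from scipy import sparse

def fit_boosting(X, y, n_estimators, fit_stage, predict_stage):
    # Convert once, up front, instead of at every boosting iteration.
    if sparse.issparse(X):
        X_csc = sparse.csc_matrix(X)   # column access while growing trees
        X_csr = sparse.csr_matrix(X)   # row access while predicting
    else:
        X_csc = X_csr = X

    trees, y_pred = [], np.zeros(len(y))
    for i in range(n_estimators):
        tree = fit_stage(i, X_csc, y, y_pred)          # grow on csc
        y_pred = y_pred + predict_stage(tree, X_csr)   # predict on csr
        trees.append(tree)
    return trees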

@@ -1318,6 +1317,12 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
If None, the random number generator is the RandomState instance used
by `np.random`.

presort : bool, optional (default='auto')
Member

It should be `presort : bool or 'auto', optional (default='auto')`.

@jmschrei
Member Author

To grow a tree on sparse data, you need the data to be in csc format. However, to make a prediction, you need csr data. This means that whenever you build a boosting model, at each step you: (1) convert the X matrix to csc if needed, (2) build the tree, (3) convert the X matrix to csr, (4) predict with the current tree, and (5) repeat from (1) until n_estimators trees have been built. These conversions between csc and csr induce computational costs that should be avoided.

Got it. Apologies for the confusion.

@jmschrei
Member Author

Comments have been incorporated. For sparse data, csc matrices are passed into fitting functions, and csr matrices are passed into prediction functions. Thanks @arjoly for the thorough review.

@glouppe
Contributor

glouppe commented Sep 28, 2015

Any more comments, @arjoly? If not, I believe we can merge this one.


self.estimators_ = np.empty((0, 0), dtype=np.object)

def _fit_stage(self, i, X, y, y_pred, sample_weight, sample_mask,
criterion, splitter, random_state):
random_state, X_idx_sorted, X_csc=None, X_csr=None):
Member

Here it would be better to have X_fit and X_predict instead of X, X_csc and X_csr.

Member Author

I disagree. X_csc and X_csr are explicit about what they are. If I saw X_fit and X_predict in the wild, I'd assume they were different datasets. Here they are the same dataset.

@arjoly
Member

arjoly commented Sep 28, 2015

I made a last round of review comments.

@jmschrei
Member Author

Comments have been taken into account, and commits have been squashed.

@arjoly
Member

arjoly commented Sep 29, 2015

LGTM

arjoly added a commit that referenced this pull request Sep 29, 2015
[MRG+1] Merge PresortBestSplitter and BestSplitter
@arjoly arjoly merged commit 2c758f3 into scikit-learn:master Sep 29, 2015
@arjoly
Member

arjoly commented Sep 29, 2015

thanks @jmschrei

@GaelVaroquaux
Member

GaelVaroquaux commented Sep 29, 2015 via email

@jmschrei
Member Author

🎺
