[MRG+1] Trees as feature transformers #5037


Merged: 2 commits into scikit-learn:master on Aug 31, 2015
Conversation

@betatim (Member) commented Jul 27, 2015:

This example uses ensembles of trees to transform samples into a high-dimensional, sparse feature space and then trains a linear model on that representation. In particular, it shows how to use the apply method of a DecisionTree. It came about from a discussion in #4488 and #4549. It is loosely based on @ogrisel's notebook here: http://nbviewer.ipython.org/github/ogrisel/notebooks/blob/master/sklearn_demos/Income%20classification.ipynb
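
For readers skimming the thread, here is a minimal sketch of the idea (not the example code itself; names and parameter values are illustrative, and the old-style cross_validation import matches the code quoted later in this thread):

from sklearn.cross_validation import train_test_split
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import OneHotEncoder

X, y = make_classification(n_samples=1000, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)

forest = RandomForestClassifier(n_estimators=10, max_depth=3)
forest.fit(X_train, y_train)

# apply() gives, for each sample, the index of the leaf it lands in for
# every tree: an array of shape (n_samples, n_estimators).
leaves_train = forest.apply(X_train)

# One-hot encode the leaf indices into a high-dimensional, sparse space,
# then train a linear model on that representation.
encoder = OneHotEncoder()
encoder.fit(leaves_train)
lm = LogisticRegression()
lm.fit(encoder.transform(leaves_train), y_train)

# Probabilities on the test set, transformed the same way.
proba = lm.predict_proba(encoder.transform(forest.apply(X_test)))[:, 1]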

In #4549 there is talk of showing the difference between a linear model and PCA. I'm not quite sure I understood what you meant.

This could interest @amueller, @ogrisel and @vene.

from sklearn.cross_validation import train_test_split
from sklearn.metrics import roc_curve

Nest = 10

Review comment on the lines above (Member):

Could you rename this n_estimators?

@betatim (Member, Author) commented Jul 28, 2015:

Content related: do you think that, as the example is now, people can understand what it is meant to show (if they didn't already know)? I was thinking of putting a comment next to the gradient boosting part to say "this is the interesting bit, using a tree's apply method".

Build related:
Any ideas why a lot of the Gaussian process tests are failing in the py3.4 build?

In a previous build a test in test_ridge.py failed (https://ci.appveyor.com/project/sklearn-ci/scikit-learn/build/1.0.1287/job/wdqjboo3dvnw6lk4) which looks like a Heisenbug?

Should I file an issue for those or are they known to be strange?

@amueller (Member):

I think what I meant in #4549 was to show a PCA on these features vs. on the original data? I'm not entirely following my own wording, though.
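
One possible reading of that suggestion, sketched with the illustrative names from the snippet earlier in the thread (forest, encoder, X_train); TruncatedSVD stands in for PCA on the sparse leaf encoding:

from sklearn.decomposition import PCA, TruncatedSVD

# 2D projection of the original features.
proj_original = PCA(n_components=2).fit_transform(X_train)

# 2D projection of the one-hot leaf encoding (sparse, hence TruncatedSVD).
leaf_features = encoder.transform(forest.apply(X_train))
proj_leaves = TruncatedSVD(n_components=2).fit_transform(leaf_features)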

@@ -0,0 +1,103 @@
"""Use trees to transform your features

Review comment on the lines above (Member):

You need to have a title, as in the other plot examples, which will be used in the example gallery.

@amueller (Member):

Thanks for working on this!

Can you post the plot here?
Also, it would be good to reference this in the user guide.

@betatim (Member, Author) commented Jul 28, 2015:

[image: roc-curve]

Undecided whether zooming in on the top-left (interesting) part is more helpful or more confusing for others.

edit: saving the image while the figure is still open helps ...

@amueller (Member):

the image looks white to me...

@betatim (Member, Author) commented Jul 28, 2015:

Updated the plot, and added a reference in the ensemble section where RandomTreesEmbedding is discussed. I looked in "Dataset transformations" but couldn't find an existing topic where this would fit. I somehow expected to see tree-based transformations there as well.

If you can't quite remember the PCA comment, and I can't work it out either, should we skip it / make it a second example?

@amueller (Member):

Looks good. Maybe it would be interesting to compare training the LR on the same training set vs. a hold-out set? Or at least mention that?

@betatim (Member, Author) commented Jul 29, 2015:

You mean like this (just showing the RF part here but I changed it for all of the models):

X, y = make_classification(n_samples=80000)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
X_train, X_train_lr, y_train, y_train_lr = train_test_split(X, y, test_size=0.5)

...
# Supervised transformation based on random forests
rf = RandomForestClassifier(max_depth=3, n_estimators=n_estimator)
rf_enc = OneHotEncoder()
rf_lm = LogisticRegression()
rf.fit(X_train, y_train)
rf_enc.fit(rf.apply(X_train))
rf_lm.fit(rf_enc.transform(rf.apply(X_train_lr)), y_train_lr)

...

[image: roc-curve]

Hard to tell if it makes a difference. What is the idea behind using a different dataset for fitting the LR instead of using the same one as for fitting the trees? We can change it or add it, but you'll have to provide the sentence to motivate it, as I don't know enough ;)

@amueller (Member):

Well, the idea would be to be closer to "stacking". Intuitively, training both on the same set should lead to crazy overfitting, but maybe not. Let's keep it simple.

@betatim (Member, Author) commented Jul 30, 2015:

OK. I looked at section 8.8 of ESL but I'll have to read it a few more times before it sinks in. Changed the example to use different subsets for fitting the trees and the LR model.

@amueller (Member):

Well, I'm not entirely sure what the "industry standard" is.

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
# It is important to train the ensemble of trees on a different subset
# of the training data than the linear regression model to avoid overfitting
X_train, X_train_lr, y_train, y_train_lr = train_test_split(X, y, test_size=0.5)

Review comment on the lines above (Contributor):

Shouldn't you do train_test_split(X_train, y_train, test_size=0.5)?
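
In other words (a sketch of the suggested fix): split the training half again, so the ensemble and the linear model are fit on disjoint subsets and the test set stays untouched:

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
X_train, X_train_lr, y_train, y_train_lr = train_test_split(
    X_train, y_train, test_size=0.5)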

@betatim (Member, Author) commented Jul 31, 2015:

[image: roc-curve]

Updated ROC curves. After discussing with @glouppe this morning I think it makes sense to split the training samples again. Updated the comment a bit to explain why/when this is important.


Each sample goes through the decisions of each tree of the ensemble
and ends up in one leaf per tree. The sample is encoded by setting
feature values for these leafs to 1 and the other feature values to 0.

Review comment on the lines above (Member):

leafs -> leaves (?)

@ndawe (Member) commented Jul 31, 2015:

Showing inverse false positive rate on a log scale vs true positive rate will "enhance" the difference:

[image: figure_1]

Also changed n_samples to 10k, n_estimators to 20, and random seed to 10.

plt.plot(tpr_rt_lm, 1 / fpr_rt_lm, label='RT + LR')
plt.plot(tpr_rf, 1 / fpr_rf, label='RF')
plt.plot(tpr_rf_lm, 1 / fpr_rf_lm, label='RF + LR')
plt.plot(tpr_grd, 1 / fpr_grd, label='GBT')
plt.plot(tpr_grd_lm, 1 / fpr_grd_lm, label='GBT + LR')
plt.yscale('log')
plt.ylabel('Inverse false positive rate')
plt.xlabel('True positive rate')
plt.title('ROC curve')
plt.legend(loc='best')
plt.show()
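
For context (not part of the quoted snippet): each fpr/tpr pair plotted above comes from roc_curve on the held-out test set; for example, for the RF + LR model, reusing the names from the RF snippet earlier in the thread:

y_pred_rf_lm = rf_lm.predict_proba(rf_enc.transform(rf.apply(X_test)))[:, 1]
fpr_rf_lm, tpr_rf_lm, _ = roc_curve(y_test, y_pred_rf_lm)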

@amueller (Member):

With "enhanced" differences I would be a bit afraid of sending a message like "this is always better than that" when this is a very zoomed-in view of a single run on a particular synthetic dataset.

@amueller (Member):

(and currently it compares against a random forest using only part of the data, right?)

@betatim (Member, Author) commented Aug 4, 2015:

I have a slight preference for the plain ROC curve. The original intent of the example was to show how to use the individual trees' apply method. (Though @glouppe was surprised that GradientBoostingClassifier does not have a public apply.)

Right now it compares RandomTreesEmbedding, RandomForestClassifier and GradientBoostingClassifier. Each of these is used as input to a LogisticRegression, which is then used to classify. For the last two it also shows the performance without the LR step. The trees and the LR are trained on different subsets of the data. For the comparison without the LR we could retrain the tree ensembles on the whole training set... though that makes it feel like a lesson on whether or not you should use this technique, which was not the original intent.
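
For reference, one way to get leaf indices out of a fitted gradient boosting model without a public apply (a sketch only, assuming grd is the fitted GradientBoostingClassifier from the example, X_train_lr is the LR training subset, and the problem is binary so that grd.estimators_ has shape (n_stages, 1)): call apply() on each underlying regression tree and stack the columns.

import numpy as np

# One column of leaf indices per boosting stage.
leaf_indices = np.hstack([tree.apply(X_train_lr).reshape(-1, 1)
                          for tree in grd.estimators_[:, 0]])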

@betatim changed the title from "Trees as feature transformers" to "[MRG] Trees as feature transformers" on Aug 9, 2015.

This example trains several tree-based ensemble methods and uses
them to transform the data into a high-dimensional, sparse space.
It then trains a linear model on this new feature space. The idea is
taken from:

Practical Lessons from Predicting Clicks on Ads at Facebook.
Junfeng Pan, He Xinran, Ou Jin, Tianbing Xu, Bo Liu, Tao Xu, Yanxin Shi,
Antoine Atallah, Ralf Herbrich, Stuart Bowers, Joaquin Quiñonero Candela.
International Workshop on Data Mining for Online Advertising (ADKDD).

https://www.facebook.com/publications/329190253909587/

@larsmans (Member):

Shouldn't this at least link to the RandomTreesEmbedding docs?

@amueller (Member):

Yeah, probably. I don't think we usually link from examples to docs. You can always click on the class to get to the API doc.

@larsmans (Member):

By link I meant at least mention it :) This seems to be doing the same thing as that estimator, except with more control over the type of trees.

@betatim (Member, Author) commented Aug 14, 2015:

It is so similar that we even use RandomTreesEmbedding in the example (RT+LR in the legend) :) What do you think of this modification?
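
The RT + LR variant referred to here, roughly sketched (RandomTreesEmbedding does the apply + one-hot step in a single transformer, so only the logistic regression is chained after it; parameter values are illustrative, and X_train / y_train / X_test follow the naming used earlier in the thread):

from sklearn.ensemble import RandomTreesEmbedding
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

rt = RandomTreesEmbedding(n_estimators=10, max_depth=3, random_state=0)
rt_lm = LogisticRegression()
pipeline = make_pipeline(rt, rt_lm)
pipeline.fit(X_train, y_train)
y_pred_rt = pipeline.predict_proba(X_test)[:, 1]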

@glouppe (Contributor) commented Aug 30, 2015:

I think the example is good enough to be merged. It is a nice demonstration of the apply method of tree-based methods. +1

@glouppe changed the title from "[MRG] Trees as feature transformers" to "[MRG+1] Trees as feature transformers" on Aug 30, 2015.

@amueller added a commit that referenced this pull request on Aug 31, 2015: [MRG+1] Trees as feature transformers

@amueller merged commit 96c329f into scikit-learn:master on Aug 31, 2015.

@amueller (Member):

Thanks!

@betatim deleted the tree-feature-transform branch on August 31, 2015 at 21:10.

@betatim (Member, Author) commented Aug 31, 2015:

Thanks! Now back to walking in the Swiss Alps: 🗻
