From 89ec72ea2fca97d2e5f7921fdf5e345bd308819f Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 27 Jul 2015 16:38:50 +0200 Subject: [PATCH 1/2] Example to demonstrate use of tree.apply() method MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This example trains several tree based ensemble methods and uses them to transform the data into a high dimensional, sparse space. The trains a linear model on this new feature space. The idea is taken from: Practical Lessons from Predicting Clicks on Ads at Facebook Junfeng Pan, He Xinran, Ou Jin, Tianbing XU, Bo Liu, Tao Xu, Yanxin Shi, Antoine Atallah, Ralf Herbrich, Stuart Bowers, Joaquin QuiƱonero Candela International Workshop on Data Mining for Online Advertising (ADKDD) https://www.facebook.com/publications/329190253909587/ --- doc/modules/ensemble.rst | 3 + .../ensemble/plot_feature_transformation.py | 113 ++++++++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 examples/ensemble/plot_feature_transformation.py diff --git a/doc/modules/ensemble.rst b/doc/modules/ensemble.rst index 315b23486a1b0..374a2db42ad3f 100644 --- a/doc/modules/ensemble.rst +++ b/doc/modules/ensemble.rst @@ -306,6 +306,9 @@ the transformation performs an implicit, non-parametric density estimation. * :ref:`example_manifold_plot_lle_digits.py` compares non-linear dimensionality reduction techniques on handwritten digits. + * :ref:`example_ensemble_plot_feature_transformation.py` compares + supervised and unsupervised tree based feature transformations. + .. seealso:: :ref:`manifold` techniques can also be useful to derive non-linear diff --git a/examples/ensemble/plot_feature_transformation.py b/examples/ensemble/plot_feature_transformation.py new file mode 100644 index 0000000000000..b208a34aa466a --- /dev/null +++ b/examples/ensemble/plot_feature_transformation.py @@ -0,0 +1,113 @@ +""" +=============================================== +Feature transformations with ensembles of trees +=============================================== + +Transform your features into a higher dimensional, sparse space. Then +train a linear model on these features. + +First fit an ensemble of trees (e.g. gradient boosted trees or a +random forest) on the training set. Then each leaf of each tree in the +ensemble is assigned a fixed arbitrary feature index in a new feature +space. These leaf indices are then encoded in a one-hot fashion. + +Each sample goes through the decisions of each tree of the ensemble +and ends up in one leaf per tree. The sample is encoded by setting +feature values for these leaves to 1 and the other feature values to 0. + +The resulting transformer has then learned a supervised, sparse, +high-dimensional categorical embedding of the data. +""" + +# Author: Tim Head +# +# License: BSD 3 clause + +import numpy as np +np.random.seed(10) + +import matplotlib.pyplot as plt + +from sklearn.datasets import make_classification +from sklearn.linear_model import LogisticRegression +from sklearn.ensemble import (RandomTreesEmbedding, RandomForestClassifier, + GradientBoostingClassifier) +from sklearn.preprocessing import OneHotEncoder +from sklearn.pipeline import Pipeline +from sklearn.cross_validation import train_test_split +from sklearn.metrics import roc_curve + +n_estimator = 10 +X, y = make_classification(n_samples=80000) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5) +# It is important to train the ensemble of trees on a different subset +# of the training data than the linear regression model to avoid +# overfitting, in particular if the total number of leaves is +# similar to the number of training samples +X_train, X_train_lr, y_train, y_train_lr = train_test_split(X_train, + y_train, + test_size=0.5) + +# Unsupervised transformation based on totally random trees +rt = RandomTreesEmbedding(max_depth=3, n_estimators=n_estimator) +rt_lm = LogisticRegression() +rt.fit(X_train, y_train) +rt_lm.fit(rt.transform(X_train_lr), y_train_lr) + +y_pred_rt = rt_lm.predict_proba(rt.transform(X_test))[:, 1] +fpr_rt_lm, tpr_rt_lm, _ = roc_curve(y_test, y_pred_rt) + + +# Supervised transformation based on random forests +rf = RandomForestClassifier(max_depth=3, n_estimators=n_estimator) +rf_enc = OneHotEncoder() +rf_lm = LogisticRegression() +rf.fit(X_train, y_train) +rf_enc.fit(rf.apply(X_train)) +rf_lm.fit(rf_enc.transform(rf.apply(X_train_lr)), y_train_lr) + +y_pred_rf_lm = rf_lm.predict_proba(rf_enc.transform(rf.apply(X_test)))[:, 1] +fpr_rf_lm, tpr_rf_lm, _ = roc_curve(y_test, y_pred_rf_lm) + + +# Supervised transformation based on gradient boosted trees. Demonstrates +# the use of each tree's apply() method. +def gradient_apply(clf, X): + X_trans = [] + for tree in clf.estimators_.ravel(): + X_trans.append(tree.apply(X)) + return np.array(X_trans).T + +grd = GradientBoostingClassifier(n_estimators=n_estimator) +grd_enc = OneHotEncoder() +grd_lm = LogisticRegression() +grd.fit(X_train, y_train) +grd_enc.fit(gradient_apply(grd, X_train)) +grd_lm.fit(grd_enc.transform(gradient_apply(grd, X_train_lr)), y_train_lr) + +y_pred_grd_lm = grd_lm.predict_proba( + grd_enc.transform(gradient_apply(grd, X_test)))[:, 1] +fpr_grd_lm, tpr_grd_lm, _ = roc_curve(y_test, y_pred_grd_lm) + + +# The gradient boosted model by itself +y_pred_grd = grd.predict_proba(X_test)[:, 1] +fpr_grd, tpr_grd, _ = roc_curve(y_test, y_pred_grd) + + +# The random forest model by itself +y_pred_rf = rf.predict_proba(X_test)[:, 1] +fpr_rf, tpr_rf, _ = roc_curve(y_test, y_pred_rf) + + +plt.plot([0, 1], [0, 1], 'k--') +plt.plot(fpr_rt_lm, tpr_rt_lm, label='RT + LR') +plt.plot(fpr_rf, tpr_rf, label='RF') +plt.plot(fpr_rf_lm, tpr_rf_lm, label='RF + LR') +plt.plot(fpr_grd, tpr_grd, label='GBT') +plt.plot(fpr_grd_lm, tpr_grd_lm, label='GBT + LR') +plt.xlabel('False positive rate') +plt.ylabel('True positive rate') +plt.title('ROC curve') +plt.legend(loc='best') +plt.show() From ea8d092f2c42c8cdab392ae67a39242458296510 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Fri, 14 Aug 2015 20:04:46 +0200 Subject: [PATCH 2/2] Explicitly mention RandomTreesEmbedding in the text --- examples/ensemble/plot_feature_transformation.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/ensemble/plot_feature_transformation.py b/examples/ensemble/plot_feature_transformation.py index b208a34aa466a..2271c51daaf75 100644 --- a/examples/ensemble/plot_feature_transformation.py +++ b/examples/ensemble/plot_feature_transformation.py @@ -6,10 +6,11 @@ Transform your features into a higher dimensional, sparse space. Then train a linear model on these features. -First fit an ensemble of trees (e.g. gradient boosted trees or a -random forest) on the training set. Then each leaf of each tree in the -ensemble is assigned a fixed arbitrary feature index in a new feature -space. These leaf indices are then encoded in a one-hot fashion. +First fit an ensemble of trees (totally random trees, a random +forest, or gradient boosted trees) on the training set. Then each leaf +of each tree in the ensemble is assigned a fixed arbitrary feature +index in a new feature space. These leaf indices are then encoded in a +one-hot fashion. Each sample goes through the decisions of each tree of the ensemble and ends up in one leaf per tree. The sample is encoded by setting @@ -17,6 +18,7 @@ The resulting transformer has then learned a supervised, sparse, high-dimensional categorical embedding of the data. + """ # Author: Tim Head