diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index f70761ccb0470..3805b1d708b1a 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -156,6 +156,11 @@ Enhancements
      visible with extra trees and on datasets with categorical or sparse
      features. By `Arnaud Joly`_.
 
+   - :class:`ensemble.GradientBoostingRegressor` and
+     :class:`ensemble.GradientBoostingClassifier` now expose an ``apply``
+     method for retrieving the leaf indices each sample ends up in for
+     each tree. By `Jacob Schreiber`_.
+
 Bug fixes
 .........
diff --git a/examples/ensemble/plot_feature_transformation.py b/examples/ensemble/plot_feature_transformation.py
index 71114a64818f6..6faf222cf9875 100644
--- a/examples/ensemble/plot_feature_transformation.py
+++ b/examples/ensemble/plot_feature_transformation.py
@@ -70,24 +70,15 @@
 y_pred_rf_lm = rf_lm.predict_proba(rf_enc.transform(rf.apply(X_test)))[:, 1]
 fpr_rf_lm, tpr_rf_lm, _ = roc_curve(y_test, y_pred_rf_lm)
 
-
-# Supervised transformation based on gradient boosted trees. Demonstrates
-# the use of each tree's apply() method.
-def gradient_apply(clf, X):
-    X_trans = []
-    for tree in clf.estimators_.ravel():
-        X_trans.append(tree.apply(X))
-    return np.array(X_trans).T
-
 grd = GradientBoostingClassifier(n_estimators=n_estimator)
 grd_enc = OneHotEncoder()
 grd_lm = LogisticRegression()
 grd.fit(X_train, y_train)
-grd_enc.fit(gradient_apply(grd, X_train))
-grd_lm.fit(grd_enc.transform(gradient_apply(grd, X_train_lr)), y_train_lr)
+grd_enc.fit(grd.apply(X_train)[:, :, 0])
+grd_lm.fit(grd_enc.transform(grd.apply(X_train_lr)[:, :, 0]), y_train_lr)
 y_pred_grd_lm = grd_lm.predict_proba(
-    grd_enc.transform(gradient_apply(grd, X_test)))[:, 1]
+    grd_enc.transform(grd.apply(X_test)[:, :, 0]))[:, 1]
 fpr_grd_lm, tpr_grd_lm, _ = roc_curve(y_test, y_pred_grd_lm)
 
 
@@ -100,7 +91,7 @@ def gradient_apply(clf, X):
 y_pred_rf = rf.predict_proba(X_test)[:, 1]
 fpr_rf, tpr_rf, _ = roc_curve(y_test, y_pred_rf)
 
-
+plt.figure(1)
 plt.plot([0, 1], [0, 1], 'k--')
 plt.plot(fpr_rt_lm, tpr_rt_lm, label='RT + LR')
 plt.plot(fpr_rf, tpr_rf, label='RF')
@@ -112,3 +103,18 @@ def gradient_apply(clf, X):
 plt.title('ROC curve')
 plt.legend(loc='best')
 plt.show()
+
+plt.figure(2)
+plt.xlim(0, 0.2)
+plt.ylim(0.8, 1)
+plt.plot([0, 1], [0, 1], 'k--')
+plt.plot(fpr_rt_lm, tpr_rt_lm, label='RT + LR')
+plt.plot(fpr_rf, tpr_rf, label='RF')
+plt.plot(fpr_rf_lm, tpr_rf_lm, label='RF + LR')
+plt.plot(fpr_grd, tpr_grd, label='GBT')
+plt.plot(fpr_grd_lm, tpr_grd_lm, label='GBT + LR')
+plt.xlabel('False positive rate')
+plt.ylabel('True positive rate')
+plt.title('ROC curve (zoomed in at top left)')
+plt.legend(loc='best')
+plt.show()
\ No newline at end of file
diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py
index 0cb432995e15f..3ceceae7b0ea9 100644
--- a/sklearn/ensemble/gradient_boosting.py
+++ b/sklearn/ensemble/gradient_boosting.py
@@ -17,7 +17,7 @@
 """
 
 # Authors: Peter Prettenhofer, Scott White, Gilles Louppe, Emanuele Olivetti,
-#          Arnaud Joly
+#          Arnaud Joly, Jacob Schreiber
 # License: BSD 3 clause
 
 from __future__ import print_function
@@ -898,6 +898,13 @@ def _resize_state(self):
 
     def _is_initialized(self):
         return len(getattr(self, 'estimators_', [])) > 0
 
+    def _check_initialized(self):
+        """Check that the estimator is initialized, raising an error if not."""
+        if self.estimators_ is None or len(self.estimators_) == 0:
+            raise NotFittedError("Estimator not fitted, call `fit`"
+                                 " before making predictions.")
+
+
     def fit(self, X, y, sample_weight=None, monitor=None):
         """Fit the gradient boosting model.
@@ -1067,9 +1074,7 @@ def _make_estimator(self, append=True):
 
     def _init_decision_function(self, X):
         """Check input and compute prediction of ``init``. """
-        if self.estimators_ is None or len(self.estimators_) == 0:
-            raise NotFittedError("Estimator not fitted, call `fit`"
-                                 " before making predictions`.")
+        self._check_initialized()
         if X.shape[1] != self.n_features:
             raise ValueError("X.shape[1] should be {0:d}, not {1:d}.".format(
                 self.n_features, X.shape[1]))
@@ -1164,9 +1169,7 @@ def feature_importances_(self):
         -------
         feature_importances_ : array, shape = [n_features]
         """
-        if self.estimators_ is None or len(self.estimators_) == 0:
-            raise NotFittedError("Estimator not fitted, call `fit` before"
-                                 " `feature_importances_`.")
+        self._check_initialized()
 
         total_sum = np.zeros((self.n_features, ), dtype=np.float64)
         for stage in self.estimators_:
@@ -1184,6 +1187,38 @@ def _validate_y(self, y):
         # Default implementation
         return y
 
+    def apply(self, X):
+        """Apply trees in the ensemble to X, return leaf indices.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = [n_samples, n_features]
+            The input samples. Internally, it will be converted to
+            ``dtype=np.float32`` and if a sparse matrix is provided
+            to a sparse ``csr_matrix``.
+
+        Returns
+        -------
+        X_leaves : array_like, shape = [n_samples, n_estimators, n_classes]
+            For each datapoint x in X and for each tree in the ensemble,
+            return the index of the leaf x ends up in.
+            In the case of binary classification n_classes is 1.
+        """
+
+        self._check_initialized()
+        X = self.estimators_[0, 0]._validate_X_predict(X, check_input=True)
+
+        # n_classes will be equal to 1 in the binary classification or the
+        # regression case.
+        n_estimators, n_classes = self.estimators_.shape
+        leaves = np.zeros((X.shape[0], n_estimators, n_classes))
+
+        for i in range(n_estimators):
+            for j in range(n_classes):
+                estimator = self.estimators_[i, j]
+                leaves[:, i, j] = estimator.apply(X, check_input=False)
+
+        return leaves
 
 class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
     """Gradient Boosting for classification.
@@ -1704,3 +1739,25 @@ def staged_predict(self, X):
         """
         for y in self._staged_decision_function(X):
             yield y.ravel()
+
+    def apply(self, X):
+        """Apply trees in the ensemble to X, return leaf indices.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = [n_samples, n_features]
+            The input samples. Internally, it will be converted to
+            ``dtype=np.float32`` and if a sparse matrix is provided
+            to a sparse ``csr_matrix``.
+
+        Returns
+        -------
+        X_leaves : array_like, shape = [n_samples, n_estimators]
+            For each datapoint x in X and for each tree in the ensemble,
+            return the index of the leaf x ends up in.
+        """
+
+        leaves = super(GradientBoostingRegressor, self).apply(X)
+        leaves = leaves.reshape(X.shape[0], self.estimators_.shape[0])
+        return leaves
+
\ No newline at end of file
diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py
index c1043e8da482f..8b579f8938eb2 100644
--- a/sklearn/ensemble/tests/test_gradient_boosting.py
+++ b/sklearn/ensemble/tests/test_gradient_boosting.py
@@ -62,6 +62,9 @@ def test_classification_toy():
         assert np.any(deviance_decrease >= 0.0), \
             "Train deviance does not monotonically decrease."
 
+        leaves = clf.apply(X)
+        assert_equal(leaves.shape, (6, 10, 1))
+
 
 def test_parameter_checks():
     # Check input parameter validation.
@@ -182,6 +185,9 @@ def test_boston():
                 assert_raises(ValueError, clf.predict, boston.data)
                 clf.fit(boston.data, boston.target,
                         sample_weight=sample_weight)
+                leaves = clf.apply(boston.data)
+                assert_equal(leaves.shape, (506, 100))
+
                 y_pred = clf.predict(boston.data)
                 mse = mean_squared_error(boston.target, y_pred)
                 assert mse < 6.0, "Failed with loss %s and " \
@@ -207,6 +213,9 @@ def test_iris():
             assert score > 0.9, "Failed with subsample %.1f " \
                 "and score = %f" % (subsample, score)
 
+            leaves = clf.apply(iris.data)
+            assert_equal(leaves.shape, (150, 100, 3))
+
 
 def test_regression_synthetic():
     # Test on synthetic regression datasets used in Leo Breiman,
@@ -1012,3 +1021,4 @@ def test_non_uniform_weights_toy_edge_case_clf():
         gb = GradientBoostingClassifier(n_estimators=5)
         gb.fit(X, y, sample_weight=sample_weight)
         assert_array_equal(gb.predict([[1, 0]]), [1])
+
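
For context, a minimal usage sketch (not part of the patch) of the workflow this diff enables: the estimator-level ``apply`` returns the leaf index of every sample in every tree, and those indices can be one-hot encoded and fed to a linear model, as plot_feature_transformation.py now does. The dataset, estimator counts, and variable names below are illustrative assumptions, not taken from the patch.

from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import OneHotEncoder

# Illustrative data: any binary classification set works here.
X, y = make_classification(n_samples=1000, random_state=0)
X_gb, y_gb = X[:500], y[:500]      # fits the boosted trees
X_lm, y_lm = X[500:], y[500:]      # fits the downstream linear model

gb = GradientBoostingClassifier(n_estimators=10, random_state=0)
gb.fit(X_gb, y_gb)

# apply() returns (n_samples, n_estimators, n_classes); for binary
# classification the trailing axis has length 1, so it is dropped here.
leaves_gb = gb.apply(X_gb)[:, :, 0]
leaves_lm = gb.apply(X_lm)[:, :, 0]
print(leaves_gb.shape)             # (500, 10)

# One-hot encode the leaf indices and train a logistic regression on them;
# handle_unknown='ignore' keeps transform() from failing on leaf indices
# that never occurred in the data the encoder was fit on.
enc = OneHotEncoder(handle_unknown='ignore')
enc.fit(leaves_gb)
lm = LogisticRegression(max_iter=1000)
lm.fit(enc.transform(leaves_lm), y_lm)

The ``[:, :, 0]`` slice mirrors the change to the example above, where the per-tree ``gradient_apply`` helper is replaced by a single call to the new ``apply`` method.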