diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 4c970d4b64326..709b72fc38955 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -142,6 +142,10 @@ Enhancements
      such that the split that maximizes this value also maximizes the impurity
      improvement. By `Arnaud Joly`_, `Jacob Schreiber`_ and `Gilles Louppe`_
 
+   - :class:`ensemble.GradientBoostingRegressor` and
+     :class:`ensemble.GradientBoostingClassifier` now expose an ``apply``
+     method for retrieving the leaf indices that each sample ends up in, for
+     every tree in the ensemble. By `Jacob Schreiber`_.
 
 Bug fixes
 .........
 
diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py
index 4d741299b2cfe..701a86d6bab6d 100644
--- a/sklearn/ensemble/gradient_boosting.py
+++ b/sklearn/ensemble/gradient_boosting.py
@@ -17,7 +17,7 @@
 """
 
 # Authors: Peter Prettenhofer, Scott White, Gilles Louppe, Emanuele Olivetti,
-#          Arnaud Joly
+#          Arnaud Joly, Jacob Schreiber
 # License: BSD 3 clause
 
 from __future__ import print_function
@@ -34,6 +34,7 @@
 from ..base import BaseEstimator
 from ..base import ClassifierMixin
 from ..base import RegressorMixin
+from ..base import is_classifier
 from ..utils import check_random_state, check_array, check_X_y, column_or_1d
 from ..utils import check_consistent_length, deprecated
 from ..utils.extmath import logsumexp
@@ -959,8 +960,41 @@ def fit(self, X, y, sample_weight=None, monitor=None):
             # fit initial model - FIXME make sample_weight optional
             self.init_.fit(X, y, sample_weight)
 
-            # init predictions
-            y_pred = self.init_.predict(X)
+            if is_classifier(self.init_):
+                n_classes = np.unique(y).shape[0]
+            else:
+                n_classes = 1
+
+            # If the initialization estimator has a predict_proba method,
+            # use its probability estimates, collapsed to a single
+            # log-odds column when there are only two classes.
+            if hasattr(self.init_, 'predict_proba'):
+                eps = np.finfo(X.dtype).eps
+                y_pred = self.init_.predict_proba(X) + eps
+                if n_classes == 2:
+                    y_pred = np.log(y_pred[:, 1] / y_pred[:, 0])
+                    y_pred = y_pred.reshape(n_samples, 1)
+
+            # Otherwise it can be one of the naive estimators defined
+            # above, in which case nothing needs to be done, a classifier
+            # whose predictions must be one-hot encoded, or a regressor
+            # whose estimates still need to be reshaped from (n_samples,)
+            # to (n_samples, 1).
+            else:
+                pred = self.init_.predict(X)
+
+                if len(pred.shape) < 2:
+                    if is_classifier(self.init_):
+                        y_pred = np.zeros((n_samples, n_classes))
+                        y_pred[np.arange(n_samples), pred] = 1.0
+                        if n_classes == 2:
+                            eps = np.finfo(X.dtype).eps
+                            y_pred = np.log((y_pred[:, 1] + eps) /
+                                            (y_pred[:, 0] + eps))
+                            y_pred = y_pred.reshape(n_samples, 1)
+                    else:
+                        y_pred = pred.reshape(n_samples, 1)
+                else:
+                    y_pred = pred
+
             begin_at_stage = 0
         else:
             # add more estimators to fitted model
@@ -975,6 +1009,13 @@ def fit(self, X, y, sample_weight=None, monitor=None):
             y_pred = self._decision_function(X)
             self._resize_state()
 
+        if is_classifier(self.init_):
+            n_classes = np.unique(y).shape[0]
+        else:
+            n_classes = 1
+
+        self.n_classes = n_classes
+
         # fit the boosting stages
         n_stages = self._fit_stages(X, y, y_pred, sample_weight, random_state,
                                     begin_at_stage, monitor)
@@ -1073,7 +1114,31 @@ def _init_decision_function(self, X):
         if X.shape[1] != self.n_features:
             raise ValueError("X.shape[1] should be {0:d}, not {1:d}.".format(
                 self.n_features, X.shape[1]))
-        score = self.init_.predict(X).astype(np.float64)
+        # init predictions
+        if hasattr(self.init_, 'predict_proba'):
+            eps = np.finfo(X.dtype).eps
+            score = self.init_.predict_proba(X) + eps
+            if self.n_classes == 2:
+                score = np.log(score[:, 1] / score[:, 0])
+                score = score.reshape(X.shape[0], 1)
+        else:
+            pred = self.init_.predict(X)
+
+            if len(pred.shape) < 2:
+                if is_classifier(self.init_):
+                    score = np.zeros((X.shape[0], self.n_classes))
+                    score[np.arange(X.shape[0]), pred] = 1.0
+                    if self.n_classes == 2:
+                        eps = np.finfo(X.dtype).eps
+                        score = np.log((score[:, 1] + eps) /
+                                       (score[:, 0] + eps))
+                        score = score.reshape(X.shape[0], 1)
+                else:
+                    score = pred.reshape(X.shape[0], 1)
+            else:
+                score = pred
+
+        score = score.astype(np.float64)
+
         return score
 
     def _decision_function(self, X):
@@ -1184,6 +1249,37 @@ def _validate_y(self, y):
         # Default implementation
         return y
 
+    def apply(self, X):
+        """Apply trees in the ensemble to X, return leaf indices.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = [n_samples, n_features]
+            The input samples. Internally, it will be converted to
+            ``dtype=np.float32`` and if a sparse matrix is provided
+            to a sparse ``csr_matrix``.
+
+        Returns
+        -------
+        X_leaves : array_like, shape = [n_samples, n_estimators, n_classes]
+            For each datapoint x in X and for each tree in the ensemble,
+            return the index of the leaf x ends up in.
+        """
+
+        if self.estimators_ is None or len(self.estimators_) == 0:
+            raise NotFittedError("Estimator not fitted, "
+                                 "call `fit` before exploiting the model.")
+
+        X = self.estimators_[0, 0]._validate_X_predict(X, check_input=True)
+
+        n_estimators, n_classes = self.estimators_.shape
+        leaves = np.zeros((X.shape[0], n_estimators, n_classes))
+
+        for i in range(n_estimators):
+            for j in range(n_classes):
+                leaves[:, i, j] = self.estimators_[i, j].apply(X)
+
+        return leaves
 
 class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
     """Gradient Boosting for classification.
diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py
index c1043e8da482f..804279ee82558 100644
--- a/sklearn/ensemble/tests/test_gradient_boosting.py
+++ b/sklearn/ensemble/tests/test_gradient_boosting.py
@@ -6,10 +6,14 @@
 from sklearn import datasets
 from sklearn.base import clone
-from sklearn.ensemble import GradientBoostingClassifier
-from sklearn.ensemble import GradientBoostingRegressor
+from sklearn.cross_validation import train_test_split
+from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
+from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
 from sklearn.ensemble.gradient_boosting import ZeroEstimator
+from sklearn.linear_model import Ridge
 from sklearn.metrics import mean_squared_error
+from sklearn.svm import SVC, SVR
+from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
 from sklearn.utils import check_random_state, tosequence
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_array_almost_equal
@@ -62,6 +66,9 @@ def test_classification_toy():
         assert np.any(deviance_decrease >= 0.0), \
             "Train deviance does not monotonically decrease."
 
+        leaves = clf.apply(X)
+        assert_equal(leaves.shape, (6, 10, 1))
+
 
 def test_parameter_checks():
     # Check input parameter validation.
@@ -1012,3 +1019,49 @@ def test_non_uniform_weights_toy_edge_case_clf():
     gb = GradientBoostingClassifier(n_estimators=5)
     gb.fit(X, y, sample_weight=sample_weight)
     assert_array_equal(gb.predict([[1, 0]]), [1])
+
+
+def test_classification_w_init():
+    # Test that boosting on top of a previously fitted classifier gives a
+    # model that is at least as accurate as the initial classifier alone.
+    digits = datasets.load_digits()
+    X, y = digits.data, digits.target
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1,
+                                                        random_state=0)
+
+    for init_est in [DecisionTreeClassifier(random_state=0),
+                     RandomForestClassifier(random_state=0, n_estimators=3),
+                     SVC(random_state=0)]:
+
+        init_est.fit(X_train, y_train)
+        acc1 = init_est.score(X_test, y_test)
+
+        clf = GradientBoostingClassifier(random_state=0,
+                                         n_estimators=1,
+                                         init=init_est)
+        clf.fit(X_train, y_train)
+        acc2 = clf.score(X_test, y_test)
+        assert acc2 >= acc1
+
+
+def test_regression_w_init():
+    # Test that boosting on top of a previously fitted regressor gives a
+    # model that scores at least as well as the initial regressor alone.
+    boston = datasets.load_boston()
+    X, y = boston.data, boston.target
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1,
+                                                        random_state=0)
+
+    for init_est in [DecisionTreeRegressor(random_state=0),
+                     RandomForestRegressor(random_state=0, n_estimators=3),
+                     SVR(), Ridge()]:
+
+        init_est.fit(X_train, y_train)
+        score1 = init_est.score(X_test, y_test)
+
+        est = GradientBoostingRegressor(random_state=0,
+                                        n_estimators=1,
+                                        init=init_est)
+        est.fit(X_train, y_train)
+        score2 = est.score(X_test, y_test)
+        assert score2 >= score1
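
For reviewers, a minimal usage sketch of the new ``apply`` method follows. It is not part of the patch; the dataset, ``n_estimators`` value and printed shape are illustrative only.

from sklearn.datasets import load_digits
from sklearn.ensemble import GradientBoostingClassifier

digits = load_digits()
X, y = digits.data, digits.target

clf = GradientBoostingClassifier(n_estimators=10, random_state=0)
clf.fit(X, y)

# One leaf index per sample, per boosting stage, per per-class tree:
# shape is (n_samples, n_estimators, n_classes), here (1797, 10, 10).
leaves = clf.apply(X)
print(leaves.shape)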
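
Likewise, a sketch of the behaviour exercised by the new tests: passing a previously fitted estimator as ``init`` and boosting on top of its predictions. The base estimator, split and parameter values are illustrative only; note that ``fit`` refits the ``init`` estimator on the training data internally.

from sklearn.cross_validation import train_test_split
from sklearn.datasets import load_boston
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor

boston = load_boston()
X_train, X_test, y_train, y_test = train_test_split(
    boston.data, boston.target, test_size=0.1, random_state=0)

# Fit a base model first, then let gradient boosting start from its
# predictions instead of the default mean estimator.
base = RandomForestRegressor(n_estimators=3, random_state=0)
base.fit(X_train, y_train)

boosted = GradientBoostingRegressor(n_estimators=100, init=base,
                                    random_state=0)
boosted.fit(X_train, y_train)

print(base.score(X_test, y_test), boosted.score(X_test, y_test))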