[MRG] Expose an apply method for gradient boosters #5222

Closed
wants to merge 10 commits into from
4 changes: 4 additions & 0 deletions doc/whats_new.rst
@@ -142,6 +142,10 @@ Enhancements
such that the split that maximizes this value also maximizes the impurity
improvement. By `Arnaud Joly`_, `Jacob Schreiber`_ and `Gilles Louppe`_

- :class:`ensemble.GradientBoostingRegressor` and
:class:`ensemble.GradientBoostingClassifier` now expose an ``apply``
method for retrieving the leaf indices that samples end up in for each
tree in the ensemble. By `Jacob Schreiber`_.

Bug fixes
.........
109 changes: 102 additions & 7 deletions sklearn/ensemble/gradient_boosting.py
@@ -17,7 +17,7 @@
"""

# Authors: Peter Prettenhofer, Scott White, Gilles Louppe, Emanuele Olivetti,
# Arnaud Joly
# Arnaud Joly, Jacob Schreiber
# License: BSD 3 clause

from __future__ import print_function
@@ -34,6 +34,7 @@
from ..base import BaseEstimator
from ..base import ClassifierMixin
from ..base import RegressorMixin
from ..base import is_classifier
from ..utils import check_random_state, check_array, check_X_y, column_or_1d
from ..utils import check_consistent_length, deprecated
from ..utils.extmath import logsumexp
@@ -948,7 +949,7 @@ def fit(self, X, y, sample_weight=None, monitor=None):
check_consistent_length(X, y, sample_weight)

y = self._validate_y(y)

random_state = check_random_state(self.random_state)
self._check_params()

@@ -959,8 +960,41 @@
# fit initial model - FIXME make sample_weight optional
self.init_.fit(X, y, sample_weight)

# init predictions
y_pred = self.init_.predict(X)
if is_classifier(self.init_):
n_classes = np.unique(y).shape[0]
else:
n_classes = 1

# If the initialization estimator has a predict_proba method,
# either use those, or collapse to a single vector if there
# are only two classes
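# (e.g. for a two-class init with predicted probabilities [0.25, 0.75],
# the collapsed score is log(0.75 / 0.25) ≈ 1.0986, i.e. a single
# log-odds column)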
if hasattr(self.init_, 'predict_proba'):
eps = np.finfo(X.dtype).eps
y_pred = self.init_.predict_proba(X) + eps
if n_classes == 2:
y_pred = np.log(y_pred[:, 1] / y_pred[:, 0])
y_pred = y_pred.reshape(n_samples, 1)

# Otherwise, it can be a naive estimator defined above, in which
# case don't do anything, or a classifier whose estimates will be
# a vector that should be hot encoded, or a regressor whose
# estimates still need to be reshaped from (n_samples,) to
# (n_samples,1)
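# (e.g. with n_classes == 3, a hard prediction of class 2 for one
# sample becomes the one-hot row [0., 0., 1.])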
else:
pred = self.init_.predict(X)

if len(pred.shape) < 2:
if is_classifier(self.init_):
y_pred = np.zeros((n_samples, n_classes))
y_pred[np.arange(n_samples), pred] = 1.0
if n_classes == 2:
y_pred = np.log(y_pred[:, 1] / y_pred[:, 0])
y_pred = y_pred.reshape(n_samples, 1)
else:
y_pred = pred.reshape(n_samples, 1)
else:
y_pred = pred

begin_at_stage = 0
else:
# add more estimators to fitted model
@@ -975,6 +1009,13 @@ def fit(self, X, y, sample_weight=None, monitor=None):
y_pred = self._decision_function(X)
self._resize_state()

if is_classifier(self.init_):
n_classes = np.unique(y).shape[0]
else:
n_classes = 1

self.n_classes = n_classes

# fit the boosting stages
n_stages = self._fit_stages(X, y, y_pred, sample_weight, random_state,
begin_at_stage, monitor)
@@ -1073,7 +1114,31 @@ def _init_decision_function(self, X):
if X.shape[1] != self.n_features:
raise ValueError("X.shape[1] should be {0:d}, not {1:d}.".format(
self.n_features, X.shape[1]))
score = self.init_.predict(X).astype(np.float64)
# init predictions

if hasattr(self.init_, 'predict_proba'):
eps = np.finfo(X.dtype).eps
score = self.init_.predict_proba(X) + eps
if self.n_classes == 2:
score = np.log(score[:, 1] / score[:, 0])
score = score.reshape(X.shape[0], 1)
else:
pred = self.init_.predict(X)

if len(pred.shape) < 2:
if is_classifier(self.init_):
score = np.zeros((X.shape[0], self.n_classes))
score[np.arange(X.shape[0]), pred] = 1.0
if self.n_classes == 2:
score = np.log(score[:, 1] / score[:, 0])
score = score.reshape(X.shape[0], 1)
else:
score = pred.reshape(X.shape[0], 1)
else:
score = pred

score = score.astype(np.float64)

return score

def _decision_function(self, X):
@@ -1107,7 +1172,7 @@ def decision_function(self, X):
return score

def _staged_decision_function(self, X):
"""Compute decision function of ``X`` for each iteration.
"""Compute decision function of ``, X`` for each iteration.

This method allows monitoring (i.e. determine error on testing set)
after each stage.
@@ -1184,6 +1249,37 @@ def _validate_y(self, y):
# Default implementation
return y

def apply(self, X):
"""Apply trees in the ensemble to X, return leaf indices.

Parameters
----------
X : array-like or sparse matrix, shape = [n_samples, n_features]
The input samples. Internally, it will be converted to
``dtype=np.float32`` and if a sparse matrix is provided
to a sparse ``csr_matrix``.

Returns
-------
X_leaves : array-like, shape = [n_samples, n_estimators, n_classes]
For each datapoint x in X and for each tree in the ensemble,
return the index of the leaf that x ends up in.
"""

if self.estimators_ is None or len(self.estimators_) == 0:
Contributor

Maybe some code from https://github.com/jmschrei/scikit-learn/blob/gbt_apply/sklearn/ensemble/gradient_boosting.py#L1068 could be factored out in a _validate_X_predict method?

raise NotFittedError("Estimator not fitted, "
"call `fit` before exploiting the model.")

X = self.estimators_[0, 0]._validate_X_predict(X, check_input=True)

n_estimators, n_classes = self.estimators_.shape
leaves = np.zeros((X.shape[0], n_estimators, n_classes))

for i in range(n_estimators):
for j in range(n_classes):
leaves[:, i, j] = self.estimators_[i, j].apply(X)

return leaves
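
A minimal usage sketch of the new ``apply`` method (the dataset, parameters, and variable names below are illustrative; the returned shape follows this PR's implementation, where binary problems use a single class column per stage):

from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier

# Illustrative binary classification data.
X, y = make_classification(n_samples=100, n_features=8, random_state=0)

clf = GradientBoostingClassifier(n_estimators=10, random_state=0)
clf.fit(X, y)

# Leaf index reached by every sample in every fitted tree.
leaves = clf.apply(X)
print(leaves.shape)  # (100, 10, 1): (n_samples, n_estimators, n_classes)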

class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
"""Gradient Boosting for classification.
@@ -1508,7 +1604,6 @@ def staged_predict_proba(self, X):
raise AttributeError('loss=%r does not support predict_proba' %
self.loss)


class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
"""Gradient Boosting for regression.

57 changes: 55 additions & 2 deletions sklearn/ensemble/tests/test_gradient_boosting.py
@@ -6,10 +6,14 @@

from sklearn import datasets
from sklearn.base import clone
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.cross_validation import train_test_split
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from sklearn.ensemble.gradient_boosting import ZeroEstimator
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
from sklearn.svm import SVC, SVR
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.utils import check_random_state, tosequence
from sklearn.utils.testing import assert_almost_equal
from sklearn.utils.testing import assert_array_almost_equal
@@ -62,6 +66,9 @@ def test_classification_toy():
assert np.any(deviance_decrease >= 0.0), \
"Train deviance does not monotonically decrease."

leaves = clf.apply(X)
assert_equal(leaves.shape, (6, 10, 1))


def test_parameter_checks():
# Check input parameter validation.
@@ -1012,3 +1019,49 @@ def test_non_uniform_weights_toy_edge_case_clf():
gb = GradientBoostingClassifier(n_estimators=5)
gb.fit(X, y, sample_weight=sample_weight)
assert_array_equal(gb.predict([[1, 0]]), [1])


def test_classification_w_init():
# Test that gradient boosting a previously learned model will improve
# the performance of that model.
digits = datasets.load_digits()
X, y = digits.data, digits.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1,
random_state=0)

for clf in [DecisionTreeClassifier(random_state=0),
RandomForestClassifier(random_state=0, n_estimators=3),
SVC(random_state=0)]:

clf.fit(X_train, y_train)
acc1 = clf.score(X_test, y_test)

clf = GradientBoostingClassifier(random_state=0,
n_estimators=1,
init=clf)
clf.fit(X_train, y_train)
acc2 = clf.score(X_test, y_test)
assert acc2 >= acc1


def test_regression_w_init():
# Test that gradient boosting a previously learned model will improve
# the performance of that model.
boston = datasets.load_boston()
X, y = boston.data, boston.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1,
random_state=0)

for clf in [DecisionTreeRegressor(random_state=0),
RandomForestRegressor(random_state=0, n_estimators=3),
SVR(), Ridge()]:

clf.fit(X_train, y_train)
acc1 = clf.score(X_test, y_test)

clf = GradientBoostingRegressor(random_state=0,
n_estimators=1,
init=clf)
clf.fit(X_train, y_train)
acc2 = clf.score(X_test, y_test)
assert acc2 >= acc1