diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 722ba5f6f8be3..cfe55da0f293a 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -67,6 +67,9 @@ Changelog
    - Fixed API inconsistency: :meth:`SGDClassifier.predict_proba` now
      returns 2d array.
 
+   - Fixed feature importance computation in
+     :ref:`ensemble.gradient_boosting`.
+
 API changes summary
 -------------------
@@ -97,6 +100,10 @@ API changes summary
      ``min_n`` and ``max_n`` were joined to the parameter ``n_gram_range``
      to enable grid-searching both at once.
 
+   - :class:`ensemble.GradientBoostingClassifier` now supports
+     :meth:`ensemble.GradientBoostingClassifier.staged_predict_proba`, and
+     :meth:`ensemble.GradientBoostingClassifier.staged_predict`.
+
 .. _changes_0_11:
 
 0.11
diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py
index 6fdb765ce7d59..3d93a0c98a821 100644
--- a/sklearn/ensemble/gradient_boosting.py
+++ b/sklearn/ensemble/gradient_boosting.py
@@ -1,4 +1,4 @@
-"""Gradient Boosting methods
+"""Gradient Boosted Regression Trees
 
 This module contains methods for fitting gradient boosted regression
 trees for both classification and regression.
@@ -13,7 +13,7 @@
   classification problems.
 
 - ``GradientBoostingRegressor`` implements gradient boosting for
-  classification problems.
+  regression problems.
 """
 
 # Authors: Peter Prettenhofer, Scott White, Gilles Louppe
@@ -44,7 +44,8 @@ class QuantileEstimator(BaseEstimator):
     """An estimator predicting the alpha-quantile of the training targets."""
     def __init__(self, alpha=0.9):
-        assert 0 < alpha < 1.0
+        if not 0 < alpha < 1.0:
+            raise ValueError("`alpha` must be in (0, 1.0)")
         self.alpha = alpha
 
     def fit(self, X, y):
@@ -56,17 +57,6 @@ def predict(self, X):
         return y
 
 
-class MedianEstimator(BaseEstimator):
-    """An estimator predicting the median of the training targets."""
-    def fit(self, X, y):
-        self.median = np.median(y)
-
-    def predict(self, X):
-        y = np.empty((X.shape[0], 1), dtype=np.float64)
-        y.fill(self.median)
-        return y
-
-
 class MeanEstimator(BaseEstimator):
     """An estimator predicting the mean of the training targets."""
     def fit(self, X, y):
@@ -174,7 +164,7 @@ def update_terminal_regions(self, tree, X, y, residual, y_pred,
         # update predictions (both in-bag and out-of-bag)
         y_pred[:, k] += learn_rate * tree.value[:, 0, 0].take(terminal_regions,
-                                                               axis=0)
+                                                              axis=0)
 
     @abstractmethod
     def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
@@ -221,7 +211,7 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
 class LeastAbsoluteError(RegressionLossFunction):
     """Loss function for least absolute deviation (LAD) regression. """
     def init_estimator(self):
-        return MedianEstimator()
+        return QuantileEstimator(alpha=0.5)
 
     def __call__(self, y, pred):
         return np.abs(y - pred.ravel()).mean()
@@ -247,7 +237,7 @@ def __init__(self, n_classes, alpha=0.9):
         self.alpha = alpha
 
     def init_estimator(self):
-        return MedianEstimator()
+        return QuantileEstimator(alpha=0.5)
 
     def __call__(self, y, pred):
         pred = pred.ravel()
@@ -532,7 +522,7 @@ def fit(self, X, y):
         n_samples, n_features = X.shape
         self.n_features = n_features
 
-        if self.max_features == None:
+        if self.max_features is None:
             self.max_features = n_features
 
         if not (0 < self.max_features <= n_features):
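Replacing ``MedianEstimator`` with ``QuantileEstimator(alpha=0.5)`` above is
behavior-preserving: the median is exactly the 0.5-quantile of the training
targets. A minimal sketch of that equivalence in plain NumPy
(``QuantileSketch`` is an illustrative stand-in, not the class in this
patch, and the exact percentile computation of the real class may differ)::

    import numpy as np

    class QuantileSketch(object):
        """Predict a constant alpha-quantile of the training targets."""
        def __init__(self, alpha=0.5):
            if not 0 < alpha < 1.0:
                raise ValueError("`alpha` must be in (0, 1.0)")
            self.alpha = alpha

        def fit(self, X, y):
            # The alpha-quantile as a percentile; X is unused.
            self.quantile = np.percentile(y, self.alpha * 100.0)

        def predict(self, X):
            y = np.empty((X.shape[0], 1), dtype=np.float64)
            y.fill(self.quantile)
            return y

    y = np.array([3.0, 1.0, 4.0, 1.0, 5.0])
    est = QuantileSketch(alpha=0.5)
    est.fit(None, y)
    # With alpha=0.5 the predicted constant equals np.median(y).
    assert est.predict(np.zeros((2, 1)))[0, 0] == np.median(y)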
""" if self.estimators_ is None or len(self.estimators_) == 0: - raise ValueError("Estimator not fitted, " \ - "call `fit` before `feature_importances_`.") - total_sum = np.zeros((self.n_features, ), dtype=np.float64) - for stage in self.estimators_: - stage_sum = sum(tree.compute_feature_importances(method='squared') - for tree in stage) / len(stage) - total_sum += stage_sum + raise ValueError("Estimator not fitted, call `fit` " \ + "before making predictions`.") + if X.shape[1] != self.n_features: + raise ValueError("X.shape[1] should be %d, not %d." % + (self.n_features, X.shape[1])) + score = self.init.predict(X).astype(np.float64) + return score - importances = total_sum / len(self.estimators_) - return importances + def decision_function(self, X): + """Compute the decision function of ``X``. + + Parameters + ---------- + X : array-like of shape = [n_samples, n_features] + The input samples. + + Returns + ------- + score : array, shape = [n_samples, k] + The decision function of the input samples. Classes are + ordered by arithmetical order. Regression and binary + classification are special cases with ``k == 1``, + otherwise ``k==n_classes``. + """ + X = array2d(X, dtype=DTYPE, order='C') + score = self._init_decision_function(X) + predict_stages(self.estimators_, X, self.learn_rate, score) + return score def staged_decision_function(self, X): - """Compute decision function for X. + """Compute decision function of ``X`` for each iteration. This method allows monitoring (i.e. determine error on testing set) after each stage. @@ -624,26 +632,32 @@ def staged_decision_function(self, X): Returns ------- - f : array of shape = [n_samples, n_classes] + score : generator of array, shape = [n_samples, k] The decision function of the input samples. Classes are ordered by arithmetical order. Regression and binary - classification are special cases with ``n_classes == 1``. + classification are special cases with ``k == 1``, + otherwise ``k==n_classes``. """ X = array2d(X, dtype=DTYPE, order='C') - - if self.estimators_ is None or len(self.estimators_) == 0: - raise ValueError("Estimator not fitted, call `fit` " \ - "before `staged_decision_function`.") - if X.shape[1] != self.n_features: - raise ValueError("X.shape[1] should be %d, not %d." % - (self.n_features, X.shape[1])) - - score = self.init.predict(X).astype(np.float64) - + score = self._init_decision_function(X) for i in range(self.n_estimators): predict_stage(self.estimators_, i, X, self.learn_rate, score) yield score + @property + def feature_importances_(self): + if self.estimators_ is None or len(self.estimators_) == 0: + raise ValueError("Estimator not fitted, " \ + "call `fit` before `feature_importances_`.") + total_sum = np.zeros((self.n_features, ), dtype=np.float64) + for stage in self.estimators_: + stage_sum = sum(tree.compute_feature_importances(method='gini') + for tree in stage) / len(stage) + total_sum += stage_sum + + importances = total_sum / len(self.estimators_) + return importances + class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): """Gradient Boosting for classification. @@ -657,11 +671,10 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): Parameters ---------- - loss : {'deviance', 'ls'}, optional (default='deviance') + loss : {'deviance'}, optional (default='deviance') loss function to be optimized. 'deviance' refers to deviance (= logistic regression) for classification - with probabilistic outputs. 'ls' refers to least squares - regression. 
@@ -657,11 +671,10 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
 
     Parameters
     ----------
-    loss : {'deviance', 'ls'}, optional (default='deviance')
+    loss : {'deviance'}, optional (default='deviance')
         loss function to be optimized. 'deviance' refers to
         deviance (= logistic regression) for classification
-        with probabilistic outputs. 'ls' refers to least squares
-        regression.
+        with probabilistic outputs.
 
     learn_rate : float, optional (default=0.1)
         learning rate shrinks the contribution of each tree by `learn_rate`.
@@ -698,6 +711,21 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
         `max_features < n_features` leads to a reduction of variance
         and an increase in bias.
 
+    Attributes
+    ----------
+    `feature_importances_` : array, shape = [n_features]
+        The feature importances (the higher, the more important the feature).
+
+    `oob_score_` : array, shape = [n_estimators]
+        Score of the training dataset obtained using an out-of-bag estimate.
+        The i-th score ``oob_score_[i]`` is the deviance (= loss) of the
+        model at iteration ``i`` on the out-of-bag sample.
+
+    `train_score_` : array, shape = [n_estimators]
+        The i-th score ``train_score_[i]`` is the deviance (= loss) of the
+        model at iteration ``i`` on the in-bag sample.
+        If ``subsample == 1`` this is the deviance on the training data.
+
     Examples
     --------
     >>> samples = [[0, 0, 2], [1, 0, 0]]
@@ -761,8 +789,38 @@ def fit(self, X, y):
 
         return super(GradientBoostingClassifier, self).fit(X, y)
 
-    def predict(self, X):
-        """Predict class for X.
+    def _score_to_proba(self, score):
+        """Compute class probability estimates from decision scores. """
+        proba = np.ones((score.shape[0], self.n_classes_), dtype=np.float64)
+        if not self.loss_.is_multi_class:
+            proba[:, 1] = 1.0 / (1.0 + np.exp(-score.ravel()))
+            proba[:, 0] -= proba[:, 1]
+        else:
+            proba = np.exp(score) / np.sum(np.exp(score), axis=1)[:, np.newaxis]
+        return proba
+
+    def predict_proba(self, X):
+        """Predict class probabilities for X.
+
+        Parameters
+        ----------
+        X : array-like of shape = [n_samples, n_features]
+            The input samples.
+
+        Returns
+        -------
+        p : array of shape = [n_samples, n_classes]
+            The class probabilities of the input samples. Classes are
+            ordered by arithmetical order.
+        """
+        score = self.decision_function(X)
+        return self._score_to_proba(score)
+
+    def staged_predict_proba(self, X):
+        """Predict class probabilities at each stage for X.
+
+        This method allows monitoring (i.e. determine error on testing set)
+        after each stage.
 
         Parameters
         ----------
@@ -772,13 +830,13 @@ def predict(self, X):
 
         Returns
         -------
-        y : array of shape = [n_samples]
-            The predicted classes.
+        p : generator of array, shape = [n_samples, n_classes]
+            The class probabilities of the input samples.
         """
-        probas = self.predict_proba(X)
-        return self.classes_.take(np.argmax(probas, axis=1), axis=0)
+        for score in self.staged_decision_function(X):
+            yield self._score_to_proba(score)
 
-    def predict_proba(self, X):
-        """Predict class probabilities for X.
+    def predict(self, X):
+        """Predict class for X.
 
         Parameters
         ----------
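The ``_score_to_proba`` helper introduced above maps raw decision scores to
probabilities: a logistic sigmoid for the single-column binary case, a
softmax across columns for the multi-class case. A standalone NumPy sketch
of the same mapping (the function name and example scores are made up for
illustration)::

    import numpy as np

    def score_to_proba(score, n_classes):
        if n_classes == 2:
            # One score column; the sigmoid gives P(class 1) and the
            # complement gives P(class 0).
            proba = np.ones((score.shape[0], 2), dtype=np.float64)
            proba[:, 1] = 1.0 / (1.0 + np.exp(-score.ravel()))
            proba[:, 0] -= proba[:, 1]
            return proba
        # One score column per class; normalize with a softmax.
        return np.exp(score) / np.sum(np.exp(score), axis=1)[:, np.newaxis]

    proba = score_to_proba(np.array([[-1.2], [0.3]]), n_classes=2)
    # Rows are valid probability distributions.
    assert np.allclose(proba.sum(axis=1), 1.0)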
@@ -787,30 +845,30 @@ def predict_proba(self, X):
 
         Returns
         -------
-        p : array of shape = [n_samples]
-            The class probabilities of the input samples. Classes are
-            ordered by arithmetical order.
+        y : array of shape = [n_samples]
+            The predicted classes.
         """
-        X = array2d(X, dtype=DTYPE, order='C')
+        proba = self.predict_proba(X)
+        return self.classes_.take(np.argmax(proba, axis=1), axis=0)
 
-        if self.estimators_ is None or len(self.estimators_) == 0:
-            raise ValueError("Estimator not fitted, " \
-                             "call `fit` before `predict_proba`.")
-        if X.shape[1] != self.n_features:
-            raise ValueError("X.shape[1] should be %d, not %d." %
-                             (self.n_features, X.shape[1]))
+    def staged_predict(self, X):
+        """Predict classes at each stage for X.
 
-        proba = np.ones((X.shape[0], self.n_classes_), dtype=np.float64)
+        This method allows monitoring (i.e. determine error on testing set)
+        after each stage.
 
-        score = self.init.predict(X).astype(np.float64)
-        predict_stages(self.estimators_, X, self.learn_rate, score)
+        Parameters
+        ----------
+        X : array-like of shape = [n_samples, n_features]
+            The input samples.
 
-        if not self.loss_.is_multi_class:
-            proba[:, 1] = 1.0 / (1.0 + np.exp(-score.ravel()))
-            proba[:, 0] -= proba[:, 1]
-        else:
-            proba = np.exp(score) / np.sum(np.exp(score), axis=1)[:, np.newaxis]
-        return proba
+        Returns
+        -------
+        y : generator of array, shape = [n_samples]
+            The predicted classes of the input samples.
+        """
+        for proba in self.staged_predict_proba(X):
+            yield self.classes_.take(np.argmax(proba, axis=1), axis=0)
 
 
 class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
@@ -823,11 +881,12 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
 
     Parameters
     ----------
-    loss : {'ls', 'lad'}, optional (default='ls')
+    loss : {'ls', 'lad', 'huber', 'quantile'}, optional (default='ls')
         loss function to be optimized. 'ls' refers to least squares
         regression. 'lad' (least absolute deviation) is a highly robust
         loss function solely based on order information of the input
-        variables.
+        variables. 'huber' is a combination of the two. 'quantile'
+        allows quantile regression (use `alpha` to specify the quantile).
 
     learn_rate : float, optional (default=0.1)
         learning rate shrinks the contribution of each tree by `learn_rate`.
@@ -883,21 +942,6 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
         model at iteration ``i`` on the in-bag sample.
         If ``subsample == 1`` this is the deviance on the training data.
 
-    Attributes
-    ----------
-    `feature_importances_` : array, shape = [n_features]
-        The feature importances (the higher, the more important the feature).
-
-    `oob_score_` : array, shape = [n_estimators]
-        Score of the training dataset obtained using an out-of-bag estimate.
-        The i-th score ``oob_score_[i]`` is the deviance (= loss) of the
-        model at iteration ``i`` on the out-of-bag sample.
-
-    `train_score_` : array, shape = [n_estimators]
-        The i-th score ``train_score_[i]`` is the deviance (= loss) of the
-        model at iteration ``i`` on the in-bag sample.
-        If ``subsample == 1`` this is the deviance on the training data.
-
     Examples
     --------
     >>> samples = [[0, 0, 2], [1, 0, 0]]
@@ -909,7 +953,7 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
 
     See also
     --------
-    sklearn.tree.DecisionTreeRegressor, RandomForestRegressor
+    DecisionTreeRegressor, RandomForestRegressor
 
     References
     ----------
@@ -969,18 +1013,7 @@ def predict(self, X):
         y: array of shape = [n_samples]
             The predicted values.
         """
-        X = array2d(X, dtype=DTYPE, order='C')
-
-        if self.estimators_ is None or len(self.estimators_) == 0:
-            raise ValueError("Estimator not fitted, " \
-                             "call `fit` before `predict`.")
-        if X.shape[1] != self.n_features:
-            raise ValueError("X.shape[1] should be %d, not %d." %
-                             (self.n_features, X.shape[1]))
-
-        y = self.init.predict(X).astype(np.float64)
-        predict_stages(self.estimators_, X, self.learn_rate, y)
-        return y.ravel()
+        return self.decision_function(X).ravel()
 
     def staged_predict(self, X):
         """Predict regression target at each stage for X.
 
         This method allows monitoring (i.e. determine error on testing set)
         after each stage.
@@ -998,6 +1031,5 @@ def staged_predict(self, X):
         y : array of shape = [n_samples]
             The predicted value of the input samples.
         """
-        X = array2d(X, dtype=DTYPE, order='C')
         for y in self.staged_decision_function(X):
             yield y.ravel()
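Taken together, the new ``staged_predict_proba`` and ``staged_predict``
methods let a user pick ``n_estimators`` from a single fit instead of
refitting once per candidate value. A usage sketch (toy data and variable
names are illustrative only)::

    import numpy as np
    from sklearn.ensemble import GradientBoostingClassifier

    rng = np.random.RandomState(1)
    X = rng.rand(200, 3)
    y = (X[:, 0] > 0.5).astype(np.int32)
    X_train, X_test = X[:100], X[100:]
    y_train, y_test = y[:100], y[100:]

    clf = GradientBoostingClassifier(n_estimators=200).fit(X_train, y_train)
    # Test-set accuracy after each boosting stage; argmax + 1 is the
    # smallest number of stages reaching the best observed accuracy.
    staged_acc = [np.mean(y_pred == y_test)
                  for y_pred in clf.staged_predict(X_test)]
    best_n_estimators = int(np.argmax(staged_acc)) + 1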