From 8cba989ec0e80df046a7e0c7a89e8181efc6f815 Mon Sep 17 00:00:00 2001 From: Peter Prettenhofer Date: Thu, 16 Aug 2012 22:10:39 +0200 Subject: [PATCH 01/11] work on BaseGradientBoostingCV --- sklearn/ensemble/gradient_boosting.py | 56 ++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index 6fdb765ce7d59..9f61964746814 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -31,6 +31,7 @@ from ..base import ClassifierMixin from ..base import RegressorMixin from ..utils import check_random_state, array2d, check_arrays +from ..cross_validation import KFold from ..tree._tree import Tree from ..tree._tree import _random_sample_mask @@ -174,7 +175,7 @@ def update_terminal_regions(self, tree, X, y, residual, y_pred, # update predictions (both in-bag and out-of-bag) y_pred[:, k] += learn_rate * tree.value[:, 0, 0].take(terminal_regions, - axis=0) + axis=0) @abstractmethod def _update_terminal_region(self, tree, terminal_regions, leaf, X, y, @@ -532,7 +533,7 @@ def fit(self, X, y): n_samples, n_features = X.shape self.n_features = n_features - if self.max_features == None: + if self.max_features is None: self.max_features = n_features if not (0 < self.max_features <= n_features): @@ -1001,3 +1002,54 @@ def staged_predict(self, X): X = array2d(X, dtype=DTYPE, order='C') for y in self.staged_decision_function(X): yield y.ravel() + + +class BaseGradientBoostingCV(BaseGradientBoosting): + + def __init__(self, loss, learn_rate, max_estimators, min_samples_split, + min_samples_leaf, max_depth, init, subsample, + max_features, random_state, alpha=0.9, cv=None): + super(BaseGradientBoostingCV, self).__init__( + loss, learn_rate, max_estimators, min_samples_split, + min_samples_leaf, max_depth, init, subsample, max_features, + random_state, alpha=alpha) + + self.max_estimators = max_estimators + self.cv = cv + + def fit(self, X, y): + if self.cv is None: + self.cv = KFold(y.shape[0], k=5) + + cv_deviance = np.zeros((self.cv.k, self.max_estimators), + dtype=np.float64) + for k, (train, test) in enumerate(self.cv): + super(BaseGradientBoostingCV, self).fit(X[train], y[train]) + for i, pred in enumerate(self.staged_predict(X[test])): + cv_deviance[k, i] = self.loss_(y[test], pred) + + mean_deviance = cv_deviance.mean(axis=1) + best_estimators = mean_deviance.argmin() + 1 + print("Best number of estimators=%d - error %.4f" % + (best_estimators, mean_deviance[best_estimators - 1])) + + self.n_estimators = best_estimators + self.cv_deviance = cv_deviance + super(BaseGradientBoostingCV, self).fit(X, y) + return self + + +class GradientBoostingClassifierCV(GradientBoostingClassifier): + + ##__metaclass__ = BaseGradientBoostingCV + + ## def __init__(self, loss='deviance', learn_rate=0.1, max_estimators=1000, + ## subsample=1.0, min_samples_split=1, min_samples_leaf=1, + ## max_depth=3, init=None, random_state=None, + ## max_features=None, cv=None): + ## super(GradientBoostingClassifierCV, self).__init__( + ## loss, learn_rate, max_estimators, min_samples_split, + ## min_samples_leaf, max_depth, init, subsample, + ## max_features, random_state, cv=cv) + + From a280e5df7b2ea45ab719d7874d626db11f6d18a3 Mon Sep 17 00:00:00 2001 From: Peter Prettenhofer Date: Fri, 17 Aug 2012 14:23:02 +0200 Subject: [PATCH 02/11] refactored prediction and decision_function (rm duplicate code) --- sklearn/ensemble/gradient_boosting.py | 175 +++++++++++++++----------- 1 file 
changed, 99 insertions(+), 76 deletions(-) diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index 9f61964746814..089093ade455a 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -30,6 +30,7 @@ from ..base import BaseEstimator from ..base import ClassifierMixin from ..base import RegressorMixin +from ..base import clone from ..utils import check_random_state, array2d, check_arrays from ..cross_validation import KFold @@ -598,22 +599,40 @@ def _make_estimator(self, append=True): # we don't need _make_estimator raise NotImplementedError() - @property - def feature_importances_(self): + def _init_decision_function(self, X): + """Check input and compute prediction of ``init``. """ if self.estimators_ is None or len(self.estimators_) == 0: - raise ValueError("Estimator not fitted, " \ - "call `fit` before `feature_importances_`.") - total_sum = np.zeros((self.n_features, ), dtype=np.float64) - for stage in self.estimators_: - stage_sum = sum(tree.compute_feature_importances(method='squared') - for tree in stage) / len(stage) - total_sum += stage_sum + raise ValueError("Estimator not fitted, call `fit` " \ + "before making predictions`.") + if X.shape[1] != self.n_features: + raise ValueError("X.shape[1] should be %d, not %d." % + (self.n_features, X.shape[1])) + score = self.init.predict(X).astype(np.float64) + return score - importances = total_sum / len(self.estimators_) - return importances + def decision_function(self, X): + """Compute the decision function of ``X``. + + Parameters + ---------- + X : array-like of shape = [n_samples, n_features] + The input samples. + + Returns + ------- + score : array, shape = [n_samples, k] + The decision function of the input samples. Classes are + ordered by arithmetical order. Regression and binary + classification are special cases with ``k == 1``, + otherwise ``k==n_classes``. + """ + X = array2d(X, dtype=DTYPE, order='C') + score = self._init_decision_function(X) + predict_stages(self.estimators_, X, self.learn_rate, score) + return score def staged_decision_function(self, X): - """Compute decision function for X. + """Compute decision function of ``X`` for each iteration. This method allows monitoring (i.e. determine error on testing set) after each stage. @@ -625,25 +644,32 @@ def staged_decision_function(self, X): Returns ------- - f : array of shape = [n_samples, n_classes] + score : generator of array, shape = [n_samples, k] The decision function of the input samples. Classes are ordered by arithmetical order. Regression and binary - classification are special cases with ``n_classes == 1``. + classification are special cases with ``k == 1``, + otherwise ``k==n_classes``. """ X = array2d(X, dtype=DTYPE, order='C') + score = self._init_decision_function(X) + for i in range(self.n_estimators): + predict_stage(self.estimators_, i, X, self.learn_rate, score) + yield score + @property + def feature_importances_(self): if self.estimators_ is None or len(self.estimators_) == 0: - raise ValueError("Estimator not fitted, call `fit` " \ - "before `staged_decision_function`.") - if X.shape[1] != self.n_features: - raise ValueError("X.shape[1] should be %d, not %d." 
% - (self.n_features, X.shape[1])) + raise ValueError("Estimator not fitted, " \ + "call `fit` before `feature_importances_`.") + total_sum = np.zeros((self.n_features, ), dtype=np.float64) + for stage in self.estimators_: + stage_sum = sum(tree.compute_feature_importances(method='squared') + for tree in stage) / len(stage) + total_sum += stage_sum - score = self.init.predict(X).astype(np.float64) + importances = total_sum / len(self.estimators_) + return importances - for i in range(self.n_estimators): - predict_stage(self.estimators_, i, X, self.learn_rate, score) - yield score class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): @@ -775,8 +801,8 @@ def predict(self, X): y : array of shape = [n_samples] The predicted classes. """ - probas = self.predict_proba(X) - return self.classes_.take(np.argmax(probas, axis=1), axis=0) + proba = self.predict_proba(X) + return self.classes_.take(np.argmax(proba, axis=1), axis=0) def predict_proba(self, X): """Predict class probabilities for X. @@ -792,20 +818,8 @@ def predict_proba(self, X): The class probabilities of the input samples. Classes are ordered by arithmetical order. """ - X = array2d(X, dtype=DTYPE, order='C') - - if self.estimators_ is None or len(self.estimators_) == 0: - raise ValueError("Estimator not fitted, " \ - "call `fit` before `predict_proba`.") - if X.shape[1] != self.n_features: - raise ValueError("X.shape[1] should be %d, not %d." % - (self.n_features, X.shape[1])) - - proba = np.ones((X.shape[0], self.n_classes_), dtype=np.float64) - - score = self.init.predict(X).astype(np.float64) - predict_stages(self.estimators_, X, self.learn_rate, score) - + score = self.decision_function(X) + proba = np.ones((score.shape[0], self.n_classes_), dtype=np.float64) if not self.loss_.is_multi_class: proba[:, 1] = 1.0 / (1.0 + np.exp(-score.ravel())) proba[:, 0] -= proba[:, 1] @@ -970,18 +984,7 @@ def predict(self, X): y: array of shape = [n_samples] The predicted values. """ - X = array2d(X, dtype=DTYPE, order='C') - - if self.estimators_ is None or len(self.estimators_) == 0: - raise ValueError("Estimator not fitted, " \ - "call `fit` before `predict`.") - if X.shape[1] != self.n_features: - raise ValueError("X.shape[1] should be %d, not %d." % - (self.n_features, X.shape[1])) - - y = self.init.predict(X).astype(np.float64) - predict_stages(self.estimators_, X, self.learn_rate, y) - return y.ravel() + return self.decision_function(X).ravel() def staged_predict(self, X): """Predict regression target at each stage for X. @@ -999,34 +1002,37 @@ def staged_predict(self, X): y : array of shape = [n_samples] The predicted value of the input samples. 
""" - X = array2d(X, dtype=DTYPE, order='C') for y in self.staged_decision_function(X): yield y.ravel() -class BaseGradientBoostingCV(BaseGradientBoosting): +class BaseGradientBoostingCV(object): - def __init__(self, loss, learn_rate, max_estimators, min_samples_split, - min_samples_leaf, max_depth, init, subsample, - max_features, random_state, alpha=0.9, cv=None): - super(BaseGradientBoostingCV, self).__init__( - loss, learn_rate, max_estimators, min_samples_split, - min_samples_leaf, max_depth, init, subsample, max_features, - random_state, alpha=alpha) + def __init__(self, **kwargs): + self.cv = kwargs.pop('cv', None) + self.max_estimators = kwargs.pop('max_estimators', 1000) - self.max_estimators = max_estimators - self.cv = cv + kwargs['n_estimators'] = self.max_estimators + self._model = self._model_class(**kwargs) def fit(self, X, y): + """Pick best ``n_estimators`` based on cross-validation ``cv``. + + Fits model on entire dataset using ``n_estimators`` found by + cross-validation ``cv``. + + Cross-validation scores are stored in ``self.cv_deviance``. + """ if self.cv is None: self.cv = KFold(y.shape[0], k=5) cv_deviance = np.zeros((self.cv.k, self.max_estimators), dtype=np.float64) for k, (train, test) in enumerate(self.cv): - super(BaseGradientBoostingCV, self).fit(X[train], y[train]) - for i, pred in enumerate(self.staged_predict(X[test])): - cv_deviance[k, i] = self.loss_(y[test], pred) + model = clone(self._model) + model.fit(X[train], y[train]) + for i, score in enumerate(model.staged_decision_function(X[test])): + cv_deviance[k, i] = model.loss_(y[test], score) mean_deviance = cv_deviance.mean(axis=1) best_estimators = mean_deviance.argmin() + 1 @@ -1035,21 +1041,38 @@ def fit(self, X, y): self.n_estimators = best_estimators self.cv_deviance = cv_deviance - super(BaseGradientBoostingCV, self).fit(X, y) + self._model.fit(X, y) return self + def __getattr__(self, name): + return getattr(self._model, name) + + +class GradientBoostingClassifierCV(BaseGradientBoostingCV): -class GradientBoostingClassifierCV(GradientBoostingClassifier): + _model_class = GradientBoostingClassifier + + def __init__(self, loss='deviance', learn_rate=0.1, max_estimators=1000, + subsample=1.0, min_samples_split=1, min_samples_leaf=1, + max_depth=3, init=None, random_state=None, + max_features=None, cv=None): + super(GradientBoostingClassifierCV, self).__init__( + loss=loss, learn_rate=learn_rate, max_estimators=max_estimators, + min_samples_split=min_samples_split, min_samples_leaf=min_samples_leaf, + max_depth=max_depth, init=init, subsample=subsample, + max_features=max_features, random_state=random_state, cv=cv) - ##__metaclass__ = BaseGradientBoostingCV - ## def __init__(self, loss='deviance', learn_rate=0.1, max_estimators=1000, - ## subsample=1.0, min_samples_split=1, min_samples_leaf=1, - ## max_depth=3, init=None, random_state=None, - ## max_features=None, cv=None): - ## super(GradientBoostingClassifierCV, self).__init__( - ## loss, learn_rate, max_estimators, min_samples_split, - ## min_samples_leaf, max_depth, init, subsample, - ## max_features, random_state, cv=cv) +class GradientBoostingRegressorCV(BaseGradientBoostingCV): + _model_class = GradientBoostingRegressor + def __init__(self, loss='ls', learn_rate=0.1, max_estimators=1000, + subsample=1.0, min_samples_split=1, min_samples_leaf=1, + max_depth=3, init=None, random_state=None, + max_features=None, alpha=0.9, cv=None): + super(GradientBoostingClassifierCV, self).__init__( + loss=loss, learn_rate=learn_rate, 
max_estimators=max_estimators, + min_samples_split=min_samples_split, min_samples_leaf=min_samples_leaf, + max_depth=max_depth, init=init, subsample=subsample, + max_features=max_features, random_state=random_state, alpha=alpha, cv=cv) From 3abe97a59a49acbe381a5b8e8cf83a1def0d06a0 Mon Sep 17 00:00:00 2001 From: Peter Prettenhofer Date: Fri, 17 Aug 2012 14:23:36 +0200 Subject: [PATCH 03/11] ENH: use gini for feature importance --- sklearn/ensemble/gradient_boosting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index 089093ade455a..3e85757c1b4c9 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -663,7 +663,7 @@ def feature_importances_(self): "call `fit` before `feature_importances_`.") total_sum = np.zeros((self.n_features, ), dtype=np.float64) for stage in self.estimators_: - stage_sum = sum(tree.compute_feature_importances(method='squared') + stage_sum = sum(tree.compute_feature_importances(method='gini') for tree in stage) / len(stage) total_sum += stage_sum From 7d204af423e0b0195aa489eae56a8a8e2381dc96 Mon Sep 17 00:00:00 2001 From: Peter Prettenhofer Date: Fri, 17 Aug 2012 16:24:00 +0200 Subject: [PATCH 04/11] GradientBoosting classes with built in cross-validation; implemented via Decorator pattern --- sklearn/ensemble/__init__.py | 2 + sklearn/ensemble/gradient_boosting.py | 332 ++++++++++++++---- .../ensemble/tests/test_gradient_boosting.py | 51 +++ 3 files changed, 325 insertions(+), 60 deletions(-) diff --git a/sklearn/ensemble/__init__.py b/sklearn/ensemble/__init__.py index d76405cf0e827..2134a54abb456 100644 --- a/sklearn/ensemble/__init__.py +++ b/sklearn/ensemble/__init__.py @@ -10,3 +10,5 @@ from .forest import ExtraTreesRegressor from .gradient_boosting import GradientBoostingClassifier from .gradient_boosting import GradientBoostingRegressor +from .gradient_boosting import GradientBoostingClassifierCV +from .gradient_boosting import GradientBoostingRegressorCV diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index 3e85757c1b4c9..6ed3a81dc81c9 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -1,4 +1,4 @@ -"""Gradient Boosting methods +"""Gradient Boosted Regression Trees This module contains methods for fitting gradient boosted regression trees for both classification and regression. @@ -13,7 +13,16 @@ classification problems. - ``GradientBoostingRegressor`` implements gradient boosting for + regression problems. + +- The ``BaseGradientBoostingCV`` base class implements a ``fit`` method + to choose the best ``n_estimators`` based on cross-validation. + +- ``GradientBoostingClassifierCV`` implements ``BaseGradientBoostingCV`` for classification problems. + +- ``GradientBoostingRegressor`` implements ``BaseGradientBoostingCV`` for + regression problems. 
""" # Authors: Peter Prettenhofer, Scott White, Gilles Louppe @@ -46,7 +55,8 @@ class QuantileEstimator(BaseEstimator): """An estimator predicting the alpha-quantile of the training targets.""" def __init__(self, alpha=0.9): - assert 0 < alpha < 1.0 + if not 0 < alpha < 1.0: + raise ValueError("`alpha` must be in (0, 1.0)") self.alpha = alpha def fit(self, X, y): @@ -58,17 +68,6 @@ def predict(self, X): return y -class MedianEstimator(BaseEstimator): - """An estimator predicting the median of the training targets.""" - def fit(self, X, y): - self.median = np.median(y) - - def predict(self, X): - y = np.empty((X.shape[0], 1), dtype=np.float64) - y.fill(self.median) - return y - - class MeanEstimator(BaseEstimator): """An estimator predicting the mean of the training targets.""" def fit(self, X, y): @@ -223,7 +222,7 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y, class LeastAbsoluteError(RegressionLossFunction): """Loss function for least absolute deviation (LAD) regression. """ def init_estimator(self): - return MedianEstimator() + return QuantileEstimator(alpha=0.5) def __call__(self, y, pred): return np.abs(y - pred.ravel()).mean() @@ -249,7 +248,7 @@ def __init__(self, n_classes, alpha=0.9): self.alpha = alpha def init_estimator(self): - return MedianEstimator() + return QuantileEstimator(alpha=0.5) def __call__(self, y, pred): pred = pred.ravel() @@ -671,7 +670,6 @@ def feature_importances_(self): return importances - class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): """Gradient Boosting for classification. @@ -684,11 +682,10 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): Parameters ---------- - loss : {'deviance', 'ls'}, optional (default='deviance') + loss : {'deviance'}, optional (default='deviance') loss function to be optimized. 'deviance' refers to deviance (= logistic regression) for classification - with probabilistic outputs. 'ls' refers to least squares - regression. + with probabilistic outputs. learn_rate : float, optional (default=0.1) learning rate shrinks the contribution of each tree by `learn_rate`. @@ -788,6 +785,52 @@ def fit(self, X, y): return super(GradientBoostingClassifier, self).fit(X, y) + def _score_to_proba(self, score): + """Compute class probability estimates from decision scores. """ + proba = np.ones((score.shape[0], self.n_classes_), dtype=np.float64) + if not self.loss_.is_multi_class: + proba[:, 1] = 1.0 / (1.0 + np.exp(-score.ravel())) + proba[:, 0] -= proba[:, 1] + else: + proba = np.exp(score) / np.sum(np.exp(score), axis=1)[:, np.newaxis] + return proba + + def predict_proba(self, X): + """Predict class probabilities for X. + + Parameters + ---------- + X : array-like of shape = [n_samples, n_features] + The input samples. + + Returns + ------- + p : array of shape = [n_samples] + The class probabilities of the input samples. Classes are + ordered by arithmetical order. + """ + score = self.decision_function(X) + return self._score_to_proba(score) + + def staged_predict_proba(self, X): + """Predict class probabilities at each stage for X. + + This method allows monitoring (i.e. determine error on testing set) + after each stage. + + Parameters + ---------- + X : array-like of shape = [n_samples, n_features] + The input samples. + + Returns + ------- + y : array of shape = [n_samples] + The predicted value of the input samples. + """ + for score in self.staged_decision_function(X): + yield self._score_to_proba(X) + def predict(self, X): """Predict class for X. 
@@ -804,8 +847,11 @@ def predict(self, X): proba = self.predict_proba(X) return self.classes_.take(np.argmax(proba, axis=1), axis=0) - def predict_proba(self, X): - """Predict class probabilities for X. + def staged_predict(self, X): + """Predict class probabilities at each stage for X. + + This method allows monitoring (i.e. determine error on testing set) + after each stage. Parameters ---------- @@ -814,18 +860,11 @@ def predict_proba(self, X): Returns ------- - p : array of shape = [n_samples] - The class probabilities of the input samples. Classes are - ordered by arithmetical order. + y : array of shape = [n_samples] + The predicted value of the input samples. """ - score = self.decision_function(X) - proba = np.ones((score.shape[0], self.n_classes_), dtype=np.float64) - if not self.loss_.is_multi_class: - proba[:, 1] = 1.0 / (1.0 + np.exp(-score.ravel())) - proba[:, 0] -= proba[:, 1] - else: - proba = np.exp(score) / np.sum(np.exp(score), axis=1)[:, np.newaxis] - return proba + for proba in self.staged_predict_proba(X): + yield self.classes_.take(np.argmax(proba, axis=1), axis=0) class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin): @@ -838,11 +877,12 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin): Parameters ---------- - loss : {'ls', 'lad'}, optional (default='ls') + loss : {'ls', 'lad', 'huber', 'quantile'}, optional (default='ls') loss function to be optimized. 'ls' refers to least squares regression. 'lad' (least absolute deviation) is a highly robust loss function soley based on order information of the input - variables. + variables. 'huber' is a combination of the two. 'quantile' + allows quantile regression (use `alpha` to specify the quantile). learn_rate : float, optional (default=0.1) learning rate shrinks the contribution of each tree by `learn_rate`. @@ -924,7 +964,7 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin): See also -------- - sklearn.tree.DecisionTreeRegressor, RandomForestRegressor + DecisionTreeRegressor, RandomForestRegressor, GradientBoostingRegressorCV References ---------- @@ -1006,73 +1046,245 @@ def staged_predict(self, X): yield y.ravel() -class BaseGradientBoostingCV(object): +class BaseGradientBoostingCV(BaseEstimator): + """Abstract base class for GB with built-in cross-validation. + This class implements the Decorator design pattern; it wraps + a concrete ``_model_class`` object and delegates attribute + access to the object. Soley the arguments ``cv``, ``max_estimators``, + and ``_model`` are stored in the decorator object. + """ + __metaclass__ = ABCMeta + + @abstractmethod def __init__(self, **kwargs): - self.cv = kwargs.pop('cv', None) - self.max_estimators = kwargs.pop('max_estimators', 1000) + # verbose syntax needed to avoid recursive __setattr__ invokation + BaseEstimator.__setattr__(self, 'cv', kwargs.pop('cv', None)) + BaseEstimator.__setattr__(self, 'max_estimators', + kwargs.pop('max_estimators', 1000)) kwargs['n_estimators'] = self.max_estimators - self._model = self._model_class(**kwargs) + BaseEstimator.__setattr__(self, '_model', self._model_class(**kwargs)) def fit(self, X, y): """Pick best ``n_estimators`` based on cross-validation ``cv``. - Fits model on entire dataset using ``n_estimators`` found by - cross-validation ``cv``. - + Finally, fits model on entire dataset using ``n_estimators``. Cross-validation scores are stored in ``self.cv_deviance``. 
""" + X, y = check_arrays(X, y, sparse_format='dense') + X = np.asfortranarray(X, dtype=DTYPE) + y = np.ravel(y, order='C') + if self.cv is None: - self.cv = KFold(y.shape[0], k=5) + cv = KFold(y.shape[0], k=5) + if isinstance(self.cv, int): + cv = KFold(y.shape[0], k=self.cv) + else: + cv = self.cv - cv_deviance = np.zeros((self.cv.k, self.max_estimators), + cv_score = np.zeros((cv.k, self.max_estimators), dtype=np.float64) - for k, (train, test) in enumerate(self.cv): + for k, (train, test) in enumerate(cv): model = clone(self._model) model.fit(X[train], y[train]) for i, score in enumerate(model.staged_decision_function(X[test])): - cv_deviance[k, i] = model.loss_(y[test], score) + cv_score[k, i] = model.loss_(y[test], score) - mean_deviance = cv_deviance.mean(axis=1) - best_estimators = mean_deviance.argmin() + 1 - print("Best number of estimators=%d - error %.4f" % - (best_estimators, mean_deviance[best_estimators - 1])) + mean_score = cv_score.mean(axis=0) + best_estimators = mean_score.argmin() + 1 - self.n_estimators = best_estimators - self.cv_deviance = cv_deviance + self._model.set_params(n_estimators=best_estimators) + self.cv_score_ = cv_score + BaseEstimator.__setattr__(self, 'cv_score_', cv_score) self._model.fit(X, y) return self def __getattr__(self, name): return getattr(self._model, name) + def __setattr__(self, name, value): + setattr(self._model, name, value) + class GradientBoostingClassifierCV(BaseGradientBoostingCV): + """GB classifier with built-in cross-validation. + + A ``GradientBoostingClassifier`` that optimizes ``n_estimators`` + via cross-validation. + + Parameters + ---------- + loss : {'deviance'}, optional (default='deviance') + loss function to be optimized. 'deviance' refers to + deviance (= logistic regression) for classification + with probabilistic outputs. + + learn_rate : float, optional (default=0.1) + learning rate shrinks the contribution of each tree by `learn_rate`. + There is a trade-off between learn_rate and n_estimators. + + max_estimators : int (default=1000) + The maximum number of boosting stages to perform. The best number + of estimators ``n_estimators`` is picked based on deviance on + held-out data. + + max_depth : integer, optional (default=3) + maximum depth of the individual regression estimators. The maximum + depth limits the number of nodes in the tree. Tune this parameter + for best performance; the best value depends on the interaction + of the input variables. + + min_samples_split : integer, optional (default=1) + The minimum number of samples required to split an internal node. + + min_samples_leaf : integer, optional (default=1) + The minimum number of samples required to be at a leaf node. + + subsample : float, optional (default=1.0) + The fraction of samples to be used for fitting the individual base + learners. If smaller than 1.0 this results in Stochastic Gradient + Boosting. `subsample` interacts with the parameter `n_estimators`. + Choosing `subsample < 1.0` leads to a reduction of variance + and an increase in bias. + + max_features : int, None, optional (default=None) + The number of features to consider when looking for the best split. + Features are choosen randomly at each split point. + If None, then `max_features=n_features`. Choosing + `max_features < n_features` leads to a reduction of variance + and an increase in bias. + + cv : cross-validation generator or int (default=5) + If int, ``cv``-fold cross-valdiation will be used. 
+ + Attributes + ---------- + `feature_importances_` : array, shape = [n_features] + The feature importances (the higher, the more important the feature). + + `oob_score_` : array, shape = [n_estimators] + Score of the training dataset obtained using an out-of-bag estimate. + The i-th score ``oob_score_[i]`` is the deviance (= loss) of the + model at iteration ``i`` on the out-of-bag sample. + + `train_score_` : array, shape = [n_estimators] + The i-th score ``train_score_[i]`` is the deviance (= loss) of the + model at iteration ``i`` on the in-bag sample. + If ``subsample == 1`` this is the deviance on the training data. + + `cv_score_` : array, shape = [cv.k, max_estimators] + The deviance scores for each fold and boosting iteration. + + See also + -------- + GradientBoostingClassifier + """ _model_class = GradientBoostingClassifier def __init__(self, loss='deviance', learn_rate=0.1, max_estimators=1000, subsample=1.0, min_samples_split=1, min_samples_leaf=1, max_depth=3, init=None, random_state=None, - max_features=None, cv=None): + max_features=None, cv=5): super(GradientBoostingClassifierCV, self).__init__( loss=loss, learn_rate=learn_rate, max_estimators=max_estimators, - min_samples_split=min_samples_split, min_samples_leaf=min_samples_leaf, - max_depth=max_depth, init=init, subsample=subsample, - max_features=max_features, random_state=random_state, cv=cv) + min_samples_split=min_samples_split, max_depth=max_depth, + min_samples_leaf=min_samples_leaf, init=init, + subsample=subsample, max_features=max_features, + random_state=random_state, cv=cv) class GradientBoostingRegressorCV(BaseGradientBoostingCV): + """GB regressor with built-in cross-validation. + + A ``GradientBoostingRegressor`` that optimizes ``n_estimators`` + via cross-validation. + + Parameters + ---------- + loss : {'ls', 'lad', 'huber', 'quantile'}, optional (default='ls') + loss function to be optimized. 'ls' refers to least squares + regression. 'lad' (least absolute deviation) is a highly robust + loss function soley based on order information of the input + variables. 'huber' is a combination of the two. 'quantile' + allows quantile regression (use `alpha` to specify the quantile). + + learn_rate : float, optional (default=0.1) + learning rate shrinks the contribution of each tree by `learn_rate`. + There is a trade-off between learn_rate and n_estimators. + + max_estimators : int (default=1000) + The maximum number of boosting stages to perform. The best number + of estimators ``n_estimators`` is picked based on deviance on + held-out data. + + max_depth : integer, optional (default=3) + maximum depth of the individual regression estimators. The maximum + depth limits the number of nodes in the tree. Tune this parameter + for best performance; the best value depends on the interaction + of the input variables. + + min_samples_split : integer, optional (default=1) + The minimum number of samples required to split an internal node. + + min_samples_leaf : integer, optional (default=1) + The minimum number of samples required to be at a leaf node. + + subsample : float, optional (default=1.0) + The fraction of samples to be used for fitting the individual base + learners. If smaller than 1.0 this results in Stochastic Gradient + Boosting. `subsample` interacts with the parameter `n_estimators`. + Choosing `subsample < 1.0` leads to a reduction of variance + and an increase in bias. + + max_features : int, None, optional (default=None) + The number of features to consider when looking for the best split. 
+ Features are choosen randomly at each split point. + If None, then `max_features=n_features`. Choosing + `max_features < n_features` leads to a reduction of variance + and an increase in bias. + + alpha : float (default=0.9) + The alpha-quantile of the huber loss function and the quantile + loss function. Only if ``loss='huber'`` or ``loss='quantile'``. + + cv : cross-validation generator or int (default=5) + If int, ``cv``-fold cross-valdiation will be used. + + Attributes + ---------- + `feature_importances_` : array, shape = [n_features] + The feature importances (the higher, the more important the feature). + + `oob_score_` : array, shape = [n_estimators] + Score of the training dataset obtained using an out-of-bag estimate. + The i-th score ``oob_score_[i]`` is the deviance (= loss) of the + model at iteration ``i`` on the out-of-bag sample. + + `train_score_` : array, shape = [n_estimators] + The i-th score ``train_score_[i]`` is the deviance (= loss) of the + model at iteration ``i`` on the in-bag sample. + If ``subsample == 1`` this is the deviance on the training data. + + `cv_score_` : array, shape = [cv.k, max_estimators] + The deviance scores for each fold and boosting iteration. + + See also + -------- + GradientBoostingRegressor + """ _model_class = GradientBoostingRegressor def __init__(self, loss='ls', learn_rate=0.1, max_estimators=1000, subsample=1.0, min_samples_split=1, min_samples_leaf=1, max_depth=3, init=None, random_state=None, - max_features=None, alpha=0.9, cv=None): - super(GradientBoostingClassifierCV, self).__init__( + max_features=None, alpha=0.9, cv=5): + super(GradientBoostingRegressorCV, self).__init__( loss=loss, learn_rate=learn_rate, max_estimators=max_estimators, - min_samples_split=min_samples_split, min_samples_leaf=min_samples_leaf, + min_samples_split=min_samples_split, + min_samples_leaf=min_samples_leaf, max_depth=max_depth, init=init, subsample=subsample, - max_features=max_features, random_state=random_state, alpha=alpha, cv=cv) + max_features=max_features, random_state=random_state, + alpha=alpha, cv=cv) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 85115ec0b0fe9..c5fd9d46e4377 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -14,6 +14,8 @@ from sklearn.ensemble import GradientBoostingClassifier from sklearn.ensemble import GradientBoostingRegressor +from sklearn.ensemble import GradientBoostingClassifierCV +from sklearn.ensemble import GradientBoostingRegressorCV from sklearn import datasets @@ -416,3 +418,52 @@ def test_mem_layout(): clf.fit(X, y_) assert_array_equal(clf.predict(T), true_result) assert_equal(100, len(clf.estimators_)) + + +def test_cv_attr(): + """Test attribute access for CV classes. """ + clf = GradientBoostingClassifierCV(max_estimators=100, + min_samples_leaf=3) + clf.fit(X, y) + + assert clf.min_samples_leaf == 3 == clf._model.min_samples_leaf + clf.min_samples_leaf = 2 + assert clf.min_samples_leaf == 2 == clf._model.min_samples_leaf + + +def test_cv_clf(): + """Test GradientBoostingClassifierCV n_estimators selection. 
""" + X, y = datasets.make_hastie_10_2(n_samples=1000, random_state=1) + + max_estimators = 50 + + clf = GradientBoostingClassifierCV(max_estimators=max_estimators) + clf.fit(X, y) + # max_estimators very small so it chooses all of them + assert clf.n_estimators == max_estimators + + clf_prime = GradientBoostingClassifier( + n_estimators=max_estimators).fit(X, y) + + assert_array_equal(clf_prime.train_score_, clf.train_score_) + assert clf.cv_score_.shape[0] == 5 # default 5-fold CV + assert clf.cv_score_.shape[1] == max_estimators + + +def test_cv_reg(): + """Test GradientBoostingRegressorCV n_estimators selection. """ + X, y = datasets.make_friedman1(n_samples=1000, random_state=1) + + max_estimators = 50 + + clf = GradientBoostingRegressorCV(max_estimators=max_estimators) + clf.fit(X, y) + # max_estimators very small so it chooses all of them + assert clf.n_estimators == max_estimators + + clf_prime = GradientBoostingRegressor( + n_estimators=max_estimators).fit(X, y) + + assert_array_equal(clf_prime.train_score_, clf.train_score_) + assert clf.cv_score_.shape[0] == 5 # default 5-fold CV + assert clf.cv_score_.shape[1] == max_estimators From 600fc631c1e98a3dd9cd24ebc29a94b9c45d4f42 Mon Sep 17 00:00:00 2001 From: Peter Prettenhofer Date: Mon, 27 Aug 2012 10:26:58 +0200 Subject: [PATCH 05/11] wip: aggregate fold via groupby --- sklearn/ensemble/gradient_boosting.py | 93 +++++++++++++++++++++------ 1 file changed, 72 insertions(+), 21 deletions(-) diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index 6ed3a81dc81c9..b8874ede207f6 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -33,15 +33,18 @@ import numpy as np +from itertools import groupby, islice +from opterator import itemgetter from scipy import stats from .base import BaseEnsemble from ..base import BaseEstimator from ..base import ClassifierMixin from ..base import RegressorMixin -from ..base import clone from ..utils import check_random_state, array2d, check_arrays from ..cross_validation import KFold +from ..grid_search import IterGrid +from ..externals.joblib import Parallel, delayed from ..tree._tree import Tree from ..tree._tree import _random_sample_mask @@ -1046,13 +1049,32 @@ def staged_predict(self, X): yield y.ravel() +def fit_grid_point(grid_idx, cv_idx, X, y, estimator_class, params, + train, test): + """Fit a single grid point and return staged scores. """ + X_train, y_train = X[train], y[train] + X_test, y_test = X[test], y[test] + + estimator = estimator_class(**params) + estimator.fit(X_train, y_train) + + test_deviance = np.fromiter( + (estimator.loss_(y_test, score) + for score in estimator.staged_decision_function(X_test)), + dtype=np.float64, count=estimator.n_estimators) + + return (grid_idx, cv_idx, test_deviance) + + class BaseGradientBoostingCV(BaseEstimator): """Abstract base class for GB with built-in cross-validation. This class implements the Decorator design pattern; it wraps - a concrete ``_model_class`` object and delegates attribute - access to the object. Soley the arguments ``cv``, ``max_estimators``, - and ``_model`` are stored in the decorator object. + a concrete ``_estimator_class`` object and delegates attribute + access to the object. + + XXX Soley the arguments ``cv``, ``max_estimators``, + and ``_estimator`` are stored in the decorator object. 
""" __metaclass__ = ABCMeta @@ -1064,7 +1086,7 @@ def __init__(self, **kwargs): kwargs.pop('max_estimators', 1000)) kwargs['n_estimators'] = self.max_estimators - BaseEstimator.__setattr__(self, '_model', self._model_class(**kwargs)) + BaseEstimator.__setattr__(self, '_params', kwargs) def fit(self, X, y): """Pick best ``n_estimators`` based on cross-validation ``cv``. @@ -1083,28 +1105,57 @@ def fit(self, X, y): else: cv = self.cv - cv_score = np.zeros((cv.k, self.max_estimators), - dtype=np.float64) - for k, (train, test) in enumerate(cv): - model = clone(self._model) - model.fit(X[train], y[train]) - for i, score in enumerate(model.staged_decision_function(X[test])): - cv_score[k, i] = model.loss_(y[test], score) + params_grid = IterGrid(self._params) + + tasks = ((grid_idx, cv_idx, X, y, self._estimator_class, + params, train, test) + for grid_idx, params in enumerate(params_grid) + for cv_idx, (train, test) in enumerate(cv)) + + grid = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, + pre_dispatch="2*n_jobs")( + delayed(fit_grid_point)(tasks)) + + out = [] - mean_score = cv_score.mean(axis=0) - best_estimators = mean_score.argmin() + 1 + for grid_idx, grid_points in groupby(grid, itemgetter(0)): + grid_points = list(grid_points) + assert len(grid_points) == cv.k - self._model.set_params(n_estimators=best_estimators) - self.cv_score_ = cv_score - BaseEstimator.__setattr__(self, 'cv_score_', cv_score) + A = np.row_stack([fold[2] for fold in grid_points]) + scores = A.mean(axis=0) + best_iter = np.argmin(scores) + best_score = scores[best_iter] + + out.append((best_score, grid_idx)) + + out = sorted(out) + best_score, best_grid_idx = out[0] + + # get best params setting + best_params = next(islice(params_grid, best_grid_idx, + best_grid_idx + 1)) + + print("best_score: %.4f; params: %s" % (best_score, best_params)) + + self._model = self._model_class(**best_params) self._model.fit(X, y) + + #BaseEstimator.__setattr__(self, 'cv_score_', cv_score) return self def __getattr__(self, name): - return getattr(self._model, name) + if self._model: + return getattr(self._model, name) + else: + raise AttributeError("type object '' has no attribute '%s'" % + (self.__class__.__name__, name)) def __setattr__(self, name, value): - setattr(self._model, name, value) + if self._model: + setattr(self._model, name, value) + else: + BaseEstimator.__setattr__(self, name, value) class GradientBoostingClassifierCV(BaseGradientBoostingCV): @@ -1181,7 +1232,7 @@ class GradientBoostingClassifierCV(BaseGradientBoostingCV): GradientBoostingClassifier """ - _model_class = GradientBoostingClassifier + _estimator_class = GradientBoostingClassifier def __init__(self, loss='deviance', learn_rate=0.1, max_estimators=1000, subsample=1.0, min_samples_split=1, min_samples_leaf=1, @@ -1275,7 +1326,7 @@ class GradientBoostingRegressorCV(BaseGradientBoostingCV): GradientBoostingRegressor """ - _model_class = GradientBoostingRegressor + _estimator_class = GradientBoostingRegressor def __init__(self, loss='ls', learn_rate=0.1, max_estimators=1000, subsample=1.0, min_samples_split=1, min_samples_leaf=1, From 1ec02b4ca40fa9f65db8cacd217665f3a4574596 Mon Sep 17 00:00:00 2001 From: Peter Prettenhofer Date: Mon, 27 Aug 2012 16:32:57 +0200 Subject: [PATCH 06/11] wip: fixing some set attr errors but still buggy if params not lists --- sklearn/ensemble/gradient_boosting.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py 
index b8874ede207f6..045eae51482e2 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -34,7 +34,7 @@ import numpy as np from itertools import groupby, islice -from opterator import itemgetter +from operator import itemgetter from scipy import stats from .base import BaseEnsemble @@ -1084,6 +1084,9 @@ def __init__(self, **kwargs): BaseEstimator.__setattr__(self, 'cv', kwargs.pop('cv', None)) BaseEstimator.__setattr__(self, 'max_estimators', kwargs.pop('max_estimators', 1000)) + BaseEstimator.__setattr__(self, 'n_jobs', kwargs.pop('n_jobs', 1)) + BaseEstimator.__setattr__(self, 'verbose', kwargs.pop('verbose', + 0)) kwargs['n_estimators'] = self.max_estimators BaseEstimator.__setattr__(self, '_params', kwargs) @@ -1114,7 +1117,7 @@ def fit(self, X, y): grid = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, pre_dispatch="2*n_jobs")( - delayed(fit_grid_point)(tasks)) + delayed(fit_grid_point)(tasks) out = [] @@ -1138,21 +1141,21 @@ def fit(self, X, y): print("best_score: %.4f; params: %s" % (best_score, best_params)) - self._model = self._model_class(**best_params) - self._model.fit(X, y) + estimator = self._estimator_class(**best_params) + estimator.fit(X, y) - #BaseEstimator.__setattr__(self, 'cv_score_', cv_score) + BaseEstimator.__setattr__(self, '_estimator', estimator) return self def __getattr__(self, name): - if self._model: + if hasattr(self, '_estimator'): return getattr(self._model, name) else: - raise AttributeError("type object '' has no attribute '%s'" % + raise AttributeError("type object '%s' has no attribute '%s'" % (self.__class__.__name__, name)) def __setattr__(self, name, value): - if self._model: + if hasattr(self, '_estimator'): setattr(self._model, name, value) else: BaseEstimator.__setattr__(self, name, value) From 5749550e319b80ed0282bf7e7963206558a773e5 Mon Sep 17 00:00:00 2001 From: Peter Prettenhofer Date: Thu, 30 Aug 2012 19:56:29 +0200 Subject: [PATCH 07/11] remove *CV classes - only pick decision_function and staged predict refactoring --- sklearn/ensemble/gradient_boosting.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index 045eae51482e2..136e14f11811a 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -14,15 +14,6 @@ - ``GradientBoostingRegressor`` implements gradient boosting for regression problems. - -- The ``BaseGradientBoostingCV`` base class implements a ``fit`` method - to choose the best ``n_estimators`` based on cross-validation. - -- ``GradientBoostingClassifierCV`` implements ``BaseGradientBoostingCV`` for - classification problems. - -- ``GradientBoostingRegressor`` implements ``BaseGradientBoostingCV`` for - regression problems. 
""" # Authors: Peter Prettenhofer, Scott White, Gilles Louppe From e002644213bbb67943854640638e9a3b0af365df Mon Sep 17 00:00:00 2001 From: Peter Prettenhofer Date: Thu, 30 Aug 2012 20:02:01 +0200 Subject: [PATCH 08/11] rm CV class tests --- .../ensemble/tests/test_gradient_boosting.py | 51 ------------------- 1 file changed, 51 deletions(-) diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index c5fd9d46e4377..85115ec0b0fe9 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -14,8 +14,6 @@ from sklearn.ensemble import GradientBoostingClassifier from sklearn.ensemble import GradientBoostingRegressor -from sklearn.ensemble import GradientBoostingClassifierCV -from sklearn.ensemble import GradientBoostingRegressorCV from sklearn import datasets @@ -418,52 +416,3 @@ def test_mem_layout(): clf.fit(X, y_) assert_array_equal(clf.predict(T), true_result) assert_equal(100, len(clf.estimators_)) - - -def test_cv_attr(): - """Test attribute access for CV classes. """ - clf = GradientBoostingClassifierCV(max_estimators=100, - min_samples_leaf=3) - clf.fit(X, y) - - assert clf.min_samples_leaf == 3 == clf._model.min_samples_leaf - clf.min_samples_leaf = 2 - assert clf.min_samples_leaf == 2 == clf._model.min_samples_leaf - - -def test_cv_clf(): - """Test GradientBoostingClassifierCV n_estimators selection. """ - X, y = datasets.make_hastie_10_2(n_samples=1000, random_state=1) - - max_estimators = 50 - - clf = GradientBoostingClassifierCV(max_estimators=max_estimators) - clf.fit(X, y) - # max_estimators very small so it chooses all of them - assert clf.n_estimators == max_estimators - - clf_prime = GradientBoostingClassifier( - n_estimators=max_estimators).fit(X, y) - - assert_array_equal(clf_prime.train_score_, clf.train_score_) - assert clf.cv_score_.shape[0] == 5 # default 5-fold CV - assert clf.cv_score_.shape[1] == max_estimators - - -def test_cv_reg(): - """Test GradientBoostingRegressorCV n_estimators selection. 
""" - X, y = datasets.make_friedman1(n_samples=1000, random_state=1) - - max_estimators = 50 - - clf = GradientBoostingRegressorCV(max_estimators=max_estimators) - clf.fit(X, y) - # max_estimators very small so it chooses all of them - assert clf.n_estimators == max_estimators - - clf_prime = GradientBoostingRegressor( - n_estimators=max_estimators).fit(X, y) - - assert_array_equal(clf_prime.train_score_, clf.train_score_) - assert clf.cv_score_.shape[0] == 5 # default 5-fold CV - assert clf.cv_score_.shape[1] == max_estimators From 2c404ed26906054070398ab90179221894e0592a Mon Sep 17 00:00:00 2001 From: Peter Prettenhofer Date: Thu, 30 Aug 2012 20:02:16 +0200 Subject: [PATCH 09/11] rm CV class legacy --- sklearn/ensemble/gradient_boosting.py | 332 ++------------------------ 1 file changed, 16 insertions(+), 316 deletions(-) diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index 136e14f11811a..3d93a0c98a821 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -24,8 +24,6 @@ import numpy as np -from itertools import groupby, islice -from operator import itemgetter from scipy import stats from .base import BaseEnsemble @@ -33,9 +31,6 @@ from ..base import ClassifierMixin from ..base import RegressorMixin from ..utils import check_random_state, array2d, check_arrays -from ..cross_validation import KFold -from ..grid_search import IterGrid -from ..externals.joblib import Parallel, delayed from ..tree._tree import Tree from ..tree._tree import _random_sample_mask @@ -716,6 +711,21 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): `max_features < n_features` leads to a reduction of variance and an increase in bias. + Attributes + ---------- + `feature_importances_` : array, shape = [n_features] + The feature importances (the higher, the more important the feature). + + `oob_score_` : array, shape = [n_estimators] + Score of the training dataset obtained using an out-of-bag estimate. + The i-th score ``oob_score_[i]`` is the deviance (= loss) of the + model at iteration ``i`` on the out-of-bag sample. + + `train_score_` : array, shape = [n_estimators] + The i-th score ``train_score_[i]`` is the deviance (= loss) of the + model at iteration ``i`` on the in-bag sample. + If ``subsample == 1`` this is the deviance on the training data. + Examples -------- >>> samples = [[0, 0, 2], [1, 0, 0]] @@ -932,21 +942,6 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin): model at iteration ``i`` on the in-bag sample. If ``subsample == 1`` this is the deviance on the training data. - Attributes - ---------- - `feature_importances_` : array, shape = [n_features] - The feature importances (the higher, the more important the feature). - - `oob_score_` : array, shape = [n_estimators] - Score of the training dataset obtained using an out-of-bag estimate. - The i-th score ``oob_score_[i]`` is the deviance (= loss) of the - model at iteration ``i`` on the out-of-bag sample. - - `train_score_` : array, shape = [n_estimators] - The i-th score ``train_score_[i]`` is the deviance (= loss) of the - model at iteration ``i`` on the in-bag sample. - If ``subsample == 1`` this is the deviance on the training data. 
- Examples -------- >>> samples = [[0, 0, 2], [1, 0, 0]] @@ -958,7 +953,7 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin): See also -------- - DecisionTreeRegressor, RandomForestRegressor, GradientBoostingRegressorCV + DecisionTreeRegressor, RandomForestRegressor References ---------- @@ -1038,298 +1033,3 @@ def staged_predict(self, X): """ for y in self.staged_decision_function(X): yield y.ravel() - - -def fit_grid_point(grid_idx, cv_idx, X, y, estimator_class, params, - train, test): - """Fit a single grid point and return staged scores. """ - X_train, y_train = X[train], y[train] - X_test, y_test = X[test], y[test] - - estimator = estimator_class(**params) - estimator.fit(X_train, y_train) - - test_deviance = np.fromiter( - (estimator.loss_(y_test, score) - for score in estimator.staged_decision_function(X_test)), - dtype=np.float64, count=estimator.n_estimators) - - return (grid_idx, cv_idx, test_deviance) - - -class BaseGradientBoostingCV(BaseEstimator): - """Abstract base class for GB with built-in cross-validation. - - This class implements the Decorator design pattern; it wraps - a concrete ``_estimator_class`` object and delegates attribute - access to the object. - - XXX Soley the arguments ``cv``, ``max_estimators``, - and ``_estimator`` are stored in the decorator object. - """ - __metaclass__ = ABCMeta - - @abstractmethod - def __init__(self, **kwargs): - # verbose syntax needed to avoid recursive __setattr__ invokation - BaseEstimator.__setattr__(self, 'cv', kwargs.pop('cv', None)) - BaseEstimator.__setattr__(self, 'max_estimators', - kwargs.pop('max_estimators', 1000)) - BaseEstimator.__setattr__(self, 'n_jobs', kwargs.pop('n_jobs', 1)) - BaseEstimator.__setattr__(self, 'verbose', kwargs.pop('verbose', - 0)) - - kwargs['n_estimators'] = self.max_estimators - BaseEstimator.__setattr__(self, '_params', kwargs) - - def fit(self, X, y): - """Pick best ``n_estimators`` based on cross-validation ``cv``. - - Finally, fits model on entire dataset using ``n_estimators``. - Cross-validation scores are stored in ``self.cv_deviance``. 
- """ - X, y = check_arrays(X, y, sparse_format='dense') - X = np.asfortranarray(X, dtype=DTYPE) - y = np.ravel(y, order='C') - - if self.cv is None: - cv = KFold(y.shape[0], k=5) - if isinstance(self.cv, int): - cv = KFold(y.shape[0], k=self.cv) - else: - cv = self.cv - - params_grid = IterGrid(self._params) - - tasks = ((grid_idx, cv_idx, X, y, self._estimator_class, - params, train, test) - for grid_idx, params in enumerate(params_grid) - for cv_idx, (train, test) in enumerate(cv)) - - grid = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, - pre_dispatch="2*n_jobs")( - delayed(fit_grid_point)(tasks) - - out = [] - - for grid_idx, grid_points in groupby(grid, itemgetter(0)): - grid_points = list(grid_points) - assert len(grid_points) == cv.k - - A = np.row_stack([fold[2] for fold in grid_points]) - scores = A.mean(axis=0) - best_iter = np.argmin(scores) - best_score = scores[best_iter] - - out.append((best_score, grid_idx)) - - out = sorted(out) - best_score, best_grid_idx = out[0] - - # get best params setting - best_params = next(islice(params_grid, best_grid_idx, - best_grid_idx + 1)) - - print("best_score: %.4f; params: %s" % (best_score, best_params)) - - estimator = self._estimator_class(**best_params) - estimator.fit(X, y) - - BaseEstimator.__setattr__(self, '_estimator', estimator) - return self - - def __getattr__(self, name): - if hasattr(self, '_estimator'): - return getattr(self._model, name) - else: - raise AttributeError("type object '%s' has no attribute '%s'" % - (self.__class__.__name__, name)) - - def __setattr__(self, name, value): - if hasattr(self, '_estimator'): - setattr(self._model, name, value) - else: - BaseEstimator.__setattr__(self, name, value) - - -class GradientBoostingClassifierCV(BaseGradientBoostingCV): - """GB classifier with built-in cross-validation. - - A ``GradientBoostingClassifier`` that optimizes ``n_estimators`` - via cross-validation. - - Parameters - ---------- - loss : {'deviance'}, optional (default='deviance') - loss function to be optimized. 'deviance' refers to - deviance (= logistic regression) for classification - with probabilistic outputs. - - learn_rate : float, optional (default=0.1) - learning rate shrinks the contribution of each tree by `learn_rate`. - There is a trade-off between learn_rate and n_estimators. - - max_estimators : int (default=1000) - The maximum number of boosting stages to perform. The best number - of estimators ``n_estimators`` is picked based on deviance on - held-out data. - - max_depth : integer, optional (default=3) - maximum depth of the individual regression estimators. The maximum - depth limits the number of nodes in the tree. Tune this parameter - for best performance; the best value depends on the interaction - of the input variables. - - min_samples_split : integer, optional (default=1) - The minimum number of samples required to split an internal node. - - min_samples_leaf : integer, optional (default=1) - The minimum number of samples required to be at a leaf node. - - subsample : float, optional (default=1.0) - The fraction of samples to be used for fitting the individual base - learners. If smaller than 1.0 this results in Stochastic Gradient - Boosting. `subsample` interacts with the parameter `n_estimators`. - Choosing `subsample < 1.0` leads to a reduction of variance - and an increase in bias. - - max_features : int, None, optional (default=None) - The number of features to consider when looking for the best split. - Features are choosen randomly at each split point. 
- If None, then `max_features=n_features`. Choosing - `max_features < n_features` leads to a reduction of variance - and an increase in bias. - - cv : cross-validation generator or int (default=5) - If int, ``cv``-fold cross-valdiation will be used. - - Attributes - ---------- - `feature_importances_` : array, shape = [n_features] - The feature importances (the higher, the more important the feature). - - `oob_score_` : array, shape = [n_estimators] - Score of the training dataset obtained using an out-of-bag estimate. - The i-th score ``oob_score_[i]`` is the deviance (= loss) of the - model at iteration ``i`` on the out-of-bag sample. - - `train_score_` : array, shape = [n_estimators] - The i-th score ``train_score_[i]`` is the deviance (= loss) of the - model at iteration ``i`` on the in-bag sample. - If ``subsample == 1`` this is the deviance on the training data. - - `cv_score_` : array, shape = [cv.k, max_estimators] - The deviance scores for each fold and boosting iteration. - - See also - -------- - GradientBoostingClassifier - """ - - _estimator_class = GradientBoostingClassifier - - def __init__(self, loss='deviance', learn_rate=0.1, max_estimators=1000, - subsample=1.0, min_samples_split=1, min_samples_leaf=1, - max_depth=3, init=None, random_state=None, - max_features=None, cv=5): - super(GradientBoostingClassifierCV, self).__init__( - loss=loss, learn_rate=learn_rate, max_estimators=max_estimators, - min_samples_split=min_samples_split, max_depth=max_depth, - min_samples_leaf=min_samples_leaf, init=init, - subsample=subsample, max_features=max_features, - random_state=random_state, cv=cv) - - -class GradientBoostingRegressorCV(BaseGradientBoostingCV): - """GB regressor with built-in cross-validation. - - A ``GradientBoostingRegressor`` that optimizes ``n_estimators`` - via cross-validation. - - Parameters - ---------- - loss : {'ls', 'lad', 'huber', 'quantile'}, optional (default='ls') - loss function to be optimized. 'ls' refers to least squares - regression. 'lad' (least absolute deviation) is a highly robust - loss function soley based on order information of the input - variables. 'huber' is a combination of the two. 'quantile' - allows quantile regression (use `alpha` to specify the quantile). - - learn_rate : float, optional (default=0.1) - learning rate shrinks the contribution of each tree by `learn_rate`. - There is a trade-off between learn_rate and n_estimators. - - max_estimators : int (default=1000) - The maximum number of boosting stages to perform. The best number - of estimators ``n_estimators`` is picked based on deviance on - held-out data. - - max_depth : integer, optional (default=3) - maximum depth of the individual regression estimators. The maximum - depth limits the number of nodes in the tree. Tune this parameter - for best performance; the best value depends on the interaction - of the input variables. - - min_samples_split : integer, optional (default=1) - The minimum number of samples required to split an internal node. - - min_samples_leaf : integer, optional (default=1) - The minimum number of samples required to be at a leaf node. - - subsample : float, optional (default=1.0) - The fraction of samples to be used for fitting the individual base - learners. If smaller than 1.0 this results in Stochastic Gradient - Boosting. `subsample` interacts with the parameter `n_estimators`. - Choosing `subsample < 1.0` leads to a reduction of variance - and an increase in bias. 
- - max_features : int, None, optional (default=None) - The number of features to consider when looking for the best split. - Features are choosen randomly at each split point. - If None, then `max_features=n_features`. Choosing - `max_features < n_features` leads to a reduction of variance - and an increase in bias. - - alpha : float (default=0.9) - The alpha-quantile of the huber loss function and the quantile - loss function. Only if ``loss='huber'`` or ``loss='quantile'``. - - cv : cross-validation generator or int (default=5) - If int, ``cv``-fold cross-valdiation will be used. - - Attributes - ---------- - `feature_importances_` : array, shape = [n_features] - The feature importances (the higher, the more important the feature). - - `oob_score_` : array, shape = [n_estimators] - Score of the training dataset obtained using an out-of-bag estimate. - The i-th score ``oob_score_[i]`` is the deviance (= loss) of the - model at iteration ``i`` on the out-of-bag sample. - - `train_score_` : array, shape = [n_estimators] - The i-th score ``train_score_[i]`` is the deviance (= loss) of the - model at iteration ``i`` on the in-bag sample. - If ``subsample == 1`` this is the deviance on the training data. - - `cv_score_` : array, shape = [cv.k, max_estimators] - The deviance scores for each fold and boosting iteration. - - See also - -------- - GradientBoostingRegressor - """ - - _estimator_class = GradientBoostingRegressor - - def __init__(self, loss='ls', learn_rate=0.1, max_estimators=1000, - subsample=1.0, min_samples_split=1, min_samples_leaf=1, - max_depth=3, init=None, random_state=None, - max_features=None, alpha=0.9, cv=5): - super(GradientBoostingRegressorCV, self).__init__( - loss=loss, learn_rate=learn_rate, max_estimators=max_estimators, - min_samples_split=min_samples_split, - min_samples_leaf=min_samples_leaf, - max_depth=max_depth, init=init, subsample=subsample, - max_features=max_features, random_state=random_state, - alpha=alpha, cv=cv) From 5895acea8f748b417e79250ccb7a3fb8fc674467 Mon Sep 17 00:00:00 2001 From: Peter Prettenhofer Date: Thu, 30 Aug 2012 20:03:07 +0200 Subject: [PATCH 10/11] remove CV class legacy --- sklearn/ensemble/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearn/ensemble/__init__.py b/sklearn/ensemble/__init__.py index 2134a54abb456..d76405cf0e827 100644 --- a/sklearn/ensemble/__init__.py +++ b/sklearn/ensemble/__init__.py @@ -10,5 +10,3 @@ from .forest import ExtraTreesRegressor from .gradient_boosting import GradientBoostingClassifier from .gradient_boosting import GradientBoostingRegressor -from .gradient_boosting import GradientBoostingClassifierCV -from .gradient_boosting import GradientBoostingRegressorCV From a8ff326acc174aa4553169411703d267e0b05d59 Mon Sep 17 00:00:00 2001 From: Peter Prettenhofer Date: Fri, 31 Aug 2012 12:27:37 +0200 Subject: [PATCH 11/11] add API changes and feature_importance fix to whatsnew --- doc/whats_new.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 722ba5f6f8be3..cfe55da0f293a 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -67,6 +67,9 @@ Changelog - Fixed API inconsistency: :meth:`SGDClassifier.predict_proba` now returns 2d array. + - Fixed feature importance computation in + :ref:`ensemble.gradient_boosting`. + API changes summary ------------------- @@ -97,6 +100,10 @@ API changes summary ``min_n`` and ``max_n`` were joined to the parameter ``n_gram_range`` to enable grid-searching both at once. 
+ - :class:`ensemble.GradientBoostingClassifier` now supports + :meth:`ensemble.GradientBoostingClassifier.staged_predict_proba`, and + :meth:`ensemble.GradientBoostingClassifier.staged_predict`. + .. _changes_0_11: 0.11
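For reference, the held-out monitoring loop that the removed ``*CV`` classes automated can be reproduced by hand with the ``staged_decision_function`` API introduced in this series. The sketch below mirrors the deviance loop in ``BaseGradientBoostingCV.fit`` / ``fit_grid_point`` from the intermediate patches; the dataset, the train/test split, and the era-specific parameter name ``learn_rate`` are illustrative assumptions, not part of the patches themselves.

    # Illustrative sketch, not part of the patch series: pick n_estimators by
    # monitoring held-out deviance stage by stage, as the removed CV classes did.
    import numpy as np

    from sklearn import datasets
    from sklearn.ensemble import GradientBoostingClassifier

    X, y = datasets.make_hastie_10_2(n_samples=2000, random_state=1)
    X_train, y_train = X[:1000], y[:1000]
    X_test, y_test = X[1000:], y[1000:]

    max_estimators = 200
    clf = GradientBoostingClassifier(n_estimators=max_estimators,
                                     learn_rate=0.1)  # renamed to 'learning_rate' in later releases
    clf.fit(X_train, y_train)

    # staged_decision_function yields the decision scores after each boosting
    # iteration; loss_ maps (y_true, score) to the deviance on that sample,
    # exactly as done in fit_grid_point above.
    test_deviance = np.fromiter(
        (clf.loss_(y_test, score)
         for score in clf.staged_decision_function(X_test)),
        dtype=np.float64, count=max_estimators)

    best_n = int(np.argmin(test_deviance)) + 1
    print("best n_estimators=%d, held-out deviance=%.4f"
          % (best_n, test_deviance[best_n - 1]))

A variant based on ``staged_predict_proba`` works the same way when probability estimates rather than raw decision scores are needed for the monitoring metric.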