Skip to content

WIP GBRT with built-in cross-validation #1036

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from

Conversation

pprett
Copy link
Member

@pprett pprett commented Aug 18, 2012

Two new classes GradientBoostingClassifierCV and GradientBoostingRegressorCV which pick n_estimators based on cross-validation.

GradientBoostingClassifierCV fits a GradientBoostingClassifier with max_estimators for each fold; it picks n_estimators based on the min deviance averaged over all test sets. Finally, it trains the model on the whole training set using the found n_estimators.

GradientBoostingClassifierCV is implemented as a GradientBoostingClassifier decorator. It soley implements fit, otherwise it delegates to GradientBoostingClassifier (see __getattr__ and __setattr__).
The current implementation might pose some problems if the client uses isinstance rather than duck typing: a GradientBoostingClassifierCV instance is not an instance of GradientBoostingClassifier. I would really appreciate any remarks/feedback to this issue.

I tried to adhere the interface of RidgeCV.

Additionally, I refactored the prediction routines in order to remove code duplication. staged_predict and staged_predict_proba has been added to GradientBoostingClassifier.

Limitations:

  • It is currently hard-wired to pick n_estimtors based on deviance (no support for custom loss function yet) - is this needed?
  • No joblib support yet
  • Only cross-validation support - should we add held-out and OOB estimation too?

kwargs.pop('max_estimators', 1000))

kwargs['n_estimators'] = self.max_estimators
BaseEstimator.__setattr__(self, '_model', self._model_class(**kwargs))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if the sub model instantiation should not be deferred to the fit call and the sub-estimator storing attribute be renamed to self.model_. That will require to store the model init kwargs as an attribute though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually that would break the __setattr__ and __getattr__ override so please ignore my previous comment.

@ogrisel
Copy link
Member

ogrisel commented Aug 18, 2012

Nice, this sounds like a nice feature for kaggle competitors :)

  • About the isinstance issue, I don't think this is an issue. AFAIK there is no assumption that a ModelCV class should be a subclass of the vanilla Model class.
  • Custom loss for model selection would be nice but can come in a later PR IMO
  • Leveraging OOB samples for model selection and / or to return test-set accuracy estimate would be a boon too IMHO but can also probably be done in a later PR.
  • Could you add some smoke tests for pickling and use in a pipeline / wrapping GridSearchCV (e.g. on the learning rate) so as to check that the base class g/setattr override does not break those meta tools? Maybe this can be done as part of the common testing framework by @amueller (I am not yet familiar with it).

@amueller
Copy link
Member

@ogrisel I am not entirely sure what you want to check. Just that GridSearchCV works?
We could do that in the common test framework very easily for all classifiers. I am not sure if it is actually true, though. Well for multi-class it should be fine, binary might not work yet, see the inconsistent shape discussion.

@ogrisel
Copy link
Member

ogrisel commented Aug 18, 2012

indeed, maye a custom test for GBRT*CV models then. The params for grid is model specific anyway.

@mblondel
Copy link
Member

It may be nice to support tuning other parameters than n_estimators in GradientBoostingClassifierCV directly. Using GridSearchCV on top of GradientBoostingClassifierCV is possible but it would be semantically a bit different. It would result in a greedy approximation: tune n_estimators with the other parameters fixed first then tune the remaining of the parameters. On the other hand, if tuning other parameters were supported directly in GradientBoostingClassifierCV, it would be an exhaustive search. In term of implementation, the idea would be to generate parameter combinations with IterGrid but to handle n_estimators specifically for efficiency.

@pprett
Copy link
Member Author

pprett commented Aug 20, 2012

thanks for the feedback!

@mblondel I'm not a friend of supporting tuning other parameters; n_estimators has strong interactions with other tuning parameters (esp. learn_rate). E.g. changing learn_rate=0.001 to learn_rate=0.0001 likely requires 10x more boosting iterations for the same training deviance. In fact, I'd actually use GridSearchCV on top of GradientBoostingClassifierCV if I have sufficient computational resources.

If computational resources are an issue, I'd use GridSearchCV for paramter tuning choosing n_estimators as large as possible (depending on computational resources). Then, I'd use GradientBoostingClassifierCV to determine n_estimators.

@ogrisel
Copy link
Member

ogrisel commented Aug 20, 2012

@pprett Would be great to wrap up this in a "Parameters selection tips" section in the narrative doc of the GBRT models.

@pprett
Copy link
Member Author

pprett commented Aug 20, 2012

I agree but I would certainly end up copying Greg Ridgeways definitive guide to parameter selection in GBM (it is linked in the docs but it deserves more promotion) http://cran.r-project.org/web/packages/gbm/gbm.pdf .

@mblondel
Copy link
Member

@mblondel I'm not a friend of supporting tuning other parameters; n_estimators has strong interactions with other tuning parameters (esp. learn_rate). E.g. changing learn_rate=0.001 to learn_rate=0.0001 likely requires 10x more boosting iterations for the same training deviance. In fact, I'd actually use GridSearchCV on top of GradientBoostingClassifierCV if I have sufficient computational resources.

So, to summarize, you would use a sane default learning_rate and optimize n_estimators only in most cases?

If computational resources are an issue, I'd use GridSearchCV for paramter tuning choosing n_estimators as large as possible (depending on computational resources). Then, I'd use GradientBoostingClassifierCV to determine n_estimators.

I don't understand that. For me GradientBoostingClassifierCV should always be more efficient to choose n_estimators (the process is incremental and the prediction scores are readily available)

@pprett
Copy link
Member Author

pprett commented Aug 20, 2012

If computational resources are an issue, I'd use GridSearchCV for paramter tuning choosing n_estimators as large as possible (depending on computational resources). Then, I'd use GradientBoostingClassifierCV to determine n_estimators.

I don't understand that. For me GradientBoostingClassifierCV should always be more efficient to choose n_estimators (the process is incremental and the prediction scores are readily available)

Sorry, I meant I'd use GridSearchCV for tuning max_depth, min_samples_split and learn_rate using a fixed n_estimators (as large as possible, e.g. 3000). Then, I'd tune n_estimators via GradientBoostingClassifierCV fixing the other parameters with the values found by GridSearchCV. Does this makes sense now?

@mblondel
Copy link
Member

@pprett yes :)

@pprett
Copy link
Member Author

pprett commented Aug 22, 2012

@mblondel please ignore my first response to your comment - I was thinking about this issue yesterday and it does make sense to wrap GridSearchCV within GradientBoostingClassifierCV** - the way I described it here would be rather wasteful in terms of computational resources (2x CV)... Apart from that I'm not totally sure about the difference between the two approaches... I need to spend more time on this issue. Anyway, sorry for the noise.

** also described in this thread on the ML http://www.mail-archive.com/scikit-learn-general@lists.sourceforge.net/msg03395.html

PS: I promise next time I'll think before I write

@mblondel
Copy link
Member

No worries.

To tune other parameters than n_estimators directly in GradientBoostingClassifierCV, the idea that I had in mind was to use IterGrid to generate all parameter combinations (n_estimators excepted). For example, given {'learn_rate':[0.05, 0.01, 0.001], 'subsample':[0.25, 0.5, 0.75]}, one parameter combination will be {'learn_rate'=0.05, 'subsample'=0.25}. Given these two values fixed, it is possible to choose n_estimators efficiently. Then this process must be repeated for each fold (to compute the average score of a parameter combination).

One thing that worries me about using GridSearchCV on top of GradientBoostingClassifierCV is that the train / validation split will have to be made twice (once inside GridSearchCV, and once inside GradientBoostingClassifierCV). Not good if you don't have so much data.

Supporting other parameters in GradientBoostingClassifierCV will however increase the implementation complexity...

@pprett
Copy link
Member Author

pprett commented Aug 22, 2012

2012/8/22 Mathieu Blondel notifications@github.com

No worries.

To tune other parameters than n_estimators directly in
GradientBoostingClassifierCV, the idea that I had in mind was to use
IterGrid to generate all parameter combinations (n_estimators excepted).
For example, given {'learn_rate':[0.05, 0.01, 0.001], 'subsample':[0.25,
0.5, 0.75]}, one parameter combination will be {'learn_rate'=0.05,
'subsample'=0.25}. Given these two values fixed, it is possible to choose
n_estimators efficiently. Then this process must be repeated for each
fold (to compute the average score of a parameter combination).

Exactly, for each grid point and fold you get an array of deviance scores
(shape=n_estimators); for each grid point you then need to compute the mean
deviance scores across all folds and pick n_estimators with the lowest
(mean) deviance.

One thing that worries me about using GridSearchCV on top of
GradientBoostingClassifierCV is that the train / validation split will
have to be made twice (once inside GridSearchCV, and once inside
GradientBoostingClassifierCV). Not good if you don't have so much data.

I agree

Supporting other parameters in GradientBoostingClassifierCV will however
increase the implementation complexity...

IterGrid and KFold do most of the heavy lifting - I simply pass the
results of each grid point - fold combination into itertools.groupby in
order to groupby grid point id and than compute mean deviance for each grid
point. I'll push an update in the evening.

thanks!


Reply to this email directly or view it on GitHubhttps://github.com//pull/1036#issuecomment-7927499.

Peter Prettenhofer

@GaelVaroquaux
Copy link
Member

On Wed, Aug 22, 2012 at 01:18:28AM -0700, Mathieu Blondel wrote:

One thing that worries me about using GridSearchCV on top of
GradientBoostingClassifierCV is that the train / validation split will
have to be made twice (once inside GridSearchCV, and once inside
GradientBoostingClassifierCV). Not good if you don't have so much data.

That's one reason why we need to be able to have cross-validation like
objects use a validation set. There is (quite a bit) of design work to do
here...

G

@amueller
Copy link
Member

Did someone say api design?

If we where able to pass the CV-like object the test-split of the grid-search, we'd be fine, right?
So we add another function fit_with_validation(X_train, y, X_test) and let the GridSearchCV check if the object has that,
the splitting-twice problem would be averted.

@raghavrv
Copy link
Member

raghavrv commented Dec 2, 2015

Revived by @vighneshbirodkar in #5689

@amueller
Copy link
Member

Fixed in #7071

@amueller amueller closed this Sep 27, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants