[MRG+1] Apply method added to GradientBoosting #5228


Merged: 1 commit into scikit-learn:master on Sep 10, 2015

Conversation

@jmschrei (Member) commented Sep 8, 2015

Fixed the issues in #5222. cc @glouppe @ogrisel @arjoly @amueller

Apologies for the convoluted PR.

@jmschrei jmschrei changed the title ENH apply method added to GradientBoosting [MRG] Apply method added to GradientBoosting Sep 8, 2015
@jmschrei (Member Author) commented Sep 8, 2015

Example updated as requested by @glouppe


if self.estimators_ is None or len(self.estimators_) == 0:
    raise NotFittedError("Estimator not fitted, "
                         "call `fit` before exploiting the model.")
Review comment (Contributor):

I know this is a bit out of scope wrt this PR, but I think factoring out this check and what is done at https://github.com/jmschrei/scikit-learn/blob/gb_apply/sklearn/ensemble/gradient_boosting.py#L1068 would be a nice thing to do.
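
For illustration, a factored-out version of this check, along the lines of the _check_initialized helper that shows up later in this PR, might look like this (a sketch, not the exact code that landed):

def _check_initialized(self):
    """Raise NotFittedError if the ensemble has not been fitted yet."""
    # Assumes NotFittedError is already imported in gradient_boosting.py.
    if self.estimators_ is None or len(self.estimators_) == 0:
        raise NotFittedError("Estimator not fitted, "
                             "call `fit` before exploiting the model.")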

@jmschrei (Member Author) commented Sep 8, 2015

Comments have been incorporated, thanks!

@glouppe glouppe changed the title [MRG] Apply method added to GradientBoosting [MRG+1] Apply method added to GradientBoosting Sep 8, 2015
@glouppe (Contributor) commented Sep 8, 2015

Great! Thanks for the quick changes. +1 on my side

@glouppe (Contributor) commented Sep 9, 2015

Ping @arjoly @pprett @betatim

@betatim (Member) commented Sep 9, 2015

I am happy with the change to the example. Nice work.

My only nitpick is that the shape of the returned array differs from the one returned by RandomForestClassifier and friends, but I can't think of a way to fix that.
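
To make the shape difference concrete, a small sketch (the shapes follow the discussion in this thread):

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier

X, y = make_classification(n_samples=100, n_classes=3, n_informative=4,
                           random_state=0)

rf = RandomForestClassifier(n_estimators=10).fit(X, y)
gb = GradientBoostingClassifier(n_estimators=10).fit(X, y)

# RandomForestClassifier.apply returns one leaf index per tree.
print(rf.apply(X).shape)  # (100, 10)

# Gradient boosting fits one tree per class per stage, so its apply
# carries an extra class axis in the multi-class case.
print(gb.apply(X).shape)  # (100, 10, 3)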

@arjoly (Member) commented Sep 9, 2015

Can you add a test for regression and a test for multi-class classification?
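
A hedged sketch of what such tests might assert (test names are hypothetical, not the tests that actually landed):

from sklearn.datasets import make_classification, make_regression
from sklearn.ensemble import (GradientBoostingClassifier,
                              GradientBoostingRegressor)


def test_apply_regression_shape():
    X, y = make_regression(n_samples=50, random_state=0)
    est = GradientBoostingRegressor(n_estimators=5).fit(X, y)
    # The regressor flattens away the class axis: (n_samples, n_estimators).
    assert est.apply(X).shape == (50, 5)


def test_apply_multiclass_shape():
    X, y = make_classification(n_samples=50, n_classes=3, n_informative=4,
                               random_state=0)
    est = GradientBoostingClassifier(n_estimators=5).fit(X, y)
    # One tree per class per stage: (n_samples, n_estimators, n_classes).
    assert est.apply(X).shape == (50, 5, 3)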

- grd_enc.fit(gradient_apply(grd, X_train))
- grd_lm.fit(grd_enc.transform(gradient_apply(grd, X_train_lr)), y_train_lr)
+ grd_enc.fit(grd.apply(X_train)[:,:,0])
+ grd_lm.fit(grd_enc.transform(grd.apply(X_train_lr)[:,:,0]), y_train_lr)
Review comment (Member):

nitpick: does [:,:,0] comply with flake8?

Reply (Member Author):

I don't know what flake8 is. @ogrisel ?

Reply (Contributor):

It is a command line tool that checks the PEP8 and PEP257 style guidelines (which we try to follow).

In particular, here I suppose Arnaud would have expected [:, :, 0] rather than [:,:,0].

Reply (Member):

+1 for respecting PEP8 and being consistent with the style of the code base. @jmschrei, if you use the Atom editor you can:

1. pip install flake8 to install the flake8 command in your PATH
2. install the linter and linter-flake8 packages in your Atom environment to run the checks in your editor

Alternatively you can run:

flake8 sklearn/path/to/module.py

from the command line.
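
For context, the example under discussion feeds the leaf indices returned by apply into a one-hot encoder and a linear model. A minimal self-contained sketch of that pattern, reusing the variable names from the diff above (the data setup and hyperparameters are illustrative assumptions):

from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

X, y = make_classification(n_samples=1000, random_state=0)
# Use half the data for the trees, half for the linear model.
X_train, X_train_lr, y_train, y_train_lr = train_test_split(
    X, y, test_size=0.5, random_state=0)

grd = GradientBoostingClassifier(n_estimators=10).fit(X_train, y_train)
grd_enc = OneHotEncoder()
grd_lm = LogisticRegression()

# Binary problem: drop the singleton class axis, leaving
# an array of shape (n_samples, n_estimators).
grd_enc.fit(grd.apply(X_train)[:, :, 0])
grd_lm.fit(grd_enc.transform(grd.apply(X_train_lr)[:, :, 0]), y_train_lr)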

@jmschrei (Member Author) commented Sep 9, 2015

Unit tests added. I would support keeping the same shape for all gradient boosting models. The gradient boosting shape has to differ from RF's in multi-class cases, so we might as well not make it even more complicated.


for i in range(n_estimators):
    for j in range(n_classes):
        # Record the leaf index each sample reaches in the tree
        # fitted for stage i and class j.
        leaves[:, i, j] = self.estimators_[i, j].apply(X)
Review comment (Member):

Please call .apply(X, check_input=False) as the inputs have already been checked previously.
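
Applied to the loop above, that suggestion would read (sketch):

for i in range(n_estimators):
    for j in range(n_classes):
        # Inputs were already validated once at the top of apply,
        # so skip the per-tree input checks.
        leaves[:, i, j] = self.estimators_[i, j].apply(X, check_input=False)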

@jmschrei (Member Author) commented Sep 9, 2015

Comments have been incorporated. GradientBoostingRegressor has its own method that wraps the base call with its own documentation, returning [n_samples, n_estimators]. flake8 has been run on the example and produces no warnings. Commits have been squashed, the branch has been rebased on master (to include #5230), and all unit tests pass.

self._check_initialized()
X = self.estimators_[0, 0]._validate_X_predict(X, check_input=True)

n_estimators, n_classes = self.estimators_.shape
Review comment (Member):

Maybe add an inline comment here that says that n_classes is 1 both for binary classification and in a regression context.

Reply (Member Author):

n_classes is 0; we changed the shape to [n_samples, n_estimators] in a regression context.

Reply (Member):

It's got to be 1. Otherwise there would be no leaf data at all.
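
The suggested inline comment might then read:

# self.estimators_ has shape (n_estimators, n_classes); n_classes is 1
# for both binary classification and regression.
n_estimators, n_classes = self.estimators_.shape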

@ogrisel (Member) commented Sep 9, 2015

Apart from the nitpick, this LGTM as well.


leaves = super(GradientBoostingRegressor, self).apply(X)
# Drop the singleton class axis so the regressor returns
# an array of shape (n_samples, n_estimators).
leaves = leaves.reshape(X.shape[0], self.estimators_.shape[0])
return leaves
Review comment (Member):

An empty line is missing after this one.
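
For illustration, a hedged usage sketch of the regressor's wrapper:

from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor

X, y = make_regression(n_samples=100, random_state=0)
est = GradientBoostingRegressor(n_estimators=20).fit(X, y)

# The regressor reshapes away the singleton class axis,
# returning an array of shape (n_samples, n_estimators).
print(est.apply(X).shape)  # (100, 20)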

@arjoly (Member) commented Sep 9, 2015

+1 also, once the remaining comments are addressed.

@jmschrei (Member Author) commented:

Changes incorporated, all tests pass.

plt.plot(fpr_grd_lm, tpr_grd_lm, label='GBT + LR')
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.title('ROC curve')
Review comment (Member):

Maybe change the title to make it explicit that this is a zoom in on the top left corner of the ROC curves.
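
For instance, something like:

plt.title('ROC curve (zoomed in at top left)')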

@jmschrei (Member Author) commented:

@ogrisel fixed

@ogrisel (Member) commented Sep 10, 2015

Thanks, LGTM as well. merging!

ogrisel added a commit that referenced this pull request Sep 10, 2015
[MRG+1] Apply method added to GradientBoosting
@ogrisel merged commit 470b9a4 into scikit-learn:master on Sep 10, 2015
@jmschrei (Member Author) commented:

🎺

@arjoly (Member) commented Sep 10, 2015

Great! :-)
