
[WIP] Generalized partial dependence plots #5653


Closed

Conversation

trevorstephens
Contributor

@trevorstephens trevorstephens commented Nov 1, 2015

Still a fair bit to do on this one, but the bones are there now for the partial_dependence function.

I have implemented an "exact" formulation of the partial dependence function, based on doing predictions on the original training data. I believe this can work on any regression model, and any classification model supporting predict_proba. It is based on https://github.com/cran/randomForest/blob/master/R/partialPlot.R

Additionally, I added an "estimated" method that uses the column means of X instead of looping over all the grid points and re-evaluating every time. It is many times faster than the other methods and gives reasonable results in terms of showing the "direction" in which a variable pushes the model's predictions.
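
Roughly, the two methods compare like this (a minimal sketch with hypothetical helper names, not the PR's actual code):

import numpy as np

def exact_partial_dependence(est, X, feature, grid):
    # "exact": for each grid value, overwrite the target feature's column
    # in the full training data and average the model's predictions
    pdp = []
    for value in grid:
        X_eval = X.copy()
        X_eval[:, feature] = value
        pdp.append(est.predict(X_eval).mean())
    return np.array(pdp)

def estimated_partial_dependence(est, X, feature, grid):
    # "estimated": hold every other feature at its column mean, so the
    # model is evaluated on only one synthetic row per grid value
    X_eval = np.tile(X.mean(axis=0), (len(grid), 1))
    X_eval[:, feature] = grid
    return est.predict(X_eval)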

RandomForestRegressor is now also supported using the recursive method that is already working for GBMs. I am still toying with ways to recurse classification trees, but that might be another PR after this merges.
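
For the single-feature case, that weighted recursion looks roughly like this (a sketch against sklearn's tree internals, not the PR's actual code, which is vectorised): splits on the target feature follow the branch the grid value dictates, while other splits descend both children weighted by training-sample fractions.

def tree_partial_dependence(tree, feature, value, node=0, weight=1.0):
    # weighted average of the tree's prediction at `value` of `feature`
    t = tree.tree_
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:  # leaf: contribute its value, weighted
        return weight * t.value[node][0][0]
    if t.feature[node] == feature:
        # split on the target feature: follow the branch `value` dictates
        child = left if value <= t.threshold[node] else right
        return tree_partial_dependence(tree, feature, value, child, weight)
    # split on another feature: descend both children, weighted by the
    # fraction of training samples that followed each branch
    w_left = t.n_node_samples[left] / t.n_node_samples[node]
    return (tree_partial_dependence(tree, feature, value, left,
                                    weight * w_left)
            + tree_partial_dependence(tree, feature, value, right,
                                      weight * (1.0 - w_left)))

For a forest, average this quantity over forest.estimators_.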

TODO:

  • Do something sensible with multioutput estimators
  • Support for Pipeline as well as BaseSearchCV estimators
  • Update the plot_partial_dependence function
  • Add some tests
  • Clean up exception logic
  • Deprecate sklearn.ensemble.partial_dependence
  • Move & enhance docs (add example for non-GBRT classification/regression, narrative docs on new methods, etc.)

High-level questions:

  • Should 'exact' or 'estimated' be the default for estimators that are not regression-tree-based?
  • Is sklearn.partial_dependence the proper place for this utility?
  • Is demeaning the regression output a smart idea? This would need y to be passed to the function as well.

Timing:

The 'exact' method is extremely slow for ensembles. The 'estimated' method is very quick on just about anything and gives surprisingly similar results. The 'recursive' method is a fair bit slower on deep forests compared to GBMs, which tend to use quite shallow trees. Here's the California housing example from the website run with both GBM and RF regressors, timings in the headings, both models with 100 trees.

[Six figures: California housing partial dependence plots for the GBM and RF regressors under each method, timings in the headings.]

Any comments on the methods being used? @pprett, since you wrote the original, I would especially value your comments.

@trevorstephens
Contributor Author

Reference issue #4405

    pdp = np.subtract(pdp, np.mean(pdp, 1)[:, np.newaxis])  # demean each row of the pdp
    pdp = pdp.transpose()
else:
    raise ValueError('est must be a fitted regressor or classifier model.')


Is there any reason you can't let est be an instance of sklearn.pipeline.Pipeline? So that you can look at the partial dependence as it applies to a predictor before some transformation is applied.

Contributor Author

I'd have to look into it; at first thought, it may get tricky to get a transformed X on which to construct the pdplot grid and calculate the function. Can you give an example of such a pipeline you would want to use, and where you'd want the pdplot to be calculated?


For example: say you apply a spline to a predictor. Then you might want to look at the partial dependence of the target on the raw predictor, even though the regression uses the transformed one.

Contributor Author

Doesn't a pipeline require the final step to be the predictor? Unless I misunderstand you... Can you give an actual example, i.e. with the actual pipeline constructor code?

In general, I don't see any reason why I can't support a pipeline or gridsearch object as input.


I meant predictor as in column, for example:

pipeline = Pipeline(
    steps=[
        ("apply_spline_to_a_feature", ApplySplineToFeature(feature_num=1)),
        ("run_a_linear_regression", LinearRegression())
    ]
)

where ApplySplineToFeature is a TransformerMixin subclass.

Then I would hope to be able to calculate the partial dependence on the entire pipeline:

partial_dependence(pipeline, X, target_features=(0,))

I would be interested in seeing the partial dependence where the horizontal axis represents the "raw" data, even though some transformation is applied before scoring. Hope this makes a little more sense!

Contributor Author

@pauljacksonrodgers is this what you were going for? Same example as above... Pipeline for discussion purposes only... I don't generally PCA my forests :-) Note that this can't work with the "recursive" option as the X's within the tree splits are all modified.

estimators = [('scaling', StandardScaler()),
              ('reduce_dim', PCA(n_components=6)),
              ('rfr', RandomForestRegressor(n_estimators=100))]
clf = Pipeline(estimators)
clf.fit(X_train, y_train)

[Two figures: partial dependence plots computed through the Pipeline.]

Contributor Author

It also works for GridSearchCV...

parameters = {'learning_rate': [0.1, 0.5]}
gbr = GradientBoostingRegressor(n_estimators=100, max_depth=4,
                                loss='huber', random_state=1)
clf = GridSearchCV(gbr, parameters)
clf.fit(X_train, y_train)

[Two figures: partial dependence plots computed through GridSearchCV.]

Contributor Author

FYI, both the Pipeline and GridSearchCV examples above were generated on my local branch with

partial_dependence(clf, target_feature,
                   X=X_train, grid_resolution=50, method=method)

where X_train is the same as the second half of http://scikit-learn.org/stable/auto_examples/ensemble/plot_partial_dependence.html


This is exactly what I meant, thanks! I hadn't even considered using GridSearchCV here, also very cool. And yes, it makes sense that you can't use the "recursive" method with either of these, since the "raw" data isn't even used in the computation.

Contributor Author

Great. Sorry if I wasn't following you earlier. It's a good enhancement and easy to support. Will push changes as soon as I stop pulling my hair out about multi-output :-)

@vene
Member

vene commented Nov 2, 2015

Here's a related reference that I ran into recently; it seemed very useful at a skim, but I haven't thoroughly looked into it or used it yet. It seems some interaction effects can be masked when you average. The paper has some nice visualisations. http://arxiv.org/pdf/1309.6392.pdf
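
That's the ICE (Individual Conditional Expectation) paper, I believe; the gist is to keep the per-sample curves that the 'exact' method averages away, so heterogeneous effects stay visible. A minimal sketch (hypothetical helper name, not from this PR):

import numpy as np

def ice_curves(est, X, feature, grid):
    # one curve per training sample: the same evaluations as the
    # "exact" method, just without the final averaging step
    curves = np.empty((X.shape[0], len(grid)))
    for j, value in enumerate(grid):
        X_eval = X.copy()
        X_eval[:, feature] = value
        curves[:, j] = est.predict(X_eval)
    return curves

Averaging them back, curves.mean(axis=0), recovers the ordinary partial dependence curve.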

    else:
        n_features = est.n_features_
elif X is None:
    raise ValueError('X is required for method="exact" or "estimated".')


Why is this so? Seems like you should be able to do exact/estimated computation on the grid too, right?

Contributor Author

The 'estimated' and 'exact' methods use X to evaluate the function via predict or predict_proba; the grid only supplies values for the target variables.
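
(The grid itself is typically built from X as well, e.g. something like

grid = np.linspace(X[:, target_feature].min(),
                   X[:, target_feature].max(),
                   grid_resolution)

though the exact construction here may differ.)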

@trevorstephens
Contributor Author

Thanks @vene, will have a flick through this when I have some time.

if isinstance(est, RegressorMixin):
    try:
        pdp = est.predict(X_eval)
    except:
Member

Never! Use except Exception. A bare except: also catches KeyboardInterrupt.

But besides, we now have common tests to ensure that calling predict before fit raises NotFittedError.
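
A minimal sketch of that pattern (note: NotFittedError's import location has moved between versions; it's in sklearn.exceptions in recent releases):

from sklearn.exceptions import NotFittedError

try:
    pdp = est.predict(X_eval)
except NotFittedError:
    raise ValueError('est must be a fitted regressor or classifier model.')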

Contributor Author

A lot of this is just placeholders, @jnothman :-) I need to figure out which errors to let the estimator pass through and which to explicitly catch in this function. Not ready for code reviews yet; just structural stuff for now, please.

Member

okay np!

@jnothman
Member

jnothman commented Nov 3, 2015

Add to your TODOs: an example for GBRT and for non-GBRT classification/regression, plus narrative docs if not already in your TODO.

@trevorstephens
Contributor Author

I had "enhance docs" in the todos. There is a fair bit of work left, this is just the bare bones. Appreciate your input though! Can you comment on the high-level questions in the original post?

@asjedh

asjedh commented Jan 26, 2016

Do we have any updates here? Partial dependence plots for random forests would be very helpful :)

@trevorstephens
Contributor Author

Hi @asjedh , been a bit busy since the holidays but will get back on the pull request soon. Might have a bit of time for it this weekend in fact!

@darribas

+1 on PDP for random forests! :)

@chris60201

Hey guys! Do we have an update here? Looks like checks have passed...

@amueller amueller changed the title [WIP] Generalized partial dependence plots [MRG] Generalized partial dependence plots Oct 11, 2016
@amueller amueller changed the title [MRG] Generalized partial dependence plots [WIP] Generalized partial dependence plots Oct 11, 2016
@amueller
Member

@trevorstephens do you want to keep working on this or have someone else take it up?

@trevorstephens
Contributor Author

Hey @amueller ... thanks for the ping. I know this is pretty stale, but I'd like to keep working on it actually.

Work being very busy and then recently moving back home to Australia made the year to date totally bonkers, but I have time now to come back to this. Will check in some code this weekend hopefully.

Looks like I will have to check the new model validation module to see how it plays with this feature after the recent update.

@amueller
Member

@trevorstephens sure, no worries. Thanks for the update. I've been out for a bit myself, but trying to do some late spring cleaning here ;)

@jph00

jph00 commented Jan 5, 2017

I'm hoping to use this in part 2 of http://course.fast.ai, otherwise I'll have to use R :( Is this still being worked on? Any chance of having it done in the next couple of weeks?
(Don't mean to hassle anybody! Just need to finalize plans for the course.)

@jnothman
Member

jnothman commented Jan 6, 2017 via email

@trevorstephens
Contributor Author

Hey @jph00 ... If I packaged the code up and pushed it to PyPI (pip-installable) as a temporary measure, would that help? Untested and unreviewed, of course... You are also more than welcome to just copy the function from the PR and use it directly; you would just need to change the relative imports to point to sklearn instead.

@Irene-GM

Hi guys! Is there any update on the partial dependence plots for Random Forest? :-)
Can we have access to the source code of the generic function, or have it as a standalone package on PyPI?
Thanks for the great effort!

@jnothman
Member

@trevorstephens, would you welcome another contributor trying to finish off your work?

@lucianoviola
Contributor

@trevorstephens @jnothman I could volunteer; I have some previous experience with partial dependence from my job.

@jnothman
Member

jnothman commented Jul 20, 2017 via email

@trevorstephens
Contributor Author

Hi @lucianoviola, thanks for the offer! I had actually been working on this locally after @Irene-GM woke me up :-D I should have some commits in this weekend and would welcome comments and suggestions from your experience soon! @jnothman, sorry for dragging it out; I've been very busy lately, but it should be ready for review in the coming days!

@amueller
Member

amueller commented Nov 6, 2018

The plotting module seems unlikely to happen soon. I think we should move this forward, possibly without the approximation.

@amueller
Member

amueller commented Nov 7, 2018

@trevorstephens would you mind if I put @NicolasHug on this to wrap up your work? I'd really like to see this get into 0.21 ;)

@trevorstephens
Contributor Author

Sorry, been hard to find the time to work on this @amueller :-( The documentation part especially will require a fair bit of effort! I was recently thinking of spinning it out as a contrib library so I could finish it up bit-by-bit for later inclusion in the main repo rather than going straight to master, but if @NicolasHug has the capacity to finish it up, then go for it 👍

@NicolasHug
Member

Thanks @trevorstephens, I'll finish this in another PR.

@NicolasHug
Member

NicolasHug commented Nov 15, 2018

Hi @trevorstephens, just FYI: I've opened #12599 (not finished yet, but soon) and of course credited you.

Here is a list of the main changes. Some of them actually come from the original implementation, not necessarily from your PR:

  • removed the approx method after discussing it with @amueller
  • removed support for multiclass-multioutput classifiers
  • merged label and output into a single target parameter
  • removed support for RandomForestRegressor with recursion, because it doesn't work in multioutput settings
  • renamed exact to brute
  • renamed axes to values, which is arguably not much better... but I find 'axes' confusing since it's not directly related to a matplotlib Axes instance
  • renamed pdp to averaged_predictions
  • factorized the tests and added a lot more of them
  • updated the docs

@trevorstephens
Contributor Author

Nice work @NicolasHug 👍

@trevorstephens trevorstephens deleted the general-pplot branch May 15, 2019 10:20