MRG: Multi-output decision trees #923


Merged · 38 commits · Jul 9, 2012

Conversation

@glouppe (Contributor) commented Jun 29, 2012

Hi folks!

Just to let you know, I am currently working on a multi-output extension of our decision trees.

Basically, this will make our implementation capable of handling classification or regression problems with several outputs. As I was discussing with @pprett, a very simple way to solve this kind of problem is to build n independent models, one for each output. However, by doing that you lose the (likely) correlations between the outputs (classes or regression values). Hence, an often better way is to build a single model that predicts all n outputs simultaneously. For decision trees, this amounts to storing n output values in each leaf and using splitting criteria that compute the average impurity reduction over all outputs.
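To make the intended usage concrete, here is a minimal sketch (the API is still in flux at this point, so the exact shapes and names are assumptions):

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier

    # Toy data: 4 samples, 2 features; y has one column per output.
    X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
    y = np.array([[0, 1],
                  [0, 1],
                  [1, 0],
                  [1, 1]])

    clf = DecisionTreeClassifier()
    clf.fit(X, y)                  # a single tree fits both outputs at once
    print(clf.predict([[1, 0]]))   # one class per output, e.g. [[1 0]]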

This PR includes a working prototype of multi-output decision trees. I tried as much as possible not to impair training time on single-output problems.

A lot of things still need to be done:

  1. Write multi-output unit tests
  2. Patch RandomForest* and ExtraTrees* to account for the API changes
  3. Patch GradientBoosting* to account for the API changes
  4. Update the documentation

glouppe added 6 commits June 25, 2012 16:37
ENH: MultiOutputTree (wip)

ENH: MultiOutputTree (wip)

ENH: MultiOutputTree (wip)

ENH: MultiOutputTree (wip)

ENH: MultiOutputTree (wip)

ENH: MultiOutputTree (wip)
@amueller (Member):
Hey Gilles. Too bad we didn't talk about this. I have a working version of this for classification, but it is not so hard to do. I'll have a look at your implementation later and compare it to mine. BTW, your pull request can't be merged :-/

@glouppe (Contributor, Author) commented Jun 29, 2012

@amueller It can now :) But anyway, this is nowhere near ready. I still have to change the ensemble estimators.

@pprett (Member) commented Jun 29, 2012

Thanks Gilles - I'm looking forward to studying it in more detail; the weather this weekend should be really fine in the Alps, so I might not make it in the next couple of days.

@amueller (Member):
@glouppe This looks pretty cool! Can I ask what your motivation was? I used this for multi-label classification. Your implementation might actually be able to cope with structured class labels.

What I did was use a list for y. It did not seem to impact performance much in the single-label case and was very flexible. Your method is probably more efficient and more amenable to optimization.

Having completely arbitrary objects as y would be pretty sweet, though (and would pose no theoretical problem as long as one can define what node purity is).

@pprett (Member) commented Jun 29, 2012

Wow, straight from my alma mater (TU Graz) - neat - thanks for the ref.

@amueller (Member):
@pprett Do you know Hough forests? They are pretty related, and also from somewhere around there, I think ;)

PS: sorry for going off-topic.

@glouppe (Contributor, Author) commented Jun 29, 2012

@amueller We are planning to apply this to classify windows of pixels in images:
http://www.montefiore.ulg.ac.be/services/stochastic/pubs/2009/DMWG09/

Regarding my implementation, I treat y as a 2D array (or convert it to that format). A few things may still need to be done, though, to convert it appropriately. I'll check that when writing the unit tests; a sketch of the idea is below.
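The conversion would look something like this (a sketch of the idea, not the exact code in this PR):

    import numpy as np

    y = np.asarray([0, 1, 1, 0])   # classic 1D single-output target
    if y.ndim == 1:
        y = y.reshape((-1, 1))     # handled internally as shape (n_samples, 1)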

@amueller (Member):
@glouppe That sounds a lot like the paper I cited above :) (but is earlier, I think).

@glouppe (Contributor, Author) commented Jun 29, 2012

@amueller Haha, yeah! I am going to read yours carefully, I guess ;)

@amueller (Member):
I haven't really read your code, but does it do any "extra work" when y is 1D? I think multi-label would be a pretty standard application for this algorithm.

@glouppe (Contributor, Author) commented Jun 29, 2012

@amueller Basically, it shouldn't. All loops degenerate into single iterations, but some (small) overhead is likely.

@amueller (Member):
Oh... I remember why I used lists: in a multi-label setup, each instance can have a different number of labels. That might not play well with your approach of a fixed-length y. We could use 1-of-n encodings, but that would make it unnecessarily slow to work with many classes, I guess. We could also fill the remaining entries with -1, but that does not really excite me either.
If we could address both the 2D patch setting and the multi-label setting, that would be awesome!

@glouppe (Contributor, Author) commented Jun 29, 2012

@amueller I handle that :)

@amueller (Member):
@glouppe SWEET! OK, I'll keep quiet until I read your code ;)

@glouppe (Contributor, Author) commented Jun 29, 2012

Please note that multi-output is different from multi-label. I don't know if we are talking about exactly the same thing? In a multi-output classification setting, each output column has its own set of classes, and only one of them can be picked at prediction time. In a multi-label setting, as I understand it, you are allowed to predict several classes from the same unique set of classes, which is different. However, you can indeed transform a multi-label problem into a multi-output problem using binary encoding (i.e., use n binary outputs, one for each class), as sketched below.
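For instance, a minimal sketch of that binary encoding (assuming the label sets are given as lists of class indices):

    import numpy as np

    labels = [[0, 2], [1], [0, 1, 2]]    # variable-length label sets, 3 classes
    Y = np.zeros((len(labels), 3), dtype=int)
    for i, active in enumerate(labels):
        Y[i, active] = 1                 # one binary output column per class
    # Y == [[1, 0, 1],
    #       [0, 1, 0],
    #       [1, 1, 1]]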

@amueller (Member):
What I wanted to say was that it would be good if we could handle both multi-output and multi-label. I was afraid the 1-of-n coding might be inefficient when there are many classes but only a few are active at any time.

@glouppe (Contributor, Author) commented Jun 29, 2012

Well, that it cannot handle. :/

However, it can handle multi-output classification problems where the set of classes has a different size at each output. That is what I wanted to say earlier.

@amueller (Member):
Well, maybe doing the binary coding isn't so bad. And doing patch-based learning in sklearn is mega-awesome (which I might not have said enough before ;).

@bdholt1 (Member) commented Jun 29, 2012

I just wanted to say thanks, guys, especially @glouppe! This has been on my list since I started, and I've just not gotten around to it, so I support this 100%!

@ogrisel (Member) commented Jul 4, 2012

Very nice example, BTW. Except for the few comments above, +1 for merging.

@pprett (Member) commented Jul 4, 2012

I'm also +1 for merging, but if you want, I can do another review of the Cython code in the evening.

Great example and PR!

PS: we should talk to @vene about a vbench script that checks for performance regressions in the tree module for future feature requests.

@glouppe (Contributor, Author) commented Jul 4, 2012

@ogrisel All your comments have been addressed.

@pprett Yes, I am not against a review of the Cython code. Actually, I found a serious bug this morning in the regressors (segfault and crash). Could you re-run your benchmark on Boston? I guess it'll take longer this time :( (But at least the results will be correct.)

@pprett (Member) commented Jul 4, 2012

Sure - I'll check.

@glouppe (Contributor, Author) commented Jul 7, 2012

Any more reviews? :)

@ogrisel (Member) commented Jul 8, 2012

Looks good to me, but as I am not a tree expert/user, I would rather have @amueller, @bdholt1 or @pprett (or someone else interested in multi-output trees) give it another round of review.

@bdholt1 (Member) commented Jul 8, 2012

Sorry for the delay - I've been away on holiday. I'd like to give it a final round if that's possible - perhaps I'll post reviews later this evening?


@glouppe (Contributor, Author) commented Jul 8, 2012

@bdholt1 Sure!

@glouppe (Contributor, Author) commented Jul 9, 2012

Thanks for this additional example, Brian!

@pprett Waiting for your approval to hit the green button :)

@bdholt1 (Member) commented Jul 9, 2012

@glouppe Thanks very much for undertaking to implement this; it's a very welcome addition! This functionality achieves what mvpart does for R, taking scikit-learn one step closer to being a complete suite.

@pprett (Member) commented Jul 9, 2012

@glouppe I re-ran the benchmarks on Boston - looks very good.

Fit
+--------+--------+-------+-------+
|        | Master | MO    | MOv2  |
+--------+--------+-------+-------+
|Tree(20)| 41.1   | 43.7  | 44.6  |
+--------+--------+-------+-------+
|Tree(1) | 0.6    | 0.712 | 0.72  |
+--------+--------+-------+-------+
|RF      | 338    | 296   | 321   |
+--------+--------+-------+-------+
|GBRT    | 90     | 109   | 112   |
+--------+--------+-------+-------+

Predict
+--------+--------+-------+-------+
|        | Master | MO    | MOv2  |
+--------+--------+-------+-------+
|Tree(20)| 0.09   | 0.1   | 0.1   |
+--------+--------+-------+-------+
|Tree(1) | 0.031  | 0.034 | 0.037 |
+--------+--------+-------+-------+
|RF      | 30     | 1.1   | 1.2   |
+--------+--------+-------+-------+
|GBRT    | 0.775  | 0.768 | 0.82  |
+--------+--------+-------+-------+

@pprett (Member) commented Jul 9, 2012

@glouppe There are two formatting errors in the doctests (tree.rst); apart from that, I'm +1.

Great work - thanks!

@glouppe (Contributor, Author) commented Jul 9, 2012

Thank you all for the reviews! Merging :)

glouppe added a commit that referenced this pull request on Jul 9, 2012: MRG: Multi-output decision trees
@glouppe merged commit aad531f into scikit-learn:master on Jul 9, 2012
@amueller (Member) commented Jul 9, 2012

Great work. Thanks a lot!

@glouppe mentioned this pull request on Jul 11, 2012
@@ -165,9 +175,10 @@ class Tree(object):
     LEAF = -1
     UNDEFINED = -2

-    def __init__(self, n_classes, n_features, capacity=3):
+    def __init__(self, n_classes, n_features, n_outputs, capacity=3):
Review comment on this diff:
I am working on #941, and I was thinking of making n_outputs an optional argument (with default 1). Would you mind?
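Hypothetically, the signature would then become (just an illustration of the proposed default, not code from this PR):

    def __init__(self, n_classes, n_features, n_outputs=1, capacity=3):
        ...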

@glouppe (Contributor, Author) replied:

I don't mind; I am okay with that.

While I am at it, I must warn you: I am currently making big changes to the tree structure (see #946). I don't know how we will resolve our future conflicts :/
