
[MRG+2] Classifier chain #7602


Merged — 96 commits merged into scikit-learn:master from adamklec:classifier_chain on Jun 29, 2017

Conversation

@adamklec commented Oct 7, 2016

This PR implements Classifier Chains for multi-label classification as described here:
http://www.cs.waikato.ac.nz/ml/publications/2009/chains.pdf
#3727 is a similar PR that seems to have stalled. This PR has one significant improvement over #3727: since in general it is impossible to know the optimal ordering of classifiers for a particular problem a priori, it is common to use an ensemble of randomly ordered classifier chains (as explained in the paper above). This implementation of classifier chains supports randomizing the chain order.

I included an example demonstrating that, on a multi-label dataset, an ensemble of 10 randomly ordered classifier chains outperforms independently trained models.
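
For later readers, a minimal sketch of the usage described above, assuming the estimator ends up as sklearn.multioutput.ClassifierChain with base_estimator, order and random_state parameters (the exact names in this PR may differ):

from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain

X, Y = make_multilabel_classification(n_samples=300, n_classes=5, random_state=0)

# A single chain with the default class order ...
chain = ClassifierChain(LogisticRegression(), random_state=0).fit(X, Y)

# ... and an ensemble of 10 chains, each trained with a different random order.
chains = [ClassifierChain(LogisticRegression(), order='random', random_state=i)
          for i in range(10)]
for c in chains:
    c.fit(X, Y)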

@amueller (Member) commented Oct 8, 2016

The AppVeyor test failure seems related.

@@ -0,0 +1,39 @@
import numpy as np
Member:

can you create an example that has a file name starting with plot_ and plots things? Plots make for more engaging examples, and only examples that start with plot_ are automatically run. Also, if you want to add a new folder to the examples, you need to create an empty README.txt file in that folder.

Author:

I will work on adding an example with plots. Also please see my comment below regarding the README.txt file.

@adamklec (Author), Oct 12, 2016:

I'm having a hard time coming up with something to plot here. I think the most compelling result is a comparison of the Jaccard similarity scores for independent models, a single classifier chain, and an ensemble of classifier chains. I suppose I could make a bar chart of those three numbers, but that seems rather trivial. I experimented with plotting precision-recall curves (or ROC curves) for individual classes, but the results are not that interesting to look at. I'm open to suggestions.

Member:

I was about to suggest the confusion matrix, but I think we haven't implemented the multi-label confusion matrix yet.
It's actually not really important to add a plot for it to be a plot_ example. That just ensures it runs. But it requires that the example runs quickly enough. If it doesn't, maybe try yeast instead?

Author:

I switched to yeast. I like it. It's much faster (the example runs in a few seconds). And it makes for a more convincing example.
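
(For reference, the kind of bar-chart comparison discussed here is straightforward with matplotlib; the scores below are placeholders, not results from the yeast example:)

import matplotlib.pyplot as plt

model_names = ['Independent', 'Chain 1', 'Chain 2', 'Ensemble']
model_scores = [0.50, 0.52, 0.53, 0.55]  # placeholder Jaccard similarity scores

fig, ax = plt.subplots(figsize=(7, 4))
ax.bar(range(len(model_names)), model_scores, tick_label=model_names)
ax.set_ylabel('Jaccard similarity score')
ax.set_title('Independent models vs. classifier chains (yeast)')
plt.tight_layout()
plt.show()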

@amueller (Member) commented Oct 8, 2016

There shouldn't be an __init__.py file in the examples folder; that might be the error.

self.classifiers = []
self.chain_order = None

def fit(self, X, Y, chain_order=None, shuffle=True):
Member:

hm I'm not sure if we use capital Y for labels (though it would be logical).

Author:

I'd be happy to lower case it or change it to anything else. Just LMK what the correct convention is.

Author:

For what it's worth, the convention of capitalizing the variable names of multi-label arrays already appears in the documentation here:
http://scikit-learn.org/stable/auto_examples/plot_multilabel.html#sphx-glr-auto-examples-plot-multilabel-py

"""

def __init__(self,
base_estimator=LogisticRegression(),
Member:

I'm not sure if having a mutable default is a good idea.

Author:

Agreed. I got rid of it.
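
(For context, one common way to avoid a mutable default is to accept None and resolve it at fit time; a minimal sketch of that pattern, not necessarily what this PR ended up doing:)

from sklearn.base import clone
from sklearn.linear_model import LogisticRegression


class ChainSketch:
    def __init__(self, base_estimator=None):
        # Store the parameter untouched; no mutable default instance here.
        self.base_estimator = base_estimator

    def fit(self, X, Y):
        # Resolve the default and clone at fit time, one copy per label,
        # so repeated fits and separate instances never share state.
        base = (self.base_estimator if self.base_estimator is not None
                else LogisticRegression())
        self.estimators_ = [clone(base) for _ in range(Y.shape[1])]
        # ... chain fitting logic would go here ...
        return self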

@amueller (Member) commented Oct 8, 2016

Thanks, this looks interesting. Can you please also add an entry to the user guide? And maybe ping @arjoly?

Adam Kleczewski added 2 commits October 10, 2016 11:13
removed logistic regression dependency
@adamklec (Author)

I removed the __init__.py file from the examples folder and added a README.txt file. The following exception is now being thrown on CircleCI.

Exception occurred:
File "/home/ubuntu/scikit-learn/doc/sphinxext/sphinx_gallery/gen_rst.py", line 182, in get_docstring_and_rest
.format(filename))
ValueError: Could not find docstring in file "../examples/multi_label/classifier_chain.py". A docstring is required by sphinx-gallery
The full traceback has been saved in /tmp/sphinx-err-BTEGNb.log, if you want to report the issue to the developers.
Please also report this if it was a user error, so that a better error message can be provided next time.
A bug report can be filed in the tracker at https://github.com/sphinx-doc/sphinx/issues. Thanks!

I tried copying the format of the README.txt files in other /examples directories but that had no effect. Any ideas?

@GaelVaroquaux (Member) commented Oct 10, 2016 via email

@adamklec (Author)

Ah. Thanks @GaelVaroquaux !
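
(For later readers: the error above means the example file needs to start with a module-level docstring whose first lines form a reST title, which is what sphinx-gallery parses. A minimal header along these lines satisfies it:)

"""
================
Classifier Chain
================

Example of using a classifier chain on a multilabel dataset.
"""
import numpy as np  # ...rest of the example follows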

@adamklec (Author)

How do I go about adding an entry to the user guide?

@jnothman (Member)

Edit doc/modules/multiclass.rst

============================

Demonstrate that a single classifier chain
out performs 10 independent classifiers
Member:

out-performs?

Member:

Also, format so that each line is about 79 characters (hopefully your editor can do that for you).

from sklearn.datasets import fetch_rcv1

# Get the Reuters Corpus Volume I dataset
rcv1 = fetch_rcv1()
Member:

Uh, can you actually check how long this thing runs? If it's too long, it doesn't qualify as a plot_ example because it would make the doc build too slow.

Author:

Yeah, I was worried about that. With my connection I believe it took a couple of minutes to download. Of course that time depends on connection speed, but it doesn't seem like something that should be run automatically.

Previously I was using make_multilabel_classification to generate a fake dataset, which is obviously much faster. But I don't think this works, because I believe the labels in the fake dataset are uncorrelated, which means there is no advantage to using a classifier chain.

Member:

I was actually not talking about the download (though that's also an issue) but about the computation. Have you tried yeast? That should have some correlations. And if make_multilabel_classification is not helpful for multi-label classification we should change/fix it.
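
(As an aside, a quick way to check whether a generated dataset actually has correlated labels, and hence whether a chain has anything to exploit, is to look at the pairwise correlations of the label columns; a rough sketch:)

import numpy as np
from sklearn.datasets import make_multilabel_classification

_, Y = make_multilabel_classification(n_samples=2000, n_classes=5, random_state=0)
# Off-diagonal entries near zero mean the labels are (nearly) independent,
# in which case a classifier chain has little extra information to use.
label_corr = np.corrcoef(Y, rowvar=False)
print(np.round(label_corr, 2))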

----------
classifiers_ : array
List of classifiers, which will be used to chain prediction.
chain_order : list of ints
Member:

Anything that's estimated during training should have a trailing underscore.

Member:

I'd be okay with dropping chain_ and just making this order_
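
(A toy sketch of the attribute-naming convention being discussed, with illustrative names rather than the PR's actual code: constructor parameters are stored verbatim, and anything derived during fit gets a trailing underscore:)

import numpy as np


class ToyChain:
    def __init__(self, order=None):
        self.order = order  # user-supplied parameter: no underscore

    def fit(self, X, Y):
        rng = np.random.RandomState(0)
        # order_ is derived during fit, hence the trailing underscore.
        if self.order is not None:
            self.order_ = np.asarray(self.order)
        else:
            self.order_ = rng.permutation(Y.shape[1])
        return self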


"""

def __init__(self,
Member:

on one line?


Parameters
----------
X : {array-like, sparse matrix}, shape (n_samples, n_features)
Member:

Same nitpick here.

return Y_pred

@if_delegate_has_method('base_estimator')
def predict_proba(self, X):
Member:

I think that this could use a little more documentation. Specifically, you may want to clarify that predict and predict_proba don't change the input to the later models in the chain: the chained features always come from predict.
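
(To illustrate the point: in the chained prediction loop, the features appended for later estimators come from hard predict outputs even when the caller asks for probabilities. A simplified sketch, with assumed attribute names and dense X:)

import numpy as np


def predict_proba_sketch(chain, X):
    # chain.estimators_: fitted binary classifiers in chain order (name assumed).
    n_outputs = len(chain.estimators_)
    Y_pred = np.zeros((X.shape[0], n_outputs))  # hard 0/1 predictions
    Y_prob = np.zeros((X.shape[0], n_outputs))  # probabilities returned to the caller
    for i, est in enumerate(chain.estimators_):
        # Later estimators see X plus the *hard* predictions of earlier ones,
        # never their probabilities.
        X_aug = np.hstack([X, Y_pred[:, :i]])
        Y_prob[:, i] = est.predict_proba(X_aug)[:, 1]
        Y_pred[:, i] = est.predict(X_aug)
    return Y_prob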


Parameters
----------
base_estimator : estimator
Member:

What are your thoughts on allowing arbitrary lists of classifiers? For instance, one could use a logistic regression for one task, a Random Forest classifier for another, etc. This is not something that would need to be implemented for this PR to get merged; I'm just asking out of curiosity.

Author:

Possibly. We'd have to think about how this would work with random chain orders, but I could imagine this being a useful feature.

Classifier Chain
============================
Example of using classifier chain on a multilabel dataset.

Member:

You may want to elaborate on the connections between classifier chains and model stacking. Might it be worth it to also mention that regression chains are possible? Maybe that's better saved for when someone implements it.

Author:

I guess my preference is to keep the text concise and focused on the example. If there is a specific point you think we should make here I'd consider adding a sentence or two. Otherwise I'm inclined to leave it as is.

I think we should stay away from dealing with regression models in this PR. But I agree that is a good direction to go.

voting ensemble of classifier chains by averaging the binary predictions of
the chains and apply a threshold of 0.5. The Jaccard similarity score of the
ensemble is greater than that of the independent models and tends to exceed
the score of each chain in the ensemble (although this is not guarenteed
Member:

guaranteed

Author:

indeed
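
(For concreteness, the ensembling described in that docstring amounts to something like the following sketch; chains, X_test and Y_test are assumed to be defined as in the example, and jaccard_similarity_score was the metric available at the time:)

import numpy as np
from sklearn.metrics import jaccard_similarity_score  # jaccard_score in newer releases

Y_pred_chains = np.array([chain.predict(X_test) for chain in chains])
Y_pred_ensemble = Y_pred_chains.mean(axis=0) >= 0.5  # vote by averaging, threshold 0.5
print(jaccard_similarity_score(Y_test, Y_pred_ensemble))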

@jmschrei (Member)

Overall, like I've said before, this is strong work and I'm glad that we're close to merging it. At the risk of being reviewer 3, I'm wondering if it might be worthwhile to consider the idea of regression chains here, and abstract out a base chain object. @adamklec what are your thoughts? Are regression chains as popular as classification chains?

@jnothman (Member)

What do you mean by "at the risk of being reviewer 3"?

@jnothman (Member)

I'm curious, @jmschrei: Do you expect that there is any benefit in waiting until regression chains are implemented before we merge this? I.e. do you expect it will change the API or help find bugs in this implementation?

Here is what I hope to create issues for once this is merged:

  • support for multitarget classification
  • support for multitarget regression (with a separate RegressionChain)
  • support for chained features being outputs of predict_proba or decision_function not just predict

@jmschrei (Member)

When submitting a paper, reviewer 3 tends to be the annoying one who makes you do a bunch of work. I am OK with not having regression chains supported in this PR; I was just interested in discussing it. Otherwise this LGTM once my minor issues are taken care of.

@adamklec (Author)

@jmschrei I made the changes you requested. I did not do anything related to regression chains, though. That seems like it could be a good idea, but I think it should be addressed in another PR.

@jnothman, to your list of issues to be addressed I would add:

  • support for multilabel models in the ensemble module

In plot_classifier_chain_yeast.py I implement a very simple voting ensemble of 10 classifier chains. The main problem, though, is that the 10 chains are trained in series. If VotingClassifier supported multilabel targets, the chains could be trained in parallel.
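
(Until then, one possible workaround, sketched here with assumed names, is to fit the independent chains in parallel with joblib directly:)

from joblib import Parallel, delayed  # sklearn.externals.joblib at the time of this PR


def _fit_chain(chain, X, Y):
    return chain.fit(X, Y)


# The chains are independent of one another, so they can be fitted in
# parallel worker processes instead of one after the other.
chains = Parallel(n_jobs=-1)(
    delayed(_fit_chain)(chain, X_train, Y_train) for chain in chains)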

@jmschrei (Member)

Okay, I am in agreement with most of your comments, but will do a full review tonight (and hopefully merge!). I do think it might be valuable somewhere to mention the connection between stacking and classifier chains, since they are similar but not identical. Is there a place you think this might fit?

@jnothman (Member) commented Jun 28, 2017 via email

@jmschrei (Member) commented Jun 28, 2017 via email

@jmschrei (Member)

LGTM!

@jmschrei jmschrei changed the title [MRG+1] Classifier chain [MRG+2] Classifier chain Jun 29, 2017
@jmschrei jmschrei merged commit b413299 into scikit-learn:master Jun 29, 2017
@jnothman (Member) commented Jun 29, 2017 via email

@adamklec (Author)

Awesome! Thanks everyone.

@adamklec adamklec deleted the classifier_chain branch July 5, 2017 15:50