Multi-label and multi-output multi-class decision functions and predict proba aren't consistent #2451


Closed
arjoly opened this issue Sep 17, 2013 · 15 comments


@arjoly
Member

arjoly commented Sep 17, 2013

The decision_function and predict_proba of a multi-label classifier (e.g. OneVsRestClassifier) return a 2d array where each column corresponds to a label and each row corresponds to a sample. (added in 0.14?)

The decision_function and predict_proba of a multi-output multi-class classifier (e.g. RandomForestClassifier) return a list, of length equal to the number of outputs, where each element is a multi-class decision_function or predict_proba output (a 2d array where each row corresponds to a sample and each column corresponds to a class).

So a multi-output problem whose outputs are all binary is a multi-label task, yet its output format isn't consistent with the multi-label format...

This is problematic if you want to write a roc_auc_score function that supports multi-label output.

@arjoly
Member Author

arjoly commented Sep 17, 2013

A small example to illustrate the issue:

from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.datasets import make_multilabel_classification

X, Y = make_multilabel_classification(random_state=0, n_samples=5,
                                      return_indicator=True, n_classes=3)
print("rf")
rf = RandomForestClassifier(random_state=0).fit(X, Y)
print(rf.predict_proba(X))
# rf
# [array([[ 0.7,  0.3],
#        [ 0.2,  0.8],
#        [ 0.9,  0.1],
#        [ 0.8,  0.2],
#        [ 0.2,  0.8]]), array([[ 0.6,  0.4],
#        [ 0.2,  0.8],
#        [ 0.2,  0.8],
#        [ 0.9,  0.1],
#        [ 0.8,  0.2]]), array([[ 0.3,  0.7],
#        [ 0.8,  0.2],
#        [ 0.1,  0.9],
#        [ 1. ,  0. ],
#        [ 0.9,  0.1]])]

print("ovr rf")
ovr_rf = OneVsRestClassifier(RandomForestClassifier(random_state=0)).fit(X, Y)
print(ovr_rf.predict_proba(X))
# ovr rf
# [[ 0.2  0.4  0.7]
#  [ 0.8  0.8  0.1]
#  [ 0.1  0.9  0.9]
#  [ 0.2  0.1  0. ]
#  [ 0.8  0.1  0.2]]

@arjoly
Member Author

arjoly commented Sep 17, 2013

There are three possible ways to solve this issue:

  1. accept two formats, and require each function to work gracefully with both;
  2. implement only one format, and let the user format their data appropriately;
  3. move to the most general format.

Option 1 means more formats to support, option 2 won't work with a grid search estimator,
and option 3 will induce a non-backward-compatible change.
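To make the trade-off concrete, a conversion between the two formats is straightforward when every output is binary. A minimal sketch of such a normalisation helper (`normalize_proba` is a hypothetical name, not a scikit-learn function), assuming each element of the list is an (n_samples, 2) array with the positive class in the second column:

```python
import numpy as np

def normalize_proba(proba):
    """Convert a multi-output binary predict_proba result (a list of
    (n_samples, 2) arrays) into the multi-label format: an
    (n_samples, n_labels) array of positive-class probabilities.
    A 2d array (already in multi-label format) is returned unchanged."""
    if isinstance(proba, list):
        return np.column_stack([p[:, 1] for p in proba])
    return np.asarray(proba)

# Two binary outputs for three samples, in the list-of-arrays format:
proba = [np.array([[0.7, 0.3], [0.2, 0.8], [0.9, 0.1]]),
         np.array([[0.6, 0.4], [0.2, 0.8], [0.1, 0.9]])]
print(normalize_proba(proba))
# [[0.3 0.4]
#  [0.8 0.8]
#  [0.1 0.9]]
```

This direction of the conversion is lossless only because the two columns of a binary output are redundant; going from multi-label back to the list format is equally mechanical.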

What is your opinion on this issue? Do you have better ideas?

@jnothman
Member

Part of the issue that you've not stated is that a multilabel label
indicator matrix may look identical to, and is semantically equivalent to,
multiple binary classifications. Assuming there were some estimator that
learnt both multioutput and multilabel classification, given such input it
would have to output one or the other format from decision_function.

Now, we already have the quirky case of binary classification resulting in
a 1d decision_function, whereas multiclass classification results in 2d
decision_function output. So it's tempting to similarly say "if it can be
interpreted as multilabel, format the output as such". But I have a feeling
that policy would be difficult to implement (and non-backwards compatible).
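That binary/multiclass quirk is easy to observe directly. A small sketch using LogisticRegression (any estimator exposing decision_function would behave the same way):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# 100 samples each (make_classification's default n_samples)
X2, y2 = make_classification(n_classes=2, random_state=0)
X3, y3 = make_classification(n_classes=3, n_informative=4, random_state=0)

d2 = LogisticRegression().fit(X2, y2).decision_function(X2)
d3 = LogisticRegression().fit(X3, y3).decision_function(X3)
print(d2.shape)  # (100,)   -- binary: 1d
print(d3.shape)  # (100, 3) -- multiclass: 2d
```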

So I more-or-less think your option (1) is agreeable, but you haven't told
us which functions must work gracefully on both: do you mean metrics?
Presumably any metric that handles multioutput and multilabel will already
handle both gracefully. So I guess your concern is metrics that handle
multilabel, but not multioutput. In such cases, I don't think the extra
validation is cumbersome (it certainly doesn't double the effort as you
suggest), although I'm not sure whether it should happen in the metric or in
make_scorer, which knows about 'decision_function'.

~J


@arjoly
Member Author

arjoly commented Sep 18, 2013

So I more-or-less think your option (1) is agreeable, but you haven't told
us which functions must work gracefully on both: do you mean metrics?

At the moment, I am thinking of metrics that take a score or a probability. None
have been extended yet to support multilabel. It should also work with
grid search estimators and cross_val_score.

@arjoly
Member Author

arjoly commented Sep 19, 2013

Related to #1781

@arjoly
Member Author

arjoly commented Feb 15, 2015

As far as I know / remember, the only "multi-label" (and not multi-output multi-class) aware classifier is the OneVsRestClassifier. The issue could be handled by deprecating the multilabel support of ovr and implementing a separate class or module for a binary relevance / multi-output classifier.

@amueller
Member

What is the motivation for deprecation? It seems to me OVR has the right interface. I'd rather remove multi-output multi-class support as it is a rather rare setting and messes with API contracts.

@arjoly
Member Author

arjoly commented Feb 23, 2015

Working with both formats is a pain: you have to perform the format normalisation in your code.
Semantically, in multi-label classification one-versus-rest is not one label versus the rest, but one estimator for each label. Furthermore, it seems to me that it is the OVR classifier which breaks the contract, since ovr.predict_proba(X).sum(axis=1) doesn't sum to 1 with the multi-label format and can be greater than 1.
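A quick sketch of that predict_proba observation (using LogisticRegression as the base estimator for brevity; any binary classifier with predict_proba would do). With multi-label targets, each label's probability is estimated independently, so the rows are not normalised:

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

X, Y = make_multilabel_classification(random_state=0, n_samples=50,
                                      n_classes=3)
ovr = OneVsRestClassifier(LogisticRegression()).fit(X, Y)
# Each row sums to roughly the expected number of positive labels,
# not to 1 as in the single-label multiclass case.
row_sums = ovr.predict_proba(X).sum(axis=1)
print(row_sums[:5])
```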

I don't think that going the other way around, by deprecating the multi-output multi-class format, is possible. More estimators support that format (e.g. dummy, k-nn, tree, forest), and removing it would break people's code without any replacement.

@amueller
Member

I didn't realize KNeighborsClassifier supported multi-output multi-class.
I just feel that multi-label is a much more common setting, and that the multi-output multi-class format is really inconvenient to handle, as it gives you lists of arrays.

tree.predict_proba(X).sum(axis=1) might raise an AttributeError, since the result is a list rather than an array. So what does that do to the contract?

@arjoly
Member Author

arjoly commented May 28, 2015

tree.predict_proba(X).sum(axis=1) might raise an AttributeError, since the result is a list rather than an array. So what does that do to the contract?

Do you have a full example?

@amueller
Member

from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_multilabel_classification

X, Y = make_multilabel_classification(random_state=0, n_samples=5,
                                      return_indicator=True, n_classes=3)
# works: fitting on a single output gives a 2d array
rf = RandomForestClassifier(random_state=0).fit(X, Y[:, 0])
Y_pred = rf.predict_proba(X).argmax(axis=1)

# AttributeError: fitting on multiple outputs gives a list of arrays,
# and a list has no argmax method
rf = RandomForestClassifier(random_state=0).fit(X, Y)
Y_pred = rf.predict_proba(X).argmax(axis=1)

I don't like that.

@amueller
Member

I am not sure how useful the multiclass multi-output is in general. Do you have any references?

@arjoly
Member Author

arjoly commented May 29, 2015

The alternative could be to have a 3d numpy array, but then some columns would be meaningless.
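A minimal sketch of that 3d-array idea, assuming outputs with fewer classes than the widest one are zero-padded (those are the meaningless columns); `stack_proba` is a hypothetical helper, not a scikit-learn function:

```python
import numpy as np

def stack_proba(proba_list):
    """Stack a multi-output predict_proba result (a list of
    (n_samples, n_classes_i) arrays) into a single
    (n_outputs, n_samples, max_classes) array, zero-padding
    outputs that have fewer classes."""
    n_samples = proba_list[0].shape[0]
    max_classes = max(p.shape[1] for p in proba_list)
    out = np.zeros((len(proba_list), n_samples, max_classes))
    for i, p in enumerate(proba_list):
        out[i, :, :p.shape[1]] = p
    return out

proba = [np.array([[0.7, 0.3], [0.2, 0.8]]),             # binary output
         np.array([[0.1, 0.6, 0.3], [0.5, 0.2, 0.3]])]   # 3-class output
print(stack_proba(proba).shape)  # (2, 2, 3)
```

The padded cells are the problem: downstream code cannot tell a genuine zero probability from a class that does not exist for that output without carrying the per-output class counts alongside the array.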

I am not sure how useful the multiclass multi-output is in general. Do you have any references?

There are applications, e.g. pixel labelling, but I am not familiar with them. I know that some real problems are tackled using the multi-output code. @glouppe might know more about this.

@amueller
Member

amueller commented Jun 1, 2015

I know of one paper for image patches, but I don't think our implementation is very good for image data, because you want to sample the features on the fly.

@adrinjalali
Member

Closing as a duplicate of a more recent / detailed issue: #19880
