[RFC] Voting classifier flatten transform (#7230) #7794

olologin
Contributor

Reference Issue

Fixes #7230

What does this implement/fix? Explain your changes.

It adds a flatten_transform parameter to VotingClassifier, which changes the shape of the transform method's output to [n_samples, n_classifiers * n_classes] instead of [n_classifiers, n_samples, n_classes].
With this parameter turned on, you can use VotingClassifier as a transformer and feed its output to other estimators/transformers in a Pipeline.
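
A minimal sketch of the shape change described above, assuming NumPy and made-up probabilities. The array values and variable names are illustrative and not taken from the PR; stacking the per-classifier blocks horizontally is one way to get the flattened layout.

```python
import numpy as np

# Hypothetical per-classifier probabilities: 3 classifiers, 4 samples, 2 classes.
rng = np.random.default_rng(0)
n_classifiers, n_samples, n_classes = 3, 4, 2
probas = rng.random((n_classifiers, n_samples, n_classes))
probas /= probas.sum(axis=2, keepdims=True)  # make each row sum to 1

# Place each classifier's (n_samples, n_classes) block side by side.
flat = np.hstack(list(probas))

print(probas.shape)  # (3, 4, 2)
print(flat.shape)    # (4, 6)
```

Each row of `flat` now describes one sample, which is the 2D layout downstream estimators in a Pipeline expect.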

Any other comments?

None; suggestions are welcome. I summon @amueller into this PR :)

@olologin olologin changed the title Voting classifier flatten transform [RFC] Voting classifier flatten transform (#7230) Oct 30, 2016
@olologin
Contributor Author

Also, I'm not sure whether a regression test is needed, because the added functionality is pretty simple.

@jnothman
Member

jnothman commented Oct 30, 2016

I'd be tempted to deprecate the current behaviour.

@amueller
Member

amueller commented Nov 2, 2016

tests are always needed!
@jnothman change the default or deprecate entirely?

@jnothman
Member

jnothman commented Nov 3, 2016

Do you see harm in deprecating entirely? I'd rather that. The current transform output has indeterminate semantics within the scikit-learn API.

@amueller
Member

amueller commented Nov 17, 2016

I'm happy with deprecating entirely. That needs two steps, though: introducing the parameter and removing it again. Or what did you want to do?

@amueller
Member

Also, it might be helpful to special-case binary classification and retain only the probabilities of one of the classes, say the positive one? That would result in smaller and more interpretable downstream models.
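
A rough sketch of the binary special case suggested above, assuming NumPy, made-up probabilities, and that the positive class sits at column index 1. None of this is taken from the PR's code.

```python
import numpy as np

# Hypothetical probabilities from 3 classifiers on 4 samples, binary problem.
probas = np.array([
    [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]],
    [[0.8, 0.2], [0.1, 0.9], [0.5, 0.5], [0.4, 0.6]],
    [[0.7, 0.3], [0.3, 0.7], [0.4, 0.6], [0.2, 0.8]],
])

# Keep only the positive-class column (index 1 is an assumption here),
# which yields one feature per classifier instead of two.
positive_only = probas[:, :, 1].T

print(positive_only.shape)  # (4, 3)
```

Since the two columns of a binary `predict_proba` output sum to one, dropping one of them loses no information.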

@amueller
Member

bonus points for implementing get_feature_names() ;)
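
One possible shape for such feature names, as a sketch only: `flattened_feature_names` is a hypothetical helper, and the `name_class` naming scheme and the estimator names are assumptions, not something the PR defines.

```python
def flattened_feature_names(estimator_names, classes):
    """Hypothetical helper: one name per column of the flattened output."""
    # Column order follows horizontal stacking: classifiers outer, classes inner.
    return [f"{name}_{cls}" for name in estimator_names for cls in classes]


names = flattened_feature_names(["lr", "rf", "gnb"], [0, 1])
print(names)
# ['lr_0', 'lr_1', 'rf_0', 'rf_1', 'gnb_0', 'gnb_1']
```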



def test_transform():
"""Check trqansform method of VotingClassifier on toy dataset."""
Member

trqansform -> transform

flatten_transform=True).fit(X, y)

assert_array_equal(eclf1.transform(X).shape, (3, 4, 2))
assert_array_equal(eclf2.transform(X).shape, (4, 6))
Member

can you check more than shapes with an assert_equal on the values after a proper reshape?
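
A sketch of what such a value check could look like, assuming NumPy. Here `stacked` and `flat` are random stand-ins for `eclf1.transform(X)` and `eclf2.transform(X)`, and the flattened layout is assumed to be a horizontal stack of the per-classifier blocks.

```python
import numpy as np
from numpy.testing import assert_array_equal

rng = np.random.default_rng(0)
stacked = rng.random((3, 4, 2))   # stand-in for eclf1.transform(X)
flat = np.hstack(list(stacked))   # stand-in for eclf2.transform(X)

# Undo the flattening: split the columns per classifier,
# then move the classifier axis back to the front.
restored = flat.reshape(4, 3, 2).swapaxes(0, 1)
assert_array_equal(restored, stacked)
```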

flatten_transform : bool, optional (default=False)
Affects shape of transform output only when voting='soft'
If voting='soft' and flatten_transform=True, transform method returns
matrix with shape [n_samples, n_classifiers * n_classes] instead of
Member

[n_samples, n_classifiers * n_classes]
->
(n_samples, n_classifiers * n_classes)

shapes are tuples

Affects shape of transform output only when voting='soft'
If voting='soft' and flatten_transform=True, transform method returns
matrix with shape [n_samples, n_classifiers * n_classes] instead of
[n_classifiers, n_samples, n_classes].
Member

same here

@@ -238,16 +247,25 @@ def transform(self, X):

Returns
-------
-If `voting='soft'`:
+If `voting='soft'` and `flatten_transform=False`:
array-like = [n_classifiers, n_samples, n_classes]
Member

shape as tuple

- Added ``flatten_transform`` parameter to :class:`ensemble.VotingClassifier`
to change output shape of `transform` method to 2 dimensional.
(`#7794 <https://github.com/scikit-learn/scikit-learn/pull/7794>`_)
by `Ibraim Ganiev`_.
Member

should be written as

:issue:`7794` by `Ibraim Ganiev`_.

@jnothman
Member

jnothman commented Jun 8, 2017

How does this relate to our concurrent ambitions towards stacking (e.g. #8960)?

@GaelVaroquaux
Member

You'll also need to rebase on master, to get the tests running.

Successfully merging this pull request may close these issues.

Output shape of VotingClassifier.transform is non-standard
6 participants