5 changes: 5 additions & 0 deletions doc/whats_new.rst
@@ -69,6 +69,11 @@ Enhancements
(`#7723 <https://github.com/scikit-learn/scikit-learn/pull/7723>`_)
by `Mikhail Korobov`_.

- Added ``flatten_transform`` parameter to :class:`ensemble.VotingClassifier`
to change output shape of `transform` method to 2 dimensional.
(`#7794 <https://github.com/scikit-learn/scikit-learn/pull/7794>`_)
by `Ibraim Ganiev`_.
Review comment (Member):
should be written as

:issue:`7794` by `Ibraim Ganiev`_.


Bug fixes
.........

22 changes: 21 additions & 1 deletion sklearn/ensemble/tests/test_voting_classifier.py
@@ -210,7 +210,7 @@ def test_gridsearch():
grid.fit(iris.data, iris.target)


-def test_parallel_predict():
+def test_parallel_fit():
"""Check parallel backend of VotingClassifier on toy dataset."""
clf1 = LogisticRegression(random_state=123)
clf2 = RandomForestClassifier(random_state=123)
@@ -258,3 +258,23 @@ def test_sample_weight():
voting='soft')
msg = ('Underlying estimator \'knn\' does not support sample weights.')
assert_raise_message(ValueError, msg, eclf3.fit, X, y, sample_weight)


def test_transform():
"""Check trqansform method of VotingClassifier on toy dataset."""
Review comment (Member):
trqansform -> transform

clf1 = LogisticRegression(random_state=123)
clf2 = RandomForestClassifier(random_state=123)
clf3 = GaussianNB()
X = np.array([[-1.1, -1.5], [-1.2, -1.4], [-3.4, -2.2], [1.1, 1.2]])
y = np.array([1, 1, 2, 2])

eclf1 = VotingClassifier(estimators=[
('lr', clf1), ('rf', clf2), ('gnb', clf3)],
voting='soft').fit(X, y)
eclf2 = VotingClassifier(estimators=[
('lr', clf1), ('rf', clf2), ('gnb', clf3)],
voting='soft',
flatten_transform=True).fit(X, y)

assert_array_equal(eclf1.transform(X).shape, (3, 4, 2))
assert_array_equal(eclf2.transform(X).shape, (4, 6))
Review comment (Member):
can you check more than shapes with an assert_equal on the values after a proper reshape?
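
One way to act on this suggestion (a sketch only, not part of the PR; it reuses the toy setup from test_transform above, takes assert_array_almost_equal from numpy.testing, and assumes the behaviour added in this PR, where the default flatten_transform still returns the 3-D array) is to compare the flattened output against a reshape of the 3-D output, so the values are checked and not just the shapes:

import numpy as np
from numpy.testing import assert_array_almost_equal
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB

# Same toy data and estimators as test_transform above.
clf1 = LogisticRegression(random_state=123)
clf2 = RandomForestClassifier(random_state=123)
clf3 = GaussianNB()
X = np.array([[-1.1, -1.5], [-1.2, -1.4], [-3.4, -2.2], [1.1, 1.2]])
y = np.array([1, 1, 2, 2])
estimators = [('lr', clf1), ('rf', clf2), ('gnb', clf3)]

eclf1 = VotingClassifier(estimators=estimators, voting='soft').fit(X, y)
eclf2 = VotingClassifier(estimators=estimators, voting='soft',
                         flatten_transform=True).fit(X, y)

probas = eclf1.transform(X)  # shape (n_classifiers, n_samples, n_classes) == (3, 4, 2)
flat = eclf2.transform(X)    # shape (n_samples, n_classifiers * n_classes) == (4, 6)

# The flattened output should equal the 3-D output with the classifier axis
# moved next to the sample axis and then collapsed into the columns.
assert_array_almost_equal(
    flat, np.transpose(probas, (1, 0, 2)).reshape(X.shape[0], -1))

np.hstack(probas), which the transform changes further down use, produces the same column layout: each classifier's class probabilities occupy adjacent columns.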

24 changes: 21 additions & 3 deletions sklearn/ensemble/voting_classifier.py
@@ -12,6 +12,7 @@
# License: BSD 3 clause

import numpy as np
import warnings

from ..base import BaseEstimator
from ..base import ClassifierMixin
@@ -61,6 +62,12 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
The number of jobs to run in parallel for ``fit``.
If -1, then the number of jobs is set to the number of cores.

flatten_transform : bool, optional (default=False)
Affects shape of transform output only when voting='soft'
If voting='soft' and flatten_transform=True, transform method returns
matrix with shape [n_samples, n_classifiers * n_classes] instead of
Review comment (Member):
[n_samples, n_classifiers * n_classes]
->
(n_samples, n_classifiers * n_classes)

shapes are tuples

[n_classifiers, n_samples, n_classes].
Review comment (Member):
same here


Attributes
----------
estimators_ : list of classifiers
@@ -100,12 +107,14 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
>>>
"""

-def __init__(self, estimators, voting='hard', weights=None, n_jobs=1):
+def __init__(self, estimators, voting='hard', weights=None, n_jobs=1,
+             flatten_transform=False):
self.estimators = estimators
self.named_estimators = dict(estimators)
self.voting = voting
self.weights = weights
self.n_jobs = n_jobs
+self.flatten_transform = flatten_transform

def fit(self, X, y, sample_weight=None):
""" Fit the estimators.
@@ -238,16 +247,25 @@ def transform(self, X):

Returns
-------
-If `voting='soft'`:
+If `voting='soft'` and `flatten_transform=False`:
array-like = [n_classifiers, n_samples, n_classes]
Review comment (Member):
shape as tuple

Class probabilities calculated by each classifier.
+If `voting='soft'` and `flatten_transform=True`:
+  array-like = [n_samples, n_classifiers * n_classes]
+    Class probabilities calculated by each classifier.
If `voting='hard'`:
array-like = [n_samples, n_classifiers]
Class labels predicted by each classifier.
"""
check_is_fitted(self, 'estimators_')
if self.voting == 'soft':
-return self._collect_probas(X)
+probas = self._collect_probas(X)
+if not self.flatten_transform:
+    warnings.warn("'flatten_transform' default value will be"
+                  " changed to True in 0.21.", DeprecationWarning)
+    return probas
+else:
+    return np.hstack(probas)
else:
return self._predict(X)
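
As a shape-only illustration of the flattening performed above (a standalone sketch, independent of VotingClassifier): np.hstack over an array of shape (n_classifiers, n_samples, n_classes) iterates over the first axis and concatenates the per-classifier blocks along the columns, which is where the (n_samples, n_classifiers * n_classes) shape described in the docstring comes from.

import numpy as np

# 3 classifiers, 4 samples, 2 classes -- mirrors the toy test above.
n_classifiers, n_samples, n_classes = 3, 4, 2
probas = np.random.RandomState(0).rand(n_classifiers, n_samples, n_classes)

flat = np.hstack(probas)  # iterates over axis 0, concatenates along the column axis
print(flat.shape)         # (4, 6), i.e. (n_samples, n_classifiers * n_classes)

# Columns are grouped per classifier:
# [clf0_c0, clf0_c1, clf1_c0, clf1_c1, clf2_c0, clf2_c1]
assert np.allclose(flat[:, 2:4], probas[1])  # the second classifier's block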
