
[MRG+1] Change VotingClassifier estimators by set_params #7674


Merged: 12 commits, Apr 10, 2017
doc/whats_new.rst (7 additions, 0 deletions)
@@ -163,6 +163,13 @@ Enhancements

- In :class:`gaussian_process.GaussianProcessRegressor`, method ``predict``
is a lot faster with ``return_std=True`` by :user:`Hadrien Bertrand <hbertrand>`.
- Added ability to use sparse matrices in :func:`feature_selection.f_regression`
with ``center=True``. :issue:`8065` by :user:`Daniel LeJeune <acadiansith>`.

- :class:`ensemble.VotingClassifier` now allows changing estimators by using
  :meth:`ensemble.VotingClassifier.set_params`. An estimator can also be
  removed by setting it to ``None``.
  :issue:`7674` by :user:`Yichuan Liu <yl565>`.

Bug fixes
.........
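A minimal sketch of the behaviour this entry describes (illustrative code, not part of the diff; the estimator names 'lr' and 'rf' are arbitrary labels):

from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB

eclf = VotingClassifier(estimators=[('lr', LogisticRegression()),
                                    ('rf', RandomForestClassifier())])
# Replace the 'rf' sub-estimator with a different classifier...
eclf.set_params(rf=GaussianNB())
# ...or remove it entirely; None estimators are skipped at fit time.
eclf.set_params(rf=None)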
sklearn/ensemble/tests/test_voting_classifier.py (90 additions, 1 deletion)
@@ -2,7 +2,7 @@

import numpy as np
from sklearn.utils.testing import assert_almost_equal, assert_array_equal
-from sklearn.utils.testing import assert_equal
+from sklearn.utils.testing import assert_equal, assert_true, assert_false
from sklearn.utils.testing import assert_raise_message
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
@@ -40,6 +40,19 @@ def test_estimator_init():
'; got 2 weights, 1 estimators')
assert_raise_message(ValueError, msg, eclf.fit, X, y)

eclf = VotingClassifier(estimators=[('lr', clf), ('lr', clf)],
weights=[1, 2])
msg = "Names provided are not unique: ['lr', 'lr']"
assert_raise_message(ValueError, msg, eclf.fit, X, y)

eclf = VotingClassifier(estimators=[('lr__', clf)])
msg = "Estimator names must not contain __: got ['lr__']"
assert_raise_message(ValueError, msg, eclf.fit, X, y)

eclf = VotingClassifier(estimators=[('estimators', clf)])
msg = "Estimator names conflict with constructor arguments: ['estimators']"
assert_raise_message(ValueError, msg, eclf.fit, X, y)
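
The three error messages asserted above correspond to the three naming rules enforced by _validate_names (shown in the pipeline.py section below). A plain-Python sketch of the checks, with illustrative names:

names = ['lr', 'rf']
# 1. names must be unique
assert len(set(names)) == len(names)
# 2. names must not contain '__', the separator for nested parameters
assert all('__' not in n for n in names)
# 3. names must not shadow constructor arguments of VotingClassifier
constructor_args = {'estimators', 'voting', 'weights', 'n_jobs'}
assert not constructor_args.intersection(names)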


def test_predictproba_hardvoting():
eclf = VotingClassifier(estimators=[('lr1', LogisticRegression()),
@@ -260,6 +273,82 @@ def test_sample_weight():
assert_raise_message(ValueError, msg, eclf3.fit, X, y, sample_weight)


def test_set_params():
"""set_params should be able to set estimators"""
clf1 = LogisticRegression(random_state=123, C=1.0)
clf2 = RandomForestClassifier(random_state=123, max_depth=None)
clf3 = GaussianNB()
eclf1 = VotingClassifier([('lr', clf1), ('rf', clf2)], voting='soft',
weights=[1, 2])
eclf1.fit(X, y)
eclf2 = VotingClassifier([('lr', clf1), ('nb', clf3)], voting='soft',
weights=[1, 2])
eclf2.set_params(nb=clf2).fit(X, y)
assert_false(hasattr(eclf2, 'nb'))

assert_array_equal(eclf1.predict(X), eclf2.predict(X))
assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))
assert_equal(eclf2.estimators[0][1].get_params(), clf1.get_params())
assert_equal(eclf2.estimators[1][1].get_params(), clf2.get_params())

eclf1.set_params(lr__C=10.0)
eclf2.set_params(nb__max_depth=5)

assert_true(eclf1.estimators[0][1].get_params()['C'] == 10.0)
Review comment (Member): Should we also test the get_params() interface of the VotingClassifier directly? More specifically, eclf1.get_params()["lr__C"] and eclf1.get_params()["lr"].get_params("C")? The get_params() interface seems untested.

assert_true(eclf2.estimators[1][1].get_params()['max_depth'] == 5)
assert_equal(eclf1.get_params()["lr__C"],
eclf1.get_params()["lr"].get_params()['C'])


def test_set_estimator_none():
"""VotingClassifier set_params should be able to set estimators as None"""
# Test predict
clf1 = LogisticRegression(random_state=123)
clf2 = RandomForestClassifier(random_state=123)
clf3 = GaussianNB()
eclf1 = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2),
('nb', clf3)],
voting='hard', weights=[1, 0, 0.5]).fit(X, y)

eclf2 = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2),
('nb', clf3)],
voting='hard', weights=[1, 1, 0.5])
eclf2.set_params(rf=None).fit(X, y)
assert_array_equal(eclf1.predict(X), eclf2.predict(X))
Review comment (Member): Can you also check eclf2.estimators_, eclf2.estimators and eclf2.get_params()?

Review comment (Member): I was suggesting to test the behaviour of eclf2.estimators, eclf2.estimators_ and eclf2.get_params().

assert_true(dict(eclf2.estimators)["rf"] is None)
assert_true(len(eclf2.estimators_) == 2)
assert_true(all([not isinstance(est, RandomForestClassifier) for est in
eclf2.estimators_]))
assert_true(eclf2.get_params()["rf"] is None)

eclf1.set_params(voting='soft').fit(X, y)
eclf2.set_params(voting='soft').fit(X, y)
assert_array_equal(eclf1.predict(X), eclf2.predict(X))
assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))
Review comment (Member): Please test soft transform. The outputs should differ between the 0-weight and None variants, though...

msg = ('All estimators are None. At least one is required'
' to be a classifier!')
assert_raise_message(
ValueError, msg, eclf2.set_params(lr=None, rf=None, nb=None).fit, X, y)

# Test soft voting transform
X1 = np.array([[1], [2]])
y1 = np.array([1, 2])
eclf1 = VotingClassifier(estimators=[('rf', clf2), ('nb', clf3)],
voting='soft', weights=[0, 0.5]).fit(X1, y1)

eclf2 = VotingClassifier(estimators=[('rf', clf2), ('nb', clf3)],
voting='soft', weights=[1, 0.5])
eclf2.set_params(rf=None).fit(X1, y1)
    assert_array_equal(eclf1.transform(X1),
                       np.array([[[0.7, 0.3], [0.3, 0.7]],
                                 [[1., 0.], [0., 1.]]]))
    assert_array_equal(eclf2.transform(X1), np.array([[[1., 0.], [0., 1.]]]))

Review comment (Member): Hmm. Looking at this makes me wonder whether we should be multiplying the outputs by the weight. Not an issue for this PR.

Review comment (Member): Should we test the output of "hard-voting" and transform as well?

eclf1.set_params(voting='hard')
eclf2.set_params(voting='hard')
assert_array_equal(eclf1.transform(X1), np.array([[0, 0], [1, 1]]))
assert_array_equal(eclf2.transform(X1), np.array([[0], [1]]))


def test_estimator_weights_format():
# Test estimator weights inputs as list and array
clf1 = LogisticRegression(random_state=123)
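To make the shape assertions in test_set_estimator_none concrete, here is an illustrative sketch (not part of the diff) of what transform returns in each voting mode:

import numpy as np
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.naive_bayes import GaussianNB

X1 = np.array([[1], [2]])
y1 = np.array([1, 2])
eclf = VotingClassifier(estimators=[('rf', RandomForestClassifier(random_state=123)),
                                    ('nb', GaussianNB())],
                        voting='soft').fit(X1, y1)
# Soft voting: one (n_samples, n_classes) probability block per fitted
# (non-None) estimator, so the shape is (2, 2, 2) here.
print(eclf.transform(X1).shape)
# Hard voting: one predicted-label column per fitted estimator,
# so the shape is (n_samples, n_estimators) = (2, 2).
eclf.set_params(voting='hard').fit(X1, y1)
print(eclf.transform(X1).shape)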
sklearn/ensemble/voting_classifier.py (71 additions, 29 deletions)
@@ -13,14 +13,13 @@

import numpy as np

from ..base import BaseEstimator
from ..base import ClassifierMixin
from ..base import TransformerMixin
from ..base import clone
from ..preprocessing import LabelEncoder
from ..externals import six
from ..externals.joblib import Parallel, delayed
from ..utils.validation import has_fit_parameter, check_is_fitted
from ..utils.metaestimators import _BaseComposition


def _parallel_fit_estimator(estimator, X, y, sample_weight):
@@ -32,7 +31,7 @@ def _parallel_fit_estimator(estimator, X, y, sample_weight):
return estimator


-class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
+class VotingClassifier(_BaseComposition, ClassifierMixin, TransformerMixin):
"""Soft Voting/Majority Rule classifier for unfitted estimators.

.. versionadded:: 0.17
@@ -44,7 +43,8 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
estimators : list of (string, estimator) tuples
Invoking the ``fit`` method on the ``VotingClassifier`` will fit clones
of those original estimators that will be stored in the class attribute
-`self.estimators_`.
+``self.estimators_``. An estimator can be set to ``None`` using
+``set_params``.

voting : str, {'hard', 'soft'} (default='hard')
If 'hard', uses predicted class labels for majority rule voting.
@@ -64,7 +64,8 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
Attributes
----------
estimators_ : list of classifiers
-The collection of fitted sub-estimators.
+The collection of fitted sub-estimators as defined in ``estimators``
+that are not ``None``.

classes_ : array-like, shape = [n_predictions]
The classes labels.
@@ -102,11 +103,14 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):

def __init__(self, estimators, voting='hard', weights=None, n_jobs=1):
self.estimators = estimators
-self.named_estimators = dict(estimators)
self.voting = voting
self.weights = weights
self.n_jobs = n_jobs

+    @property
+    def named_estimators(self):
+        return dict(self.estimators)

Review comment (Member): I wish this didn't exist, but I know it's not your problem.

Review comment (Member): You wish the property didn't exist? Or that it wasn't a property but a function?

Review comment (Member): I wish that we had not copied this bad design feature from pipeline!
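
Making named_estimators a property (instead of the dict built once in __init__) is what keeps it consistent with set_params; a small sketch of the difference, with illustrative estimators:

from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB

eclf = VotingClassifier(estimators=[('lr', LogisticRegression())])
eclf.set_params(lr=GaussianNB())
# Recomputed from self.estimators, so the replacement is visible here;
# the old __init__-time dict would still hold the stale LogisticRegression.
assert isinstance(eclf.named_estimators['lr'], GaussianNB)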

def fit(self, X, y, sample_weight=None):
""" Fit the estimators.

@@ -150,23 +154,36 @@ def fit(self, X, y, sample_weight=None):
if sample_weight is not None:
for name, step in self.estimators:
if not has_fit_parameter(step, 'sample_weight'):
-raise ValueError('Underlying estimator \'%s\' does not support'
-                 ' sample weights.' % name)
-
-self.le_ = LabelEncoder()
-self.le_.fit(y)
+raise ValueError('Underlying estimator \'%s\' does not'
+                 ' support sample weights.' % name)
names, clfs = zip(*self.estimators)
self._validate_names(names)

n_isnone = np.sum([clf is None for _, clf in self.estimators])
if n_isnone == len(self.estimators):
raise ValueError('All estimators are None. At least one is '
'required to be a classifier!')
+self.le_ = LabelEncoder().fit(y)
self.classes_ = self.le_.classes_
self.estimators_ = []

transformed_y = self.le_.transform(y)

Review comment (Member): This change to estimators_ needs to be documented under Attributes.

self.estimators_ = Parallel(n_jobs=self.n_jobs)(
    delayed(_parallel_fit_estimator)(clone(clf), X, transformed_y,
                                     sample_weight)
-    for _, clf in self.estimators)
+    for clf in clfs if clf is not None)

return self

@property
def _weights_not_none(self):
"""Get the weights of not `None` estimators"""
if self.weights is None:
return None
return [w for est, w in zip(self.estimators,
self.weights) if est[1] is not None]

def predict(self, X):
""" Predict class labels for X.

@@ -188,11 +205,10 @@ def predict(self, X):

else: # 'hard' voting
predictions = self._predict(X)
-maj = np.apply_along_axis(lambda x:
-                          np.argmax(np.bincount(x,
-                                    weights=self.weights)),
-                          axis=1,
-                          arr=predictions.astype('int'))
+maj = np.apply_along_axis(
+    lambda x: np.argmax(
+        np.bincount(x, weights=self._weights_not_none)),
+    axis=1, arr=predictions.astype('int'))

maj = self.le_.inverse_transform(maj)

@@ -208,7 +224,8 @@ def _predict_proba(self, X):
raise AttributeError("predict_proba is not available when"
" voting=%r" % self.voting)
check_is_fitted(self, 'estimators_')
-avg = np.average(self._collect_probas(X), axis=0, weights=self.weights)
+avg = np.average(self._collect_probas(X), axis=0,
+                 weights=self._weights_not_none)
return avg

@property
@@ -252,17 +269,42 @@ def transform(self, X):
else:
return self._predict(X)

    def set_params(self, **params):
        """Set the parameters for the voting classifier.

        Valid parameter keys can be listed with ``get_params()``.

        Parameters
        ----------
        params : keyword arguments
            Specific parameters using e.g.
            ``set_params(parameter_name=new_value)``. In addition to setting
            the parameters of the ``VotingClassifier``, the individual
            classifiers of the ``VotingClassifier`` can also be set, or
            replaced by setting them to None.

        Examples
        --------
        # In this example, the RandomForestClassifier is removed.
        clf1 = LogisticRegression()
        clf2 = RandomForestClassifier()
        eclf = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2)])
        eclf.set_params(rf=None)
        """
        super(VotingClassifier, self)._set_params('estimators', **params)
        return self

Review comment (Member): Can you document this?

Review comment (Member): Can you add under the get_params heading? I would just say "Get the parameters of the VotingClassifier". In addition, I would also document the parameter deep, saying that setting it to True gets the various classifiers and the parameters of the classifiers as well.

Review comment (MechCoder, Apr 7, 2017): In addition to setting the parameters of the VotingClassifier (with double backticks), the individual classifiers of the VotingClassifier can also be set or replaced by setting them to None.

    def get_params(self, deep=True):
-        """Return estimator parameter names for GridSearch support"""
-        if not deep:
-            return super(VotingClassifier, self).get_params(deep=False)
-        else:
-            out = super(VotingClassifier, self).get_params(deep=False)
-            out.update(self.named_estimators.copy())
-            for name, step in six.iteritems(self.named_estimators):
-                for key, value in six.iteritems(step.get_params(deep=True)):
-                    out['%s__%s' % (name, key)] = value
-            return out
+        """Get the parameters of the VotingClassifier.
+
+        Parameters
+        ----------
+        deep : bool
+            Setting it to True gets the various classifiers and the
+            parameters of the classifiers as well.
+        """
+        return super(VotingClassifier,
+                     self)._get_params('estimators', deep=deep)
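
Taken together, _set_params and _get_params accept three kinds of keys; a sketch of how they map onto the VotingClassifier (illustrative, assuming the scikit-learn of this PR):

from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression

eclf = VotingClassifier(estimators=[('lr', LogisticRegression()),
                                    ('rf', RandomForestClassifier())])
eclf.set_params(voting='soft')   # 1. a plain constructor argument
eclf.set_params(rf=None)         # 2. a named estimator, replaced or removed
eclf.set_params(lr__C=10.0)      # 3. a '__'-delimited nested parameter
# get_params(deep=True) exposes all three levels of keys:
params = eclf.get_params(deep=True)
assert params['voting'] == 'soft'
assert params['rf'] is None
assert params['lr__C'] == 10.0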

def _predict(self, X):
"""Collect results from clf.predict calls. """
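For the hard-voting path above, the weighted majority is just a weighted bincount over each row of predictions; a tiny numpy sketch with made-up votes and weights:

import numpy as np

row = np.array([0, 1, 1])       # class votes from three estimators
weights = [1.0, 2.0, 0.5]       # weights of the non-None estimators
# bincount sums the weight behind each class; argmax picks the winner:
# class 0 gets 1.0, class 1 gets 2.5
assert np.argmax(np.bincount(row, weights=weights)) == 1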
sklearn/pipeline.py (5 additions, 60 deletions)
@@ -10,6 +10,7 @@
# License: BSD

from collections import defaultdict

from abc import ABCMeta, abstractmethod

import numpy as np
@@ -22,68 +23,12 @@
from .utils.metaestimators import if_delegate_has_method
from .utils import Bunch

-__all__ = ['Pipeline', 'FeatureUnion']
+from .utils.metaestimators import _BaseComposition
+
+__all__ = ['Pipeline', 'FeatureUnion']

-class _BasePipeline(six.with_metaclass(ABCMeta, BaseEstimator)):
-    """Handles parameter management for classifiers composed of named steps.
-    """
-
-    @abstractmethod
-    def __init__(self):
-        pass
-
-    def _replace_step(self, steps_attr, name, new_val):
-        # assumes `name` is a valid step name
-        new_steps = getattr(self, steps_attr)[:]
-        for i, (step_name, _) in enumerate(new_steps):
-            if step_name == name:
-                new_steps[i] = (name, new_val)
-                break
-        setattr(self, steps_attr, new_steps)
-
-    def _get_params(self, steps_attr, deep=True):
-        out = super(_BasePipeline, self).get_params(deep=False)
-        if not deep:
-            return out
-        steps = getattr(self, steps_attr)
-        out.update(steps)
-        for name, estimator in steps:
-            if estimator is None:
-                continue
-            for key, value in six.iteritems(estimator.get_params(deep=True)):
-                out['%s__%s' % (name, key)] = value
-        return out
-
-    def _set_params(self, steps_attr, **params):
-        # Ensure strict ordering of parameter setting:
-        # 1. All steps
-        if steps_attr in params:
-            setattr(self, steps_attr, params.pop(steps_attr))
-        # 2. Step replacement
-        step_names, _ = zip(*getattr(self, steps_attr))
-        for name in list(six.iterkeys(params)):
-            if '__' not in name and name in step_names:
-                self._replace_step(steps_attr, name, params.pop(name))
-        # 3. Step parameters and other initialisation arguments
-        super(_BasePipeline, self).set_params(**params)
-        return self
-
-    def _validate_names(self, names):
-        if len(set(names)) != len(names):
-            raise ValueError('Names provided are not unique: '
-                             '{0!r}'.format(list(names)))
-        invalid_names = set(names).intersection(self.get_params(deep=False))
-        if invalid_names:
-            raise ValueError('Step names conflict with constructor arguments: '
-                             '{0!r}'.format(sorted(invalid_names)))
-        invalid_names = [name for name in names if '__' in name]
-        if invalid_names:
-            raise ValueError('Step names must not contain __: got '
-                             '{0!r}'.format(invalid_names))

-class Pipeline(_BasePipeline):
+class Pipeline(_BaseComposition):
"""Pipeline of transforms with a final estimator.

Sequentially apply a list of transforms and a final estimator.
@@ -631,7 +576,7 @@ def _fit_transform_one(transformer, weight, X, y,
return res * weight, transformer


-class FeatureUnion(_BasePipeline, TransformerMixin):
+class FeatureUnion(_BaseComposition, TransformerMixin):
"""Concatenates results of multiple transformer objects.

This estimator applies a list of transformer objects in parallel to the
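The _set_params ordering that moved into _BaseComposition (all steps first, then step replacement, then nested parameters) applies to Pipeline and FeatureUnion in exactly the same way; a sketch with illustrative step names, assuming the scikit-learn of this era:

from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

pipe = Pipeline([('scale', StandardScaler()), ('clf', LogisticRegression())])
# 1. replace the whole steps list
pipe.set_params(steps=[('reduce', PCA()), ('clf', LogisticRegression())])
# 2. replace one named step
pipe.set_params(reduce=PCA(n_components=2))
# 3. set a nested parameter of a step
pipe.set_params(clf__C=0.1)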