Skip to content

[MRG+2] FIX adaboost estimators not randomising correctly #7411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 23, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -477,8 +477,21 @@ Bug fixes

- Fix bug in :func:`metrics.silhouette_samples` so that it now works with
arbitrary labels, not just those ranging from 0 to n_clusters - 1.

- Fix bug where :class:`ensemble.AdaBoostClassifier` and
:class:`ensemble.AdaBoostRegressor` would perform poorly if the
``random_state`` was fixed
(`#7411 <https://github.com/scikit-learn/scikit-learn/pull/7411>`_).
By `Joel Nothman`_.

- Fix bug in ensembles with randomization where the ensemble would not
set ``random_state`` on base estimators in a pipeline or similar nesting.
(`#7411 <https://github.com/scikit-learn/scikit-learn/pull/7411>`_).
Note, results for :class:`ensemble.BaggingClassifier`
:class:`ensemble.BaggingRegressor`, :class:`ensemble.AdaBoostClassifier`
and :class:`ensemble.AdaBoostRegressor` will now differ from previous
versions. By `Joel Nothman`_.


API changes summary
-------------------
Expand Down
8 changes: 2 additions & 6 deletions sklearn/ensemble/bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,8 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight,
(i + 1, n_estimators, total_n_estimators))

random_state = np.random.RandomState(seeds[i])
estimator = ensemble._make_estimator(append=False)

try: # Not all estimators accept a random_state
estimator.set_params(random_state=seeds[i])
except ValueError:
pass
estimator = ensemble._make_estimator(append=False,
random_state=random_state)

# Draw random feature, sample indices
features, indices = _generate_bagging_indices(random_state,
Expand Down
45 changes: 43 additions & 2 deletions sklearn/ensemble/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,45 @@
from ..base import clone
from ..base import BaseEstimator
from ..base import MetaEstimatorMixin
from ..utils import _get_n_jobs
from ..utils import _get_n_jobs, check_random_state

MAX_RAND_SEED = np.iinfo(np.int32).max


def _set_random_states(estimator, random_state=None):
"""Sets fixed random_state parameters for an estimator

Finds all parameters ending ``random_state`` and sets them to integers
derived from ``random_state``.

Parameters
----------

estimator : estimator supporting get/set_params
Estimator with potential randomness managed by random_state
parameters.

random_state : numpy.RandomState or int, optional
Random state used to generate integer values.

Notes
-----
This does not necessarily set *all* ``random_state`` attributes that
control an estimator's randomness, only those accessible through
``estimator.get_params()``. ``random_state``s not controlled include
those belonging to:

* cross-validation splitters
* ``scipy.stats`` rvs
"""
random_state = check_random_state(random_state)
to_set = {}
for key in sorted(estimator.get_params(deep=True)):
if key == 'random_state' or key.endswith('__random_state'):
to_set[key] = random_state.randint(MAX_RAND_SEED)

if to_set:
estimator.set_params(**to_set)


class BaseEnsemble(BaseEstimator, MetaEstimatorMixin):
Expand Down Expand Up @@ -67,7 +105,7 @@ def _validate_estimator(self, default=None):
if self.base_estimator_ is None:
raise ValueError("base_estimator cannot be None")

def _make_estimator(self, append=True):
def _make_estimator(self, append=True, random_state=None):
"""Make and configure a copy of the `base_estimator_` attribute.

Warning: This method should be used to properly instantiate new
Expand All @@ -77,6 +115,9 @@ def _make_estimator(self, append=True):
estimator.set_params(**dict((p, getattr(self, p))
for p in self.estimator_params))

if random_state is not None:
_set_random_states(estimator, random_state)

if append:
self.estimators_.append(estimator)

Expand Down
4 changes: 2 additions & 2 deletions sklearn/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ def fit(self, X, y, sample_weight=None):

trees = []
for i in range(n_more_estimators):
tree = self._make_estimator(append=False)
tree.set_params(random_state=random_state.randint(MAX_INT))
tree = self._make_estimator(append=False,
random_state=random_state)
trees.append(tree)

# Parallel loop: we use the threading backend as the Cython code
Expand Down
2 changes: 2 additions & 0 deletions sklearn/ensemble/tests/test_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,8 @@ def test_bagging_with_pipeline():
DecisionTreeClassifier()),
max_features=2)
estimator.fit(iris.data, iris.target)
assert_true(isinstance(estimator[0].steps[-1][1].random_state,
int))


class DummyZeroEstimator(BaseEstimator):
Expand Down
71 changes: 68 additions & 3 deletions sklearn/ensemble/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,45 @@
# Authors: Gilles Louppe
# License: BSD 3 clause

import numpy as np
from numpy.testing import assert_equal
from nose.tools import assert_true

from sklearn.utils.testing import assert_raise_message
from sklearn.utils.testing import assert_not_equal
from sklearn.datasets import load_iris
from sklearn.ensemble import BaggingClassifier
from sklearn.ensemble.base import _set_random_states
from sklearn.linear_model import Perceptron
from sklearn.externals.odict import OrderedDict
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import Pipeline
from sklearn.feature_selection import SelectFromModel


def test_base():
# Check BaseEnsemble methods.
ensemble = BaggingClassifier(base_estimator=Perceptron(), n_estimators=3)
ensemble = BaggingClassifier(base_estimator=Perceptron(random_state=None),
n_estimators=3)

iris = load_iris()
ensemble.fit(iris.data, iris.target)
ensemble.estimators_ = [] # empty the list and create estimators manually

ensemble._make_estimator()
ensemble._make_estimator()
ensemble._make_estimator()
random_state = np.random.RandomState(3)
ensemble._make_estimator(random_state=random_state)
ensemble._make_estimator(random_state=random_state)
ensemble._make_estimator(append=False)

assert_equal(3, len(ensemble))
assert_equal(3, len(ensemble.estimators_))

assert_true(isinstance(ensemble[0], Perceptron))
assert_equal(ensemble[0].random_state, None)
assert_true(isinstance(ensemble[1].random_state, int))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For completeness you could add:

    assert_true(isinstance(ensemble[2].random_state, int))

assert_true(isinstance(ensemble[2].random_state, int))
assert_not_equal(ensemble[1].random_state, ensemble[2].random_state)


def test_base_zero_n_estimators():
Expand All @@ -41,3 +54,55 @@ def test_base_zero_n_estimators():
assert_raise_message(ValueError,
"n_estimators must be greater than zero, got 0.",
ensemble.fit, iris.data, iris.target)


def test_set_random_states():
# Linear Discriminant Analysis doesn't have random state: smoke test
_set_random_states(LinearDiscriminantAnalysis(), random_state=17)

clf1 = Perceptron(random_state=None)
assert_equal(clf1.random_state, None)
# check random_state is None still sets
_set_random_states(clf1, None)
assert_true(isinstance(clf1.random_state, int))

# check random_state fixes results in consistent initialisation
_set_random_states(clf1, 3)
assert_true(isinstance(clf1.random_state, int))
clf2 = Perceptron(random_state=None)
_set_random_states(clf2, 3)
assert_equal(clf1.random_state, clf2.random_state)

# nested random_state

def make_steps():
return [('sel', SelectFromModel(Perceptron(random_state=None))),
('clf', Perceptron(random_state=None))]

est1 = Pipeline(make_steps())
_set_random_states(est1, 3)
assert_true(isinstance(est1.steps[0][1].estimator.random_state, int))
assert_true(isinstance(est1.steps[1][1].random_state, int))
assert_not_equal(est1.get_params()['sel__estimator__random_state'],
est1.get_params()['clf__random_state'])

# ensure multiple random_state paramaters are invariant to get_params()
# iteration order

class AlphaParamPipeline(Pipeline):
def get_params(self, *args, **kwargs):
params = Pipeline.get_params(self, *args, **kwargs).items()
return OrderedDict(sorted(params))

class RevParamPipeline(Pipeline):
def get_params(self, *args, **kwargs):
params = Pipeline.get_params(self, *args, **kwargs).items()
return OrderedDict(sorted(params, reverse=True))

for cls in [AlphaParamPipeline, RevParamPipeline]:
est2 = cls(make_steps())
_set_random_states(est2, 3)
assert_equal(est1.get_params()['sel__estimator__random_state'],
est2.get_params()['sel__estimator__random_state'])
assert_equal(est1.get_params()['clf__random_state'],
est2.get_params()['clf__random_state'])
20 changes: 16 additions & 4 deletions sklearn/ensemble/tests/test_weight_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from sklearn.utils.testing import assert_array_equal, assert_array_less
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_equal, assert_true
from sklearn.utils.testing import assert_equal, assert_true, assert_greater
from sklearn.utils.testing import assert_raises, assert_raises_regexp

from sklearn.base import BaseEstimator
Expand Down Expand Up @@ -113,6 +113,12 @@ def test_iris():
assert score > 0.9, "Failed with algorithm %s and score = %f" % \
(alg, score)

# Check we used multiple estimators
assert_greater(len(clf.estimators_), 1)
# Check for distinct random states (see issue #7408)
assert_equal(len(set(est.random_state for est in clf.estimators_)),
len(clf.estimators_))

# Somewhat hacky regression test: prior to
# ae7adc880d624615a34bafdb1d75ef67051b8200,
# predict_proba returned SAMME.R values for SAMME.
Expand All @@ -123,11 +129,17 @@ def test_iris():

def test_boston():
# Check consistency on dataset boston house prices.
clf = AdaBoostRegressor(random_state=0)
clf.fit(boston.data, boston.target)
score = clf.score(boston.data, boston.target)
reg = AdaBoostRegressor(random_state=0)
reg.fit(boston.data, boston.target)
score = reg.score(boston.data, boston.target)
assert score > 0.85

# Check we used multiple estimators
assert_true(len(reg.estimators_) > 1)
# Check for distinct random states (see issue #7408)
assert_equal(len(set(est.random_state for est in reg.estimators_)),
len(reg.estimators_))


def test_staged_predict():
# Check staged predictions.
Expand Down
Loading