diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index f647806e023ff..6074acd2b55b0 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -477,8 +477,21 @@ Bug fixes
 
    - Fix bug in :func:`metrics.silhouette_samples` so that it now works with
      arbitrary labels, not just those ranging from 0 to n_clusters - 1.
+
+   - Fix bug where :class:`ensemble.AdaBoostClassifier` and
+     :class:`ensemble.AdaBoostRegressor` would perform poorly if the
+     ``random_state`` was fixed
+     (`#7411 `_). By `Joel Nothman`_.
+   - Fix bug in ensembles with randomization where the ensemble would not
+     set ``random_state`` on base estimators in a pipeline or similar nesting.
+     (`#7411 `_).
+     Note: results for :class:`ensemble.BaggingClassifier`,
+     :class:`ensemble.BaggingRegressor`, :class:`ensemble.AdaBoostClassifier`
+     and :class:`ensemble.AdaBoostRegressor` will now differ from previous
+     versions. By `Joel Nothman`_.
+
 API changes summary
 -------------------
 
diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py
index c73b4d50d6c22..06712e79ddf5b 100644
--- a/sklearn/ensemble/bagging.py
+++ b/sklearn/ensemble/bagging.py
@@ -86,12 +86,8 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight,
                                   (i + 1, n_estimators, total_n_estimators))
 
         random_state = np.random.RandomState(seeds[i])
-        estimator = ensemble._make_estimator(append=False)
-
-        try:  # Not all estimators accept a random_state
-            estimator.set_params(random_state=seeds[i])
-        except ValueError:
-            pass
+        estimator = ensemble._make_estimator(append=False,
+                                             random_state=random_state)
 
         # Draw random feature, sample indices
         features, indices = _generate_bagging_indices(random_state,
diff --git a/sklearn/ensemble/base.py b/sklearn/ensemble/base.py
index 9610f3402c8fb..2add9d730062e 100644
--- a/sklearn/ensemble/base.py
+++ b/sklearn/ensemble/base.py
@@ -10,7 +10,45 @@
 from ..base import clone
 from ..base import BaseEstimator
 from ..base import MetaEstimatorMixin
-from ..utils import _get_n_jobs
+from ..utils import _get_n_jobs, check_random_state
+
+MAX_RAND_SEED = np.iinfo(np.int32).max
+
+
+def _set_random_states(estimator, random_state=None):
+    """Set fixed random_state parameters for an estimator.
+
+    Finds all parameters ending in ``random_state`` and sets them to integers
+    derived from ``random_state``.
+
+    Parameters
+    ----------
+    estimator : estimator supporting get/set_params
+        Estimator with potential randomness managed by random_state
+        parameters.
+
+    random_state : numpy.RandomState or int, optional
+        Random state used to generate integer values.
+
+    Notes
+    -----
+    This does not necessarily set *all* ``random_state`` attributes that
+    control an estimator's randomness, only those accessible through
+    ``estimator.get_params()``.  ``random_state`` parameters not controlled
+    include those belonging to:
+
+        * cross-validation splitters
+        * ``scipy.stats`` rvs
+    """
+    random_state = check_random_state(random_state)
+    to_set = {}
+    for key in sorted(estimator.get_params(deep=True)):
+        if key == 'random_state' or key.endswith('__random_state'):
+            to_set[key] = random_state.randint(MAX_RAND_SEED)
+
+    if to_set:
+        estimator.set_params(**to_set)
 
 
 class BaseEnsemble(BaseEstimator, MetaEstimatorMixin):
@@ -67,7 +105,7 @@ def _validate_estimator(self, default=None):
         if self.base_estimator_ is None:
             raise ValueError("base_estimator cannot be None")
 
-    def _make_estimator(self, append=True):
+    def _make_estimator(self, append=True, random_state=None):
         """Make and configure a copy of the `base_estimator_` attribute.
 
         Warning: This method should be used to properly instantiate new
@@ -77,6 +115,9 @@ def _make_estimator(self, append=True):
         estimator.set_params(**dict((p, getattr(self, p))
                                     for p in self.estimator_params))
 
+        if random_state is not None:
+            _set_random_states(estimator, random_state)
+
         if append:
             self.estimators_.append(estimator)
 
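To see what the new helper does end to end: `_set_random_states` walks `get_params(deep=True)` and gives every `random_state` parameter, however deeply nested, its own integer seed drawn from one shared generator. A minimal sketch of that behaviour, not part of the patch, with a Pipeline of SelectFromModel and Perceptron chosen purely for illustration:

    import numpy as np
    from sklearn.feature_selection import SelectFromModel
    from sklearn.linear_model import Perceptron
    from sklearn.pipeline import make_pipeline

    pipe = make_pipeline(SelectFromModel(Perceptron()), Perceptron())
    rng = np.random.RandomState(0)

    # Mirror the loop in _set_random_states: every parameter named
    # 'random_state', at any depth exposed by get_params(deep=True),
    # gets its own integer seed from the shared generator.
    to_set = {key: rng.randint(np.iinfo(np.int32).max)
              for key in sorted(pipe.get_params(deep=True))
              if key == 'random_state' or key.endswith('__random_state')}
    pipe.set_params(**to_set)  # seeds the nested Perceptrons, not just the top level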
diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py
index 4d4f04bc12408..424b8266fe376 100644
--- a/sklearn/ensemble/forest.py
+++ b/sklearn/ensemble/forest.py
@@ -304,8 +304,8 @@ def fit(self, X, y, sample_weight=None):
             trees = []
             for i in range(n_more_estimators):
-                tree = self._make_estimator(append=False)
-                tree.set_params(random_state=random_state.randint(MAX_INT))
+                tree = self._make_estimator(append=False,
+                                            random_state=random_state)
                 trees.append(tree)
 
             # Parallel loop: we use the threading backend as the Cython code
diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py
index 70e44eb2824a6..0170a3fa2262f 100644
--- a/sklearn/ensemble/tests/test_bagging.py
+++ b/sklearn/ensemble/tests/test_bagging.py
@@ -553,6 +553,8 @@ def test_bagging_with_pipeline():
                                             DecisionTreeClassifier()),
                                   max_features=2)
     estimator.fit(iris.data, iris.target)
+    assert_true(isinstance(estimator[0].steps[-1][1].random_state,
+                           int))
 
 
 class DummyZeroEstimator(BaseEstimator):
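The new assertion pins down the user-visible fix: previously the swallowed ValueError meant a Pipeline base estimator silently kept random_state=None on all of its steps. A hedged end-to-end check of the patched behaviour (illustrative, not part of the patch; numbers.Integral is used instead of a bare int check only to stay robust across NumPy integer types):

    from numbers import Integral
    from sklearn.datasets import load_iris
    from sklearn.ensemble import BaggingClassifier
    from sklearn.feature_selection import SelectKBest
    from sklearn.pipeline import make_pipeline
    from sklearn.tree import DecisionTreeClassifier

    iris = load_iris()
    est = BaggingClassifier(make_pipeline(SelectKBest(k=1),
                                          DecisionTreeClassifier()),
                            max_features=2, random_state=0)
    est.fit(iris.data, iris.target)

    # After the fix, each fitted pipeline's final tree carries its own seed.
    seeds = [pipe.steps[-1][1].random_state for pipe in est.estimators_]
    assert all(isinstance(seed, Integral) for seed in seeds)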
diff --git a/sklearn/ensemble/tests/test_base.py b/sklearn/ensemble/tests/test_base.py
index 0268715cde9ef..948f94c76f764 100644
--- a/sklearn/ensemble/tests/test_base.py
+++ b/sklearn/ensemble/tests/test_base.py
@@ -5,32 +5,45 @@
 # Authors: Gilles Louppe
 # License: BSD 3 clause
 
+import numpy as np
 from numpy.testing import assert_equal
 from nose.tools import assert_true
 
 from sklearn.utils.testing import assert_raise_message
+from sklearn.utils.testing import assert_not_equal
 from sklearn.datasets import load_iris
 from sklearn.ensemble import BaggingClassifier
+from sklearn.ensemble.base import _set_random_states
 from sklearn.linear_model import Perceptron
+from sklearn.externals.odict import OrderedDict
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+from sklearn.pipeline import Pipeline
+from sklearn.feature_selection import SelectFromModel
 
 
 def test_base():
     # Check BaseEnsemble methods.
-    ensemble = BaggingClassifier(base_estimator=Perceptron(), n_estimators=3)
+    ensemble = BaggingClassifier(base_estimator=Perceptron(random_state=None),
+                                 n_estimators=3)
 
     iris = load_iris()
     ensemble.fit(iris.data, iris.target)
     ensemble.estimators_ = []  # empty the list and create estimators manually
 
     ensemble._make_estimator()
-    ensemble._make_estimator()
-    ensemble._make_estimator()
+    random_state = np.random.RandomState(3)
+    ensemble._make_estimator(random_state=random_state)
+    ensemble._make_estimator(random_state=random_state)
     ensemble._make_estimator(append=False)
 
     assert_equal(3, len(ensemble))
     assert_equal(3, len(ensemble.estimators_))
 
     assert_true(isinstance(ensemble[0], Perceptron))
+    assert_equal(ensemble[0].random_state, None)
+    assert_true(isinstance(ensemble[1].random_state, int))
+    assert_true(isinstance(ensemble[2].random_state, int))
+    assert_not_equal(ensemble[1].random_state, ensemble[2].random_state)
 
 
 def test_base_zero_n_estimators():
@@ -41,3 +54,55 @@ def test_base_zero_n_estimators():
     assert_raise_message(ValueError,
                          "n_estimators must be greater than zero, got 0.",
                          ensemble.fit, iris.data, iris.target)
+
+
+def test_set_random_states():
+    # LinearDiscriminantAnalysis has no random_state: smoke test
+    _set_random_states(LinearDiscriminantAnalysis(), random_state=17)
+
+    clf1 = Perceptron(random_state=None)
+    assert_equal(clf1.random_state, None)
+    # check that random_state=None still sets an integer seed
+    _set_random_states(clf1, None)
+    assert_true(isinstance(clf1.random_state, int))
+
+    # check that a fixed random_state gives consistent initialisation
+    _set_random_states(clf1, 3)
+    assert_true(isinstance(clf1.random_state, int))
+    clf2 = Perceptron(random_state=None)
+    _set_random_states(clf2, 3)
+    assert_equal(clf1.random_state, clf2.random_state)
+
+    # nested random_state
+
+    def make_steps():
+        return [('sel', SelectFromModel(Perceptron(random_state=None))),
+                ('clf', Perceptron(random_state=None))]
+
+    est1 = Pipeline(make_steps())
+    _set_random_states(est1, 3)
+    assert_true(isinstance(est1.steps[0][1].estimator.random_state, int))
+    assert_true(isinstance(est1.steps[1][1].random_state, int))
+    assert_not_equal(est1.get_params()['sel__estimator__random_state'],
+                     est1.get_params()['clf__random_state'])
+
+    # ensure multiple random_state parameters are invariant to get_params()
+    # iteration order
+
+    class AlphaParamPipeline(Pipeline):
+        def get_params(self, *args, **kwargs):
+            params = Pipeline.get_params(self, *args, **kwargs).items()
+            return OrderedDict(sorted(params))
+
+    class RevParamPipeline(Pipeline):
+        def get_params(self, *args, **kwargs):
+            params = Pipeline.get_params(self, *args, **kwargs).items()
+            return OrderedDict(sorted(params, reverse=True))
+
+    for cls in [AlphaParamPipeline, RevParamPipeline]:
+        est2 = cls(make_steps())
+        _set_random_states(est2, 3)
+        assert_equal(est1.get_params()['sel__estimator__random_state'],
+                     est2.get_params()['sel__estimator__random_state'])
+        assert_equal(est1.get_params()['clf__random_state'],
+                     est2.get_params()['clf__random_state'])
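The AlphaParamPipeline/RevParamPipeline pair encodes a subtle requirement: `_set_random_states` draws one integer per matched key from a single generator, so the key-to-seed mapping is reproducible only because the keys are visited in sorted order rather than in whatever order `get_params()` happens to yield them. A standalone sketch of that property (the key names below are hypothetical):

    import numpy as np

    MAX_RAND_SEED = np.iinfo(np.int32).max
    keys = ['clf__random_state', 'sel__estimator__random_state']

    def seeds_for(param_keys, seed=3):
        rng = np.random.RandomState(seed)
        # Sorting first decouples the assignment from dict iteration order.
        return {key: rng.randint(MAX_RAND_SEED) for key in sorted(param_keys)}

    assert seeds_for(keys) == seeds_for(list(reversed(keys)))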
diff --git a/sklearn/ensemble/tests/test_weight_boosting.py b/sklearn/ensemble/tests/test_weight_boosting.py
index 83a5e819c13f7..6064a89f10e46 100755
--- a/sklearn/ensemble/tests/test_weight_boosting.py
+++ b/sklearn/ensemble/tests/test_weight_boosting.py
@@ -3,7 +3,7 @@
 import numpy as np
 from sklearn.utils.testing import assert_array_equal, assert_array_less
 from sklearn.utils.testing import assert_array_almost_equal
-from sklearn.utils.testing import assert_equal, assert_true
+from sklearn.utils.testing import assert_equal, assert_true, assert_greater
 from sklearn.utils.testing import assert_raises, assert_raises_regexp
 
 from sklearn.base import BaseEstimator
@@ -113,6 +113,12 @@ def test_iris():
         assert score > 0.9, "Failed with algorithm %s and score = %f" % \
             (alg, score)
 
+        # Check that we used multiple estimators
+        assert_greater(len(clf.estimators_), 1)
+        # Check for distinct random states (see issue #7408)
+        assert_equal(len(set(est.random_state for est in clf.estimators_)),
+                     len(clf.estimators_))
+
     # Somewhat hacky regression test: prior to
     # ae7adc880d624615a34bafdb1d75ef67051b8200,
     # predict_proba returned SAMME.R values for SAMME.
@@ -123,11 +129,17 @@ def test_iris():
 
 def test_boston():
     # Check consistency on the Boston house prices dataset.
-    clf = AdaBoostRegressor(random_state=0)
-    clf.fit(boston.data, boston.target)
-    score = clf.score(boston.data, boston.target)
+    reg = AdaBoostRegressor(random_state=0)
+    reg.fit(boston.data, boston.target)
+    score = reg.score(boston.data, boston.target)
     assert score > 0.85
 
+    # Check that we used multiple estimators
+    assert_true(len(reg.estimators_) > 1)
+    # Check for distinct random states (see issue #7408)
+    assert_equal(len(set(est.random_state for est in reg.estimators_)),
+                 len(reg.estimators_))
+
 
 def test_staged_predict():
     # Check staged predictions.
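The distinct-random-state assertions guard against the original bug (#7408): every `_boost` call used to receive the same fixed integer from `self.random_state`, so with a stochastic base learner each round could fit a near-identical estimator, which is why a fixed `random_state` made AdaBoost perform poorly. A minimal check of the repaired behaviour, assuming the post-patch API:

    from sklearn.datasets import load_iris
    from sklearn.ensemble import AdaBoostClassifier

    iris = load_iris()
    clf = AdaBoostClassifier(random_state=0).fit(iris.data, iris.target)

    # With one RandomState threaded through fit(), every boosting round's
    # estimator receives its own derived seed.
    states = [est.random_state for est in clf.estimators_]
    assert len(set(states)) == len(states)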
""" if self.algorithm == 'SAMME.R': - return self._boost_real(iboost, X, y, sample_weight) + return self._boost_real(iboost, X, y, sample_weight, random_state) else: # elif self.algorithm == "SAMME": - return self._boost_discrete(iboost, X, y, sample_weight) + return self._boost_discrete(iboost, X, y, sample_weight, + random_state) - def _boost_real(self, iboost, X, y, sample_weight): + def _boost_real(self, iboost, X, y, sample_weight, random_state): """Implement a single boost using the SAMME.R real algorithm.""" - estimator = self._make_estimator() - - try: - estimator.set_params(random_state=self.random_state) - except ValueError: - pass + estimator = self._make_estimator(random_state=random_state) estimator.fit(X, y, sample_weight=sample_weight) @@ -527,14 +532,9 @@ def _boost_real(self, iboost, X, y, sample_weight): return sample_weight, 1., estimator_error - def _boost_discrete(self, iboost, X, y, sample_weight): + def _boost_discrete(self, iboost, X, y, sample_weight, random_state): """Implement a single boost using the SAMME discrete algorithm.""" - estimator = self._make_estimator() - - try: - estimator.set_params(random_state=self.random_state) - except ValueError: - pass + estimator = self._make_estimator(random_state=random_state) estimator.fit(X, y, sample_weight=sample_weight) @@ -959,7 +959,7 @@ def _validate_estimator(self): super(AdaBoostRegressor, self)._validate_estimator( default=DecisionTreeRegressor(max_depth=3)) - def _boost(self, iboost, X, y, sample_weight): + def _boost(self, iboost, X, y, sample_weight, random_state): """Implement a single boost for regression Perform a single boost according to the AdaBoost.R2 algorithm and @@ -981,6 +981,9 @@ def _boost(self, iboost, X, y, sample_weight): sample_weight : array-like of shape = [n_samples] The current sample weights. + random_state : numpy.RandomState + The current random number generator + Returns ------- sample_weight : array-like of shape = [n_samples] or None @@ -995,20 +998,13 @@ def _boost(self, iboost, X, y, sample_weight): The regression error for the current boost. If None then boosting has terminated early. """ - estimator = self._make_estimator() - - try: - estimator.set_params(random_state=self.random_state) - except ValueError: - pass - - generator = check_random_state(self.random_state) + estimator = self._make_estimator(random_state=random_state) # Weighted sampling of the training set with replacement # For NumPy >= 1.7.0 use np.random.choice cdf = sample_weight.cumsum() cdf /= cdf[-1] - uniform_samples = generator.random_sample(X.shape[0]) + uniform_samples = random_state.random_sample(X.shape[0]) bootstrap_idx = cdf.searchsorted(uniform_samples, side='right') # searchsorted returns a scalar bootstrap_idx = np.array(bootstrap_idx, copy=False)