From 9b66811d486f783597b7a5767bd9152d486585d3 Mon Sep 17 00:00:00 2001 From: matt Date: Sat, 17 Aug 2019 22:16:38 -0500 Subject: [PATCH 01/27] Add max_samples bootstrap size kwarg --- sklearn/ensemble/forest.py | 90 +++++++++++++++++++++------ sklearn/ensemble/tests/test_forest.py | 89 ++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 19 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 6050fd2773a5f..bc1ad30a23c0e 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -40,6 +40,7 @@ class calls the ``fit`` method of each sub-estimator on random samples # License: BSD 3 clause +import numbers from warnings import catch_warnings, simplefilter, warn import threading @@ -72,17 +73,35 @@ class calls the ``fit`` method of each sub-estimator on random samples MAX_INT = np.iinfo(np.int32).max -def _generate_sample_indices(random_state, n_samples): +def _generate_sample_indices(random_state, n_samples, max_samples=None): """Private function used to _parallel_build_trees function.""" + + # Validate `max_samples` + if max_samples is None: + max_samples = n_samples + elif isinstance(max_samples, numbers.Integral): + if not (0 < max_samples <= n_samples): + msg = "`max_samples` must be in range 1 ... 
{} but got value {}" + raise ValueError(msg.format(n_samples, max_samples)) + elif isinstance(max_samples, numbers.Real): + if not (0 < max_samples <= 1.0): + msg = "`max_samples` must be in range (0, 1.0] but got value {}" + raise ValueError(msg.format(max_samples)) + max_samples = int(round(n_samples * max_samples)) + else: + msg = "`max_samples` should be int or float, but got type '{}'" + raise TypeError(msg.format(type(max_samples))) + random_instance = check_random_state(random_state) - sample_indices = random_instance.randint(0, n_samples, n_samples) + sample_indices = random_instance.randint(0, n_samples, max_samples) return sample_indices -def _generate_unsampled_indices(random_state, n_samples): +def _generate_unsampled_indices(random_state, n_samples, max_samples=None): """Private function used to forest._set_oob_score function.""" - sample_indices = _generate_sample_indices(random_state, n_samples) + sample_indices = _generate_sample_indices( + random_state, n_samples, max_samples) sample_counts = np.bincount(sample_indices, minlength=n_samples) unsampled_mask = sample_counts == 0 indices_range = np.arange(n_samples) @@ -92,7 +111,7 @@ def _generate_unsampled_indices(random_state, n_samples): def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, - verbose=0, class_weight=None): + verbose=0, class_weight=None, max_samples=None): """Private function used to fit a single tree in parallel.""" if verbose > 1: print("building tree %d of %d" % (tree_idx + 1, n_trees)) @@ -104,7 +123,8 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, else: curr_sample_weight = sample_weight.copy() - indices = _generate_sample_indices(tree.random_state, n_samples) + indices = _generate_sample_indices(tree.random_state, n_samples, + max_samples=max_samples) sample_counts = np.bincount(indices, minlength=n_samples) curr_sample_weight *= sample_counts @@ -140,7 +160,8 @@ def __init__(self, random_state=None, verbose=0, 
warm_start=False, - class_weight=None): + class_weight=None, + max_samples=None): super().__init__( base_estimator=base_estimator, n_estimators=n_estimators, @@ -153,6 +174,7 @@ def __init__(self, self.verbose = verbose self.warm_start = warm_start self.class_weight = class_weight + self.max_samples = max_samples def apply(self, X): """Apply trees in the forest to X, return leaf indices. @@ -320,7 +342,8 @@ def fit(self, X, y, sample_weight=None): **_joblib_parallel_args(prefer='threads'))( delayed(_parallel_build_trees)( t, self, X, y, sample_weight, i, len(trees), - verbose=self.verbose, class_weight=self.class_weight) + verbose=self.verbose, class_weight=self.class_weight, + max_samples=self.max_samples) for i, t in enumerate(trees)) # Collect newly grown trees @@ -410,7 +433,8 @@ def __init__(self, random_state=None, verbose=0, warm_start=False, - class_weight=None): + class_weight=None, + max_samples=None): super().__init__( base_estimator, n_estimators=n_estimators, @@ -421,7 +445,8 @@ def __init__(self, random_state=random_state, verbose=verbose, warm_start=warm_start, - class_weight=class_weight) + class_weight=class_weight, + max_samples=max_samples) def _set_oob_score(self, X, y): """Compute out-of-bag score""" @@ -650,7 +675,8 @@ def __init__(self, n_jobs=None, random_state=None, verbose=0, - warm_start=False): + warm_start=False, + max_samples=None): super().__init__( base_estimator, n_estimators=n_estimators, @@ -660,7 +686,8 @@ def __init__(self, n_jobs=n_jobs, random_state=random_state, verbose=verbose, - warm_start=warm_start) + warm_start=warm_start, + max_samples=max_samples) def predict(self, X): """Predict regression target for X. 
@@ -715,7 +742,8 @@ def _set_oob_score(self, X, y): for estimator in self.estimators_: unsampled_indices = _generate_unsampled_indices( - estimator.random_state, n_samples) + estimator.random_state, n_samples, + max_samples=self.max_samples) p_estimator = estimator.predict( X[unsampled_indices, :], check_input=False) @@ -913,6 +941,12 @@ class RandomForestClassifier(ForestClassifier): Note that these weights will be multiplied with sample_weight (passed through the fit method) if sample_weight is specified. + max_samples : int or float (default=None) + The number of samples to draw from X to train each base estimator. + - If None (default), then draw `X.shape[0]` samples. + - If int, then draw `max_samples` samples. + - If float, then draw `max_samples * X.shape[0]` samples. + Attributes ---------- estimators_ : list of DecisionTreeClassifier @@ -1001,7 +1035,8 @@ def __init__(self, random_state=None, verbose=0, warm_start=False, - class_weight=None): + class_weight=None, + max_samples=None): super().__init__( base_estimator=DecisionTreeClassifier(), n_estimators=n_estimators, @@ -1016,7 +1051,8 @@ def __init__(self, random_state=random_state, verbose=verbose, warm_start=warm_start, - class_weight=class_weight) + class_weight=class_weight, + max_samples=max_samples) self.criterion = criterion self.max_depth = max_depth @@ -1442,6 +1478,12 @@ class ExtraTreesClassifier(ForestClassifier): Note that these weights will be multiplied with sample_weight (passed through the fit method) if sample_weight is specified. + max_samples : int or float (default=None) + The number of samples to draw from X to train each base estimator. + - If None (default), then draw `X.shape[0]` samples. + - If int, then draw `max_samples` samples. + - If float, then draw `max_samples * X.shape[0]` samples. 
+ Attributes ---------- estimators_ : list of DecisionTreeClassifier @@ -1510,7 +1552,8 @@ def __init__(self, random_state=None, verbose=0, warm_start=False, - class_weight=None): + class_weight=None, + max_samples=None): super().__init__( base_estimator=ExtraTreeClassifier(), n_estimators=n_estimators, @@ -1525,7 +1568,8 @@ def __init__(self, random_state=random_state, verbose=verbose, warm_start=warm_start, - class_weight=class_weight) + class_weight=class_weight, + max_samples=max_samples) self.criterion = criterion self.max_depth = max_depth @@ -1678,6 +1722,12 @@ class ExtraTreesRegressor(ForestRegressor): and add more estimators to the ensemble, otherwise, just fit a whole new forest. See :term:`the Glossary `. + max_samples : int or float (default=None) + The number of samples to draw from X to train each base estimator. + - If None (default), then draw `X.shape[0]` samples. + - If int, then draw `max_samples` samples. + - If float, then draw `max_samples * X.shape[0]` samples. + Attributes ---------- estimators_ : list of DecisionTreeRegressor @@ -1733,7 +1783,8 @@ def __init__(self, n_jobs=None, random_state=None, verbose=0, - warm_start=False): + warm_start=False, + max_samples=None): super().__init__( base_estimator=ExtraTreeRegressor(), n_estimators=n_estimators, @@ -1747,7 +1798,8 @@ def __init__(self, n_jobs=n_jobs, random_state=random_state, verbose=verbose, - warm_start=warm_start) + warm_start=warm_start, + max_samples=max_samples) self.criterion = criterion self.max_depth = max_depth diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 01102c9679053..d8f868b8bf94b 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1330,3 +1330,92 @@ def test_forest_degenerate_feature_importances(): gbr = RandomForestRegressor(n_estimators=10).fit(X, y) assert_array_equal(gbr.feature_importances_, np.zeros(10, dtype=np.float64)) + + +def test__generate_sample_indices(): + + 
from sklearn.ensemble.forest import _generate_sample_indices + + rng = np.random.RandomState(1234) + n_samples = 10 + + # Check that indices without max_samples kwarg defaults to full n_samples + indices = _generate_sample_indices(rng, n_samples) + assert len(indices) == n_samples + + # Check that indices with max_samples kwarg subsamples by that amount + indices = _generate_sample_indices(rng, n_samples, max_samples=5) + assert len(indices) == 5 + + # Check that ValueError is raised when `max_samples` is integral + # and greater than `n_samples` + with pytest.raises(ValueError): + _generate_sample_indices(rng, n_samples, max_samples=n_samples + 1) + + # Check that ValueError is raised when `max_samples` is float + # and not in range (0, 1] + with pytest.raises(ValueError): + _generate_sample_indices(rng, n_samples, max_samples=2.0) + with pytest.raises(ValueError): + _generate_sample_indices(rng, n_samples, max_samples=0.0) + with pytest.raises(ValueError): + _generate_sample_indices(rng, n_samples, max_samples=np.nan) + + # Check that TypeError is raised when `max_samples` is garbage + with pytest.raises(TypeError): + _generate_sample_indices(rng, n_samples, + max_samples='bad max sample type') + with pytest.raises(TypeError): + _generate_sample_indices(rng, n_samples, + max_samples=np.ones(n_samples)) + + +def check_max_samples_classification(name): + """ Checks that the `max_samples` option works as expected + for a simple two-class problem + """ + + rng = np.random.RandomState(1) + + # Make a two-sample, two-class dataset + X = np.array([ + [-1.], + [+1.], + ]) + y = np.array([0, 1]) + + # Initialize the classifier + rfc = FOREST_CLASSIFIERS[name]( + n_estimators=1, + random_state=rng, + bootstrap=True, + max_depth=1, + ) + + # Limiting bootstrap samples to 1 on the two-sample + # dataset with `n_estimators=1` and `max_depth=1` + # should yield an accuracy of 0.5 + rfc.max_samples = 1 + rfc.fit(X, y) + assert rfc.score(X, y) == 0.5 + + # Should be equivlaent 
to `max_samples=1` + rfc.max_samples = 0.5 + rfc.fit(X, y) + assert rfc.score(X, y) == 0.5 + + # Allowing bootstrap samples to 2 should allow choosing + # the optimal threshold between -1 and +1 yielding perfect accuracy + rfc.max_samples = 2 + rfc.fit(X, y) + assert rfc.score(X, y) == 1.0 + + # Should be equivalent to `max_samples=2` + rfc.max_samples = 1.0 + rfc.fit(X, y) + assert rfc.score(X, y) == 1.0 + + +@pytest.mark.parametrize('name', FOREST_CLASSIFIERS) +def test_max_samples_classification(name): + check_max_samples_classification(name) From b845709a1f5ab4081a5946f2cb133e8c8d8c2dfe Mon Sep 17 00:00:00 2001 From: matt Date: Sun, 18 Aug 2019 12:49:37 -0500 Subject: [PATCH 02/27] Refactor unit tests --- sklearn/ensemble/tests/test_forest.py | 62 +++++++++++---------------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index d8f868b8bf94b..6124a4cd24af4 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1333,6 +1333,8 @@ def test_forest_degenerate_feature_importances(): def test__generate_sample_indices(): + """ Test the behavior of the sample index generation helper + """ from sklearn.ensemble.forest import _generate_sample_indices @@ -1361,7 +1363,7 @@ def test__generate_sample_indices(): with pytest.raises(ValueError): _generate_sample_indices(rng, n_samples, max_samples=np.nan) - # Check that TypeError is raised when `max_samples` is garbage + # Check that TypeError is raised when `max_samples` is the wrong type with pytest.raises(TypeError): _generate_sample_indices(rng, n_samples, max_samples='bad max sample type') @@ -1370,52 +1372,38 @@ def test__generate_sample_indices(): max_samples=np.ones(n_samples)) -def check_max_samples_classification(name): - """ Checks that the `max_samples` option works as expected - for a simple two-class problem +def check_classification_toy_max_samples(name): + """ Test that the toy 
example is separable via a bootstrap size of only 2 """ rng = np.random.RandomState(1) + max_tries = 100 - # Make a two-sample, two-class dataset - X = np.array([ - [-1.], - [+1.], - ]) - y = np.array([0, 1]) - - # Initialize the classifier - rfc = FOREST_CLASSIFIERS[name]( + # The toy example is separable using just one + # decision stump, and choosing 2 examples from the full + # 6-example dataset *if* the 2 examples are chosen correctly. + est = FOREST_CLASSIFIERS[name]( n_estimators=1, - random_state=rng, bootstrap=True, + max_samples=2, max_depth=1, + random_state=rng, ) - # Limiting bootstrap samples to 1 on the two-sample - # dataset with `n_estimators=1` and `max_depth=1` - # should yield an accuracy of 0.5 - rfc.max_samples = 1 - rfc.fit(X, y) - assert rfc.score(X, y) == 0.5 - - # Should be equivlaent to `max_samples=1` - rfc.max_samples = 0.5 - rfc.fit(X, y) - assert rfc.score(X, y) == 0.5 - - # Allowing bootstrap samples to 2 should allow choosing - # the optimal threshold between -1 and +1 yielding perfect accuracy - rfc.max_samples = 2 - rfc.fit(X, y) - assert rfc.score(X, y) == 1.0 + # Each call to fit uses a different bootstrap sample of size two. If we + # fit multiple times, we expect that we eventually hit a case where + # the two examples chosen for the bootstrap sample are from the opposite + # class and yield a perfect score across the entire dataset. 
+ perfect_score = False + for _ in range(max_tries): + est.fit(X, y) + if est.score(X, y) == 1.0: + perfect_score = True + break - # Should be equivalent to `max_samples=2` - rfc.max_samples = 1.0 - rfc.fit(X, y) - assert rfc.score(X, y) == 1.0 + assert perfect_score @pytest.mark.parametrize('name', FOREST_CLASSIFIERS) -def test_max_samples_classification(name): - check_max_samples_classification(name) +def test_classification_toy_max_samples(name): + check_classification_toy_max_samples(name) From 6db3384abc53d76116543a912ac5b812d4aa483d Mon Sep 17 00:00:00 2001 From: matt Date: Sun, 18 Aug 2019 13:01:17 -0500 Subject: [PATCH 03/27] Add one more assert check in index helper test --- sklearn/ensemble/tests/test_forest.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 6124a4cd24af4..abf009ffe5981 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1349,6 +1349,10 @@ def test__generate_sample_indices(): indices = _generate_sample_indices(rng, n_samples, max_samples=5) assert len(indices) == 5 + # Check that indices with max_samples kwarg subsamples by that amount + indices = _generate_sample_indices(rng, n_samples, max_samples=0.5) + assert len(indices) == 5 + # Check that ValueError is raised when `max_samples` is integral # and greater than `n_samples` with pytest.raises(ValueError): From 2507d5b44f278047eae2cb1c0bea4705bb266622 Mon Sep 17 00:00:00 2001 From: matt Date: Sun, 18 Aug 2019 18:30:55 -0500 Subject: [PATCH 04/27] Move validation and bootstrap size get to helper --- sklearn/ensemble/forest.py | 45 ++++++++++++++--------- sklearn/ensemble/tests/test_forest.py | 51 ++++++++++++++++----------- 2 files changed, 58 insertions(+), 38 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index bc1ad30a23c0e..a32a18f66a70f 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ 
-73,35 +73,42 @@ class calls the ``fit`` method of each sub-estimator on random samples MAX_INT = np.iinfo(np.int32).max -def _generate_sample_indices(random_state, n_samples, max_samples=None): - """Private function used to _parallel_build_trees function.""" +def _get_n_bootstrap_samples(n_samples, max_samples=None): + """Generates the number of bootstrap samples given the total + available `n_samples` and the limit `max_samples`, which can be + either None, integral, or real valued. + """ - # Validate `max_samples` if max_samples is None: - max_samples = n_samples + return n_samples elif isinstance(max_samples, numbers.Integral): - if not (0 < max_samples <= n_samples): - msg = "`max_samples` must be in range 1 ... {} but got value {}" + if not (1 <= max_samples <= n_samples): + msg = "`max_samples` must be in range 1 to {} but got value {}" raise ValueError(msg.format(n_samples, max_samples)) + return max_samples elif isinstance(max_samples, numbers.Real): - if not (0 < max_samples <= 1.0): - msg = "`max_samples` must be in range (0, 1.0] but got value {}" + if not (0 < max_samples < 1.0): + msg = "`max_samples` must be in range (0, 1.0) but got value {}" raise ValueError(msg.format(max_samples)) - max_samples = int(round(n_samples * max_samples)) + return int(round(n_samples * max_samples)) else: msg = "`max_samples` should be int or float, but got type '{}'" raise TypeError(msg.format(type(max_samples))) + +def _generate_sample_indices(random_state, n_samples, n_bootstrap_samples): + """Private function used to _parallel_build_trees function.""" + random_instance = check_random_state(random_state) - sample_indices = random_instance.randint(0, n_samples, max_samples) + sample_indices = random_instance.randint(0, n_samples, n_bootstrap_samples) return sample_indices -def _generate_unsampled_indices(random_state, n_samples, max_samples=None): +def _generate_unsampled_indices(random_state, n_samples, n_bootstrap_samples): """Private function used to 
forest._set_oob_score function.""" - sample_indices = _generate_sample_indices( - random_state, n_samples, max_samples) + sample_indices = _generate_sample_indices(random_state, n_samples, + n_bootstrap_samples) sample_counts = np.bincount(sample_indices, minlength=n_samples) unsampled_mask = sample_counts == 0 indices_range = np.arange(n_samples) @@ -123,8 +130,9 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, else: curr_sample_weight = sample_weight.copy() + n_bootstrap_samples = _get_n_bootstrap_samples(n_samples, max_samples) indices = _generate_sample_indices(tree.random_state, n_samples, - max_samples=max_samples) + n_bootstrap_samples) sample_counts = np.bincount(indices, minlength=n_samples) curr_sample_weight *= sample_counts @@ -461,8 +469,10 @@ def _set_oob_score(self, X, y): for k in range(self.n_outputs_)] for estimator in self.estimators_: + n_bootstrap_samples = _get_n_bootstrap_samples( + n_samples, self.max_samples) unsampled_indices = _generate_unsampled_indices( - estimator.random_state, n_samples) + estimator.random_state, n_samples, n_bootstrap_samples) p_estimator = estimator.predict_proba(X[unsampled_indices, :], check_input=False) @@ -741,9 +751,10 @@ def _set_oob_score(self, X, y): n_predictions = np.zeros((n_samples, self.n_outputs_)) for estimator in self.estimators_: + n_bootstrap_samples = _get_n_bootstrap_samples( + n_samples, self.max_samples) unsampled_indices = _generate_unsampled_indices( - estimator.random_state, n_samples, - max_samples=self.max_samples) + estimator.random_state, n_samples, n_bootstrap_samples) p_estimator = estimator.predict( X[unsampled_indices, :], check_input=False) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index abf009ffe5981..64178b65b3ee4 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1332,48 +1332,57 @@ def test_forest_degenerate_feature_importances(): np.zeros(10, 
dtype=np.float64)) -def test__generate_sample_indices(): - """ Test the behavior of the sample index generation helper +def test__get_n_bootstrap_samples(): + """ Test bootstrap sample number getter """ - from sklearn.ensemble.forest import _generate_sample_indices + from sklearn.ensemble.forest import _get_n_bootstrap_samples - rng = np.random.RandomState(1234) n_samples = 10 - # Check that indices without max_samples kwarg defaults to full n_samples - indices = _generate_sample_indices(rng, n_samples) - assert len(indices) == n_samples + # Check that max_samples=None defaults to full n_samples + n_bootstrap = _get_n_bootstrap_samples(n_samples, max_samples=None) + assert n_bootstrap == n_samples - # Check that indices with max_samples kwarg subsamples by that amount - indices = _generate_sample_indices(rng, n_samples, max_samples=5) - assert len(indices) == 5 + # Check that max less total yields max + n_bootstrap = _get_n_bootstrap_samples(n_samples, max_samples=5) + assert n_bootstrap == 5 # Check that indices with max_samples kwarg subsamples by that amount - indices = _generate_sample_indices(rng, n_samples, max_samples=0.5) - assert len(indices) == 5 + n_bootstrap = _get_n_bootstrap_samples(n_samples, max_samples=0.1) + assert n_bootstrap == 1 # Check that ValueError is raised when `max_samples` is integral # and greater than `n_samples` with pytest.raises(ValueError): - _generate_sample_indices(rng, n_samples, max_samples=n_samples + 1) + _get_n_bootstrap_samples(n_samples, max_samples=n_samples + 1) # Check that ValueError is raised when `max_samples` is float - # and not in range (0, 1] + # and not in range (0, 1) + with pytest.raises(ValueError): + # Edge case at 1.0 + _get_n_bootstrap_samples(n_samples, max_samples=1.0) + with pytest.raises(ValueError): + # Greater than 1.0 + _get_n_bootstrap_samples(n_samples, max_samples=2.0) + with pytest.raises(ValueError): + # Edge case at 0.0 + _get_n_bootstrap_samples(n_samples, max_samples=0.0) with 
pytest.raises(ValueError): - _generate_sample_indices(rng, n_samples, max_samples=2.0) + # Less than 0.0 + _get_n_bootstrap_samples(n_samples, max_samples=-1.0) with pytest.raises(ValueError): - _generate_sample_indices(rng, n_samples, max_samples=0.0) + # NaN: valid float, so should raise ValueError + _get_n_bootstrap_samples(n_samples, max_samples=np.nan) with pytest.raises(ValueError): - _generate_sample_indices(rng, n_samples, max_samples=np.nan) + # Inf: valid float, so should raise ValueError + _get_n_bootstrap_samples(n_samples, max_samples=np.inf) # Check that TypeError is raised when `max_samples` is the wrong type with pytest.raises(TypeError): - _generate_sample_indices(rng, n_samples, - max_samples='bad max sample type') + _get_n_bootstrap_samples(n_samples, max_samples='bad max sample type') with pytest.raises(TypeError): - _generate_sample_indices(rng, n_samples, - max_samples=np.ones(n_samples)) + _get_n_bootstrap_samples(n_samples, max_samples=np.ones(n_samples)) def check_classification_toy_max_samples(name): From 4fa10a9f6d2a41d014ce7b8451b814429070f933 Mon Sep 17 00:00:00 2001 From: matt Date: Sun, 18 Aug 2019 18:52:52 -0500 Subject: [PATCH 05/27] Compute bootstrap size just once --- sklearn/ensemble/forest.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index a32a18f66a70f..8b341d5fa3def 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -118,7 +118,8 @@ def _generate_unsampled_indices(random_state, n_samples, n_bootstrap_samples): def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, - verbose=0, class_weight=None, max_samples=None): + verbose=0, class_weight=None, + n_bootstrap_samples=None): """Private function used to fit a single tree in parallel.""" if verbose > 1: print("building tree %d of %d" % (tree_idx + 1, n_trees)) @@ -130,7 +131,6 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, 
tree_idx, n_trees, else: curr_sample_weight = sample_weight.copy() - n_bootstrap_samples = _get_n_bootstrap_samples(n_samples, max_samples) indices = _generate_sample_indices(tree.random_state, n_samples, n_bootstrap_samples) sample_counts = np.bincount(indices, minlength=n_samples) @@ -307,6 +307,10 @@ def fit(self, X, y, sample_weight=None): else: sample_weight = expanded_class_weight + # Get bootstrap sample size + n_bootstrap_samples = _get_n_bootstrap_samples( + n_samples=X.shape[0], max_samples=self.max_samples) + # Check parameters self._validate_estimator() @@ -351,7 +355,7 @@ def fit(self, X, y, sample_weight=None): delayed(_parallel_build_trees)( t, self, X, y, sample_weight, i, len(trees), verbose=self.verbose, class_weight=self.class_weight, - max_samples=self.max_samples) + n_bootstrap_samples=n_bootstrap_samples) for i, t in enumerate(trees)) # Collect newly grown trees From 3eb299cb5ec65ee65418bbefc5cc27bfab8b8b2e Mon Sep 17 00:00:00 2001 From: matt Date: Tue, 20 Aug 2019 18:17:14 -0500 Subject: [PATCH 06/27] n_bootstrap_samples -> n_samples_bootstrap; update docstring --- sklearn/ensemble/forest.py | 28 +++++++++++++++++++-------- sklearn/ensemble/tests/test_forest.py | 26 ++++++++++++------------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 8b341d5fa3def..5761b2599022f 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -73,12 +73,24 @@ class calls the ``fit`` method of each sub-estimator on random samples MAX_INT = np.iinfo(np.int32).max -def _get_n_bootstrap_samples(n_samples, max_samples=None): - """Generates the number of bootstrap samples given the total - available `n_samples` and the limit `max_samples`, which can be - either None, integral, or real valued. - """ +def _get_n_samples_bootstrap(n_samples, max_samples): + """Get the number of samples in a bootstrap sample. 
+ Parameters + ---------- + n_samples : int + Number of samples in the dataset. + max_samples : int, float or None, default=None + The maximum number of samples to draw from the total available: + - float indicates a fraction of the total. + - int indicates the exact number of samples. + - None indicates the total number of samples + + Returns + ------- + n_samples_bootstrap : int + The total number of samples to draw for the bootstrap sample. + """ if max_samples is None: return n_samples elif isinstance(max_samples, numbers.Integral): @@ -308,7 +320,7 @@ def fit(self, X, y, sample_weight=None): sample_weight = expanded_class_weight # Get bootstrap sample size - n_bootstrap_samples = _get_n_bootstrap_samples( + n_bootstrap_samples = _get_n_samples_bootstrap( n_samples=X.shape[0], max_samples=self.max_samples) # Check parameters @@ -473,7 +485,7 @@ def _set_oob_score(self, X, y): for k in range(self.n_outputs_)] for estimator in self.estimators_: - n_bootstrap_samples = _get_n_bootstrap_samples( + n_bootstrap_samples = _get_n_samples_bootstrap( n_samples, self.max_samples) unsampled_indices = _generate_unsampled_indices( estimator.random_state, n_samples, n_bootstrap_samples) @@ -755,7 +767,7 @@ def _set_oob_score(self, X, y): n_predictions = np.zeros((n_samples, self.n_outputs_)) for estimator in self.estimators_: - n_bootstrap_samples = _get_n_bootstrap_samples( + n_bootstrap_samples = _get_n_samples_bootstrap( n_samples, self.max_samples) unsampled_indices = _generate_unsampled_indices( estimator.random_state, n_samples, n_bootstrap_samples) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 64178b65b3ee4..6de447dbf04a5 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1336,53 +1336,53 @@ def test__get_n_bootstrap_samples(): """ Test bootstrap sample number getter """ - from sklearn.ensemble.forest import _get_n_bootstrap_samples + from sklearn.ensemble.forest import 
_get_n_samples_bootstrap n_samples = 10 # Check that max_samples=None defaults to full n_samples - n_bootstrap = _get_n_bootstrap_samples(n_samples, max_samples=None) + n_bootstrap = _get_n_samples_bootstrap(n_samples, max_samples=None) assert n_bootstrap == n_samples # Check that max less total yields max - n_bootstrap = _get_n_bootstrap_samples(n_samples, max_samples=5) + n_bootstrap = _get_n_samples_bootstrap(n_samples, max_samples=5) assert n_bootstrap == 5 # Check that indices with max_samples kwarg subsamples by that amount - n_bootstrap = _get_n_bootstrap_samples(n_samples, max_samples=0.1) + n_bootstrap = _get_n_samples_bootstrap(n_samples, max_samples=0.1) assert n_bootstrap == 1 # Check that ValueError is raised when `max_samples` is integral # and greater than `n_samples` with pytest.raises(ValueError): - _get_n_bootstrap_samples(n_samples, max_samples=n_samples + 1) + _get_n_samples_bootstrap(n_samples, max_samples=n_samples + 1) # Check that ValueError is raised when `max_samples` is float # and not in range (0, 1) with pytest.raises(ValueError): # Edge case at 1.0 - _get_n_bootstrap_samples(n_samples, max_samples=1.0) + _get_n_samples_bootstrap(n_samples, max_samples=1.0) with pytest.raises(ValueError): # Greater than 1.0 - _get_n_bootstrap_samples(n_samples, max_samples=2.0) + _get_n_samples_bootstrap(n_samples, max_samples=2.0) with pytest.raises(ValueError): # Edge case at 0.0 - _get_n_bootstrap_samples(n_samples, max_samples=0.0) + _get_n_samples_bootstrap(n_samples, max_samples=0.0) with pytest.raises(ValueError): # Less than 0.0 - _get_n_bootstrap_samples(n_samples, max_samples=-1.0) + _get_n_samples_bootstrap(n_samples, max_samples=-1.0) with pytest.raises(ValueError): # NaN: valid float, so should raise ValueError - _get_n_bootstrap_samples(n_samples, max_samples=np.nan) + _get_n_samples_bootstrap(n_samples, max_samples=np.nan) with pytest.raises(ValueError): # Inf: valid float, so should raise ValueError - _get_n_bootstrap_samples(n_samples, 
max_samples=np.inf) + _get_n_samples_bootstrap(n_samples, max_samples=np.inf) # Check that TypeError is raised when `max_samples` is the wrong type with pytest.raises(TypeError): - _get_n_bootstrap_samples(n_samples, max_samples='bad max sample type') + _get_n_samples_bootstrap(n_samples, max_samples='bad max sample type') with pytest.raises(TypeError): - _get_n_bootstrap_samples(n_samples, max_samples=np.ones(n_samples)) + _get_n_samples_bootstrap(n_samples, max_samples=np.ones(n_samples)) def check_classification_toy_max_samples(name): From 081b7b7bc33d0a2b69493e34e4e60de0244c1c3b Mon Sep 17 00:00:00 2001 From: matt Date: Tue, 20 Aug 2019 18:45:17 -0500 Subject: [PATCH 07/27] Refactor exception tests for max_samples --- sklearn/ensemble/tests/test_forest.py | 69 +++++++-------------------- 1 file changed, 18 insertions(+), 51 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 6de447dbf04a5..a3e8cb17120c8 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1331,58 +1331,25 @@ def test_forest_degenerate_feature_importances(): assert_array_equal(gbr.feature_importances_, np.zeros(10, dtype=np.float64)) - -def test__get_n_bootstrap_samples(): - """ Test bootstrap sample number getter +@pytest.mark.parametrize('name', FOREST_CLASSIFIERS) +@pytest.mark.parametrize( + 'max_samples_exception', + [(int(1e9), ValueError), + (1.0, ValueError), + (2.0, ValueError), + (0.0, ValueError), + (np.nan, ValueError), + (np.inf, ValueError), + ('str max_samples?!', TypeError), + (np.ones(2), TypeError)] +) +def test_max_samples_exceptions(name, max_samples_exception): + """ Check invalid `max_samples` values """ - - from sklearn.ensemble.forest import _get_n_samples_bootstrap - - n_samples = 10 - - # Check that max_samples=None defaults to full n_samples - n_bootstrap = _get_n_samples_bootstrap(n_samples, max_samples=None) - assert n_bootstrap == n_samples - - # Check that max 
less total yields max - n_bootstrap = _get_n_samples_bootstrap(n_samples, max_samples=5) - assert n_bootstrap == 5 - - # Check that indices with max_samples kwarg subsamples by that amount - n_bootstrap = _get_n_samples_bootstrap(n_samples, max_samples=0.1) - assert n_bootstrap == 1 - - # Check that ValueError is raised when `max_samples` is integral - # and greater than `n_samples` - with pytest.raises(ValueError): - _get_n_samples_bootstrap(n_samples, max_samples=n_samples + 1) - - # Check that ValueError is raised when `max_samples` is float - # and not in range (0, 1) - with pytest.raises(ValueError): - # Edge case at 1.0 - _get_n_samples_bootstrap(n_samples, max_samples=1.0) - with pytest.raises(ValueError): - # Greater than 1.0 - _get_n_samples_bootstrap(n_samples, max_samples=2.0) - with pytest.raises(ValueError): - # Edge case at 0.0 - _get_n_samples_bootstrap(n_samples, max_samples=0.0) - with pytest.raises(ValueError): - # Less than 0.0 - _get_n_samples_bootstrap(n_samples, max_samples=-1.0) - with pytest.raises(ValueError): - # NaN: valid float, so should raise ValueError - _get_n_samples_bootstrap(n_samples, max_samples=np.nan) - with pytest.raises(ValueError): - # Inf: valid float, so should raise ValueError - _get_n_samples_bootstrap(n_samples, max_samples=np.inf) - - # Check that TypeError is raised when `max_samples` is the wrong type - with pytest.raises(TypeError): - _get_n_samples_bootstrap(n_samples, max_samples='bad max sample type') - with pytest.raises(TypeError): - _get_n_samples_bootstrap(n_samples, max_samples=np.ones(n_samples)) + max_samples, exception = max_samples_exception + est = FOREST_CLASSIFIERS_REGRESSORS[name](max_samples=max_samples) + with pytest.raises(exception): + est.fit(X, y) def check_classification_toy_max_samples(name): From cf8b0cb2b96bd08422a7963dc357aa49c47d2d37 Mon Sep 17 00:00:00 2001 From: matt Date: Tue, 20 Aug 2019 18:46:12 -0500 Subject: [PATCH 08/27] n_bootstrap_samples -> n_samples_bootstrap in signatures 
--- sklearn/ensemble/forest.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 5761b2599022f..145fec84d0153 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -108,19 +108,19 @@ def _get_n_samples_bootstrap(n_samples, max_samples): raise TypeError(msg.format(type(max_samples))) -def _generate_sample_indices(random_state, n_samples, n_bootstrap_samples): +def _generate_sample_indices(random_state, n_samples, n_samples_bootstrap): """Private function used to _parallel_build_trees function.""" random_instance = check_random_state(random_state) - sample_indices = random_instance.randint(0, n_samples, n_bootstrap_samples) + sample_indices = random_instance.randint(0, n_samples, n_samples_bootstrap) return sample_indices -def _generate_unsampled_indices(random_state, n_samples, n_bootstrap_samples): +def _generate_unsampled_indices(random_state, n_samples, n_samples_bootstrap): """Private function used to forest._set_oob_score function.""" sample_indices = _generate_sample_indices(random_state, n_samples, - n_bootstrap_samples) + n_samples_bootstrap) sample_counts = np.bincount(sample_indices, minlength=n_samples) unsampled_mask = sample_counts == 0 indices_range = np.arange(n_samples) @@ -131,7 +131,7 @@ def _generate_unsampled_indices(random_state, n_samples, n_bootstrap_samples): def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, verbose=0, class_weight=None, - n_bootstrap_samples=None): + n_samples_bootstrap=None): """Private function used to fit a single tree in parallel.""" if verbose > 1: print("building tree %d of %d" % (tree_idx + 1, n_trees)) @@ -144,7 +144,7 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, curr_sample_weight = sample_weight.copy() indices = _generate_sample_indices(tree.random_state, n_samples, - n_bootstrap_samples) + n_samples_bootstrap) 
sample_counts = np.bincount(indices, minlength=n_samples) curr_sample_weight *= sample_counts @@ -320,7 +320,7 @@ def fit(self, X, y, sample_weight=None): sample_weight = expanded_class_weight # Get bootstrap sample size - n_bootstrap_samples = _get_n_samples_bootstrap( + n_samples_bootstrap = _get_n_samples_bootstrap( n_samples=X.shape[0], max_samples=self.max_samples) # Check parameters @@ -367,7 +367,7 @@ def fit(self, X, y, sample_weight=None): delayed(_parallel_build_trees)( t, self, X, y, sample_weight, i, len(trees), verbose=self.verbose, class_weight=self.class_weight, - n_bootstrap_samples=n_bootstrap_samples) + n_samples_bootstrap=n_samples_bootstrap) for i, t in enumerate(trees)) # Collect newly grown trees @@ -485,10 +485,10 @@ def _set_oob_score(self, X, y): for k in range(self.n_outputs_)] for estimator in self.estimators_: - n_bootstrap_samples = _get_n_samples_bootstrap( + n_samples_bootstrap = _get_n_samples_bootstrap( n_samples, self.max_samples) unsampled_indices = _generate_unsampled_indices( - estimator.random_state, n_samples, n_bootstrap_samples) + estimator.random_state, n_samples, n_samples_bootstrap) p_estimator = estimator.predict_proba(X[unsampled_indices, :], check_input=False) @@ -767,10 +767,10 @@ def _set_oob_score(self, X, y): n_predictions = np.zeros((n_samples, self.n_outputs_)) for estimator in self.estimators_: - n_bootstrap_samples = _get_n_samples_bootstrap( + n_samples_bootstrap = _get_n_samples_bootstrap( n_samples, self.max_samples) unsampled_indices = _generate_unsampled_indices( - estimator.random_state, n_samples, n_bootstrap_samples) + estimator.random_state, n_samples, n_samples_bootstrap) p_estimator = estimator.predict( X[unsampled_indices, :], check_input=False) From 49a949a2ab3c2ca5c8c72ceb02f2c8d2119f1a9f Mon Sep 17 00:00:00 2001 From: matt Date: Tue, 20 Aug 2019 18:50:19 -0500 Subject: [PATCH 09/27] Add max_samples kwarg to RandomForestRegressor --- sklearn/ensemble/forest.py | 12 ++++++++++-- 
sklearn/ensemble/tests/test_forest.py | 3 ++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 145fec84d0153..d5ae749635364 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -1235,6 +1235,12 @@ class RandomForestRegressor(ForestRegressor): and add more estimators to the ensemble, otherwise, just fit a whole new forest. See :term:`the Glossary <warm_start>`. + max_samples : int or float (default=None) + The number of samples to draw from X to train each base estimator. + - If None (default), then draw `X.shape[0]` samples. + - If int, then draw `max_samples` samples. + - If float, then draw `max_samples * X.shape[0]` samples. + Attributes ---------- estimators_ : list of DecisionTreeRegressor @@ -1317,7 +1323,8 @@ def __init__(self, n_jobs=None, random_state=None, verbose=0, - warm_start=False): + warm_start=False, + max_samples=None): super().__init__( base_estimator=DecisionTreeRegressor(), n_estimators=n_estimators, @@ -1331,7 +1338,8 @@ def __init__(self, n_jobs=n_jobs, random_state=random_state, verbose=verbose, - warm_start=warm_start) + warm_start=warm_start, + max_samples=max_samples) self.criterion = criterion self.max_depth = max_depth diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index a3e8cb17120c8..0701c5f7a7d97 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1331,7 +1331,8 @@ def test_forest_degenerate_feature_importances(): assert_array_equal(gbr.feature_importances_, np.zeros(10, dtype=np.float64)) -@pytest.mark.parametrize('name', FOREST_CLASSIFIERS) + +@pytest.mark.parametrize('name', FOREST_CLASSIFIERS_REGRESSORS) @pytest.mark.parametrize( 'max_samples_exception', [(int(1e9), ValueError), From 0e2b38683b43749283fa8cf58e391d57ccf35e51 Mon Sep 17 00:00:00 2001 From: matt Date: Tue, 20 Aug 2019 18:54:56 -0500 Subject: [PATCH 10/27] Revert doc default ---
sklearn/ensemble/forest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index d5ae749635364..cf87dcebf0db2 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -80,7 +80,7 @@ def _get_n_samples_bootstrap(n_samples, max_samples): ---------- n_samples : int Number of samples in the dataset. - max_samples : int, float or None, default=None + max_samples : int, float or None (default=None) The maximum number of samples to draw from the total available: - float indicates a fraction of the total. - int indicates the exact number of samples. From ef5efe1fe8bf1b16bfd71bc9b49b539213a021fa Mon Sep 17 00:00:00 2001 From: matt Date: Tue, 20 Aug 2019 19:52:35 -0500 Subject: [PATCH 11/27] Move n_samples_bootstrap out of loop --- sklearn/ensemble/forest.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index cf87dcebf0db2..57af333d769e5 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -484,9 +484,10 @@ def _set_oob_score(self, X, y): predictions = [np.zeros((n_samples, n_classes_[k])) for k in range(self.n_outputs_)] + n_samples_bootstrap = _get_n_samples_bootstrap( + n_samples, self.max_samples) + for estimator in self.estimators_: - n_samples_bootstrap = _get_n_samples_bootstrap( - n_samples, self.max_samples) unsampled_indices = _generate_unsampled_indices( estimator.random_state, n_samples, n_samples_bootstrap) p_estimator = estimator.predict_proba(X[unsampled_indices, :], @@ -766,9 +767,10 @@ def _set_oob_score(self, X, y): predictions = np.zeros((n_samples, self.n_outputs_)) n_predictions = np.zeros((n_samples, self.n_outputs_)) + n_samples_bootstrap = _get_n_samples_bootstrap( + n_samples, self.max_samples) + for estimator in self.estimators_: - n_samples_bootstrap = _get_n_samples_bootstrap( - n_samples, self.max_samples) unsampled_indices = 
_generate_unsampled_indices( estimator.random_state, n_samples, n_samples_bootstrap) p_estimator = estimator.predict( From 5bc77960d4437d63d6b6fe8c3b0b4af94185a3ee Mon Sep 17 00:00:00 2001 From: matt Date: Tue, 20 Aug 2019 20:21:01 -0500 Subject: [PATCH 12/27] Add max_samples to RandomTreesEmbedding --- sklearn/ensemble/forest.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 4917fdc24bea6..ccef3ba9b67f4 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -2040,6 +2040,12 @@ class RandomTreesEmbedding(BaseForest): .. versionadded:: 0.22 + max_samples : int or float (default=None) + The number of samples to draw from X to train each base estimator. + - If None (default), then draw `X.shape[0]` samples. + - If int, then draw `max_samples` samples. + - If float, then draw `max_samples * X.shape[0]` samples. + Attributes ---------- estimators_ : list of DecisionTreeClassifier @@ -2072,7 +2078,8 @@ def __init__(self, random_state=None, verbose=0, warm_start=False, - ccp_alpha=0.0): + ccp_alpha=0.0, + max_samples=None): super().__init__( base_estimator=ExtraTreeRegressor(), n_estimators=n_estimators, @@ -2086,7 +2093,8 @@ def __init__(self, n_jobs=n_jobs, random_state=random_state, verbose=verbose, - warm_start=warm_start) + warm_start=warm_start, + max_samples=max_samples) self.max_depth = max_depth self.min_samples_split = min_samples_split From 75ce3384087ecfca6541c5fd78546a448b3183d7 Mon Sep 17 00:00:00 2001 From: matt Date: Mon, 26 Aug 2019 19:52:15 -0500 Subject: [PATCH 13/27] Change docstring style for default --- sklearn/ensemble/forest.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index ccef3ba9b67f4..dca12c219d948 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -80,7 +80,7 @@ def _get_n_samples_bootstrap(n_samples, 
max_samples): ---------- n_samples : int Number of samples in the dataset. - max_samples : int, float or None (default=None) + max_samples : int, float or None, default=None The maximum number of samples to draw from the total available: - float indicates a fraction of the total. - int indicates the exact number of samples. @@ -980,7 +980,7 @@ class RandomForestClassifier(ForestClassifier): .. versionadded:: 0.22 - max_samples : int or float (default=None) + max_samples : int or float, default=None The number of samples to draw from X to train each base estimator. - If None (default), then draw `X.shape[0]` samples. - If int, then draw `max_samples` samples. @@ -1263,7 +1263,7 @@ class RandomForestRegressor(ForestRegressor): .. versionadded:: 0.22 - max_samples : int or float (default=None) + max_samples : int or float, default=None The number of samples to draw from X to train each base estimator. - If None (default), then draw `X.shape[0]` samples. - If int, then draw `max_samples` samples. @@ -1557,7 +1557,7 @@ class ExtraTreesClassifier(ForestClassifier): .. versionadded:: 0.22 - max_samples : int or float (default=None) + max_samples : int or float, default=None The number of samples to draw from X to train each base estimator. - If None (default), then draw `X.shape[0]` samples. - If int, then draw `max_samples` samples. @@ -1817,7 +1817,7 @@ class ExtraTreesRegressor(ForestRegressor): .. versionadded:: 0.22 - max_samples : int or float (default=None) + max_samples : int or float, default=None The number of samples to draw from X to train each base estimator. - If None (default), then draw `X.shape[0]` samples. - If int, then draw `max_samples` samples. @@ -2040,7 +2040,7 @@ class RandomTreesEmbedding(BaseForest): .. versionadded:: 0.22 - max_samples : int or float (default=None) + max_samples : int or float, default=None The number of samples to draw from X to train each base estimator. - If None (default), then draw `X.shape[0]` samples. 
- If int, then draw `max_samples` samples. From 37104d3528ac45dab5f3e57ed5f377b0e66a5651 Mon Sep 17 00:00:00 2001 From: matt Date: Mon, 26 Aug 2019 19:58:36 -0500 Subject: [PATCH 14/27] Update grammar in docstring --- sklearn/ensemble/forest.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index dca12c219d948..af06873fd2954 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -80,11 +80,11 @@ def _get_n_samples_bootstrap(n_samples, max_samples): ---------- n_samples : int Number of samples in the dataset. - max_samples : int, float or None, default=None + max_samples : int or float The maximum number of samples to draw from the total available: - - float indicates a fraction of the total. - - int indicates the exact number of samples. - - None indicates the total number of samples + - if float, this indicates a fraction of the total; + - if int, this indicates the exact number of samples; + - if None, this indicates the total number of samples. 
Returns ------- @@ -100,7 +100,7 @@ def _get_n_samples_bootstrap(n_samples, max_samples): return max_samples elif isinstance(max_samples, numbers.Real): if not (0 < max_samples < 1.0): - msg = "`max_samples` must be in range (0, 1.0) but got value {}" + msg = "`max_samples` must be in range (0, 1) but got value {}" raise ValueError(msg.format(max_samples)) return int(round(n_samples * max_samples)) else: From 93e184093b6c417ef8fde899ff74a54e467b9642 Mon Sep 17 00:00:00 2001 From: matt Date: Mon, 26 Aug 2019 20:02:59 -0500 Subject: [PATCH 15/27] Refactor conditional structures --- sklearn/ensemble/forest.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index af06873fd2954..64542ea54dfe7 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -93,19 +93,21 @@ def _get_n_samples_bootstrap(n_samples, max_samples): """ if max_samples is None: return n_samples - elif isinstance(max_samples, numbers.Integral): + + if isinstance(max_samples, numbers.Integral): if not (1 <= max_samples <= n_samples): msg = "`max_samples` must be in range 1 to {} but got value {}" raise ValueError(msg.format(n_samples, max_samples)) return max_samples - elif isinstance(max_samples, numbers.Real): - if not (0 < max_samples < 1.0): + + if isinstance(max_samples, numbers.Real): + if not (0 < max_samples < 1): msg = "`max_samples` must be in range (0, 1) but got value {}" raise ValueError(msg.format(max_samples)) return int(round(n_samples * max_samples)) - else: - msg = "`max_samples` should be int or float, but got type '{}'" - raise TypeError(msg.format(type(max_samples))) + + msg = "`max_samples` should be int or float, but got type '{}'" + raise TypeError(msg.format(type(max_samples))) def _generate_sample_indices(random_state, n_samples, n_samples_bootstrap): From 005d6ca039937f1af9848c26031e21bd7cb80e5d Mon Sep 17 00:00:00 2001 From: matt Date: Mon, 26 Aug 2019 20:11:03 -0500 
Subject: [PATCH 16/27] Add version added tag; change call style --- sklearn/ensemble/forest.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 64542ea54dfe7..752668e288998 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -323,7 +323,9 @@ def fit(self, X, y, sample_weight=None): # Get bootstrap sample size n_samples_bootstrap = _get_n_samples_bootstrap( - n_samples=X.shape[0], max_samples=self.max_samples) + n_samples=X.shape[0], + max_samples=self.max_samples + ) # Check parameters self._validate_estimator() @@ -487,7 +489,8 @@ def _set_oob_score(self, X, y): for k in range(self.n_outputs_)] n_samples_bootstrap = _get_n_samples_bootstrap( - n_samples, self.max_samples) + n_samples, self.max_samples + ) for estimator in self.estimators_: unsampled_indices = _generate_unsampled_indices( @@ -770,7 +773,8 @@ def _set_oob_score(self, X, y): n_predictions = np.zeros((n_samples, self.n_outputs_)) n_samples_bootstrap = _get_n_samples_bootstrap( - n_samples, self.max_samples) + n_samples, self.max_samples + ) for estimator in self.estimators_: unsampled_indices = _generate_unsampled_indices( @@ -988,6 +992,8 @@ class RandomForestClassifier(ForestClassifier): - If int, then draw `max_samples` samples. - If float, then draw `max_samples * X.shape[0]` samples. + .. versionadded:: 0.22 + Attributes ---------- base_estimator_ : DecisionTreeClassifier @@ -1271,6 +1277,8 @@ class RandomForestRegressor(ForestRegressor): - If int, then draw `max_samples` samples. - If float, then draw `max_samples * X.shape[0]` samples. + .. versionadded:: 0.22 + Attributes ---------- base_estimator_ : DecisionTreeRegressor @@ -1565,6 +1573,8 @@ class ExtraTreesClassifier(ForestClassifier): - If int, then draw `max_samples` samples. - If float, then draw `max_samples * X.shape[0]` samples. + .. 
versionadded:: 0.22 + Attributes ---------- base_estimator_ : ExtraTreeClassifier @@ -1825,6 +1835,8 @@ class ExtraTreesRegressor(ForestRegressor): - If int, then draw `max_samples` samples. - If float, then draw `max_samples * X.shape[0]` samples. + .. versionadded:: 0.22 + Attributes ---------- base_estimator_ : ExtraTreeRegressor @@ -2048,6 +2060,8 @@ class RandomTreesEmbedding(BaseForest): - If int, then draw `max_samples` samples. - If float, then draw `max_samples * X.shape[0]` samples. + .. versionadded:: 0.22 + Attributes ---------- estimators_ : list of DecisionTreeClassifier From eac5ad632806e32ff409963d2ba82177c1eed622 Mon Sep 17 00:00:00 2001 From: matt Date: Mon, 26 Aug 2019 20:11:59 -0500 Subject: [PATCH 17/27] Remove docstring from unit test --- sklearn/ensemble/tests/test_forest.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 0701c5f7a7d97..ab391d373e646 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1345,8 +1345,7 @@ def test_forest_degenerate_feature_importances(): (np.ones(2), TypeError)] ) def test_max_samples_exceptions(name, max_samples_exception): - """ Check invalid `max_samples` values - """ + # Check invalid `max_samples` values max_samples, exception = max_samples_exception est = FOREST_CLASSIFIERS_REGRESSORS[name](max_samples=max_samples) with pytest.raises(exception): From d9de5416027cff3f8f0003836eaacf9dd6d98ed3 Mon Sep 17 00:00:00 2001 From: matt Date: Mon, 26 Aug 2019 20:26:34 -0500 Subject: [PATCH 18/27] Add exception message checks to unit test --- sklearn/ensemble/tests/test_forest.py | 34 +++++++++++++++++---------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index ab391d373e646..e629d9170d1d6 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ 
b/sklearn/ensemble/tests/test_forest.py @@ -1334,22 +1334,32 @@ def test_forest_degenerate_feature_importances(): @pytest.mark.parametrize('name', FOREST_CLASSIFIERS_REGRESSORS) @pytest.mark.parametrize( - 'max_samples_exception', - [(int(1e9), ValueError), - (1.0, ValueError), - (2.0, ValueError), - (0.0, ValueError), - (np.nan, ValueError), - (np.inf, ValueError), - ('str max_samples?!', TypeError), - (np.ones(2), TypeError)] + 'max_samples, exc_type, exc_msg', + [(int(1e9), ValueError, + "`max_samples` must be in range 1 to 6 but got value 1000000000"), + (1.0, ValueError, + "`max_samples` must be in range (0, 1) but got value 1.0"), + (2.0, ValueError, + "`max_samples` must be in range (0, 1) but got value 2.0"), + (0.0, ValueError, + "`max_samples` must be in range (0, 1) but got value 0.0"), + (np.nan, ValueError, + "`max_samples` must be in range (0, 1) but got value nan"), + (np.inf, ValueError, + "`max_samples` must be in range (0, 1) but got value inf"), + ('str max_samples?!', TypeError, + "`max_samples` should be int or float, but got type '<class 'str'>'"), + (np.ones(2), TypeError, + "`max_samples` should be int or float, but got type " + "'<class 'numpy.ndarray'>'")] ) -def test_max_samples_exceptions(name, max_samples_exception): +def test_max_samples_exceptions(name, max_samples, exc_type, exc_msg): # Check invalid `max_samples` values - max_samples, exception = max_samples_exception + print(exc_msg) est = FOREST_CLASSIFIERS_REGRESSORS[name](max_samples=max_samples) - with pytest.raises(exception): + with pytest.raises(exc_type) as exc_info: est.fit(X, y) + assert str(exc_info.value) == exc_msg def check_classification_toy_max_samples(name): From 93e579938588829a44f80d7bd4e8a9e8d39e0984 Mon Sep 17 00:00:00 2001 From: matt Date: Mon, 26 Aug 2019 20:29:37 -0500 Subject: [PATCH 19/27] Refactor toy data unit test --- sklearn/ensemble/tests/test_forest.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py
b/sklearn/ensemble/tests/test_forest.py index e629d9170d1d6..2d27a86410e66 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1355,16 +1355,13 @@ def test_forest_degenerate_feature_importances(): ) def test_max_samples_exceptions(name, max_samples, exc_type, exc_msg): # Check invalid `max_samples` values - print(exc_msg) est = FOREST_CLASSIFIERS_REGRESSORS[name](max_samples=max_samples) - with pytest.raises(exc_type) as exc_info: + with pytest.raises(exc_type, match=exc_msg): est.fit(X, y) - assert str(exc_info.value) == exc_msg def check_classification_toy_max_samples(name): - """ Test that the toy example is separable via a bootstrap size of only 2 - """ + # Test that the toy example is separable via a bootstrap size of only 2 rng = np.random.RandomState(1) max_tries = 100 @@ -1391,7 +1388,8 @@ def check_classification_toy_max_samples(name): perfect_score = True break - assert perfect_score + msg = "Perfect accuracy is achievable with `max_samples=2` on toy data" + assert perfect_score, msg @pytest.mark.parametrize('name', FOREST_CLASSIFIERS) From 9cdd5c279e1c33cfa477daf85979f7175793620a Mon Sep 17 00:00:00 2001 From: matt Date: Mon, 26 Aug 2019 20:55:16 -0500 Subject: [PATCH 20/27] Add node count check unit test --- sklearn/ensemble/tests/test_forest.py | 38 +++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 2d27a86410e66..29f93bb15d97f 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1356,11 +1356,13 @@ def test_forest_degenerate_feature_importances(): def test_max_samples_exceptions(name, max_samples, exc_type, exc_msg): # Check invalid `max_samples` values est = FOREST_CLASSIFIERS_REGRESSORS[name](max_samples=max_samples) - with pytest.raises(exc_type, match=exc_msg): + with pytest.raises(exc_type) as exc_info: est.fit(X, y) + assert 
str(exc_info.value) == exc_msg, "Exception message does not match" -def check_classification_toy_max_samples(name): +@pytest.mark.parametrize('name', FOREST_CLASSIFIERS) +def test_classification_toy_max_samples(name): # Test that the toy example is separable via a bootstrap size of only 2 rng = np.random.RandomState(1) @@ -1392,6 +1394,32 @@ def check_classification_toy_max_samples(name): assert perfect_score, msg -@pytest.mark.parametrize('name', FOREST_CLASSIFIERS) -def test_classification_toy_max_samples(name): - check_classification_toy_max_samples(name) +@pytest.mark.parametrize('name', FOREST_CLASSIFIERS_REGRESSORS) +def test_little_tree_with_small_max_samples(name): + rng = np.random.RandomState(1) + + X = rng.randn(10000, 2) + y = rng.randn(10000) > 0 + + # First fit with no restriction on max samples + est1 = FOREST_CLASSIFIERS_REGRESSORS[name]( + n_estimators=1, + random_state=rng, + max_samples=None + ) + + # Second fit with max samples restricted to just 2 + est2 = FOREST_CLASSIFIERS_REGRESSORS[name]( + n_estimators=1, + random_state=rng, + max_samples=2, + ) + + est1.fit(X, y) + est2.fit(X, y) + + tree1 = est1.estimators_[0].tree_ + tree2 = est2.estimators_[0].tree_ + + msg = "Tree without `max_samples` restriction should have more nodes" + assert tree1.node_count > tree2.node_count, msg From 9d72d36c2fe9ebb96892308b17cd1ae72809c555 Mon Sep 17 00:00:00 2001 From: matt Date: Tue, 10 Sep 2019 18:59:44 -0500 Subject: [PATCH 21/27] Limit unit test to RandomForest* --- sklearn/ensemble/tests/test_forest.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 29f93bb15d97f..5fbe93e1d7155 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1394,22 +1394,24 @@ def test_classification_toy_max_samples(name): assert perfect_score, msg -@pytest.mark.parametrize('name', FOREST_CLASSIFIERS_REGRESSORS) -def 
test_little_tree_with_small_max_samples(name): +@pytest.mark.parametrize( + 'forest_class', [RandomForestClassifier, RandomForestRegressor] +) +def test_little_tree_with_small_max_samples(forest_class): rng = np.random.RandomState(1) X = rng.randn(10000, 2) y = rng.randn(10000) > 0 # First fit with no restriction on max samples - est1 = FOREST_CLASSIFIERS_REGRESSORS[name]( + est1 = forest_class( n_estimators=1, random_state=rng, - max_samples=None + max_samples=None, ) # Second fit with max samples restricted to just 2 - est2 = FOREST_CLASSIFIERS_REGRESSORS[name]( + est2 = forest_class( n_estimators=1, random_state=rng, max_samples=2, From c95cecf96933fff8529aee146f58a46ed85b583e Mon Sep 17 00:00:00 2001 From: matt Date: Tue, 10 Sep 2019 19:29:24 -0500 Subject: [PATCH 22/27] Add whats new entry --- doc/whats_new/v0.22.rst | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index b99c9b0e3f334..2d99e5b4f9481 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -167,6 +167,19 @@ Changelog `predict_proba` give consistent results. :pr:`14114` by :user:`Guillaume Lemaitre <glemaitre>`. +- |Enhancement| Addition of ``max_samples`` argument allows limiting + size of bootstrap samples to be less than size of dataset. Added to + :class:`ensemble.forest.ForestClassifier`, + :class:`ensemble.forest.ForestRegressor`, + :class:`ensemble.forest.RandomForestClassifier`, + :class:`ensemble.forest.RandomForestRegressor`, + :class:`ensemble.forest.ExtraTreesClassifier`, + :class:`ensemble.forest.ExtraTreesRegressor`, + :class:`ensemble.forest.RandomTreesEmbedding`. :pr:`14682` by + :user:`Matt Hancock <notmatthancock>` and + :pr:`5963` by :user:`Pablo Duboue <DrDub>`. + + :mod:`sklearn.feature_extraction` .................................
From 454819175355930a60865ca93fb86a1b26a8dee5 Mon Sep 17 00:00:00 2001 From: matt Date: Thu, 12 Sep 2019 17:47:41 -0500 Subject: [PATCH 23/27] Include bootstrap condition in docstring comments --- sklearn/ensemble/forest.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 3fc345a774eaa..129e6c28cd8a2 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -986,7 +986,8 @@ class RandomForestClassifier(ForestClassifier): .. versionadded:: 0.22 max_samples : int or float, default=None - The number of samples to draw from X to train each base estimator. + If bootstrap is True, the number of samples to draw from X + to train each base estimator. - If None (default), then draw `X.shape[0]` samples. - If int, then draw `max_samples` samples. - If float, then draw `max_samples * X.shape[0]` samples. @@ -1272,7 +1273,8 @@ class RandomForestRegressor(ForestRegressor): .. versionadded:: 0.22 max_samples : int or float, default=None - The number of samples to draw from X to train each base estimator. + If bootstrap is True, the number of samples to draw from X + to train each base estimator. - If None (default), then draw `X.shape[0]` samples. - If int, then draw `max_samples` samples. - If float, then draw `max_samples * X.shape[0]` samples. @@ -1569,7 +1571,8 @@ class ExtraTreesClassifier(ForestClassifier): .. versionadded:: 0.22 max_samples : int or float, default=None - The number of samples to draw from X to train each base estimator. + If bootstrap is True, the number of samples to draw from X + to train each base estimator. - If None (default), then draw `X.shape[0]` samples. - If int, then draw `max_samples` samples. - If float, then draw `max_samples * X.shape[0]` samples. @@ -1832,7 +1835,8 @@ class ExtraTreesRegressor(ForestRegressor): .. 
versionadded:: 0.22 max_samples : int or float, default=None - The number of samples to draw from X to train each base estimator. + If bootstrap is True, the number of samples to draw from X + to train each base estimator. - If None (default), then draw `X.shape[0]` samples. - If int, then draw `max_samples` samples. - If float, then draw `max_samples * X.shape[0]` samples. @@ -2058,7 +2062,8 @@ class RandomTreesEmbedding(BaseForest): .. versionadded:: 0.22 max_samples : int or float, default=None - The number of samples to draw from X to train each base estimator. + If bootstrap is True, the number of samples to draw from X + to train each base estimator. - If None (default), then draw `X.shape[0]` samples. - If int, then draw `max_samples` samples. - If float, then draw `max_samples * X.shape[0]` samples. From 58f005c8b5256efb62fe08e23d15e560bcc8561b Mon Sep 17 00:00:00 2001 From: matt Date: Thu, 12 Sep 2019 17:49:16 -0500 Subject: [PATCH 24/27] Rename forest_class -> ForestClass --- sklearn/ensemble/tests/test_forest.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 5fbe93e1d7155..941cfe2005e1a 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1395,23 +1395,23 @@ def test_classification_toy_max_samples(name): @pytest.mark.parametrize( - 'forest_class', [RandomForestClassifier, RandomForestRegressor] + 'ForestClass', [RandomForestClassifier, RandomForestRegressor] ) -def test_little_tree_with_small_max_samples(forest_class): +def test_little_tree_with_small_max_samples(ForestClass): rng = np.random.RandomState(1) X = rng.randn(10000, 2) y = rng.randn(10000) > 0 # First fit with no restriction on max samples - est1 = forest_class( + est1 = ForestClass( n_estimators=1, random_state=rng, max_samples=None, ) # Second fit with max samples restricted to just 2 - est2 = forest_class( + est2 = ForestClass( 
n_estimators=1, random_state=rng, max_samples=2, From 7869f87f10a6bfeef711d53753a94a0da7d911a7 Mon Sep 17 00:00:00 2001 From: matt Date: Fri, 13 Sep 2019 18:02:03 -0500 Subject: [PATCH 25/27] Escape all the special chars --- sklearn/ensemble/tests/test_forest.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 941cfe2005e1a..b92bd41f0a05f 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1338,27 +1338,27 @@ def test_forest_degenerate_feature_importances(): [(int(1e9), ValueError, "`max_samples` must be in range 1 to 6 but got value 1000000000"), (1.0, ValueError, - "`max_samples` must be in range (0, 1) but got value 1.0"), + r"`max_samples` must be in range \(0, 1\) but got value 1.0"), (2.0, ValueError, - "`max_samples` must be in range (0, 1) but got value 2.0"), + r"`max_samples` must be in range \(0, 1\) but got value 2.0"), (0.0, ValueError, - "`max_samples` must be in range (0, 1) but got value 0.0"), + r"`max_samples` must be in range \(0, 1\) but got value 0.0"), (np.nan, ValueError, - "`max_samples` must be in range (0, 1) but got value nan"), + r"`max_samples` must be in range \(0, 1\) but got value nan"), (np.inf, ValueError, - "`max_samples` must be in range (0, 1) but got value inf"), + r"`max_samples` must be in range \(0, 1\) but got value inf"), ('str max_samples?!', TypeError, - "`max_samples` should be int or float, but got type '<class 'str'>'"), + r"`max_samples` should be int or float, but got " + r"type '\<class 'str'\>'"), (np.ones(2), TypeError, - "`max_samples` should be int or float, but got type " - "'<class 'numpy.ndarray'>'")] + r"`max_samples` should be int or float, but got type " + r"'\<class 'numpy.ndarray'\>'")] ) def test_max_samples_exceptions(name, max_samples, exc_type, exc_msg): # Check invalid `max_samples` values est = FOREST_CLASSIFIERS_REGRESSORS[name](max_samples=max_samples) - with pytest.raises(exc_type) as exc_info: + 
with pytest.raises(exc_type, match=exc_msg): est.fit(X, y) - assert str(exc_info.value) == exc_msg, "Exception message does not match" @pytest.mark.parametrize('name', FOREST_CLASSIFIERS) From e8535d42a3f26e663f51786db4fdd30604d7fc5a Mon Sep 17 00:00:00 2001 From: matt Date: Sat, 14 Sep 2019 08:25:55 -0500 Subject: [PATCH 26/27] Remove extraneous test --- sklearn/ensemble/tests/test_forest.py | 33 --------------------------- 1 file changed, 33 deletions(-) diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index b92bd41f0a05f..b41d0e0e1ba13 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -1361,39 +1361,6 @@ def test_max_samples_exceptions(name, max_samples, exc_type, exc_msg): est.fit(X, y) -@pytest.mark.parametrize('name', FOREST_CLASSIFIERS) -def test_classification_toy_max_samples(name): - # Test that the toy example is separable via a bootstrap size of only 2 - - rng = np.random.RandomState(1) - max_tries = 100 - - # The toy example is separable using just one - # decision stump, and choosing 2 examples from the full - # 6-example dataset *if* the 2 examples are chosen correctly. - est = FOREST_CLASSIFIERS[name]( - n_estimators=1, - bootstrap=True, - max_samples=2, - max_depth=1, - random_state=rng, - ) - - # Each call to fit uses a different bootstrap sample of size two. If we - # fit multiple times, we expect that we eventually hit a case where - # the two examples chosen for the bootstrap sample are from the opposite - # class and yield a perfect score across the entire dataset. 
- perfect_score = False - for _ in range(max_tries): - est.fit(X, y) - if est.score(X, y) == 1.0: - perfect_score = True - break - - msg = "Perfect accuracy is achievable with `max_samples=2` on toy data" - assert perfect_score, msg - - @pytest.mark.parametrize( 'ForestClass', [RandomForestClassifier, RandomForestRegressor] ) From 5b72dccb1f17c630e6f97efffe0cb37a3f804383 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 20 Sep 2019 09:58:11 +0200 Subject: [PATCH 27/27] comments adrin --- sklearn/ensemble/forest.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 129e6c28cd8a2..a062c913aadcb 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -82,7 +82,8 @@ def _get_n_samples_bootstrap(n_samples, max_samples): Number of samples in the dataset. max_samples : int or float The maximum number of samples to draw from the total available: - - if float, this indicates a fraction of the total; + - if float, this indicates a fraction of the total and should be + in the interval `(0, 1)`; - if int, this indicates the exact number of samples; - if None, this indicates the total number of samples. @@ -990,7 +991,8 @@ class RandomForestClassifier(ForestClassifier): to train each base estimator. - If None (default), then draw `X.shape[0]` samples. - If int, then draw `max_samples` samples. - - If float, then draw `max_samples * X.shape[0]` samples. + - If float, then draw `max_samples * X.shape[0]` samples. Thus, + `max_samples` should be in the interval `(0, 1)`. .. versionadded:: 0.22 @@ -1277,7 +1279,8 @@ class RandomForestRegressor(ForestRegressor): to train each base estimator. - If None (default), then draw `X.shape[0]` samples. - If int, then draw `max_samples` samples. - - If float, then draw `max_samples * X.shape[0]` samples. + - If float, then draw `max_samples * X.shape[0]` samples. Thus, + `max_samples` should be in the interval `(0, 1)`. 
.. versionadded:: 0.22 @@ -1575,7 +1578,8 @@ class ExtraTreesClassifier(ForestClassifier): to train each base estimator. - If None (default), then draw `X.shape[0]` samples. - If int, then draw `max_samples` samples. - - If float, then draw `max_samples * X.shape[0]` samples. + - If float, then draw `max_samples * X.shape[0]` samples. Thus, + `max_samples` should be in the interval `(0, 1)`. .. versionadded:: 0.22 @@ -1839,7 +1843,8 @@ class ExtraTreesRegressor(ForestRegressor): to train each base estimator. - If None (default), then draw `X.shape[0]` samples. - If int, then draw `max_samples` samples. - - If float, then draw `max_samples * X.shape[0]` samples. + - If float, then draw `max_samples * X.shape[0]` samples. Thus, + `max_samples` should be in the interval `(0, 1)`. .. versionadded:: 0.22 @@ -2066,7 +2071,8 @@ class RandomTreesEmbedding(BaseForest): to train each base estimator. - If None (default), then draw `X.shape[0]` samples. - If int, then draw `max_samples` samples. - - If float, then draw `max_samples * X.shape[0]` samples. + - If float, then draw `max_samples * X.shape[0]` samples. Thus, + `max_samples` should be in the interval `(0, 1)`. .. versionadded:: 0.22