
[WIP] Issue #11993 - add node_bootstrap param to RandomForest #17504

Closed
wants to merge 10 commits into from
89 changes: 33 additions & 56 deletions sklearn/ensemble/_forest.py
@@ -40,7 +40,6 @@ class calls the ``fit`` method of each sub-estimator on random samples
# License: BSD 3 clause


import numbers
from warnings import catch_warnings, simplefilter, warn
import threading

@@ -56,7 +55,9 @@ class calls the ``fit`` method of each sub-estimator on random samples
from ..tree import (DecisionTreeClassifier, DecisionTreeRegressor,
ExtraTreeClassifier, ExtraTreeRegressor)
from ..tree._tree import DTYPE, DOUBLE
from ..utils import check_random_state, check_array, compute_sample_weight
from ..utils import (check_random_state, check_array,
compute_sample_weight, _generate_sample_indices,
_get_n_samples_bootstrap)
from ..exceptions import DataConversionWarning
from ._base import BaseEnsemble, _partition_estimators
from ..utils.fixes import _joblib_parallel_args
@@ -74,55 +75,6 @@ class calls the ``fit`` method of each sub-estimator on random samples
MAX_INT = np.iinfo(np.int32).max


def _get_n_samples_bootstrap(n_samples, max_samples):
"""
Get the number of samples in a bootstrap sample.

Parameters
----------
n_samples : int
Number of samples in the dataset.
max_samples : int or float
The maximum number of samples to draw from the total available:
        - if float, this indicates a fraction of the total and should be
          in the interval `(0, 1)`;
- if int, this indicates the exact number of samples;
- if None, this indicates the total number of samples.

Returns
-------
n_samples_bootstrap : int
The total number of samples to draw for the bootstrap sample.
"""
if max_samples is None:
return n_samples

if isinstance(max_samples, numbers.Integral):
if not (1 <= max_samples <= n_samples):
msg = "`max_samples` must be in range 1 to {} but got value {}"
raise ValueError(msg.format(n_samples, max_samples))
return max_samples

if isinstance(max_samples, numbers.Real):
if not (0 < max_samples < 1):
msg = "`max_samples` must be in range (0, 1) but got value {}"
raise ValueError(msg.format(max_samples))
return round(n_samples * max_samples)

msg = "`max_samples` should be int or float, but got type '{}'"
raise TypeError(msg.format(type(max_samples)))


def _generate_sample_indices(random_state, n_samples, n_samples_bootstrap):
"""
    Private function used by the _parallel_build_trees function."""

random_instance = check_random_state(random_state)
sample_indices = random_instance.randint(0, n_samples, n_samples_bootstrap)

return sample_indices


def _generate_unsampled_indices(random_state, n_samples, n_samples_bootstrap):
"""
    Private function used by the forest._set_oob_score function."""
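
For review context, here is a small illustrative sketch of what the two moved helpers compute (sample counts are arbitrary; on this branch the functions are imported from ``sklearn.utils``):

    # Illustrative only: resolving max_samples for a 100-sample dataset.
    _get_n_samples_bootstrap(n_samples=100, max_samples=None)  # -> 100
    _get_n_samples_bootstrap(n_samples=100, max_samples=25)    # -> 25
    _get_n_samples_bootstrap(n_samples=100, max_samples=0.5)   # -> 50, i.e. round(100 * 0.5)
    _get_n_samples_bootstrap(n_samples=100, max_samples=150)   # raises ValueError

    # Drawing the bootstrap indices themselves (sampling with replacement):
    ind = _generate_sample_indices(random_state=0, n_samples=100,
                                   n_samples_bootstrap=50)  # 50 indices in [0, 100)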
@@ -186,26 +138,30 @@ def __init__(self,
n_estimators=100, *,
estimator_params=tuple(),
bootstrap=False,
node_bootstrap=False,
oob_score=False,
n_jobs=None,
random_state=None,
verbose=0,
warm_start=False,
class_weight=None,
max_samples=None):
max_samples=None,
node_max_samples=None):
super().__init__(
base_estimator=base_estimator,
n_estimators=n_estimators,
estimator_params=estimator_params)

self.bootstrap = bootstrap
self.node_bootstrap = node_bootstrap
self.oob_score = oob_score
self.n_jobs = n_jobs
self.random_state = random_state
self.verbose = verbose
self.warm_start = warm_start
self.class_weight = class_weight
self.max_samples = max_samples
self.node_max_samples = node_max_samples

def apply(self, X):
"""
@@ -486,25 +442,29 @@ def __init__(self,
n_estimators=100, *,
estimator_params=tuple(),
bootstrap=False,
node_bootstrap=False,
oob_score=False,
n_jobs=None,
random_state=None,
verbose=0,
warm_start=False,
class_weight=None,
max_samples=None):
max_samples=None,
node_max_samples=None):
super().__init__(
base_estimator,
n_estimators=n_estimators,
estimator_params=estimator_params,
bootstrap=bootstrap,
node_bootstrap=node_bootstrap,
oob_score=oob_score,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start,
class_weight=class_weight,
max_samples=max_samples)
max_samples=max_samples,
node_max_samples=node_max_samples)

def _set_oob_score(self, X, y):
"""
@@ -880,6 +840,7 @@ def _compute_partial_dependence_recursion(self, grid, target_features):

return averaged_predictions


class RandomForestClassifier(ForestClassifier):
"""
A random forest classifier.
@@ -997,6 +958,9 @@ class RandomForestClassifier(ForestClassifier):
Whether bootstrap samples are used when building trees. If False, the
whole dataset is used to build each tree.

node_bootstrap : bool, default=False
        Whether a bootstrap sample is drawn at each node when selecting
        the best split.

oob_score : bool, default=False
Whether to use out-of-bag samples to estimate
the generalization accuracy.
@@ -1066,6 +1030,16 @@ class RandomForestClassifier(ForestClassifier):
- If float, then draw `max_samples * X.shape[0]` samples. Thus,
`max_samples` should be in the interval `(0, 1)`.

node_max_samples : int or float, default=None
        If node_bootstrap is True, the number of samples to draw at each
        node when selecting the best split.

- If None (default), then draw `node size` samples.
- If int, then draw `node_max_samples` samples.
- If float, then draw `node_max_samples * node size` samples. Thus,
`node_max_samples` should be in the interval `(0, 1)`.


.. versionadded:: 0.22

Attributes
@@ -1162,22 +1136,25 @@ def __init__(self,
min_impurity_decrease=0.,
min_impurity_split=None,
bootstrap=True,
node_bootstrap=False,
oob_score=False,
n_jobs=None,
random_state=None,
verbose=0,
warm_start=False,
class_weight=None,
ccp_alpha=0.0,
max_samples=None):
max_samples=None,
node_max_samples=None):
super().__init__(
base_estimator=DecisionTreeClassifier(),
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes",
"min_impurity_decrease", "min_impurity_split",
"random_state", "ccp_alpha"),
"random_state", "ccp_alpha", "node_bootstrap",
"node_max_samples"),
bootstrap=bootstrap,
oob_score=oob_score,
n_jobs=n_jobs,
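
For reviewers, a minimal usage sketch of the proposed parameters; this only runs on this branch, and the dataset and values are illustrative:

    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier

    X, y = load_iris(return_X_y=True)

    # Per-tree bagging as before, plus the per-node resampling proposed
    # in this PR: node_bootstrap and node_max_samples are new parameters.
    clf = RandomForestClassifier(
        n_estimators=100,
        bootstrap=True,
        node_bootstrap=True,
        node_max_samples=0.75,
        random_state=0,
    )
    clf.fit(X, y)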
44 changes: 36 additions & 8 deletions sklearn/tree/_classes.py
@@ -30,8 +30,8 @@
from ..base import is_classifier
from ..base import MultiOutputMixin
from ..utils import Bunch
from ..utils import check_array
from ..utils import check_random_state
from ..utils import (check_array, check_random_state,
_get_n_samples_bootstrap, _generate_sample_indices)
from ..utils.validation import _check_sample_weight
from ..utils import compute_sample_weight
from ..utils.multiclass import check_classification_targets
@@ -97,7 +97,9 @@ def __init__(self, *,
min_impurity_decrease,
min_impurity_split,
class_weight=None,
ccp_alpha=0.0):
ccp_alpha=0.0,
node_bootstrap=False,
node_max_samples=None):
self.criterion = criterion
self.splitter = splitter
self.max_depth = max_depth
@@ -111,6 +113,8 @@ def __init__(self, *,
self.min_impurity_split = min_impurity_split
self.class_weight = class_weight
self.ccp_alpha = ccp_alpha
self.node_bootstrap = node_bootstrap
self.node_max_samples = node_max_samples

def get_depth(self):
"""Return the depth of the decision tree.
@@ -370,7 +374,14 @@ def fit(self, X, y, sample_weight=None, check_input=True,
self.min_impurity_decrease,
min_impurity_split)

builder.build(self.tree_, X, y, sample_weight)
        if self.node_bootstrap:
            # Draw a bootstrap sample of the full training set and build
            # the tree on the resampled data.
            boot_size = _get_n_samples_bootstrap(X.shape[0],
                                                 self.node_max_samples)
            ind = _generate_sample_indices(self.random_state, X.shape[0],
                                           boot_size)
            # Index sample_weight alongside X and y so that weighted fits
            # stay consistent with the resampled rows.
            builder.build(self.tree_, X[ind], y[ind],
                          sample_weight if sample_weight is None
                          else sample_weight[ind])
        else:
            builder.build(self.tree_, X, y, sample_weight)

if self.n_outputs_ == 1 and is_classifier(self):
self.n_classes_ = self.n_classes_[0]
@@ -732,6 +743,11 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree):
``ccp_alpha`` will be chosen. By default, no pruning is performed. See
:ref:`minimal_cost_complexity_pruning` for details.

    node_bootstrap : FIXME (do we want to show this param? or only use it
        with ensemble methods?)

node_max_samples : FIXME

.. versionadded:: 0.22

Attributes
@@ -824,7 +840,9 @@ def __init__(self, *,
min_impurity_decrease=0.,
min_impurity_split=None,
class_weight=None,
ccp_alpha=0.0):
ccp_alpha=0.0,
node_bootstrap=False,
node_max_samples=None):
super().__init__(
criterion=criterion,
splitter=splitter,
@@ -838,7 +856,9 @@
random_state=random_state,
min_impurity_decrease=min_impurity_decrease,
min_impurity_split=min_impurity_split,
ccp_alpha=ccp_alpha)
ccp_alpha=ccp_alpha,
node_bootstrap=node_bootstrap,
node_max_samples=node_max_samples)

def fit(self, X, y, sample_weight=None, check_input=True,
X_idx_sorted="deprecated"):
@@ -1088,6 +1108,10 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
``ccp_alpha`` will be chosen. By default, no pruning is performed. See
:ref:`minimal_cost_complexity_pruning` for details.

node_bootstrap : FIXME

node_max_samples : FIXME

.. versionadded:: 0.22

Attributes
@@ -1170,7 +1194,9 @@ def __init__(self, *,
max_leaf_nodes=None,
min_impurity_decrease=0.,
min_impurity_split=None,
ccp_alpha=0.0):
ccp_alpha=0.0,
node_bootstrap=False,
node_max_samples=None):
super().__init__(
criterion=criterion,
splitter=splitter,
@@ -1183,7 +1209,9 @@
random_state=random_state,
min_impurity_decrease=min_impurity_decrease,
min_impurity_split=min_impurity_split,
ccp_alpha=ccp_alpha)
ccp_alpha=ccp_alpha,
node_bootstrap=node_bootstrap,
node_max_samples=node_max_samples)

def fit(self, X, y, sample_weight=None, check_input=True,
X_idx_sorted="deprecated"):
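
For intuition, the ``node_bootstrap`` branch added to ``fit`` above does roughly the following, written out by hand against released scikit-learn (the 0.75 fraction is arbitrary):

    import numpy as np

    from sklearn.datasets import load_diabetes
    from sklearn.tree import DecisionTreeRegressor

    X, y = load_diabetes(return_X_y=True)

    # Draw round(0.75 * n_samples) row indices with replacement, then
    # build the tree on the resampled data -- the same effect as
    # node_bootstrap=True, node_max_samples=0.75 on this branch.
    rng = np.random.RandomState(0)
    boot_size = round(0.75 * X.shape[0])
    ind = rng.randint(0, X.shape[0], boot_size)

    reg = DecisionTreeRegressor(random_state=0)
    reg.fit(X[ind], y[ind])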
28 changes: 28 additions & 0 deletions sklearn/tree/tests/test_tree.py
@@ -1961,3 +1961,31 @@ def test_X_idx_sorted_deprecated(TreeEstimator):
with pytest.warns(FutureWarning,
match="The parameter 'X_idx_sorted' is deprecated"):
tree.fit(X, y, X_idx_sorted=X_idx_sorted)


@pytest.mark.parametrize("tree_type,dataset",
[(DecisionTreeRegressor, diabetes),
(DecisionTreeClassifier, iris)])
@skip_if_32bit
def test_node_bootstrap_accuracy(tree_type, dataset):
error_threshold = 0.1
rng = np.random.RandomState(1)

est = tree_type(
random_state=rng,
node_bootstrap=False,
node_max_samples=None)
est_bootstrap = tree_type(
random_state=rng,
node_bootstrap=True,
node_max_samples=0.75)

est.fit(dataset.data, dataset.target)
est_bootstrap.fit(dataset.data, dataset.target)

error = mean_squared_error(
est.feature_importances_,
est_bootstrap.feature_importances_)

    assert error < error_threshold, (
        f"Difference between bootstrap and full_sample is {error}, "
        f"exceeding {error_threshold}")
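
To exercise just the new test on this branch, something like:

    pytest sklearn/tree/tests/test_tree.py -k test_node_bootstrap_accuracy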