Integration and test cases for RandomForest subsampling #9645

Closed
107 changes: 84 additions & 23 deletions sklearn/ensemble/forest.py
@@ -45,6 +45,8 @@ class calls the ``fit`` method of each sub-estimator on random samples
from warnings import warn

from abc import ABCMeta, abstractmethod
+from math import ceil
+import numbers
import numpy as np
from scipy.sparse import issparse
from scipy.sparse import hstack as sparse_hstack
@@ -74,17 +76,17 @@ class calls the ``fit`` method of each sub-estimator on random samples
MAX_INT = np.iinfo(np.int32).max


-def _generate_sample_indices(random_state, n_samples):
+def _generate_sample_indices(random_state, n_samples, max_samples):
"""Private function used to _parallel_build_trees function."""
random_instance = check_random_state(random_state)
-sample_indices = random_instance.randint(0, n_samples, n_samples)
+sample_indices = random_instance.randint(0, n_samples, max_samples)

return sample_indices


-def _generate_unsampled_indices(random_state, n_samples):
+def _generate_unsampled_indices(random_state, n_samples, max_samples):
"""Private function used to forest._set_oob_score function."""
-sample_indices = _generate_sample_indices(random_state, n_samples)
+sample_indices = _generate_sample_indices(random_state, n_samples, max_samples)
sample_counts = np.bincount(sample_indices, minlength=n_samples)
unsampled_mask = sample_counts == 0
indices_range = np.arange(n_samples)
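Note: a standalone sketch of what these two helpers now do (plain NumPy, names illustrative). `randint(0, n_samples, max_samples)` draws `max_samples` indices with replacement, and the `bincount` turns them into per-sample draw counts, so any sample with a zero count is out-of-bag:

```python
import numpy as np

rng = np.random.RandomState(0)
n_samples, max_samples = 10, 6

# Bootstrap sample of size max_samples, drawn with replacement.
sample_indices = rng.randint(0, n_samples, max_samples)

# Per-sample draw counts; samples never drawn are out-of-bag (OOB).
sample_counts = np.bincount(sample_indices, minlength=n_samples)
oob_indices = np.arange(n_samples)[sample_counts == 0]

print(sample_indices)  # e.g. [5 0 3 3 7 9]
print(oob_indices)     # e.g. [1 2 4 6 8]
```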
@@ -94,19 +96,27 @@ def _generate_unsampled_indices(random_state, n_samples):


def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
-verbose=0, class_weight=None):
+verbose=0, class_weight=None, max_samples=1.0):
"""Private function used to fit a single tree in parallel."""
if verbose > 1:
print("building tree %d of %d" % (tree_idx + 1, n_trees))

if forest.bootstrap:
n_samples = X.shape[0]

+# if max_samples is a float, interpret it as a fraction of n_samples
+if not isinstance(max_samples, (numbers.Integral, np.integer)):
+max_samples = int(ceil(max_samples * n_samples))
+
+if not 0 < max_samples <= n_samples:
+raise ValueError("max_samples must be in (0, %d] but got %r" % (n_samples, max_samples))

if sample_weight is None:
curr_sample_weight = np.ones((n_samples,), dtype=np.float64)
else:
curr_sample_weight = sample_weight.copy()

-indices = _generate_sample_indices(tree.random_state, n_samples)
+indices = _generate_sample_indices(tree.random_state, n_samples, max_samples)
sample_counts = np.bincount(indices, minlength=n_samples)
curr_sample_weight *= sample_counts

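The float-or-int normalization above is repeated verbatim in both `_set_oob_score` implementations below. A sketch of the shared logic as one helper; `_resolve_max_samples` is a hypothetical name, not part of this patch:

```python
import numbers
from math import ceil

import numpy as np


def _resolve_max_samples(max_samples, n_samples):
    """Map an int count or float fraction to a validated sample count."""
    if not isinstance(max_samples, (numbers.Integral, np.integer)):
        # Floats are fractions of the training set size, rounded up.
        max_samples = int(ceil(max_samples * n_samples))
    if not 0 < max_samples <= n_samples:
        raise ValueError("max_samples must be in (0, %d] but got %r"
                         % (n_samples, max_samples))
    return max_samples


print(_resolve_max_samples(0.25, 1000))  # 250
print(_resolve_max_samples(1.0, 1000))   # 1000, the classic bootstrap
```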
@@ -142,7 +152,8 @@ def __init__(self,
random_state=None,
verbose=0,
warm_start=False,
-class_weight=None):
+class_weight=None,
+max_samples=1.0):
super(BaseForest, self).__init__(
base_estimator=base_estimator,
n_estimators=n_estimators,
@@ -155,6 +166,7 @@ def __init__(self,
self.verbose = verbose
self.warm_start = warm_start
self.class_weight = class_weight
+self.max_samples = max_samples

def apply(self, X):
"""Apply trees in the forest to X, return leaf indices.
@@ -323,7 +335,8 @@ def fit(self, X, y, sample_weight=None):
backend="threading")(
delayed(_parallel_build_trees)(
t, self, X, y, sample_weight, i, len(trees),
-verbose=self.verbose, class_weight=self.class_weight)
+verbose=self.verbose, class_weight=self.class_weight,
+max_samples=self.max_samples)
for i, t in enumerate(trees))

# Collect newly grown trees
@@ -406,7 +419,8 @@ def __init__(self,
random_state=None,
verbose=0,
warm_start=False,
-class_weight=None):
+class_weight=None,
+max_samples=1.0):

super(ForestClassifier, self).__init__(
base_estimator,
@@ -418,14 +432,23 @@ def __init__(self,
random_state=random_state,
verbose=verbose,
warm_start=warm_start,
-class_weight=class_weight)
+class_weight=class_weight,
+max_samples=max_samples)

def _set_oob_score(self, X, y):
"""Compute out-of-bag score"""
X = check_array(X, dtype=DTYPE, accept_sparse='csr')

n_classes_ = self.n_classes_
n_samples = y.shape[0]
+max_samples = self.max_samples
+
+# if max_samples is a float, interpret it as a fraction of n_samples
+if not isinstance(max_samples, (numbers.Integral, np.integer)):
+max_samples = int(ceil(max_samples * n_samples))
+
+if not 0 < max_samples <= n_samples:
+raise ValueError("max_samples must be in (0, %d] but got %r" % (n_samples, max_samples))

oob_decision_function = []
oob_score = 0.0
@@ -436,7 +459,7 @@ def _set_oob_score(self, X, y):

for estimator in self.estimators_:
unsampled_indices = _generate_unsampled_indices(
-estimator.random_state, n_samples)
+estimator.random_state, n_samples, max_samples)
p_estimator = estimator.predict_proba(X[unsampled_indices, :],
check_input=False)

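One point for the OOB test cases: the probability that a given sample is out-of-bag for one tree is (1 - 1/n_samples) ** max_samples, roughly exp(-max_samples / n_samples), so smaller max_samples leaves more OOB samples per tree. A quick check:

```python
import numpy as np

n_samples = 1000
for max_samples in (1000, 500, 100):
    oob_frac = (1 - 1.0 / n_samples) ** max_samples  # P(never drawn)
    print(max_samples, round(oob_frac, 3),
          round(float(np.exp(-max_samples / float(n_samples))), 3))
# 1000 0.368 0.368  (classic bootstrap: ~36.8% OOB per tree)
# 500 0.606 0.607
# 100 0.905 0.905
```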
@@ -643,7 +666,8 @@ def __init__(self,
n_jobs=1,
random_state=None,
verbose=0,
-warm_start=False):
+warm_start=False,
+max_samples=1.0):
super(ForestRegressor, self).__init__(
base_estimator,
n_estimators=n_estimators,
@@ -653,7 +677,8 @@ def __init__(self,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
-warm_start=warm_start)
+warm_start=warm_start,
+max_samples=max_samples)

def predict(self, X):
"""Predict regression target for X.
@@ -700,13 +725,21 @@ def _set_oob_score(self, X, y):
X = check_array(X, dtype=DTYPE, accept_sparse='csr')

n_samples = y.shape[0]
+max_samples = self.max_samples
+
+# if max_samples is a float, interpret it as a fraction of n_samples
+if not isinstance(max_samples, (numbers.Integral, np.integer)):
+max_samples = int(ceil(max_samples * n_samples))
+
+if not 0 < max_samples <= n_samples:
+raise ValueError("max_samples must be in (0, %d] but got %r" % (n_samples, max_samples))

predictions = np.zeros((n_samples, self.n_outputs_))
n_predictions = np.zeros((n_samples, self.n_outputs_))

for estimator in self.estimators_:
unsampled_indices = _generate_unsampled_indices(
-estimator.random_state, n_samples)
+estimator.random_state, n_samples, max_samples)
p_estimator = estimator.predict(
X[unsampled_indices, :], check_input=False)

@@ -755,6 +788,11 @@ class RandomForestClassifier(ForestClassifier):
n_estimators : integer, optional (default=10)
The number of trees in the forest.

+max_samples : int or float, optional (default=1.0)
+The number of samples to draw from X to train each base estimator.
+- If int, then draw `max_samples` samples.
+- If float, then draw `ceil(max_samples * X.shape[0])` samples.

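Assuming the API added in this PR, typical usage with the float form would look like this sketch (`oob_score=True` exercises the modified `_set_oob_score` path):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=1000, random_state=0)

# Each tree is trained on ceil(0.25 * 1000) = 250 bootstrap draws.
clf = RandomForestClassifier(n_estimators=50, max_samples=0.25,
                             bootstrap=True, oob_score=True,
                             random_state=0)
clf.fit(X, y)
print(clf.oob_score_)
```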
criterion : string, optional (default="gini")
The function to measure the quality of a split. Supported criteria are
"gini" for the Gini impurity and "entropy" for the information gain.
@@ -984,7 +1022,8 @@ def __init__(self,
random_state=None,
verbose=0,
warm_start=False,
-class_weight=None):
+class_weight=None,
+max_samples=1.0):
super(RandomForestClassifier, self).__init__(
base_estimator=DecisionTreeClassifier(),
n_estimators=n_estimators,
@@ -999,7 +1038,8 @@ def __init__(self,
random_state=random_state,
verbose=verbose,
warm_start=warm_start,
-class_weight=class_weight)
+class_weight=class_weight,
+max_samples=max_samples)

self.criterion = criterion
self.max_depth = max_depth
@@ -1029,6 +1069,11 @@ class RandomForestRegressor(ForestRegressor):
n_estimators : integer, optional (default=10)
The number of trees in the forest.

+max_samples : int or float, optional (default=1.0)
+The number of samples to draw from X to train each base estimator.
+- If int, then draw `max_samples` samples.
+- If float, then draw `ceil(max_samples * X.shape[0])` samples.

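And the int form on the regressor, again a sketch against the API this PR adds:

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

X, y = make_regression(n_samples=1000, random_state=0)

# max_samples as an int: each tree sees exactly 200 bootstrap draws.
reg = RandomForestRegressor(n_estimators=50, max_samples=200,
                            bootstrap=True, random_state=0)
reg.fit(X, y)
print(reg.score(X, y))
```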
criterion : string, optional (default="mse")
The function to measure the quality of a split. Supported criteria
are "mse" for the mean squared error, which is equal to variance
@@ -1223,7 +1268,8 @@ def __init__(self,
n_jobs=1,
random_state=None,
verbose=0,
-warm_start=False):
+warm_start=False,
+max_samples=1.0):
super(RandomForestRegressor, self).__init__(
base_estimator=DecisionTreeRegressor(),
n_estimators=n_estimators,
@@ -1237,7 +1283,8 @@ def __init__(self,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
-warm_start=warm_start)
+warm_start=warm_start,
+max_samples=max_samples)

self.criterion = criterion
self.max_depth = max_depth
@@ -1265,6 +1312,11 @@ class ExtraTreesClassifier(ForestClassifier):
n_estimators : integer, optional (default=10)
The number of trees in the forest.

+max_samples : int or float, optional (default=1.0)
+The number of samples to draw from X to train each base estimator.
+- If int, then draw `max_samples` samples.
+- If float, then draw `ceil(max_samples * X.shape[0])` samples.

criterion : string, optional (default="gini")
The function to measure the quality of a split. Supported criteria are
"gini" for the Gini impurity and "entropy" for the information gain.
@@ -1466,7 +1518,8 @@ def __init__(self,
random_state=None,
verbose=0,
warm_start=False,
-class_weight=None):
+class_weight=None,
+max_samples=1.0):
super(ExtraTreesClassifier, self).__init__(
base_estimator=ExtraTreeClassifier(),
n_estimators=n_estimators,
@@ -1481,7 +1534,8 @@ def __init__(self,
random_state=random_state,
verbose=verbose,
warm_start=warm_start,
-class_weight=class_weight)
+class_weight=class_weight,
+max_samples=max_samples)

self.criterion = criterion
self.max_depth = max_depth
@@ -1509,6 +1563,11 @@ class ExtraTreesRegressor(ForestRegressor):
n_estimators : integer, optional (default=10)
The number of trees in the forest.

+max_samples : int or float, optional (default=1.0)
+The number of samples to draw from X to train each base estimator.
+- If int, then draw `max_samples` samples.
+- If float, then draw `ceil(max_samples * X.shape[0])` samples.

criterion : string, optional (default="mse")
The function to measure the quality of a split. Supported criteria
are "mse" for the mean squared error, which is equal to variance
@@ -1677,7 +1736,8 @@ def __init__(self,
n_jobs=1,
random_state=None,
verbose=0,
-warm_start=False):
+warm_start=False,
+max_samples=1.0):
super(ExtraTreesRegressor, self).__init__(
base_estimator=ExtraTreeRegressor(),
n_estimators=n_estimators,
@@ -1691,7 +1751,8 @@ def __init__(self,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
-warm_start=warm_start)
+warm_start=warm_start,
+max_samples=max_samples)

self.criterion = criterion
self.max_depth = max_depth