
[WIP] Balanced Random Forest #8728

Closed
wants to merge 17 commits
102 changes: 86 additions & 16 deletions sklearn/ensemble/forest.py
@@ -82,6 +82,56 @@ def _generate_sample_indices(random_state, n_samples):
    return sample_indices


def _get_class_balance_data(y):
    """Private function used to fit function."""
    if len(y.shape) == 1:
        classes, class_counts = np.unique(y, return_counts=True)
        class_indices = [np.nonzero(y == cls)[0] for cls in classes]

    else:
        classes, class_counts, class_indices = [], [], []
        for i in range(y.shape[1]):
            y_i = y[:, i]
            classes_i, class_counts_i = np.unique(y_i, return_counts=True)
            class_indices_i = [np.nonzero(y_i == cls)[0]
                               for cls in classes_i]
            classes_i = [(i, cls) for cls in classes_i]

            classes.extend(classes_i)
            class_counts.extend(class_counts_i)
            class_indices.extend(class_indices_i)

    return classes, class_counts, class_indices
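
For a quick sense of the triple this helper returns, here is a minimal sketch (not part of the diff) exercising the 1-D branch on a toy imbalanced target; the comments show NumPy's output:

import numpy as np

y = np.array([0, 0, 0, 0, 1, 1])
classes, class_counts = np.unique(y, return_counts=True)
class_indices = [np.nonzero(y == cls)[0] for cls in classes]

print(classes)        # [0 1]
print(class_counts)   # [4 2]
print(class_indices)  # [array([0, 1, 2, 3]), array([4, 5])]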


def _generate_balanced_sample_indices(random_state, balance_data):
    """Private function used to _parallel_build_trees function.

    Generates samples according to the balanced random forest method [1],
    adapted for the multi-class case, i.e. a bootstrap sample from the
    minority class and a random sample with replacement of the same size
    from all other classes.

    References
    ----------
    .. [1] Chen, C., Liaw, A., Breiman, L. (2004) "Using Random Forest
           to Learn Imbalanced Data", Tech. Rep. 666.
    """
    classes, class_counts, class_indices = balance_data
    min_count = np.min(class_counts)
    n_class = len(classes)

    random_instance = check_random_state(random_state)
    sample_indices = np.empty(n_class * min_count, dtype=int)

    for i, (count, indices) in enumerate(zip(class_counts, class_indices)):
        random_instances = random_instance.randint(0, count, min_count)
        random_indices = indices[random_instances]
        sample_indices[i * min_count:(i + 1) * min_count] = random_indices

    return sample_indices
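
To see the balanced draw in action, the two helpers can be run standalone; a minimal sketch, reusing the toy target from the previous example (min_count is 2, so every tree sees two indices per class, drawn with replacement):

import numpy as np

y = np.array([0, 0, 0, 0, 1, 1])
balance_data = _get_class_balance_data(y)
indices = _generate_balanced_sample_indices(random_state=0,
                                            balance_data=balance_data)
print(indices)  # four indices: two per class, repeats possible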


def _generate_unsampled_indices(random_state, n_samples):
    """Private function used to forest._set_oob_score function."""
    sample_indices = _generate_sample_indices(random_state, n_samples)
@@ -106,7 +156,12 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
        else:
            curr_sample_weight = sample_weight.copy()

        if class_weight == 'balanced_bootstrapping':
            indices = _generate_balanced_sample_indices(tree.random_state,
                                                        balance_data)
        else:
            indices = _generate_sample_indices(tree.random_state, n_samples)

        sample_counts = bincount(indices, minlength=n_samples)
        curr_sample_weight *= sample_counts

@@ -117,6 +172,7 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
        elif class_weight == 'balanced_subsample':
            curr_sample_weight *= compute_sample_weight('balanced', y, indices)

        tree.sample_weight = curr_sample_weight
        tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
    else:
        tree.fit(X, y, sample_weight=sample_weight, check_input=False)
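
End to end, the new mode would be requested through class_weight like any other preset; a hypothetical usage sketch, assuming this WIP branch were installed (the preset does not exist in released scikit-learn):

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=1000, weights=[0.9, 0.1],
                           random_state=0)
clf = RandomForestClassifier(n_estimators=100,
                             class_weight='balanced_bootstrapping',
                             random_state=0)
clf.fit(X, y)  # each tree is grown on a class-balanced bootstrap
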
@@ -207,7 +263,7 @@ def decision_path(self, X):
        indicators = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
                              backend="threading")(
            delayed(parallel_helper)(tree, 'decision_path', X,
                                     check_input=False)
            for tree in self.estimators_)

        n_nodes = [0]
@@ -222,8 +278,8 @@ def fit(self, X, y, sample_weight=None):
        Parameters
        ----------
        X : array-like or sparse matrix of shape = [n_samples, n_features]
            The training input samples. Internally, its dtype will be converted
            to ``dtype=np.float32``. If a sparse matrix is provided, it will be
            converted into a sparse ``csc_matrix``.

        y : array-like, shape = [n_samples] or [n_samples, n_outputs]
@@ -315,6 +371,9 @@ def fit(self, X, y, sample_weight=None):
                                             random_state=random_state)
                trees.append(tree)

            balance_data = (_get_class_balance_data(y)
                            if self.class_weight == 'balanced_bootstrapping'
                            else None)

            # Parallel loop: we use the threading backend as the Cython code
            # for fitting the trees is internally releasing the Python GIL
            # making threading always more efficient than multiprocessing in
@@ -479,7 +538,8 @@ def _validate_y_class_weight(self, y):

        y_store_unique_indices = np.zeros(y.shape, dtype=np.int)
        for k in range(self.n_outputs_):
            classes_k, y_store_unique_indices[:, k] = np.unique(
                y[:, k], return_inverse=True)
            self.classes_.append(classes_k)
            self.n_classes_.append(classes_k.shape[0])
        y = y_store_unique_indices
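
The loop above leans on np.unique(..., return_inverse=True) to map arbitrary labels to indices into the sorted class array; a quick illustration (not part of the diff):

import numpy as np

y_k = np.array(['spam', 'ham', 'spam', 'eggs'])
classes_k, encoded = np.unique(y_k, return_inverse=True)
print(classes_k)  # ['eggs' 'ham' 'spam'], the sorted unique labels
print(encoded)    # [2 1 2 0], each label's index into classes_k
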
@@ -489,16 +549,18 @@ def _validate_y_class_weight(self, y):
            if isinstance(self.class_weight, six.string_types):
                if self.class_weight not in valid_presets:
                    raise ValueError('Valid presets for class_weight include '
                                     '"balanced", "balanced_subsample" and '
                                     '"balanced_bootstrapping". Given "%s".'
                                     % self.class_weight)
                if self.warm_start:
                    warn('class_weight presets "balanced" or '
                         '"balanced_subsample" are '
                         'not recommended for warm_start if the fitted data '
                         'differs from the full dataset. In order to use '
                         '"balanced" weights, use compute_class_weight('
                         '"balanced", classes, y). In place of y you can use '
                         'a large enough sample of the full training set '
                         'target to properly estimate the class frequency '
                         'distributions. Pass the resulting weights as the '
                         'class_weight parameter.')
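
The workaround this warning recommends can be spelled out; a small sketch of the assumed pattern (positional arguments match the scikit-learn API of this era): compute "balanced" weights once from the full target and pass them explicitly.

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils.class_weight import compute_class_weight

y_full = np.array([0] * 900 + [1] * 100)
classes = np.unique(y_full)
weights = compute_class_weight('balanced', classes, y_full)
# n_samples / (n_classes * bincount): {0: 0.555..., 1: 5.0}
clf = RandomForestClassifier(warm_start=True,
                             class_weight=dict(zip(classes, weights)))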

@@ -554,8 +616,8 @@ def predict_proba(self, X):

        The predicted class probabilities of an input sample are computed as
        the mean predicted class probabilities of the trees in the forest. The
        class probability of a single tree is the fraction of samples of the
        same class in a leaf.

        Parameters
        ----------
@@ -1300,7 +1362,8 @@ class ExtraTreesClassifier(ForestClassifier):
        and add more estimators to the ensemble, otherwise, just fit a whole
        new forest.

    class_weight : dict, list of dicts, "balanced", "balanced_subsample",
        "balanced_bootstrapping" or None, optional (default=None)
        Weights associated with classes in the form ``{class_label: weight}``.
        If not given, all classes are supposed to have weight one. For
        multi-output problems, a list of dicts can be provided in the same
@@ -1310,14 +1373,18 @@ class ExtraTreesClassifier(ForestClassifier):
        weights inversely proportional to class frequencies in the input data
        as ``n_samples / (n_classes * np.bincount(y))``

        The "balanced_subsample" mode is the same as "balanced" except that
        weights are computed based on the bootstrap sample for every tree
        grown.

        The "balanced_bootstrapping" mode bootstraps the samples used to
        grow each tree according to the balanced random forest method [2].

        For multi-output, the weights of each column of y will be multiplied.

        Note that these weights will be multiplied with sample_weight (passed
        through the fit method) if sample_weight is specified.
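
For concreteness, the "balanced" formula quoted above works out as follows on a toy target (illustrative only):

import numpy as np

y = np.array([0, 0, 0, 1])  # three samples of class 0, one of class 1
n_samples, n_classes = len(y), len(np.unique(y))
print(n_samples / (n_classes * np.bincount(y)))  # roughly [0.667, 2.0]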

Attributes
----------
estimators_ : list of DecisionTreeClassifier
@@ -1363,6 +1430,9 @@ class labels (multi-output problem).
    .. [1] P. Geurts, D. Ernst., and L. Wehenkel, "Extremely randomized trees",
           Machine Learning, 63(1), 3-42, 2006.

    .. [2] Chen, C., Liaw, A., Breiman, L. (2004) "Using Random Forest to
           Learn Imbalanced Data", Tech. Rep. 666.

See also
--------
sklearn.tree.ExtraTreeClassifier : Base classifier for this ensemble.