
[MRG] ENH Add support for missing values to Tree based Classifiers #5974


Closed
wants to merge 8 commits into from
118 changes: 86 additions & 32 deletions sklearn/ensemble/forest.py
@@ -31,11 +31,11 @@ class calls the ``fit`` method of each sub-estimator on random samples
Single and multi-output problems are both handled.

"""

# Authors: Gilles Louppe <g.louppe@gmail.com>
# Brian Holt <bdholt1@gmail.com>
# Joly Arnaud <arnaud.v.joly@gmail.com>
# Fares Hedayati <fares.hedayati@gmail.com>
# Raghav R V <rvraghav93@gmail.com>
#
# License: BSD 3 clause

@@ -94,7 +94,7 @@ def _generate_unsampled_indices(random_state, n_samples):


def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
verbose=0, class_weight=None):
verbose=0, class_weight=None, missing_mask=None):
"""Private function used to fit a single tree in parallel."""
if verbose > 1:
print("building tree %d of %d" % (tree_idx + 1, n_trees))
@@ -117,9 +117,11 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
elif class_weight == 'balanced_subsample':
curr_sample_weight *= compute_sample_weight('balanced', y, indices)

tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False)
tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False,
missing_mask=missing_mask)
else:
tree.fit(X, y, sample_weight=sample_weight, check_input=False)
tree.fit(X, y, sample_weight=sample_weight, check_input=False,
missing_mask=missing_mask)

return tree

@@ -143,7 +145,8 @@ def __init__(self,
random_state=None,
verbose=0,
warm_start=False,
class_weight=None):
class_weight=None,
missing_values=None):
super(BaseForest, self).__init__(
base_estimator=base_estimator,
n_estimators=n_estimators,
@@ -156,6 +159,11 @@ def __init__(self,
self.verbose = verbose
self.warm_start = warm_start
self.class_weight = class_weight
self.allow_missing = missing_values is not None
This doesn't work with grid search and pipeline, because GridSearchCV finds the allowed parameters with the get_params method rather than receiving them as inputs. GridSearchCV instantiates the class with its current constructor parameters and only changes the values afterwards, so allow_missing will always stay equal to False. I don't know the best solution to resolve this issue. Maybe add allow_missing as a constructor argument?

Member Author:
Ah! Very good point, thanks. We could set allow_missing as a fit param...

Contributor:
+1, this should be determined at fit time; the same goes for setting self._allow_nan below.

Member Author:
Okay!
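
To make the point above concrete, here is a minimal, self-contained sketch (not code from this PR, and deliberately simplified relative to sklearn.base.clone and GridSearchCV) of how a parameter changed via set_params after construction leaves an attribute derived in __init__ stale:

import numpy as np  # only used by the surrounding discussion; not needed here

class Estimator(object):
    def __init__(self, missing_values=None):
        self.missing_values = missing_values
        # Derived at construction time, so it only reflects what __init__ saw.
        self.allow_missing = missing_values is not None

    def get_params(self, deep=True):
        return {"missing_values": self.missing_values}

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self


def clone_like_grid_search(estimator, candidate_params):
    # Roughly what happens for each grid point: re-instantiate from
    # get_params(), then overwrite the searched parameters via set_params().
    cloned = Estimator(**estimator.get_params())
    cloned.set_params(**candidate_params)
    return cloned


base = Estimator()                                           # missing_values=None
clf = clone_like_grid_search(base, {"missing_values": "NaN"})
print(clf.missing_values)   # 'NaN'
print(clf.allow_missing)    # False -- stale, __init__ never saw the new value

Deriving allow_missing at fit time from self.missing_values sidesteps this, which is what the thread converges on.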


# If missing values is int/None
self._allow_nan = False
self.missing_values = missing_values

def apply(self, X):
"""Apply trees in the forest to X, return leaf indices.
@@ -174,9 +182,12 @@ def apply(self, X):
return the index of the leaf x ends up in.
"""
X = self._validate_X_predict(X)
missing_mask = self._validate_missing_mask(X)
results = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
backend="threading")(
delayed(parallel_helper)(tree, 'apply', X, check_input=False)
delayed(parallel_helper)(tree, 'apply', X,
check_input=False,
missing_mask=missing_mask)
for tree in self.estimators_)

return np.array(results).T
@@ -205,10 +216,12 @@ def decision_path(self, X):

"""
X = self._validate_X_predict(X)
missing_mask = self._validate_missing_mask(X)
indicators = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
backend="threading")(
delayed(parallel_helper)(tree, 'decision_path', X,
check_input=False)
check_input=False,
missing_mask=missing_mask)
for tree in self.estimators_)

n_nodes = [0]
@@ -243,8 +256,22 @@ def fit(self, X, y, sample_weight=None):
self : object
Returns self.
"""
if self.allow_missing:
if ((isinstance(self.missing_values, str) and
self.missing_values.strip().lower() == "nan") or
(isinstance(self.missing_values, np.float) and
np.isnan(self.missing_values))):

self._allow_nan = True
self.missing_values = np.nan
elif not isinstance(self.missing_values, int):
raise ValueError("missing_values should be 'NaN' or int. "
"Got %s" % self.missing_values)

# Validate or convert input data
X = check_array(X, accept_sparse="csc", dtype=DTYPE)
X = check_array(X, accept_sparse="csc", dtype=DTYPE,
allow_nan=self._allow_nan)
missing_mask = self._validate_missing_mask(X)
y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
if issparse(X):
# Pre-sort indices to avoid that each individual tree of the
@@ -310,8 +337,10 @@ def fit(self, X, y, sample_weight=None):

trees = []
for i in range(n_more_estimators):
tree = self._make_estimator(append=False,
random_state=random_state)
tree = self._make_estimator(append=False)
tree.set_params(random_state=random_state.randint(MAX_INT))
tree.allow_missing = self.allow_missing
tree._allow_nan = self._allow_nan
trees.append(tree)

# Parallel loop: we use the threading backend as the Cython code
@@ -322,14 +351,15 @@ def fit(self, X, y, sample_weight=None):
backend="threading")(
delayed(_parallel_build_trees)(
t, self, X, y, sample_weight, i, len(trees),
verbose=self.verbose, class_weight=self.class_weight)
verbose=self.verbose, class_weight=self.class_weight,
missing_mask=missing_mask)
for i, t in enumerate(trees))

# Collect newly grown trees
self.estimators_.extend(trees)

if self.oob_score:
self._set_oob_score(X, y)
self._set_oob_score(X, y, missing_mask)

# Decapsulate classes_ attributes
if hasattr(self, "classes_") and self.n_outputs_ == 1:
@@ -339,7 +369,7 @@ def fit(self, X, y, sample_weight=None):
return self

@abstractmethod
def _set_oob_score(self, X, y):
def _set_oob_score(self, X, y, missing_mask=None):
"""Calculate out of bag predictions and score."""

def _validate_y_class_weight(self, y):
@@ -354,6 +384,17 @@ def _validate_X_predict(self, X):

return self.estimators_[0]._validate_X_predict(X, check_input=True)

def _validate_missing_mask(self, X, missing_mask=None):
Contributor:
I am wondering whether adding this missing_mask really brings anything. Wouldn't it have been the same to compare the feature values directly when splitting and predicting? That would have saved a lot of boilerplate code to transport this array from one place to another. What do you think? Maybe I am missing something.

Contributor:
(At the moment, you don't really save computation, since you still need to perform boolean comparisons on this array instead of doing float comparisons directly on X.)

Member Author:
Without the missing mask this would have to be isnan. The reason we decided to go with missing_mask is that NaN representations differ, which makes comparing a float to NaN costly. (Ref: #5870 (comment))

"""Generate a new missing_mask or validate a given one"""
if self.allow_missing and missing_mask is None:
# Fortran ordered 8 bit boolean mask
if self._allow_nan: # Missing value is a NaN
missing_mask = np.asfortranarray(np.isnan(X), dtype=np.bool8)
else:
missing_mask = np.zeros(X.shape, dtype=np.bool8, order='F')
missing_mask[X == self.missing_values] = True
return missing_mask
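
As context for the thread above, here is a short NumPy-only sketch of what the mask ends up holding; it mirrors the logic of _validate_missing_mask rather than importing this branch, so treat the exact dtypes as illustrative:

import numpy as np

X = np.array([[1.0, np.nan],
              [3.0, 4.0]], dtype=np.float32)

# NaN case: one isnan() pass over X, stored Fortran-ordered so the Cython
# splitters can walk it column by column.
mask_nan = np.asfortranarray(np.isnan(X), dtype=bool)
# [[False  True]
#  [False False]]

# Integer-sentinel case (e.g. missing_values=-1): an equality test fills the
# mask, so the tree code never has to compare floats against NaN.
X_int = np.array([[1.0, -1.0],
                  [3.0, 4.0]], dtype=np.float32)
mask_int = np.zeros(X_int.shape, dtype=bool, order='F')
mask_int[X_int == -1] = True

The design trades one extra n_samples x n_features boolean array for cheap boolean lookups inside the splitter, instead of repeated isnan or sentinel comparisons on X.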

@property
def feature_importances_(self):
"""Return the feature importances (the higher, the more important the
@@ -394,7 +435,8 @@ def __init__(self,
random_state=None,
verbose=0,
warm_start=False,
class_weight=None):
class_weight=None,
missing_values=None):

super(ForestClassifier, self).__init__(
base_estimator,
@@ -406,11 +448,12 @@ def __init__(self,
random_state=random_state,
verbose=verbose,
warm_start=warm_start,
class_weight=class_weight)
class_weight=class_weight,
missing_values=missing_values)

def _set_oob_score(self, X, y):
def _set_oob_score(self, X, y, missing_mask=None):
"""Compute out-of-bag score"""
X = check_array(X, dtype=DTYPE, accept_sparse='csr')
X = self._validate_X_predict(X)

n_classes_ = self.n_classes_
n_samples = y.shape[0]
@@ -426,7 +469,8 @@ def _set_oob_score(self, X, y):
unsampled_indices = _generate_unsampled_indices(
estimator.random_state, n_samples)
p_estimator = estimator.predict_proba(X[unsampled_indices, :],
check_input=False)
check_input=False,
missing_mask=missing_mask)

if self.n_outputs_ == 1:
p_estimator = [p_estimator]
@@ -571,6 +615,7 @@ class in a leaf.
"""
# Check data
X = self._validate_X_predict(X)
missing_mask = self._validate_missing_mask(X)

# Assign chunk of trees to jobs
n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs)
@@ -579,7 +624,8 @@ class in a leaf.
all_proba = Parallel(n_jobs=n_jobs, verbose=self.verbose,
backend="threading")(
delayed(parallel_helper)(e, 'predict_proba', X,
check_input=False)
check_input=False,
missing_mask=missing_mask)
for e in self.estimators_)

# Reduce
@@ -651,7 +697,8 @@ def __init__(self,
n_jobs=1,
random_state=None,
verbose=0,
warm_start=False):
warm_start=False,
missing_values=None):
super(ForestRegressor, self).__init__(
base_estimator,
n_estimators=n_estimators,
@@ -661,7 +708,8 @@ def __init__(self,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start)
warm_start=warm_start,
missing_values=missing_values)

def predict(self, X):
"""Predict regression target for X.
@@ -698,9 +746,9 @@ def predict(self, X):

return y_hat

def _set_oob_score(self, X, y):
def _set_oob_score(self, X, y, missing_mask=None):
"""Compute out-of-bag scores"""
X = check_array(X, dtype=DTYPE, accept_sparse='csr')
X = self._validate_X_predict(X)

n_samples = y.shape[0]

@@ -711,7 +759,9 @@ def _set_oob_score(self, X, y):
unsampled_indices = _generate_unsampled_indices(
estimator.random_state, n_samples)
p_estimator = estimator.predict(
X[unsampled_indices, :], check_input=False)
X[unsampled_indices, :],
missing_mask=None,
check_input=False)

if self.n_outputs_ == 1:
p_estimator = p_estimator[:, np.newaxis]
@@ -923,21 +973,23 @@ def __init__(self,
random_state=None,
verbose=0,
warm_start=False,
class_weight=None):
class_weight=None,
missing_values=None):
super(RandomForestClassifier, self).__init__(
base_estimator=DecisionTreeClassifier(),
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
"random_state", "missing_values"),
bootstrap=bootstrap,
oob_score=oob_score,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start,
class_weight=class_weight)
class_weight=class_weight,
missing_values=missing_values)

self.criterion = criterion
self.max_depth = max_depth
@@ -948,7 +1000,6 @@ def __init__(self,
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_split = min_impurity_split
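
For orientation, a hypothetical end-to-end usage sketch of the missing_values parameter wired through above; this API exists only on this PR branch, not in a released scikit-learn:

import numpy as np
from sklearn.ensemble import RandomForestClassifier

X = np.array([[0.0, 1.0],
              [np.nan, 1.0],   # missing entry
              [1.0, 0.0],
              [1.0, np.nan]])  # missing entry
y = np.array([0, 0, 1, 1])

# missing_values may be the string "NaN", np.nan itself, or an integer
# sentinel; fit() normalizes the string/float forms to np.nan.
clf = RandomForestClassifier(n_estimators=10, missing_values="NaN",
                             random_state=0)
clf.fit(X, y)
print(clf.predict(np.array([[np.nan, 0.5]])))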


class RandomForestRegressor(ForestRegressor):
"""A random forest regressor.

@@ -1308,21 +1359,23 @@ def __init__(self,
random_state=None,
verbose=0,
warm_start=False,
class_weight=None):
class_weight=None,
missing_values=None):
super(ExtraTreesClassifier, self).__init__(
base_estimator=ExtraTreeClassifier(),
n_estimators=n_estimators,
estimator_params=("criterion", "max_depth", "min_samples_split",
"min_samples_leaf", "min_weight_fraction_leaf",
"max_features", "max_leaf_nodes", "min_impurity_split",
"random_state"),
"random_state", "missing_values"),
bootstrap=bootstrap,
oob_score=oob_score,
n_jobs=n_jobs,
random_state=random_state,
verbose=verbose,
warm_start=warm_start,
class_weight=class_weight)
class_weight=class_weight,
missing_values=missing_values)

self.criterion = criterion
self.max_depth = max_depth
@@ -1687,7 +1740,8 @@ def fit_transform(self, X, y=None, sample_weight=None):
"""
# ensure_2d=False because there are actually unit test checking we fail
# for 1d.
X = check_array(X, accept_sparse=['csc'], ensure_2d=False)
X = check_array(X, accept_sparse=['csc'], ensure_2d=False,
allow_nan=self._allow_nan)
if issparse(X):
# Pre-sort indices to avoid that each individual tree of the
# ensemble sorts the indices.