diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index 3d252dfa4ff3c..e61c556770042 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -31,11 +31,11 @@ class calls the ``fit`` method of each sub-estimator on random samples Single and multi-output problems are both handled. """ - # Authors: Gilles Louppe # Brian Holt # Joly Arnaud # Fares Hedayati +# Raghav R V # # License: BSD 3 clause @@ -94,7 +94,7 @@ def _generate_unsampled_indices(random_state, n_samples): def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, - verbose=0, class_weight=None): + verbose=0, class_weight=None, missing_mask=None): """Private function used to fit a single tree in parallel.""" if verbose > 1: print("building tree %d of %d" % (tree_idx + 1, n_trees)) @@ -117,9 +117,11 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, elif class_weight == 'balanced_subsample': curr_sample_weight *= compute_sample_weight('balanced', y, indices) - tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False) + tree.fit(X, y, sample_weight=curr_sample_weight, check_input=False, + missing_mask=missing_mask) else: - tree.fit(X, y, sample_weight=sample_weight, check_input=False) + tree.fit(X, y, sample_weight=sample_weight, check_input=False, + missing_mask=missing_mask) return tree @@ -143,7 +145,8 @@ def __init__(self, random_state=None, verbose=0, warm_start=False, - class_weight=None): + class_weight=None, + missing_values=None): super(BaseForest, self).__init__( base_estimator=base_estimator, n_estimators=n_estimators, @@ -156,6 +159,11 @@ def __init__(self, self.verbose = verbose self.warm_start = warm_start self.class_weight = class_weight + self.allow_missing = missing_values is not None + + # If missing values is int/None + self._allow_nan = False + self.missing_values = missing_values def apply(self, X): """Apply trees in the forest to X, return leaf indices. @@ -174,9 +182,12 @@ def apply(self, X): return the index of the leaf x ends up in. """ X = self._validate_X_predict(X) + missing_mask = self._validate_missing_mask(X) results = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, backend="threading")( - delayed(parallel_helper)(tree, 'apply', X, check_input=False) + delayed(parallel_helper)(tree, 'apply', X, + check_input=False, + missing_mask=missing_mask) for tree in self.estimators_) return np.array(results).T @@ -205,10 +216,12 @@ def decision_path(self, X): """ X = self._validate_X_predict(X) + missing_mask = self._validate_missing_mask(X) indicators = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, backend="threading")( delayed(parallel_helper)(tree, 'decision_path', X, - check_input=False) + check_input=False, + missing_mask=missing_mask) for tree in self.estimators_) n_nodes = [0] @@ -243,8 +256,22 @@ def fit(self, X, y, sample_weight=None): self : object Returns self. """ + if self.allow_missing: + if ((isinstance(self.missing_values, str) and + self.missing_values.strip().lower() == "nan") or + (isinstance(self.missing_values, np.float) and + np.isnan(self.missing_values))): + + self._allow_nan = True + self.missing_values = np.nan + elif not isinstance(self.missing_values, int): + raise ValueError("missing_values should be 'NaN' or int. 
" + "Got %s" % self.missing_values) + # Validate or convert input data - X = check_array(X, accept_sparse="csc", dtype=DTYPE) + X = check_array(X, accept_sparse="csc", dtype=DTYPE, + allow_nan=self._allow_nan) + missing_mask = self._validate_missing_mask(X) y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None) if issparse(X): # Pre-sort indices to avoid that each individual tree of the @@ -310,8 +337,10 @@ def fit(self, X, y, sample_weight=None): trees = [] for i in range(n_more_estimators): - tree = self._make_estimator(append=False, - random_state=random_state) + tree = self._make_estimator(append=False) + tree.set_params(random_state=random_state.randint(MAX_INT)) + tree.allow_missing = self.allow_missing + tree._allow_nan = self._allow_nan trees.append(tree) # Parallel loop: we use the threading backend as the Cython code @@ -322,14 +351,15 @@ def fit(self, X, y, sample_weight=None): backend="threading")( delayed(_parallel_build_trees)( t, self, X, y, sample_weight, i, len(trees), - verbose=self.verbose, class_weight=self.class_weight) + verbose=self.verbose, class_weight=self.class_weight, + missing_mask=missing_mask) for i, t in enumerate(trees)) # Collect newly grown trees self.estimators_.extend(trees) if self.oob_score: - self._set_oob_score(X, y) + self._set_oob_score(X, y, missing_mask) # Decapsulate classes_ attributes if hasattr(self, "classes_") and self.n_outputs_ == 1: @@ -339,7 +369,7 @@ def fit(self, X, y, sample_weight=None): return self @abstractmethod - def _set_oob_score(self, X, y): + def _set_oob_score(self, X, y, missing_mask=None): """Calculate out of bag predictions and score.""" def _validate_y_class_weight(self, y): @@ -354,6 +384,17 @@ def _validate_X_predict(self, X): return self.estimators_[0]._validate_X_predict(X, check_input=True) + def _validate_missing_mask(self, X, missing_mask=None): + """Generate a new missing_mask or validate a given one""" + if self.allow_missing and missing_mask is None: + # Fortran ordered 8 bit boolean mask + if self._allow_nan: # Missing value is a NaN + missing_mask = np.asfortranarray(np.isnan(X), dtype=np.bool8) + else: + missing_mask = np.zeros(X.shape, dtype=np.bool8, order='F') + missing_mask[X == self.missing_values] = True + return missing_mask + @property def feature_importances_(self): """Return the feature importances (the higher, the more important the @@ -394,7 +435,8 @@ def __init__(self, random_state=None, verbose=0, warm_start=False, - class_weight=None): + class_weight=None, + missing_values=None): super(ForestClassifier, self).__init__( base_estimator, @@ -406,11 +448,12 @@ def __init__(self, random_state=random_state, verbose=verbose, warm_start=warm_start, - class_weight=class_weight) + class_weight=class_weight, + missing_values=missing_values) - def _set_oob_score(self, X, y): + def _set_oob_score(self, X, y, missing_mask=None): """Compute out-of-bag score""" - X = check_array(X, dtype=DTYPE, accept_sparse='csr') + X = self._validate_X_predict(X) n_classes_ = self.n_classes_ n_samples = y.shape[0] @@ -426,7 +469,8 @@ def _set_oob_score(self, X, y): unsampled_indices = _generate_unsampled_indices( estimator.random_state, n_samples) p_estimator = estimator.predict_proba(X[unsampled_indices, :], - check_input=False) + check_input=False, + missing_mask=missing_mask) if self.n_outputs_ == 1: p_estimator = [p_estimator] @@ -571,6 +615,7 @@ class in a leaf. 
""" # Check data X = self._validate_X_predict(X) + missing_mask = self._validate_missing_mask(X) # Assign chunk of trees to jobs n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs) @@ -579,7 +624,8 @@ class in a leaf. all_proba = Parallel(n_jobs=n_jobs, verbose=self.verbose, backend="threading")( delayed(parallel_helper)(e, 'predict_proba', X, - check_input=False) + check_input=False, + missing_mask=missing_mask) for e in self.estimators_) # Reduce @@ -651,7 +697,8 @@ def __init__(self, n_jobs=1, random_state=None, verbose=0, - warm_start=False): + warm_start=False, + missing_values=None): super(ForestRegressor, self).__init__( base_estimator, n_estimators=n_estimators, @@ -661,7 +708,8 @@ def __init__(self, n_jobs=n_jobs, random_state=random_state, verbose=verbose, - warm_start=warm_start) + warm_start=warm_start, + missing_values=missing_values) def predict(self, X): """Predict regression target for X. @@ -698,9 +746,9 @@ def predict(self, X): return y_hat - def _set_oob_score(self, X, y): + def _set_oob_score(self, X, y, missing_mask=None): """Compute out-of-bag scores""" - X = check_array(X, dtype=DTYPE, accept_sparse='csr') + X = self._validate_X_predict(X) n_samples = y.shape[0] @@ -711,7 +759,9 @@ def _set_oob_score(self, X, y): unsampled_indices = _generate_unsampled_indices( estimator.random_state, n_samples) p_estimator = estimator.predict( - X[unsampled_indices, :], check_input=False) + X[unsampled_indices, :], + missing_mask=None, + check_input=False) if self.n_outputs_ == 1: p_estimator = p_estimator[:, np.newaxis] @@ -923,21 +973,23 @@ def __init__(self, random_state=None, verbose=0, warm_start=False, - class_weight=None): + class_weight=None, + missing_values=None): super(RandomForestClassifier, self).__init__( base_estimator=DecisionTreeClassifier(), n_estimators=n_estimators, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "missing_values"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, random_state=random_state, verbose=verbose, warm_start=warm_start, - class_weight=class_weight) + class_weight=class_weight, + missing_values=missing_values) self.criterion = criterion self.max_depth = max_depth @@ -948,7 +1000,6 @@ def __init__(self, self.max_leaf_nodes = max_leaf_nodes self.min_impurity_split = min_impurity_split - class RandomForestRegressor(ForestRegressor): """A random forest regressor. @@ -1308,21 +1359,23 @@ def __init__(self, random_state=None, verbose=0, warm_start=False, - class_weight=None): + class_weight=None, + missing_values=None): super(ExtraTreesClassifier, self).__init__( base_estimator=ExtraTreeClassifier(), n_estimators=n_estimators, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_split", - "random_state"), + "random_state", "missing_values"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, random_state=random_state, verbose=verbose, warm_start=warm_start, - class_weight=class_weight) + class_weight=class_weight, + missing_values=missing_values) self.criterion = criterion self.max_depth = max_depth @@ -1687,7 +1740,8 @@ def fit_transform(self, X, y=None, sample_weight=None): """ # ensure_2d=False because there are actually unit test checking we fail # for 1d. 
- X = check_array(X, accept_sparse=['csc'], ensure_2d=False) + X = check_array(X, accept_sparse=['csc'], ensure_2d=False, + allow_nan=self._allow_nan) if issparse(X): # Pre-sort indices to avoid that each individual tree of the # ensemble sorts the indices. diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 5ff4cf851f2d2..061a5d752a9dc 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -5,7 +5,8 @@ # Authors: Gilles Louppe, # Brian Holt, # Andreas Mueller, -# Arnaud Joly +# Arnaud Joly, +# Raghav RV # License: BSD 3 clause import pickle @@ -49,9 +50,18 @@ # toy sample X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]] y = [-1, -1, -1, 1, 1, 1] + +# toy sample with missing data +X_missing = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], + [-2, np.nan], [-1, np.nan], [2, np.nan], [1, np.nan]] +y_missing = [-1, -1, -1, 1, 1, 1, 1, 1, -1, -1] + T = [[-1, -1], [2, 2], [3, 2]] true_result = [-1, 1, 1] +T_missing = [[-1, -1], [2, 2], [3, 2], [-3, np.nan], [3, np.nan]] +true_result_missing = [-1, 1, 1, 1, -1] + # also load the iris dataset # and randomly permute it iris = datasets.load_iris() @@ -109,10 +119,48 @@ def check_classification_toy(name): leaf_indices = clf.apply(X) assert_equal(leaf_indices.shape, (len(X), clf.n_estimators)) + # Check if missing_values are correctly set + assert_true(np.all(list(clf.estimators_[i].missing_values is None + for i in range(10)))) + + +def check_classification_toy_missing_data(name): + # Check classification on a toy dataset with missing values + ForestClassifier = FOREST_CLASSIFIERS[name] + + clf = ForestClassifier(n_estimators=10, missing_values="NaN", + random_state=1) + clf.fit(X_missing, y_missing) + assert_array_equal(clf.predict(T_missing), true_result_missing) + assert_equal(10, len(clf)) + + # Check if missing_values are correctly set + assert_true(np.all(list(np.isnan(clf.estimators_[i].missing_values) + for i in range(10)))) + + clf = ForestClassifier(n_estimators=10, max_features=1, + missing_values="NaN", random_state=1) + clf.fit(X_missing, y_missing) + assert_array_equal(clf.predict(T_missing), true_result_missing) + assert_equal(10, len(clf)) + + # also test apply + leaf_indices = clf.apply(X_missing) + assert_equal(leaf_indices.shape, (len(X_missing), clf.n_estimators)) + + # Enabling missing value support should not mess with the classification + # for data without missing values + clf = ForestClassifier(n_estimators=10, missing_values="NaN", + random_state=1) + clf.fit(X, y) + assert_array_equal(clf.predict(T), true_result) + assert_equal(10, len(clf)) + def test_classification_toy(): for name in FOREST_CLASSIFIERS: yield check_classification_toy, name + yield check_classification_toy_missing_data, name def check_iris_criterion(name, criterion): @@ -126,6 +174,13 @@ def check_iris_criterion(name, criterion): assert_greater(score, 0.9, "Failed with criterion %s and score = %f" % (criterion, score)) + clf = ForestClassifier(n_estimators=10, criterion=criterion, + missing_values="NaN", random_state=1) + clf.fit(iris.data, iris.target) + score = clf.score(iris.data, iris.target) + assert_greater(score, 0.9, "Failed with criterion %s and score = %f" + % (criterion, score)) + clf = ForestClassifier(n_estimators=10, criterion=criterion, max_features=2, random_state=1) clf.fit(iris.data, iris.target) @@ -227,6 +282,7 @@ def check_importances(name, criterion, X, y): importances = est.feature_importances_ assert_true(np.all(importances >= 
0.0))
+    # Check scale invariance
     for scale in [0.5, 10, 100]:
         est = ForestEstimator(n_estimators=20, random_state=0, criterion=criterion)
         est.fit(X, y, sample_weight=scale * sample_weight)
@@ -505,9 +561,44 @@ def check_multioutput(name):
     assert_equal(log_proba[1].shape, (4, 4))
 
 
+def check_multioutput_missing_value(name):
+    # Check estimators on multi-output problems with missing data.
+
+    X_train = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-2, 1],
+               [-1, 1], [-1, 2], [2, -1], [1, -1], [1, -2],
+               [np.nan, -3], [np.nan, -2], [np.nan, -1],
+               [1, np.nan], [2, np.nan], [3, np.nan]]
+    y_train = [[-1, 0], [-1, 0], [-1, 0], [1, 1], [1, 1], [1, 1], [-1, 2],
+               [-1, 2], [-1, 2], [1, 3], [1, 3], [1, 3],
+               [1, 3], [1, 3], [1, 3],
+               # If the 2nd feature is NaN, there is a 1/3 chance that the
+               # 2nd output is class 3 and a 2/3 chance that it is class 1
+               [1, 3], [1, 1], [1, 1]]
+
+    X_test = [[-1, -1], [1, 1], [-1, 1], [1, -1], [3, np.nan], [np.nan, -5]]
+    y_test = [[-1, 0], [1, 1], [-1, 2], [1, 3], [1, 1], [1, 3]]
+
+    est = FOREST_CLASSIFIERS[name](random_state=0, missing_values="NaN",
+                                   bootstrap=False)
+    y_pred = est.fit(X_train, y_train).predict(X_test)
+    assert_array_almost_equal(y_pred, y_test)
+
+    with np.errstate(divide="ignore"):
+        proba = est.predict_proba(X_test)
+        assert_equal(len(proba), 2)
+        assert_equal(proba[0].shape, (6, 2))
+        assert_equal(proba[1].shape, (6, 4))
+
+        log_proba = est.predict_log_proba(X_test)
+        assert_equal(len(log_proba), 2)
+        assert_equal(log_proba[0].shape, (6, 2))
+        assert_equal(log_proba[1].shape, (6, 4))
+
+
 def test_multioutput():
     for name in FOREST_CLASSIFIERS:
         yield check_multioutput, name
+        yield check_multioutput_missing_value, name
 
     for name in FOREST_REGRESSORS:
         yield check_multioutput, name
@@ -1177,6 +1268,28 @@ def check_decision_path(name):
     assert_array_almost_equal(leave_indicator, np.ones(shape=n_samples))
 
 
+def check_decision_path_missing_values(name):
+    X, y = X_missing, y_missing
+    n_samples = len(X)
+    ForestEstimator = FOREST_ESTIMATORS[name]
+    est = ForestEstimator(n_estimators=5, max_depth=1, warm_start=False,
+                          missing_values="NaN", random_state=1)
+    est.fit(X, y)
+    indicator, n_nodes_ptr = est.decision_path(X)
+
+    assert_equal(indicator.shape[1], n_nodes_ptr[-1])
+    assert_equal(indicator.shape[0], n_samples)
+    assert_array_equal(np.diff(n_nodes_ptr),
+                       [e.tree_.node_count for e in est.estimators_])
+
+    # Assert that the leaf indices are correct
+    leaves = est.apply(X)
+    for est_id in range(leaves.shape[1]):
+        leave_indicator = [indicator[i, n_nodes_ptr[est_id] + j]
+                           for i, j in enumerate(leaves[:, est_id])]
+        assert_array_almost_equal(leave_indicator, np.ones(shape=n_samples))
+
+
 def test_decision_path():
     for name in FOREST_CLASSIFIERS:
         yield check_decision_path, name
diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd
index cf6d32d1b7fe1..646964cd541dc 100644
--- a/sklearn/tree/_criterion.pxd
+++ b/sklearn/tree/_criterion.pxd
@@ -4,6 +4,7 @@
 # Joel Nothman
 # Arnaud Joly
 # Jacob Schreiber
+# Raghav R V
 #
 # License: BSD 3 clause
 
@@ -37,7 +38,7 @@ cdef class Criterion:
     cdef SIZE_t n_samples                # Number of samples
     cdef SIZE_t n_node_samples           # Number of samples in the node (end-start)
    cdef double weighted_n_samples       # Weighted number of samples (in total)
-    cdef double weighted_n_node_samples  # Weighted number of samples in the node
+    cdef double weighted_n_node_samples  # Weighted number of samples in the node (missing included)
     cdef double weighted_n_left          # Weighted number of samples in the left node
     cdef double weighted_n_right         # Weighted number of samples in the right node
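The Criterion fields declared in the next hunk maintain, for each output, the invariant sum_total == sum_available + sum_missing, which init_missing relies on when it subtracts the missing statistics from the node totals. A minimal NumPy sketch of that bookkeeping (illustrative only; the helper name is hypothetical and not part of this patch):

    import numpy as np

    def node_class_stats(y_node, w_node, missing_rows, n_classes):
        # sum_total: weighted class counts over every sample in the node
        sum_total = np.zeros(n_classes)
        np.add.at(sum_total, y_node, w_node)
        # sum_missing: the same counts over the missing rows only
        sum_missing = np.zeros(n_classes)
        np.add.at(sum_missing, y_node[missing_rows], w_node[missing_rows])
        # init_missing derives the rest: sum_available = sum_total - sum_missing
        return sum_total, sum_missing, sum_total - sum_missing

For example, node_class_stats(np.array([0, 1, 1]), np.ones(3), np.array([False, False, True]), 2) gives sum_missing == [0., 1.] and sum_available == [1., 1.].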
@@ -49,6 +50,19 @@ cdef class Criterion:
     cdef double* sum_left        # Same as above, but for the left side of the split
     cdef double* sum_right       # same as above, but for the right side of the split
 
+    # To handle missing values
+    cdef bint allow_missing      # Whether to permit missing values
+    cdef double* sum_available   # The sum of the weighted count of each label,
+                                 # but only for the values that are not missing
+    cdef double* sum_missing     # Same as above, but only for the missing values
+    cdef SIZE_t n_missing        # The number of samples with missing values
+                                 # in the current node, for the current feature;
+                                 # samples[end - n_missing:end] holds the indices
+                                 # of the samples with missing values
+    cdef SIZE_t missing_direction  # The partition to which the missing samples are sent
+    cdef double weighted_n_node_available  # Weighted number of available samples in the node
+    cdef double weighted_n_node_missing    # Weighted number of missing samples in the node
+
     # The criterion object is maintained such that left and right collected
     # statistics correspond to samples[start:pos] and samples[pos:end].
 
@@ -56,9 +70,11 @@ cdef class Criterion:
     cdef void init(self, DOUBLE_t* y, SIZE_t y_stride, DOUBLE_t* sample_weight,
                    double weighted_n_samples, SIZE_t* samples, SIZE_t start,
                    SIZE_t end) nogil
+    cdef void init_missing(self, SIZE_t n_missing) nogil
     cdef void reset(self) nogil
     cdef void reverse_reset(self) nogil
     cdef void update(self, SIZE_t new_pos) nogil
+    cdef void move_missing(self, SIZE_t direction=*) nogil
     cdef double node_impurity(self) nogil
     cdef void children_impurity(self, double* impurity_left,
                                 double* impurity_right) nogil
diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx
index e7ad82f6dcd49..5888d5d571b4c 100644
--- a/sklearn/tree/_criterion.pyx
+++ b/sklearn/tree/_criterion.pyx
@@ -13,6 +13,7 @@
 # Fares Hedayati
 # Jacob Schreiber
 # Nelson Liu
+# Raghav R V
 #
 # License: BSD 3 clause
 
@@ -31,6 +32,11 @@ from ._utils cimport safe_realloc
 from ._utils cimport sizet_ptr_to_ndarray
 from ._utils cimport WeightedMedianCalculator
 
+# Constants to handle missing values
+cdef SIZE_t MISSING_DIR_LEFT = 0
+cdef SIZE_t MISSING_DIR_RIGHT = 1
+cdef SIZE_t MISSING_DIR_UNDEF = 2
+
 cdef class Criterion:
     """Interface for impurity criteria.
 
@@ -44,6 +50,9 @@ cdef class Criterion:
         free(self.sum_total)
         free(self.sum_left)
         free(self.sum_right)
+        # free(NULL) is a no-op, so this is safe even when sum_missing is NULL
+        free(self.sum_missing)
+        free(self.sum_available)
 
     def __getstate__(self):
         return {}
@@ -79,6 +88,17 @@ cdef class Criterion:
 
         pass
 
+    cdef void init_missing(self, SIZE_t n_missing) nogil:
+        """Split the total statistics computed in init into missing
+        and available statistics.
+
+        Indices of missing samples are stored in samples[end - n_missing:end].
+
+        If n_missing is 0, the missing_direction is set to MISSING_DIR_UNDEF.
+        """
+
+        pass
+
     cdef void reset(self) nogil:
         """Reset the criterion at pos=start.
 
@@ -88,7 +108,7 @@
         pass
 
     cdef void reverse_reset(self) nogil:
-        """Reset the criterion at pos=end.
+        """Reset the criterion at pos=end - n_missing.
 
         This method must be implemented by the subclass.
         """
@@ -97,6 +117,8 @@
     cdef void update(self, SIZE_t new_pos) nogil:
         """Updated statistics by moving samples[pos:new_pos] to the left child.
 
+        Leave the missing samples (if any) unaffected.
+
         This updates the collected statistics by moving samples[pos:new_pos]
        from the right child to the left child.
It must be implemented by the subclass. @@ -109,6 +131,15 @@ cdef class Criterion: pass + cdef void move_missing(self, SIZE_t direction=MISSING_DIR_LEFT) nogil: + """Updated statistics by changing the missing_direction to the given dir. + + The missing values (samples[end-missing-1:n_missing]) are moved to + specified direction. It must be implemented by the subclass. + """ + + pass + cdef double node_impurity(self) nogil: """Placeholder for calculating the impurity of the node. @@ -214,7 +245,8 @@ cdef class ClassificationCriterion(Criterion): cdef SIZE_t sum_stride def __cinit__(self, SIZE_t n_outputs, - np.ndarray[SIZE_t, ndim=1] n_classes): + np.ndarray[SIZE_t, ndim=1] n_classes, + bint allow_missing=False): """Initialize attributes for this criterion. Parameters @@ -247,6 +279,15 @@ cdef class ClassificationCriterion(Criterion): self.sum_right = NULL self.n_classes = NULL + # To handle missing values + self.n_missing = 0 + self.allow_missing = allow_missing + self.missing_direction = MISSING_DIR_UNDEF + self.sum_missing = NULL + self.sum_available = NULL + self.weighted_n_node_missing = 0.0 + self.weighted_n_node_available = 0.0 + safe_realloc(&self.n_classes, n_outputs) cdef SIZE_t k = 0 @@ -263,15 +304,23 @@ cdef class ClassificationCriterion(Criterion): self.sum_stride = sum_stride cdef SIZE_t n_elements = n_outputs * sum_stride - self.sum_total = calloc(n_elements, sizeof(double)) - self.sum_left = calloc(n_elements, sizeof(double)) - self.sum_right = calloc(n_elements, sizeof(double)) + + safe_realloc(&self.sum_total, n_elements) + safe_realloc(&self.sum_left, n_elements) + safe_realloc(&self.sum_right, n_elements) if (self.sum_total == NULL or self.sum_left == NULL or self.sum_right == NULL): raise MemoryError() + if allow_missing: + safe_realloc(&self.sum_missing, n_elements) + safe_realloc(&self.sum_available, n_elements) + + if self.sum_missing == NULL or self.sum_available == NULL: + raise MemoryError() + def __dealloc__(self): """Destructor.""" free(self.n_classes) @@ -279,7 +328,8 @@ cdef class ClassificationCriterion(Criterion): def __reduce__(self): return (type(self), (self.n_outputs, - sizet_ptr_to_ndarray(self.n_classes, self.n_outputs)), + sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), + self.allow_missing), self.__getstate__()) cdef void init(self, DOUBLE_t* y, SIZE_t y_stride, @@ -349,22 +399,97 @@ cdef class ClassificationCriterion(Criterion): # Reset to pos=start self.reset() - cdef void reset(self) nogil: - """Reset the criterion at pos=start.""" + cdef void init_missing(self, SIZE_t n_missing) nogil: + """Separate the total statistics computed at the init, into missing and available statistics - self.pos = self.start + missing sample indices are stored from samples[end-missing-1:end] + """ + # NOTE init must have been called before + self.n_missing = n_missing + if n_missing == 0: + self.missing_direction = MISSING_DIR_UNDEF + return - self.weighted_n_left = 0.0 - self.weighted_n_right = self.weighted_n_node_samples + cdef SIZE_t end = self.end + cdef SIZE_t* samples = self.samples + cdef DOUBLE_t* sample_weight = self.sample_weight + cdef SIZE_t y_stride = self.y_stride + cdef DOUBLE_t* y = self.y + + cdef SIZE_t i + cdef SIZE_t p + cdef SIZE_t k + cdef SIZE_t c + cdef DOUBLE_t w = 1.0 + cdef SIZE_t offset = 0 + + # Keep the missing direction as right at the start + self.missing_direction = MISSING_DIR_RIGHT + self.weighted_n_node_missing = 0.0 + self.weighted_n_node_available = 0.0 + + cdef double* sum_missing = self.sum_missing + cdef double* 
sum_available = self.sum_available + cdef SIZE_t* n_classes = self.n_classes + cdef double* sum_total = self.sum_total + + # SELFNOTE + # Why can't we - memset(sum_missing, 0, self.n_outputs * self.sum_stride * sizeof(double)) + for k in range(self.n_outputs): + memset(sum_missing + offset, 0, n_classes[k] * sizeof(double)) + memset(sum_available + offset, 0, n_classes[k] * sizeof(double)) + offset += self.sum_stride + + # Compute sum_missing + for p in range(end-n_missing, end): + i = samples[p] + + # w is originally set to be 1.0, meaning that if no sample weights + # are given, the default weight of each sample is 1.0 + if sample_weight != NULL: + w = sample_weight[i] + + # Count weighted class frequency for each target + offset = i * y_stride + + for k in range(self.n_outputs): + c = y[offset + k] + sum_missing[k * self.sum_stride + c] += w + self.weighted_n_node_missing += w + + # Using the computed sum_missing, separate the sum_available + # from the sum_total + # {sum_total} --> {sum_available} + {sum_missing} + for k in range(self.n_outputs): + offset = k * self.sum_stride + for c in range(n_classes[k]): + i = c + offset + sum_available[i] = sum_total[i] - sum_missing[i] + self.weighted_n_node_available = self.weighted_n_node_samples - self.weighted_n_node_missing + + cdef void reset(self) nogil: + """Reset the criterion at pos=start. + + Move the missing values also to the right, if include_missing is True. + """ cdef double* sum_total = self.sum_total cdef double* sum_left = self.sum_left cdef double* sum_right = self.sum_right + cdef double* sum_missing = self.sum_missing + cdef double* sum_available = self.sum_available cdef SIZE_t* n_classes = self.n_classes cdef SIZE_t k + self.pos = self.start + + self.missing_direction = MISSING_DIR_RIGHT + self.weighted_n_left = 0.0 + self.weighted_n_right = self.weighted_n_node_samples + for k in range(self.n_outputs): + # Move everything from left to the right partition memset(sum_left, 0, n_classes[k] * sizeof(double)) memcpy(sum_right, sum_total, n_classes[k] * sizeof(double)) @@ -373,19 +498,25 @@ cdef class ClassificationCriterion(Criterion): sum_right += self.sum_stride cdef void reverse_reset(self) nogil: - """Reset the criterion at pos=end.""" - self.pos = self.end - - self.weighted_n_left = self.weighted_n_node_samples - self.weighted_n_right = 0.0 + """Reset the criterion at pos=end. Move the missing values also to the left""" cdef double* sum_total = self.sum_total cdef double* sum_left = self.sum_left cdef double* sum_right = self.sum_right + cdef double* sum_missing = self.sum_missing + cdef double* sum_available = self.sum_available cdef SIZE_t* n_classes = self.n_classes cdef SIZE_t k + # The missing values are stored at the end irrespective of the + # direction of criterion. 
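+        # Resetting pos to the right boundary must therefore stop short of them.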
+        self.pos = self.end - self.n_missing
+
+        self.missing_direction = MISSING_DIR_LEFT
+        self.weighted_n_left = self.weighted_n_node_samples
+        self.weighted_n_right = 0.0
+
         for k in range(self.n_outputs):
             memset(sum_right, 0, n_classes[k] * sizeof(double))
             memcpy(sum_left, sum_total, n_classes[k] * sizeof(double))
@@ -405,7 +536,8 @@
         """
         cdef DOUBLE_t* y = self.y
         cdef SIZE_t pos = self.pos
-        cdef SIZE_t end = self.end
+        # Ignore the missing sample indices stored at the end when updating pos
+        cdef SIZE_t end_available = self.end - self.n_missing
 
         cdef double* sum_left = self.sum_left
         cdef double* sum_right = self.sum_right
@@ -421,6 +553,7 @@
         cdef SIZE_t c
         cdef SIZE_t label_index
         cdef DOUBLE_t w = 1.0
+        cdef SIZE_t direction
 
         # Update statistics up to new_pos
         #
@@ -428,9 +561,10 @@
         # sum_left[x] +  sum_right[x] = sum_total[x]
         # and that sum_total is known, we are going to update
         # sum_left from the direction that require the least amount
-        # of computations, i.e. from pos to new_pos or from end to new_po.
+        # of computations, i.e. from pos to new_pos or from end - n_missing to
+        # new_pos.
 
-        if (new_pos - pos) <= (end - new_pos):
+        if (new_pos - pos) <= (end_available - new_pos):
             for p in range(pos, new_pos):
                 i = samples[p]
 
@@ -445,9 +579,12 @@
                     self.weighted_n_left += w
 
         else:
+            # Reverse reset, then restore the missing to their previous side
+            direction = self.missing_direction
             self.reverse_reset()
+            self.move_missing(direction)
 
-            for p in range(end - 1, new_pos - 1, -1):
+            for p in range(end_available - 1, new_pos - 1, -1):
                 i = samples[p]
 
                 if sample_weight != NULL:
@@ -472,6 +609,49 @@
 
         self.pos = new_pos
 
+    cdef void move_missing(self, SIZE_t direction=MISSING_DIR_LEFT) nogil:
+        """Update statistics by moving the missing samples to the given side"""
+        cdef SIZE_t old_dir = self.missing_direction
+        self.missing_direction = direction
+        # TODO prevent updating the missing_direction if n_missing <= 0 ?
+ # Or if already at left + if (old_dir == direction) or self.n_missing <= 0: + return + + cdef double* sum_missing = self.sum_missing + cdef double* sum_left = self.sum_left + cdef double* sum_right = self.sum_right + cdef SIZE_t* n_classes = self.n_classes + + cdef SIZE_t k + cdef SIZE_t c + + # If missing must be moved from right to left + if direction == MISSING_DIR_LEFT: + self.weighted_n_left += self.weighted_n_node_missing + self.weighted_n_right -= self.weighted_n_node_missing + + for k in range(self.n_outputs): + for c in range(n_classes[k]): + sum_left[c] += sum_missing[c] + sum_right[c] -= sum_missing[c] + sum_missing += self.sum_stride + sum_left += self.sum_stride + sum_right += self.sum_stride + + # If missing must be moved from left to right + elif direction == MISSING_DIR_RIGHT: + self.weighted_n_left -= self.weighted_n_node_missing + self.weighted_n_right += self.weighted_n_node_missing + + for k in range(self.n_outputs): + for c in range(n_classes[k]): + sum_left[c] -= sum_missing[c] + sum_right[c] += sum_missing[c] + sum_missing += self.sum_stride + sum_left += self.sum_stride + sum_right += self.sum_stride + cdef double node_impurity(self) nogil: pass @@ -687,7 +867,8 @@ cdef class RegressionCriterion(Criterion): cdef double sq_sum_total - def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples): + def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples, + bint allow_missing=False): """Initialize parameters for this criterion. Parameters @@ -734,8 +915,13 @@ cdef class RegressionCriterion(Criterion): self.sum_right == NULL): raise MemoryError() + # To handle missing values + self.allow_missing = allow_missing + def __reduce__(self): - return (type(self), (self.n_outputs, self.n_samples), self.__getstate__()) + return (type(self), + (self.n_outputs, self.n_samples, self.allow_missing), + self.__getstate__()) cdef void init(self, DOUBLE_t* y, SIZE_t y_stride, DOUBLE_t* sample_weight, double weighted_n_samples, SIZE_t* samples, SIZE_t start, @@ -818,6 +1004,7 @@ cdef class RegressionCriterion(Criterion): cdef SIZE_t k cdef DOUBLE_t w = 1.0 cdef DOUBLE_t y_ik + cdef SIZE_t direction # Update statistics up to new_pos # @@ -840,7 +1027,10 @@ cdef class RegressionCriterion(Criterion): self.weighted_n_left += w else: + # Reverse reset but retain the missing at the existing partition + direction = self.missing_direction self.reverse_reset() + self.move_missing(direction) for p in range(end - 1, new_pos - 1, -1): i = samples[p] @@ -983,7 +1173,8 @@ cdef class MAE(RegressionCriterion): cdef np.ndarray right_child cdef DOUBLE_t* node_medians - def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples): + def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples, + bint allow_missing=False): """Initialize parameters for this criterion. 
Parameters @@ -1321,5 +1512,5 @@ cdef class FriedmanMSE(MSE): diff = (self.weighted_n_right * total_sum_left - self.weighted_n_left * total_sum_right) / self.n_outputs - return (diff * diff / (self.weighted_n_left * self.weighted_n_right * + return (diff * diff / (self.weighted_n_left * self.weighted_n_right * self.weighted_n_node_samples)) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 4b97d01614b9f..ca78da2e42e2a 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -4,6 +4,7 @@ # Joel Nothman # Arnaud Joly # Jacob Schreiber +# Raghav R V # # License: BSD 3 clause @@ -19,6 +20,7 @@ ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight ctypedef np.npy_intp SIZE_t # Type for indices and counters ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer +ctypedef np.npy_uint8 BOOL_t # 8 bit boolean (np.bool8) cdef struct SplitRecord: # Data to track sample split @@ -30,6 +32,9 @@ cdef struct SplitRecord: double improvement # Impurity improvement given parent node. double impurity_left # Impurity of the left split. double impurity_right # Impurity of the right split. + # The child/partition to which the missing values should be sent to + SIZE_t missing_direction + SIZE_t n_missing # The number of missing samples at this split cdef class Splitter: # The splitter searches in the input space for a feature and a threshold @@ -64,6 +69,11 @@ cdef class Splitter: cdef SIZE_t y_stride cdef DOUBLE_t* sample_weight + cdef BOOL_t* missing_mask + cdef SIZE_t missing_mask_stride + cdef bint allow_missing + cdef SIZE_t n_missing + # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. With this setting, # `node_split` reorganizes the node samples `samples[start:end]` in two @@ -83,7 +93,8 @@ cdef class Splitter: # Methods cdef void init(self, object X, np.ndarray y, DOUBLE_t* sample_weight, - np.ndarray X_idx_sorted=*) except * + np.ndarray X_idx_sorted=*, + np.ndarray missing_mask=*) except * cdef void node_reset(self, SIZE_t start, SIZE_t end, double* weighted_n_node_samples) nogil @@ -95,4 +106,4 @@ cdef class Splitter: cdef void node_value(self, double* dest) nogil - cdef double node_impurity(self) nogil \ No newline at end of file + cdef double node_impurity(self) nogil diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 0617508aab236..13d39cb2266a1 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -43,13 +43,21 @@ cdef DTYPE_t FEATURE_THRESHOLD = 1e-7 # in SparseSplitter cdef DTYPE_t EXTRACT_NNZ_SWITCH = 0.1 -cdef inline void _init_split(SplitRecord* self, SIZE_t start_pos) nogil: +# Constants to handle missing values +cdef SIZE_t MISSING_DIR_LEFT = 0 +cdef SIZE_t MISSING_DIR_RIGHT = 1 + +cdef inline void _init_split(SplitRecord* self, + SIZE_t start_pos) nogil: self.impurity_left = INFINITY self.impurity_right = INFINITY self.pos = start_pos self.feature = 0 self.threshold = 0. self.improvement = -INFINITY + # SELFNOTE changed def. dir UNDEF --> RIGHT + self.missing_direction = MISSING_DIR_RIGHT + self.n_missing = 0 cdef class Splitter: """Abstract splitter class. 
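The SplitRecord fields added above give every split a missing_direction in addition to the usual threshold: the partitioning logic further down moves samples with a missing value to whichever side missing_direction names, while available values keep the ordinary threshold comparison, and a threshold of INFINITY sends every available value left. A tiny Python sketch of that routing rule (hypothetical helper, not part of the patch):

    import numpy as np

    MISSING_DIR_LEFT, MISSING_DIR_RIGHT = 0, 1

    def goes_left(value, threshold, missing_direction):
        # A missing value ignores the threshold and follows missing_direction
        if np.isnan(value):
            return missing_direction == MISSING_DIR_LEFT
        # An available value uses the usual comparison; with
        # threshold = np.inf every available value therefore goes left
        return value <= threshold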
@@ -60,7 +68,7 @@ cdef class Splitter: def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint allow_missing): """ Parameters ---------- @@ -102,6 +110,11 @@ cdef class Splitter: self.random_state = random_state self.presort = presort + self.allow_missing = allow_missing + self.missing_mask = NULL + self.missing_mask_stride = 0 + self.n_missing = 0 + def __dealloc__(self): """Destructor.""" @@ -120,7 +133,8 @@ cdef class Splitter: object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, - np.ndarray X_idx_sorted=None) except *: + np.ndarray X_idx_sorted=None, + np.ndarray missing_mask=None) except*: """Initialize the splitter. Take in the input data X, the target Y, and optional sample weights. @@ -181,6 +195,11 @@ cdef class Splitter: self.sample_weight = sample_weight + if self.allow_missing: + self.missing_mask = missing_mask.data + self.missing_mask_stride = missing_mask.strides[1] / missing_mask.itemsize + self.n_missing = 0 + cdef void node_reset(self, SIZE_t start, SIZE_t end, double* weighted_n_node_samples) nogil: """Reset splitter on node samples[start:end]. @@ -242,7 +261,7 @@ cdef class BaseDenseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint allow_missing): self.X = NULL self.X_sample_stride = 0 @@ -250,6 +269,7 @@ cdef class BaseDenseSplitter(Splitter): self.X_idx_sorted_ptr = NULL self.X_idx_sorted_stride = 0 self.sample_mask = NULL + # SELFNOTE Why doesn't the parent cinit take care of this? self.presort = presort def __dealloc__(self): @@ -261,11 +281,13 @@ cdef class BaseDenseSplitter(Splitter): object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, - np.ndarray X_idx_sorted=None) except *: + np.ndarray X_idx_sorted=None, + np.ndarray missing_mask=None) except *: """Initialize the splitter.""" # Call parent init - Splitter.init(self, X, y, sample_weight) + Splitter.init(self, X=X, y=y, sample_weight=sample_weight, + X_idx_sorted=X_idx_sorted, missing_mask=missing_mask) # Initialize X cdef np.ndarray X_ndarray = X @@ -293,7 +315,8 @@ cdef class BestSplitter(BaseDenseSplitter): self.min_samples_leaf, self.min_weight_leaf, self.random_state, - self.presort), self.__getstate__()) + self.presort, + self.allow_missing), self.__getstate__()) cdef void node_split(self, double impurity, SplitRecord* split, SIZE_t* n_constant_features) nogil: @@ -324,13 +347,7 @@ cdef class BestSplitter(BaseDenseSplitter): cdef double best_proxy_improvement = -INFINITY cdef SIZE_t f_i = n_features - cdef SIZE_t f_j - cdef SIZE_t tmp - cdef SIZE_t p - cdef SIZE_t feature_idx_offset - cdef SIZE_t feature_offset - cdef SIZE_t i - cdef SIZE_t j + cdef SIZE_t f_j, tmp, p, feature_idx_offset, feature_offset, i, j cdef SIZE_t n_visited_features = 0 # Number of features discovered to be constant during the split search @@ -343,10 +360,19 @@ cdef class BestSplitter(BaseDenseSplitter): cdef DTYPE_t current_feature_value cdef SIZE_t partition_end + # To handle missing + cdef SIZE_t missing_direction, end_available = end, start_available + cdef SIZE_t q, miss_mask_feat_offset, directions_to_check + cdef SIZE_t n_missing = 0 + cdef BOOL_t* missing_mask = self.missing_mask + cdef SIZE_t pos_missing_offset = 0 # missing offset based on missing direction to help compute the 
stopping criteria correctly + _init_split(&best, end) if self.presort == 1: for p in range(start, end): + # To denote which ones are the samples from the current split + # that must be extracted from the sorted_idx_mask sample_mask[samples[p]] = 1 # Sample up to max_features without replacement using a @@ -400,23 +426,46 @@ cdef class BestSplitter(BaseDenseSplitter): # presorting, or by copying the values into an array and # sorting the array in a manner which utilizes the cache more # effectively. + n_missing = 0 + end_available = end + miss_mask_feat_offset = self.missing_mask_stride * current.feature + if self.presort == 1: p = start feature_idx_offset = self.X_idx_sorted_stride * current.feature - for i in range(self.n_total_samples): + # Filter the samples based on the samples_mask and move the + # missing, if any to the end + for i in range(self.n_total_samples): j = X_idx_sorted[i + feature_idx_offset] if sample_mask[j] == 1: samples[p] = j Xf[p] = X[self.X_sample_stride * j + feature_offset] p += 1 + + if self.allow_missing: + end_available = move_missing_to_end( + samples, missing_mask, start, end, miss_mask_feat_offset) + n_missing = end - end_available + else: - for i in range(start, end): + # Move the missing to the end, extract available values + # and sort them. + if self.allow_missing: + end_available = move_missing_to_end( + samples, missing_mask, start, end, miss_mask_feat_offset) + n_missing = end - end_available + + for i in range(start, end_available): Xf[i] = X[self.X_sample_stride * samples[i] + feature_offset] - sort(Xf + start, samples + start, end - start) + sort(Xf + start, samples + start, end_available - start) - if Xf[end - 1] <= Xf[start] + FEATURE_THRESHOLD: + # If there are no missing values and all the available values are constant, + # in the current feature + if (((Xf[end_available - 1] <= Xf[start] + FEATURE_THRESHOLD) and (n_missing == 0)) or + # or if all the values are missing + (n_missing == (end - start))): features[f_j] = features[n_total_constants] features[n_total_constants] = current.feature @@ -427,64 +476,156 @@ cdef class BestSplitter(BaseDenseSplitter): f_i -= 1 features[f_i], features[f_j] = features[f_j], features[f_i] - # Evaluate all splits + + self.criterion.init_missing(n_missing) + current.n_missing = n_missing + + # SELFNOTE changed for n_missing == 0 : UNDEF --> RIGHT + current.missing_direction = MISSING_DIR_RIGHT + if n_missing > 0: + # Search for the best split by sending the missing values + # both sides + directions_to_check = 2 + else: + directions_to_check = 1 + + pos_missing_offset = 0 self.criterion.reset() - p = start - while p < end: - while (p + 1 < end and - Xf[p + 1] <= Xf[p] + FEATURE_THRESHOLD): + while directions_to_check > 0: + p = start + while p < end_available: + while (p + 1 < end_available and + Xf[p + 1] <= Xf[p] + FEATURE_THRESHOLD): + p += 1 + + # (p + 1 >= end_available) or (X[samples[p + 1], current.feature] > + # X[samples[p], current.feature]) p += 1 + # (p >= end_available) or (X[samples[p], current.feature] > + # X[samples[p - 1], current.feature]) + + if p < end_available: + current.pos = p + + if ((((p + pos_missing_offset) - start) < min_samples_leaf) or + ((end - (p + pos_missing_offset)) < min_samples_leaf)): + continue + + # The criterion is updated with the split position w.r.t + # start always as the missing sample indices + # remain at the end of the `samples` array, + # irrespective of the missing_direction + self.criterion.update(current.pos) + + # Reject if min_weight_leaf is not satisfied 
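+                            # (weighted_n_left/right include the missing
+                            # samples on whichever side they currently are)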
+ if ((self.criterion.weighted_n_left < min_weight_leaf) or + (self.criterion.weighted_n_right < min_weight_leaf)): + continue + + current_proxy_improvement = self.criterion.proxy_impurity_improvement() + + if current_proxy_improvement > best_proxy_improvement: + best_proxy_improvement = current_proxy_improvement + current.threshold = (Xf[p - 1] + Xf[p]) / 2.0 + if current.threshold == Xf[p]: + current.threshold = Xf[p - 1] + + best = current # copy + + directions_to_check -= 1 + if directions_to_check == 1: + # Check if sending missing values alone to the right gives the best split + # (where all the available values remain at the left partition) + # (We need to reverse reset to get this partition + # as criterion does not update beyond the last sample) + self.criterion.reverse_reset() + self.criterion.move_missing(MISSING_DIR_RIGHT) + current_proxy_improvement = self.criterion.proxy_impurity_improvement() - # (p + 1 >= end) or (X[samples[p + 1], current.feature] > - # X[samples[p], current.feature]) - p += 1 - # (p >= end) or (X[samples[p], current.feature] > - # X[samples[p - 1], current.feature]) + # Now check if it satisfies the stopping criterion + # and is better than current best + if ((n_missing >= min_samples_leaf) and + ((end_available - start) >= min_samples_leaf) and + (self.criterion.weighted_n_left >= min_weight_leaf) and + (self.criterion.weighted_n_right >= min_weight_leaf) and + (current_proxy_improvement > best_proxy_improvement)): + best_proxy_improvement = current_proxy_improvement - if p < end: - current.pos = p + # If yes, retain the missing at right, and the available at the left + # and set the threshold to infinity, so all values, which are not missing, + # are sent to the left + current.threshold = INFINITY + # Don't move the missing samples, just + # set pos to end_available, to denote that all + # the available samples must be moved left + current.pos = end_available + best = current - # Reject if min_samples_leaf is not guaranteed - if (((current.pos - start) < min_samples_leaf) or - ((end - current.pos) < min_samples_leaf)): - continue + # Now compute all the other splits with the missing kept at the left + # NOTE do not offset the criterion's position, as the missing sample + # indices are stored at the end in the `samples`, irrespective of the + # missing_direction + pos_missing_offset = n_missing + self.criterion.reset() + self.criterion.move_missing(MISSING_DIR_LEFT) + current.missing_direction = MISSING_DIR_LEFT - self.criterion.update(current.pos) - # Reject if min_weight_leaf is not satisfied - if ((self.criterion.weighted_n_left < min_weight_leaf) or - (self.criterion.weighted_n_right < min_weight_leaf)): - continue + # Update n_missing/end_available for the best's feature + # XXX Not doing this is what screwed up the implementation :@ + n_missing = best.n_missing + start_available = start + end_available = end - n_missing - current_proxy_improvement = self.criterion.proxy_impurity_improvement() + # The samples array is now organized based on the last feature searched for split + # Reorganize this samples array based on the best split's feature. 
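+        # Depending on best.missing_direction, the final layout of the
+        # samples array is one of the three cases described below: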
- if current_proxy_improvement > best_proxy_improvement: - best_proxy_improvement = current_proxy_improvement - current.threshold = (Xf[p - 1] + Xf[p]) / 2.0 + # If missing_direction is UNDEF (n_missing is 0) + # {Right partition available vals} | {Left partition available vals} - if current.threshold == Xf[p]: - current.threshold = Xf[p - 1] + # If missing_direction is LEFT + # {Missing vals + left partition available vals} | {Right partition available vals} - best = current # copy + # If missing_direction is RIGHT + # {left partition available vals} | {Right partition available vals + Missing vals} - # Reorganize into samples[start:best.pos] + samples[best.pos:end] - if best.pos < end: + # If best pos is invalid (at the end) don't allow + # except when the right split contains all the missing values + if (best.pos < end_available or + (best.pos == end_available and + best.missing_direction == MISSING_DIR_RIGHT)): feature_offset = X_feature_stride * best.feature - partition_end = end - p = start - - while p < partition_end: + miss_mask_feat_offset = self.missing_mask_stride * best.feature + + if n_missing > 0: + if best.missing_direction == MISSING_DIR_RIGHT: + move_missing_to_end(samples, missing_mask, start, end, + miss_mask_feat_offset) + + elif best.missing_direction == MISSING_DIR_LEFT: + move_missing_to_start(samples, missing_mask, start, end, + miss_mask_feat_offset) + + # Missing vals were assumed to be in right parition, correct that assumption + best.pos += n_missing + start_available += n_missing + end_available += n_missing + + p = start_available + q = end_available + # Now repartition the samples based on the available values. + while p < q: if X[X_sample_stride * samples[p] + feature_offset] <= best.threshold: p += 1 - else: - partition_end -= 1 - - tmp = samples[partition_end] - samples[partition_end] = samples[p] - samples[p] = tmp + q -= 1 + samples[q], samples[p] = samples[p], samples[q] + # Now that we have sent the missing values to start/end of `samples`, + # forget that they exist and update the criterion based on the + # adjusted best.pos + self.criterion.init_missing(0) self.criterion.reset() self.criterion.update(best.pos) best.improvement = self.criterion.impurity_improvement(impurity) @@ -621,6 +762,110 @@ cdef void heapsort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil: end = end - 1 +cdef inline SIZE_t move_missing_to_end(SIZE_t* samples, BOOL_t* missing_mask, + SIZE_t start, SIZE_t end, + SIZE_t miss_mask_feat_offset) nogil: + """Move the indices with missing values to the end of the current split. + + Returns + ------- + end_available : SIZE_t + The position next to the last available value. + """ + cdef SIZE_t p = start, q = end, j + + while p < q: + j = samples[p] + if missing_mask[j + miss_mask_feat_offset] == 1: + q -= 1 + samples[p] = samples[q] + samples[q] = j + else: + p += 1 + + # end_available is q (as q - 1 was the posn. of the last avail. sample.) + return q + + +cdef inline SIZE_t move_missing_to_start(SIZE_t* samples, BOOL_t* missing_mask, + SIZE_t start, SIZE_t end, + SIZE_t miss_mask_feat_offset) nogil: + """Move the indices with missing values to the start of the current split. + + Returns + ------- + end_available : SIZE_t + The position next to the last available value. + """ + cdef SIZE_t p = start, q = end - 1, j + + while q > p: + j = samples[q] + if missing_mask[j + miss_mask_feat_offset] == 1: + samples[q] = samples[p] + samples[p] = j + p += 1 + else: + q -= 1 + + # end_available is q + 1 (as q is the posn. 
of the last available sample) + return q + 1 + + +cdef inline SIZE_t move_missing_to_end_sparse(SIZE_t* samples, + INT32_t* X_indices, + SIZE_t* index_to_samples, + BOOL_t* missing_mask, + SIZE_t samples_start, + SIZE_t samples_end, + SIZE_t indptr_start, + SIZE_t indptr_end) nogil: + """Move missing to the end in a sparse CSC matrix and return n_missing""" + cdef SIZE_t i, samples_i + + cdef SIZE_t n_missing = 0 + + for i in range(indptr_start, indptr_end): + samples_i = index_to_samples[X_indices[i]] + if samples_start <= samples_i <= samples_end: + if missing_mask[i] == 1: + n_missing += 1 + # If X_data[X_i] is missing, move the corresponding sample, + # samples[samples_i] (or samples[index_to_samples[X_i]]) + # to the end + sparse_swap(index_to_samples, samples, + samples_i, samples_end - n_missing) + + return n_missing + + +cdef inline SIZE_t move_missing_to_start_sparse(SIZE_t* samples, + INT32_t* X_indices, + SIZE_t* index_to_samples, + BOOL_t* missing_mask, + SIZE_t samples_start, + SIZE_t samples_end, + SIZE_t indptr_start, + SIZE_t indptr_end) nogil: + """Move missing to the start in a sparse CSC matrix and return n_missing""" + cdef SIZE_t i, samples_i + + cdef SIZE_t n_missing = 0 + + for i in range(indptr_start, indptr_end): + samples_i = index_to_samples[X_indices[i]] + if samples_start <= samples_i <= samples_end: + if missing_mask[i] == 1: + # If X_data[X_i] is missing, move the corresponding sample, + # samples[samples_i] (or samples[index_to_samples[X_i]]) + # to the start + sparse_swap(index_to_samples, samples, + samples_i, samples_start + n_missing) + n_missing += 1 + + return n_missing + + cdef class RandomSplitter(BaseDenseSplitter): """Splitter for finding the best random split.""" def __reduce__(self): @@ -629,7 +874,8 @@ cdef class RandomSplitter(BaseDenseSplitter): self.min_samples_leaf, self.min_weight_leaf, self.random_state, - self.presort), self.__getstate__()) + self.presort, + self.allow_missing), self.__getstate__()) cdef void node_split(self, double impurity, SplitRecord* split, SIZE_t* n_constant_features) nogil: @@ -660,7 +906,7 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef SIZE_t f_j cdef SIZE_t p cdef SIZE_t tmp - cdef SIZE_t feature_stride + cdef SIZE_t feature_offset # Number of features discovered to be constant during the split search cdef SIZE_t n_found_constants = 0 # Number of features known to be constant and drawn without replacement @@ -674,6 +920,13 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef DTYPE_t current_feature_value cdef SIZE_t partition_end + # To handle missing + cdef SIZE_t missing_direction, end_available = end, start_available + cdef SIZE_t q, miss_mask_feat_offset, directions_to_check + cdef SIZE_t n_missing = 0 + cdef BOOL_t* missing_mask = self.missing_mask + cdef SIZE_t pos_missing_offset = 0 # missing offset based on missing direction to help compute the stopping criteria correctly + _init_split(&best, end) # Sample up to max_features without replacement using a @@ -721,15 +974,27 @@ cdef class RandomSplitter(BaseDenseSplitter): # f_j in the interval [n_total_constants, f_i[ current.feature = features[f_j] - feature_stride = X_feature_stride * current.feature + feature_offset = X_feature_stride * current.feature + + p = start + n_missing = 0 + end_available = end + miss_mask_feat_offset = self.missing_mask_stride * current.feature + + # First move any missing values to the end + if self.allow_missing: + end_available = move_missing_to_end( + samples, missing_mask, start, end, miss_mask_feat_offset) + 
n_missing = end - end_available # Find min, max - min_feature_value = X[X_sample_stride * samples[start] + feature_stride] + # NOTE Do this after moving the missing values to the end + min_feature_value = X[X_sample_stride * samples[start] + feature_offset] max_feature_value = min_feature_value Xf[start] = min_feature_value - for p in range(start + 1, end): - current_feature_value = X[X_sample_stride * samples[p] + feature_stride] + for p in range(start + 1, end_available): + current_feature_value = X[X_sample_stride * samples[p] + feature_offset] Xf[p] = current_feature_value if current_feature_value < min_feature_value: @@ -737,7 +1002,14 @@ cdef class RandomSplitter(BaseDenseSplitter): elif current_feature_value > max_feature_value: max_feature_value = current_feature_value - if max_feature_value <= min_feature_value + FEATURE_THRESHOLD: + # If all the feature values are missing, max_feature_value and + # min_feature_value could be nan, so we first check that + if ((n_missing == end - start) or + # Now check for the case when all the features values + # are constant and there is no missing + ((max_feature_value <= min_feature_value + FEATURE_THRESHOLD) and + (n_missing == 0))): + features[f_j] = features[n_total_constants] features[n_total_constants] = current.feature @@ -752,12 +1024,12 @@ cdef class RandomSplitter(BaseDenseSplitter): current.threshold = rand_uniform(min_feature_value, max_feature_value, random_state) - if current.threshold == max_feature_value: current.threshold = min_feature_value - # Partition - partition_end = end + # Partition the available values based on the randomly + # chosen threshold + partition_end = end_available p = start while p < partition_end: current_feature_value = Xf[p] @@ -775,17 +1047,38 @@ cdef class RandomSplitter(BaseDenseSplitter): current.pos = partition_end - # Reject if min_samples_leaf is not guaranteed - if (((current.pos - start) < min_samples_leaf) or - ((end - current.pos) < min_samples_leaf)): - continue + self.criterion.init_missing(n_missing) + current.n_missing = n_missing + pos_missing_offset = 0 + if n_missing > 0: + # When the randomly chosen threshold (before or after + # correction for the case of max_feature_value) is + # min_feature_value, set the threshold to INFINITY + # and send all the missing values to the right. 
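+                    # (an infinite threshold sends every available value left)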
+ if current.threshold == min_feature_value: + current.threshold = INFINITY + current.missing_direction = MISSING_DIR_RIGHT + else: + # If not choose the missing direction randomly + current.missing_direction = rand_int(0, 2, random_state) + else: + # SELFNOTE undef --> right + current.missing_direction = MISSING_DIR_RIGHT # Evaluate split - self.criterion.reset() + if current.missing_direction == MISSING_DIR_LEFT: + self.criterion.reset() + self.criterion.move_missing(MISSING_DIR_LEFT) + pos_missing_offset = n_missing + else: + self.criterion.reset() self.criterion.update(current.pos) - # Reject if min_weight_leaf is not satisfied - if ((self.criterion.weighted_n_left < min_weight_leaf) or + # min_samples_leaf should be guaranteed + if (((current.pos + pos_missing_offset - start) < min_samples_leaf) or + ((end - (current.pos + pos_missing_offset)) < min_samples_leaf) or + # min_weight_leaf should be satisfied + (self.criterion.weighted_n_left < min_weight_leaf) or (self.criterion.weighted_n_right < min_weight_leaf)): continue @@ -795,25 +1088,46 @@ cdef class RandomSplitter(BaseDenseSplitter): best_proxy_improvement = current_proxy_improvement best = current # copy - # Reorganize into samples[start:best.pos] + samples[best.pos:end] - feature_stride = X_feature_stride * best.feature - if best.pos < end: - if current.feature != best.feature: - partition_end = end - p = start - - while p < partition_end: - if X[X_sample_stride * samples[p] + feature_stride] <= best.threshold: - p += 1 - - else: - partition_end -= 1 - - tmp = samples[partition_end] - samples[partition_end] = samples[p] - samples[p] = tmp + # Update n_missing for the best's feature + n_missing = best.n_missing + start_available = start + end_available = end - n_missing + # Reorganize into samples[start:best.pos] + samples[best.pos:end] + if (best.pos < end_available or + (best.pos == end_available and + best.missing_direction == MISSING_DIR_RIGHT)): + feature_offset = X_feature_stride * best.feature + miss_mask_feat_offset = self.missing_mask_stride * best.feature + + # The missing values should be moved to the right partition + if n_missing > 0: + if best.missing_direction == MISSING_DIR_RIGHT: + move_missing_to_end(samples, missing_mask, start, end, + miss_mask_feat_offset) + elif best.missing_direction == MISSING_DIR_LEFT: + move_missing_to_start(samples, missing_mask, start, end, + miss_mask_feat_offset) + + # Adjust best.pos to include the missing values which is at start now + best.pos += n_missing + start_available += n_missing + end_available += n_missing + + p = start_available + q = end_available + + while p < q: + if X[X_sample_stride * samples[p] + feature_offset] <= best.threshold: + p += 1 + else: + q -= 1 + samples[q], samples[p] = samples[p], samples[q] + # Now that we have sent the missing values to start/end of `samples`, + # forget that they exist and update the criterion based on the + # adjusted best.pos + self.criterion.init_missing(0) self.criterion.reset() self.criterion.update(best.pos) best.improvement = self.criterion.impurity_improvement(impurity) @@ -848,7 +1162,7 @@ cdef class BaseSparseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint allow_missing): # Parent __cinit__ is automatically called self.X_data = NULL @@ -869,11 +1183,13 @@ cdef class BaseSparseSplitter(Splitter): object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* 
sample_weight, - np.ndarray X_idx_sorted=None) except *: + np.ndarray X_idx_sorted=None, + np.ndarray missing_mask=None) except *: """Initialize the splitter.""" # Call parent init - Splitter.init(self, X, y, sample_weight) + Splitter.init(self, X=X, y=y, sample_weight=sample_weight, + X_idx_sorted=None, missing_mask=missing_mask) if not isinstance(X, csc_matrix): raise ValueError("X should be in csc format") @@ -905,9 +1221,10 @@ cdef class BaseSparseSplitter(Splitter): index_to_samples[samples[p]] = p cdef inline SIZE_t _partition(self, double threshold, + SIZE_t start_available, SIZE_t end_available, SIZE_t end_negative, SIZE_t start_positive, SIZE_t zero_pos) nogil: - """Partition samples[start:end] based on threshold.""" + """Partition samples[start_available:end_available] based on threshold.""" cdef double value cdef SIZE_t partition_end @@ -918,11 +1235,11 @@ cdef class BaseSparseSplitter(Splitter): cdef SIZE_t* index_to_samples = self.index_to_samples if threshold < 0.: - p = self.start + p = start_available partition_end = end_negative elif threshold > 0.: p = start_positive - partition_end = self.end + partition_end = end_available else: # Data are already split return zero_pos @@ -942,18 +1259,20 @@ cdef class BaseSparseSplitter(Splitter): return partition_end - cdef inline void extract_nnz(self, SIZE_t feature, + cdef inline void extract_nnz(self, SIZE_t feature, SIZE_t missing_direction, + SIZE_t* start_available, SIZE_t* end_available, SIZE_t* end_negative, SIZE_t* start_positive, bint* is_samples_sorted) nogil: """Extract and partition values for a given feature. The extracted values are partitioned between negative values - Xf[start:end_negative[0]] and positive values Xf[start_positive[0]:end]. - The samples and index_to_samples are modified according to this - partition. + Xf[start:end_negative[0]] and positive values + Xf[start_positive[0]:end-n_missing]. Missing values if any are stored + at the end. The samples and index_to_samples are modified according to + this partition. The extraction corresponds to the intersection between the arrays - X_indices[indptr_start:indptr_end] and samples[start:end]. + X_indices[indptr_start:indptr_end] and samples[start:end-n_missing]. This is done efficiently using either an index_to_samples based approach or binary search based approach. @@ -962,6 +1281,8 @@ cdef class BaseSparseSplitter(Splitter): feature : SIZE_t, Index of the feature we want to extract non zero value. + missing_direction : SIZE_t + The direction for sending the missing values. end_negative, start_positive : SIZE_t*, SIZE_t*, Return extracted non zero values in self.samples[start:end] where @@ -974,11 +1295,41 @@ cdef class BaseSparseSplitter(Splitter): the sorted version of self.samples[start:end]. 
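+
+        Note: when missing values are allowed, the missing entries of this
+        feature are first grouped at one end of ``samples[start:end]``
+        (following ``missing_direction``) and excluded from the extraction
+        through ``start_available`` and ``end_available``.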
""" - cdef SIZE_t indptr_start = self.X_indptr[feature], - cdef SIZE_t indptr_end = self.X_indptr[feature + 1] - cdef SIZE_t n_indices = (indptr_end - indptr_start) - cdef SIZE_t n_samples = self.end - self.start - + cdef: + SIZE_t indptr_start = self.X_indptr[feature], + SIZE_t indptr_end = self.X_indptr[feature + 1] + SIZE_t n_indices = (indptr_end - indptr_start) + SIZE_t start, end, start_available_, end_available_, i + SIZE_t n_samples = self.end - self.start, n_missing = 0 + BOOL_t* missing_mask = self.missing_mask + SIZE_t* index_to_samples = self.index_to_samples + + start = start_available_ = self.start + end = end_available_ = self.end + + if self.allow_missing: + # Move those missing values to the end + if missing_direction == MISSING_DIR_RIGHT: + n_missing = move_missing_to_end_sparse(self.samples, + self.X_indices, + index_to_samples, + missing_mask, + self.start, + self.end, + indptr_start, + indptr_end) + end_available_ -= n_missing + + elif missing_direction == MISSING_DIR_LEFT: + n_missing = move_missing_to_start_sparse(self.samples, + self.X_indices, + index_to_samples, + missing_mask, + self.start, + self.end, + indptr_start, + indptr_end) + start_available_ += n_missing # Use binary search if n_samples * log(n_indices) < # n_indices and index_to_samples approach otherwise. # O(n_samples * log(n_indices)) is the running time of binary @@ -988,7 +1339,8 @@ cdef class BaseSparseSplitter(Splitter): n_samples * log(n_indices) < EXTRACT_NNZ_SWITCH * n_indices): extract_nnz_binary_search(self.X_indices, self.X_data, indptr_start, indptr_end, - self.samples, self.start, self.end, + self.samples, + start_available_, end_available_, self.index_to_samples, self.feature_values, end_negative, start_positive, @@ -999,11 +1351,14 @@ cdef class BaseSparseSplitter(Splitter): else: extract_nnz_index_to_samples(self.X_indices, self.X_data, indptr_start, indptr_end, - self.samples, self.start, self.end, + self.samples, + start_available_, end_available_, self.index_to_samples, self.feature_values, end_negative, start_positive) + start_available[0] = start_available_ + end_available[0] = end_available_ cdef int compare_SIZE_t(const void* a, const void* b) nogil: """Comparison function for sort.""" @@ -1040,8 +1395,8 @@ cdef inline void extract_nnz_index_to_samples(INT32_t* X_indices, INT32_t indptr_start, INT32_t indptr_end, SIZE_t* samples, - SIZE_t start, - SIZE_t end, + SIZE_t start_available, + SIZE_t end_available, SIZE_t* index_to_samples, DTYPE_t* Xf, SIZE_t* end_negative, @@ -1052,11 +1407,11 @@ cdef inline void extract_nnz_index_to_samples(INT32_t* X_indices, """ cdef INT32_t k cdef SIZE_t index - cdef SIZE_t end_negative_ = start - cdef SIZE_t start_positive_ = end + cdef SIZE_t end_negative_ = start_available + cdef SIZE_t start_positive_ = end_available for k in range(indptr_start, indptr_end): - if start <= index_to_samples[X_indices[k]] < end: + if start_available <= index_to_samples[X_indices[k]] < end_available: if X_data[k] > 0: start_positive_ -= 1 Xf[start_positive_] = X_data[k] @@ -1080,8 +1435,8 @@ cdef inline void extract_nnz_binary_search(INT32_t* X_indices, INT32_t indptr_start, INT32_t indptr_end, SIZE_t* samples, - SIZE_t start, - SIZE_t end, + SIZE_t start_available, + SIZE_t end_available, SIZE_t* index_to_samples, DTYPE_t* Xf, SIZE_t* end_negative, @@ -1090,8 +1445,8 @@ cdef inline void extract_nnz_binary_search(INT32_t* X_indices, bint* is_samples_sorted) nogil: """Extract and partition values for a given feature using binary search. 
-    If n_samples = end - start and n_indices = indptr_end - indptr_start,
-    the complexity is
+    If n_samples = end_available - start_available and
+    n_indices = indptr_end - indptr_start, the complexity is

        O((1 - is_samples_sorted[0]) * n_samples * log(n_samples) +
          n_samples * log(n_indices)).

@@ -1099,28 +1454,28 @@ cdef inline void extract_nnz_binary_search(INT32_t* X_indices,
     cdef SIZE_t n_samples

     if not is_samples_sorted[0]:
-        n_samples = end - start
-        memcpy(sorted_samples + start, samples + start,
+        n_samples = end_available - start_available
+        memcpy(sorted_samples + start_available, samples + start_available,
                n_samples * sizeof(SIZE_t))
-        qsort(sorted_samples + start, n_samples, sizeof(SIZE_t),
+        qsort(sorted_samples + start_available, n_samples, sizeof(SIZE_t),
               compare_SIZE_t)
         is_samples_sorted[0] = 1

     while (indptr_start < indptr_end and
-           sorted_samples[start] > X_indices[indptr_start]):
+           sorted_samples[start_available] > X_indices[indptr_start]):
         indptr_start += 1

     while (indptr_start < indptr_end and
-           sorted_samples[end - 1] < X_indices[indptr_end - 1]):
+           sorted_samples[end_available - 1] < X_indices[indptr_end - 1]):
         indptr_end -= 1

-    cdef SIZE_t p = start
+    cdef SIZE_t p = start_available
     cdef SIZE_t index
     cdef SIZE_t k
-    cdef SIZE_t end_negative_ = start
-    cdef SIZE_t start_positive_ = end
+    cdef SIZE_t end_negative_ = start_available
+    cdef SIZE_t start_positive_ = end_available

-    while (p < end and indptr_start < indptr_end):
+    while (p < end_available and indptr_start < indptr_end):
         # Find index of sorted_samples[p] in X_indices
         binary_search(X_indices, indptr_start, indptr_end,
                       sorted_samples[p], &k, &indptr_start)
@@ -1164,7 +1519,8 @@ cdef class BestSparseSplitter(BaseSparseSplitter):
                            self.min_samples_leaf,
                            self.min_weight_leaf,
                            self.random_state,
-                           self.presort), self.__getstate__())
+                           self.presort,
+                           self.allow_missing), self.__getstate__())

     cdef void node_split(self, double impurity, SplitRecord* split,
                          SIZE_t* n_constant_features) nogil:
@@ -1213,8 +1569,14 @@ cdef class BestSparseSplitter(BaseSparseSplitter):
         cdef SIZE_t p_prev
         cdef bint is_samples_sorted = 0  # indicate if sorted_samples is
                                          # initialized
-
-        # We assume implicitely that end_positive = end and
+        # To handle missing values
+        cdef SIZE_t missing_direction, end_available = end, start_available
+        cdef SIZE_t q, directions_to_check
+        cdef SIZE_t n_missing = 0
+        cdef BOOL_t* missing_mask = self.missing_mask
+        # Offset, depending on the missing direction, used to compute the
+        # stopping criteria correctly
+        cdef SIZE_t pos_missing_offset = 0
+
+        # We assume implicitly that end_positive = end_available and
         # start_negative = start
         cdef SIZE_t start_positive
         cdef SIZE_t end_negative
@@ -1265,19 +1627,29 @@ cdef class BestSparseSplitter(BaseSparseSplitter):
                 # f_j in the interval [n_total_constants, f_i[
                 current.feature = features[f_j]

-                self.extract_nnz(current.feature,
+                p = start
+                n_missing = 0
+                missing_direction = MISSING_DIR_RIGHT
+                start_available = start
+                end_available = end
+
+                # Now partition the available samples into
+                # [-ves, zeros, +ves, MVs] or [MVs, -ves, zeros, +ves]
+                self.extract_nnz(current.feature, missing_direction,
+                                 &start_available, &end_available,
                                  &end_negative, &start_positive,
                                  &is_samples_sorted)
+                n_missing = end - end_available

                 # Sort the positive and negative parts of `Xf`
                 sort(Xf + start, samples + start, end_negative - start)
                 sort(Xf + start_positive, samples + start_positive,
-                     end - start_positive)
+                     end_available - start_positive)

                 # Update index_to_samples to take into account the sort
                 for p in range(start, end_negative):
                     index_to_samples[samples[p]] = p
-                for p in range(start_positive, end):
+                for p in range(start_positive, end_available):
                     index_to_samples[samples[p]] = p

                 # Add one or two zeros in Xf, if there is any
@@ -1289,7 +1661,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter):
                         Xf[end_negative] = 0.
                         end_negative += 1

-                if Xf[end - 1] <= Xf[start] + FEATURE_THRESHOLD:
+                if Xf[end_available - 1] <= Xf[start] + FEATURE_THRESHOLD:
                     features[f_j] = features[n_total_constants]
                     features[n_total_constants] = current.feature

@@ -1300,69 +1672,123 @@ cdef class BestSparseSplitter(BaseSparseSplitter):
                     f_i -= 1
                     features[f_i], features[f_j] = features[f_j], features[f_i]

-                    # Evaluate all splits
-                    self.criterion.reset()
-                    p = start
+                    self.criterion.init_missing(n_missing)
+                    current.n_missing = n_missing

-                    while p < end:
-                        if p + 1 != end_negative:
-                            p_next = p + 1
-                        else:
-                            p_next = start_positive
+                    if n_missing > 0:
+                        directions_to_check = 2
+                        current.missing_direction = MISSING_DIR_RIGHT
+                    else:
+                        directions_to_check = 1
+                        # NOTE the default direction (no missing values) was
+                        # changed from UNDEF to RIGHT
+                        current.missing_direction = MISSING_DIR_RIGHT

-                        while (p_next < end and
-                               Xf[p_next] <= Xf[p] + FEATURE_THRESHOLD):
-                            p = p_next
+                    pos_missing_offset = 0
+                    self.criterion.reset()
+
+                    while directions_to_check > 0:
+                        p = start
+                        while p < end_available:
                             if p + 1 != end_negative:
                                 p_next = p + 1
                             else:
                                 p_next = start_positive
+                            while (p_next < end_available and
+                                   Xf[p_next] <= Xf[p] + FEATURE_THRESHOLD):
+                                p = p_next
+                                if p + 1 != end_negative:
+                                    p_next = p + 1
+                                else:
+                                    p_next = start_positive

-                        # (p_next >= end) or (X[samples[p_next], current.feature] >
-                        #                     X[samples[p], current.feature])
-                        p_prev = p
-                        p = p_next
-                        # (p >= end) or (X[samples[p], current.feature] >
-                        #                X[samples[p_prev], current.feature])
+                            # (p_next >= end) or (X[samples[p_next], current.feature] >
+                            #                     X[samples[p], current.feature])
+                            p_prev = p
+                            p = p_next
+                            # (p >= end) or (X[samples[p], current.feature] >
+                            #                X[samples[p_prev], current.feature])
+
+                            if p < end_available:
+                                current.pos = p
+
+                                # Reject if min_samples_leaf is not guaranteed
+                                if ((((current.pos + pos_missing_offset) - start) < min_samples_leaf) or
+                                    ((end_available - (current.pos + pos_missing_offset)) < min_samples_leaf)):
+                                    continue

-                        if p < end:
-                            current.pos = p
+                                self.criterion.update(current.pos)

-                            # Reject if min_samples_leaf is not guaranteed
-                            if (((current.pos - start) < min_samples_leaf) or
-                                ((end - current.pos) < min_samples_leaf)):
-                                continue
+                                # Reject if min_weight_leaf is not satisfied
+                                if ((self.criterion.weighted_n_left < min_weight_leaf) or
+                                    (self.criterion.weighted_n_right < min_weight_leaf)):
+                                    continue

-                            self.criterion.update(current.pos)
+                                current_proxy_improvement = self.criterion.proxy_impurity_improvement()

-                            # Reject if min_weight_leaf is not satisfied
-                            if ((self.criterion.weighted_n_left < min_weight_leaf) or
-                                (self.criterion.weighted_n_right < min_weight_leaf)):
-                                continue
+                                if current_proxy_improvement > best_proxy_improvement:
+                                    best_proxy_improvement = current_proxy_improvement
+                                    current.threshold = (Xf[p_prev] + Xf[p]) / 2.0
+                                    if current.threshold == Xf[p]:
+                                        current.threshold = Xf[p_prev]
+                                    best = current

+                        directions_to_check -= 1
+                        if directions_to_check == 1:
+                            # Send all missing alone to the right
+                            self.criterion.reverse_reset()
+                            self.criterion.move_missing(direction=MISSING_DIR_RIGHT)
                             current_proxy_improvement = self.criterion.proxy_impurity_improvement()

-                            if current_proxy_improvement >
best_proxy_improvement: + # Check stopping criterion and if sending all missing alone right is the best split + if ((n_missing >= min_samples_leaf) and + ((end_available - start) >= min_samples_leaf) and + (self.criterion.weighted_n_left >= min_weight_leaf) and + (self.criterion.weighted_n_right >= min_weight_leaf) and + (current_proxy_improvement > best_proxy_improvement)): best_proxy_improvement = current_proxy_improvement - current.threshold = (Xf[p_prev] + Xf[p]) / 2.0 - if current.threshold == Xf[p]: - current.threshold = Xf[p_prev] - + # If yes, retain the missing at right, and the available at the left + # and set the threshold to infinity, so all values, which are not missing, + # are sent to the left + current.threshold = INFINITY + current.pos = end_available best = current + # Now compute all the other splits with the missing kept at the left + pos_missing_offset = n_missing + self.criterion.reset() + self.criterion.move_missing(direction=MISSING_DIR_LEFT) + current.missing_direction = MISSING_DIR_LEFT + # Reorganize into samples[start:best.pos] + samples[best.pos:end] - if best.pos < end: - self.extract_nnz(best.feature, &end_negative, &start_positive, + n_missing = best.n_missing + if (best.pos < end_available or + (best.pos == end_available and + best.missing_direction == MISSING_DIR_RIGHT)): + # This will also move the MVs to the start/end + self.extract_nnz(best.feature, best.missing_direction, + &start_available, &end_available, + &end_negative, &start_positive, &is_samples_sorted) - self._partition(best.threshold, end_negative, start_positive, + self._partition(best.threshold, + start_available, end_available, + end_negative, start_positive, best.pos) + # as the MVs are moved to the right partition, forget they exist + # and split at (best.pos + n_missing) if MISSING_DIR_LEFT or + # at best.pos if MISSING_DIR_RIGHT + self.criterion.init_missing(0) self.criterion.reset() - self.criterion.update(best.pos) + if best.missing_direction == MISSING_DIR_LEFT: + pos_missing_offset = n_missing + else: + pos_missing_offset = 0 + self.criterion.update(best.pos + pos_missing_offset) best.improvement = self.criterion.impurity_improvement(impurity) self.criterion.children_impurity(&best.impurity_left, &best.impurity_right) @@ -1391,7 +1817,8 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): self.min_samples_leaf, self.min_weight_leaf, self.random_state, - self.presort), self.__getstate__()) + self.presort, + self.allow_missing), self.__getstate__()) cdef void node_split(self, double impurity, SplitRecord* split, SIZE_t* n_constant_features) nogil: @@ -1438,6 +1865,13 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): cdef SIZE_t n_total_constants = n_known_constants cdef SIZE_t partition_end + # To handle missing + cdef SIZE_t missing_direction, end_available = end, start_available + cdef SIZE_t q, miss_mask_feat_offset, directions_to_check + cdef SIZE_t n_missing = 0 + cdef BOOL_t* missing_mask = self.missing_mask + + cdef SIZE_t pos_missing_offset = 0 # missing offset based on missing direction to help compute the stopping criteria correctly cdef DTYPE_t min_feature_value cdef DTYPE_t max_feature_value @@ -1495,10 +1929,19 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): # f_j in the interval [n_total_constants, f_i[ current.feature = features[f_j] - - self.extract_nnz(current.feature, + p = start + n_missing = 0 + start_available = start + end_available = end + missing_direction = MISSING_DIR_RIGHT + + # Move the missing value to the end and repartition the + # 
samples[start:end_available] into Negative...Zeros...Positive + self.extract_nnz(current.feature, missing_direction, + &start_available, &end_available, &end_negative, &start_positive, &is_samples_sorted) + n_missing = end - end_available # Add one or two zeros in Xf, if there is any if end_negative < start_positive: @@ -1521,8 +1964,8 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): elif current_feature_value > max_feature_value: max_feature_value = current_feature_value - # Update min, max given Xf[start_positive:end] - for p in range(start_positive, end): + # Update min, max given Xf[start_positive:end_available] + for p in range(start_positive, end_available): current_feature_value = Xf[p] if current_feature_value < min_feature_value: @@ -1530,7 +1973,10 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): elif current_feature_value > max_feature_value: max_feature_value = current_feature_value - if max_feature_value <= min_feature_value + FEATURE_THRESHOLD: + if ((n_missing == end - start) or + ((max_feature_value <= min_feature_value + FEATURE_THRESHOLD) and + (n_missing == 0))): + features[f_j] = features[n_total_constants] features[n_total_constants] = current.feature @@ -1549,22 +1995,47 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): if current.threshold == max_feature_value: current.threshold = min_feature_value - # Partition - current.pos = self._partition(current.threshold, - end_negative, - start_positive, - start_positive + - (Xf[start_positive] == 0.)) - - # Reject if min_samples_leaf is not guaranteed - if (((current.pos - start) < min_samples_leaf) or - ((end - current.pos) < min_samples_leaf)): - continue + current.n_missing = n_missing + if n_missing > 0: + if current.threshold == min_feature_value: + # Move missing alone to the right + current.threshold = INFINITY + current.missing_direction = MISSING_DIR_RIGHT + current.pos = end_available + else: + current.missing_direction = rand_int(0, 2, random_state) + else: + # SELFNOTE default dir. 
changed from UNDEF --> RIGHT + current.missing_direction = MISSING_DIR_RIGHT + + # We need to move the missing to the correct partition + self.extract_nnz(current.feature, current.missing_direction, + &start_available, &end_available, + &end_negative, &start_positive, + &is_samples_sorted) + + # Partition the available values + p = self._partition(current.threshold, start_available, + end_available, end_negative, + start_positive, start_positive + + (Xf[start_positive] == 0.)) + + if current.threshold != INFINITY: + current.pos = p + + # We have already sent all the missing to the correct partition + self.criterion.init_missing(0) # Evaluate split self.criterion.reset() + # This is corrected by _partition to include missing values self.criterion.update(current.pos) + # Reject if min_samples_leaf is not guaranteed + if ((((current.pos) - start) < min_samples_leaf) or + ((end - (current.pos)) < min_samples_leaf)): + continue + # Reject if min_weight_leaf is not satisfied if ((self.criterion.weighted_n_left < min_weight_leaf) or (self.criterion.weighted_n_right < min_weight_leaf)): @@ -1581,20 +2052,27 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): best = current # Reorganize into samples[start:best.pos] + samples[best.pos:end] - if best.pos < end: - if current.feature != best.feature: - self.extract_nnz(best.feature, &end_negative, &start_positive, - &is_samples_sorted) - - self._partition(best.threshold, end_negative, start_positive, - best.pos) + if (best.pos < end_available or + (best.pos == end_available and + best.missing_direction == MISSING_DIR_RIGHT)): + + # This will recount the n_missing and move the missing to correct + # partition + extract the non zeros into -ves, 0s, +ves + self.extract_nnz(best.feature, best.missing_direction, + &start_available, &end_available, + &end_negative, &start_positive, + &is_samples_sorted) + self._partition(best.threshold, + start_available, end_available, + end_negative, start_positive, + best.pos) + self.criterion.init_missing(0) self.criterion.reset() self.criterion.update(best.pos) best.improvement = self.criterion.impurity_improvement(impurity) self.criterion.children_impurity(&best.impurity_left, &best.impurity_right) - # Respect invariant for constant features: the original order of # element in features[:n_known_constants] must be preserved for sibling # and child nodes diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index dbf0545b1e1d5..83b472cdeeb9a 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -5,6 +5,7 @@ # Arnaud Joly # Jacob Schreiber # Nelson Liu +# Raghav R V # # License: BSD 3 clause @@ -18,6 +19,7 @@ ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight ctypedef np.npy_intp SIZE_t # Type for indices and counters ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer +ctypedef np.npy_uint8 BOOL_t # 8 bit boolean from ._splitter cimport Splitter from ._splitter cimport SplitRecord @@ -32,6 +34,7 @@ cdef struct Node: DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) SIZE_t n_node_samples # Number of samples at the node DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node + SIZE_t missing_direction # To specify the path of missing samples in a node cdef class Tree: @@ -44,6 +47,8 @@ cdef class Tree: cdef SIZE_t* n_classes # Number of classes in y[:, k] cdef public SIZE_t n_outputs # Number of outputs in y cdef public SIZE_t max_n_classes # max(n_classes) + cdef bint allow_missing # 
Whether to allow missing values or not + cdef object missing_values # The missing value placeholder # Inner structures: values are stored separately from node structure, # since size is determined at runtime. @@ -58,22 +63,23 @@ cdef class Tree: cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, SIZE_t feature, double threshold, double impurity, SIZE_t n_node_samples, - double weighted_n_samples) nogil + double weighted_n_samples, + SIZE_t missing_direction) nogil cdef void _resize(self, SIZE_t capacity) except * cdef int _resize_c(self, SIZE_t capacity=*) nogil cdef np.ndarray _get_value_ndarray(self) cdef np.ndarray _get_node_ndarray(self) - cpdef np.ndarray predict(self, object X) + cpdef np.ndarray predict(self, object X, np.ndarray missing_mask=*) - cpdef np.ndarray apply(self, object X) - cdef np.ndarray _apply_dense(self, object X) - cdef np.ndarray _apply_sparse_csr(self, object X) + cpdef np.ndarray apply(self, object X, np.ndarray missing_mask=*) + cdef np.ndarray _apply_dense(self, object X, np.ndarray missing_mask) + cdef np.ndarray _apply_sparse_csr(self, object X, np.ndarray missing_mask) - cpdef object decision_path(self, object X) - cdef object _decision_path_dense(self, object X) - cdef object _decision_path_sparse_csr(self, object X) + cpdef object decision_path(self, object X, np.ndarray missing_mask=*) + cdef object _decision_path_dense(self, object X, np.ndarray missing_mask) + cdef object _decision_path_sparse_csr(self, object X, np.ndarray missing_mask) cpdef compute_feature_importances(self, normalize=*) @@ -97,8 +103,10 @@ cdef class TreeBuilder: cdef double min_weight_leaf # Minimum weight in a leaf cdef SIZE_t max_depth # Maximal tree depth cdef double min_impurity_split # Impurity threshold for early stopping + cdef SIZE_t allow_missing # Whether to allow missing values or not cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=*, - np.ndarray X_idx_sorted=*) + np.ndarray X_idx_sorted=*, + np.ndarray missing_mask=*) cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index f8632ab1640d8..7bf5d7568494b 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -13,6 +13,7 @@ # Fares Hedayati # Jacob Schreiber # Nelson Liu +# Raghav R V # # License: BSD 3 clause @@ -37,6 +38,8 @@ from ._utils cimport PriorityHeap from ._utils cimport PriorityHeapRecord from ._utils cimport safe_realloc from ._utils cimport sizet_ptr_to_ndarray +from ._utils cimport rand_int +from ._utils cimport RAND_R_MAX cdef extern from "numpy/arrayobject.h": object PyArray_NewFromDescr(object subtype, np.dtype descr, @@ -65,12 +68,17 @@ cdef SIZE_t _TREE_LEAF = TREE_LEAF cdef SIZE_t _TREE_UNDEFINED = TREE_UNDEFINED cdef SIZE_t INITIAL_STACK_SIZE = 10 +# Constants to decide the direction of missing values in a node +cdef SIZE_t MISSING_DIR_LEFT = 0 +cdef SIZE_t MISSING_DIR_RIGHT = 1 + # Repeat struct definition for numpy NODE_DTYPE = np.dtype({ 'names': ['left_child', 'right_child', 'feature', 'threshold', 'impurity', - 'n_node_samples', 'weighted_n_node_samples'], + 'n_node_samples', 'weighted_n_node_samples', + 'missing_direction'], 'formats': [np.intp, np.intp, np.intp, np.float64, np.float64, np.intp, - np.float64], + np.float64, np.intp], 'offsets': [ &( NULL).left_child, &( NULL).right_child, @@ -78,7 +86,8 @@ NODE_DTYPE = np.dtype({ &( NULL).threshold, &( NULL).impurity, &( NULL).n_node_samples, - &( NULL).weighted_n_node_samples + &( 
NULL).weighted_n_node_samples, + &( NULL).missing_direction, ] }) @@ -91,7 +100,8 @@ cdef class TreeBuilder: cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None, - np.ndarray X_idx_sorted=None): + np.ndarray X_idx_sorted=None, + np.ndarray missing_mask=None): """Build a decision tree from the training set (X, y).""" pass @@ -141,7 +151,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None, - np.ndarray X_idx_sorted=None): + np.ndarray X_idx_sorted=None, + np.ndarray missing_mask=None): """Build a decision tree from the training set (X, y).""" # check input @@ -170,7 +181,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef double min_impurity_split = self.min_impurity_split # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight_ptr, X_idx_sorted) + splitter.init(X, y, sample_weight_ptr, X_idx_sorted, missing_mask) cdef SIZE_t start cdef SIZE_t end @@ -234,7 +245,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): node_id = tree._add_node(parent, is_left, is_leaf, split.feature, split.threshold, impurity, n_node_samples, - weighted_n_node_samples) + weighted_n_node_samples, + split.missing_direction) if node_id == (-1): rc = -1 @@ -302,7 +314,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None, - np.ndarray X_idx_sorted=None): + np.ndarray X_idx_sorted=None, + np.ndarray missing_mask=None): """Build a decision tree from the training set (X, y).""" # check input @@ -320,7 +333,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef SIZE_t min_samples_split = self.min_samples_split # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight_ptr, X_idx_sorted) + splitter.init(X, y, sample_weight_ptr, X_idx_sorted, missing_mask) cdef PriorityHeap frontier = PriorityHeap(INITIAL_STACK_SIZE) cdef PriorityHeapRecord record @@ -362,6 +375,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): node.right_child = _TREE_LEAF node.feature = _TREE_UNDEFINED node.threshold = _TREE_UNDEFINED + node.missing_direction = _TREE_UNDEFINED else: # Node is expandable @@ -451,7 +465,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): else _TREE_UNDEFINED, is_left, is_leaf, split.feature, split.threshold, impurity, n_node_samples, - weighted_n_node_samples) + weighted_n_node_samples, split.missing_direction) if node_id == (-1): return -1 @@ -577,17 +591,23 @@ cdef class Tree: def __get__(self): return self._get_node_ndarray()['weighted_n_node_samples'][:self.node_count] + property missing_direction: + def __get__(self): + return self._get_node_ndarray()['missing_direction'][:self.node_count] + property value: def __get__(self): return self._get_value_ndarray()[:self.node_count] def __cinit__(self, int n_features, np.ndarray[SIZE_t, ndim=1] n_classes, - int n_outputs): + int n_outputs, bint allow_missing, object missing_values): """Constructor.""" # Input/Output layout self.n_features = n_features self.n_outputs = n_outputs self.n_classes = NULL + self.allow_missing = allow_missing + self.missing_values = missing_values safe_realloc(&self.n_classes, n_outputs) self.max_n_classes = np.max(n_classes) @@ -613,9 +633,13 @@ cdef class Tree: def __reduce__(self): """Reduce re-implementation, for pickling.""" - return (Tree, (self.n_features, - sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), - self.n_outputs), self.__getstate__()) + return (Tree, + (self.n_features, + 
sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), + self.n_outputs, + self.allow_missing, + self.missing_values), + self.__getstate__()) def __getstate__(self): """Getstate re-implementation, for pickling.""" @@ -702,7 +726,8 @@ cdef class Tree: cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, SIZE_t feature, double threshold, double impurity, - SIZE_t n_node_samples, double weighted_n_node_samples) nogil: + SIZE_t n_node_samples, double weighted_n_node_samples, + SIZE_t missing_direction) nogil: """Add a node to the tree. The new node registers itself as the child of its parent. @@ -731,35 +756,40 @@ cdef class Tree: node.right_child = _TREE_LEAF node.feature = _TREE_UNDEFINED node.threshold = _TREE_UNDEFINED + node.missing_direction = _TREE_UNDEFINED else: # left_child and right_child will be set later node.feature = feature node.threshold = threshold + node.missing_direction = missing_direction self.node_count += 1 return node_id - cpdef np.ndarray predict(self, object X): + cpdef np.ndarray predict(self, object X, np.ndarray missing_mask=None): """Predict target for X.""" - out = self._get_value_ndarray().take(self.apply(X), axis=0, - mode='clip') + out = self._get_value_ndarray().take(self.apply( + X, missing_mask=missing_mask), + axis=0, mode='clip') if self.n_outputs == 1: out = out.reshape(X.shape[0], self.max_n_classes) return out - cpdef np.ndarray apply(self, object X): + cpdef np.ndarray apply(self, object X, np.ndarray missing_mask=None): """Finds the terminal region (=leaf node) for each sample in X.""" if issparse(X): - return self._apply_sparse_csr(X) + return self._apply_sparse_csr(X, missing_mask=missing_mask) else: - return self._apply_dense(X) + return self._apply_dense(X, missing_mask=missing_mask) - cdef inline np.ndarray _apply_dense(self, object X): + cdef inline np.ndarray _apply_dense(self, object X, + np.ndarray missing_mask): """Finds the terminal region (=leaf node) for each sample in X.""" # Check input + # SELFNOTE Why?? Don't we already check at tree.py if not isinstance(X, np.ndarray): raise ValueError("X should be in np.ndarray format, got %s" % type(X)) @@ -782,13 +812,30 @@ cdef class Tree: cdef Node* node = NULL cdef SIZE_t i = 0 + # To handle missing values + cdef BOOL_t* missing_mask_ptr = NULL + cdef SIZE_t missing_mask_fx_stride + cdef bint allow_missing = self.allow_missing + + # If allow_missing is True, the missing_mask should be specified. + # Either at the higher level + if allow_missing: + # Extract the missing mask + missing_mask_ptr = missing_mask.data + missing_mask_fx_stride = missing_mask.strides[1] / missing_mask.itemsize + with nogil: for i in range(n_samples): node = self.nodes # While node not a leaf while node.left_child != _TREE_LEAF: # ... and node.right_child != _TREE_LEAF: - if X_ptr[X_sample_stride * i + + if allow_missing and missing_mask_ptr[i + missing_mask_fx_stride * node.feature] == 1: + if node.missing_direction == MISSING_DIR_RIGHT: + node = &self.nodes[node.right_child] + else: + node = &self.nodes[node.left_child] + elif X_ptr[X_sample_stride * i + X_fx_stride * node.feature] <= node.threshold: node = &self.nodes[node.left_child] else: @@ -798,7 +845,8 @@ cdef class Tree: return out - cdef inline np.ndarray _apply_sparse_csr(self, object X): + cdef inline np.ndarray _apply_sparse_csr(self, object X, + np.ndarray missing_mask): """Finds the terminal region (=leaf node) for each sample in sparse X. 
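+
+        Missing entries are looked up through ``missing_mask``, which is
+        aligned with ``X.data``; when ``missing_values == 0``, the implicit
+        zeros of the sparse matrix are treated as missing as well. Each
+        missing entry then follows the ``missing_direction`` stored on the
+        node.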
""" # Check input @@ -841,26 +889,58 @@ cdef class Tree: safe_realloc(&X_sample, n_features) safe_realloc(&feature_to_sample, n_features) + # To handle missing values + cdef BOOL_t* missing_mask_ptr = NULL + cdef SIZE_t missing_mask_fx_stride + cdef bint allow_missing = self.allow_missing + cdef bint is_missing = False + cdef bint is_missing_value_zero = False + + cdef BOOL_t* X_missing = NULL + safe_realloc(&X_missing, n_features * sizeof(BOOL_t)) + for i in range(n_features): + X_missing[i] = False + + if allow_missing: + missing_mask_ptr = missing_mask.data + is_missing_value_zero = (isinstance(self.missing_values, int) and + (self.missing_values == 0)) + with nogil: memset(feature_to_sample, -1, n_features * sizeof(SIZE_t)) for i in range(n_samples): node = self.nodes - for k in range(X_indptr[i], X_indptr[i + 1]): - feature_to_sample[X_indices[k]] = i - X_sample[X_indices[k]] = X_data[k] + if allow_missing and not is_missing_value_zero: + for k in range(X_indptr[i], X_indptr[i + 1]): + feature_to_sample[X_indices[k]] = i + X_sample[X_indices[k]] = X_data[k] + X_missing[X_indices[k]] = missing_mask_ptr[k] + else: + for k in range(X_indptr[i], X_indptr[i + 1]): + feature_to_sample[X_indices[k]] = i + X_sample[X_indices[k]] = X_data[k] # While node not a leaf while node.left_child != _TREE_LEAF: # ... and node.right_child != _TREE_LEAF: if feature_to_sample[node.feature] == i: feature_value = X_sample[node.feature] - + is_missing = X_missing[node.feature] + elif is_missing_value_zero: + is_missing = True else: feature_value = 0. + is_missing = False - if feature_value <= node.threshold: + if is_missing: + if node.missing_direction == MISSING_DIR_RIGHT: + node = &self.nodes[node.right_child] + else: + node = &self.nodes[node.left_child] + + elif feature_value <= node.threshold: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] @@ -869,18 +949,20 @@ cdef class Tree: # Free auxiliary arrays free(X_sample) + free(X_missing) free(feature_to_sample) return out - cpdef object decision_path(self, object X): + cpdef object decision_path(self, object X, np.ndarray missing_mask=None): """Finds the decision path (=node) for each sample in X.""" if issparse(X): - return self._decision_path_sparse_csr(X) + return self._decision_path_sparse_csr(X, missing_mask) else: - return self._decision_path_dense(X) + return self._decision_path_dense(X, missing_mask) - cdef inline object _decision_path_dense(self, object X): + cdef inline object _decision_path_dense(self, object X, + np.ndarray missing_mask): """Finds the decision path (=node) for each sample in X.""" # Check input @@ -911,6 +993,17 @@ cdef class Tree: cdef Node* node = NULL cdef SIZE_t i = 0 + # To handle missing values + cdef BOOL_t* missing_mask_ptr = NULL + cdef SIZE_t missing_mask_fx_stride + + cdef bint allow_missing = self.allow_missing + + if allow_missing: + # Extract the missing mask + missing_mask_ptr = missing_mask.data + missing_mask_fx_stride = missing_mask.strides[1] / missing_mask.itemsize + with nogil: for i in range(n_samples): node = self.nodes @@ -922,7 +1015,13 @@ cdef class Tree: indices_ptr[indptr_ptr[i + 1]] = (node - self.nodes) indptr_ptr[i + 1] += 1 - if X_ptr[X_sample_stride * i + + if allow_missing and missing_mask_ptr[i + missing_mask_fx_stride * node.feature]: + if node.missing_direction == MISSING_DIR_RIGHT: + node = &self.nodes[node.right_child] + else: + node = &self.nodes[node.left_child] + + elif X_ptr[X_sample_stride * i + X_fx_stride * node.feature] <= node.threshold: node = 
&self.nodes[node.left_child] else: @@ -940,7 +1039,8 @@ cdef class Tree: return out - cdef inline object _decision_path_sparse_csr(self, object X): + cdef inline object _decision_path_sparse_csr(self, object X, + np.ndarray missing_mask): """Finds the decision path (=node) for each sample in X.""" # Check input @@ -987,6 +1087,25 @@ cdef class Tree: safe_realloc(&X_sample, n_features) safe_realloc(&feature_to_sample, n_features) + # To handle missing values + cdef BOOL_t* missing_mask_ptr = NULL + cdef SIZE_t missing_mask_fx_stride + cdef bint allow_missing = self.allow_missing + cdef bint is_missing = False + cdef bint is_missing_value_zero = False + + # Auxiliary data-structure to record if the X is missing or not + cdef BOOL_t* X_missing = NULL + safe_realloc(&X_missing, n_features * sizeof(BOOL_t)) + for i in range(n_features): + X_missing[i] = False + + if allow_missing: + # NOTE the missing mask is only for the X.data not the whole X + missing_mask_ptr = missing_mask.data + is_missing_value_zero = (isinstance(self.missing_values, int) and + self.missing_values == 0) + with nogil: memset(feature_to_sample, -1, n_features * sizeof(SIZE_t)) @@ -994,9 +1113,15 @@ cdef class Tree: node = self.nodes indptr_ptr[i + 1] = indptr_ptr[i] - for k in range(X_indptr[i], X_indptr[i + 1]): - feature_to_sample[X_indices[k]] = i - X_sample[X_indices[k]] = X_data[k] + if allow_missing and not is_missing_value_zero: + for k in range(X_indptr[i], X_indptr[i + 1]): + feature_to_sample[X_indices[k]] = i + X_sample[X_indices[k]] = X_data[k] + X_missing[X_indices[k]] = missing_mask_ptr[k] + else: + for k in range(X_indptr[i], X_indptr[i + 1]): + feature_to_sample[X_indices[k]] = i + X_sample[X_indices[k]] = X_data[k] # While node not a leaf while node.left_child != _TREE_LEAF: @@ -1007,11 +1132,20 @@ cdef class Tree: if feature_to_sample[node.feature] == i: feature_value = X_sample[node.feature] - + is_missing = X_missing[node.feature] + elif is_missing_value_zero: + is_missing = True else: feature_value = 0. + is_missing = False + + if is_missing: + if node.missing_direction == MISSING_DIR_RIGHT: + node = &self.nodes[node.right_child] + else: + node = &self.nodes[node.left_child] - if feature_value <= node.threshold: + elif feature_value <= node.threshold: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] @@ -1022,6 +1156,7 @@ cdef class Tree: # Free auxiliary arrays free(X_sample) + free(X_missing) free(feature_to_sample) indices = indices[:indptr[n_samples]] diff --git a/sklearn/tree/export.py b/sklearn/tree/export.py index 43e8aa11b9611..ff12ff45935a2 100644 --- a/sklearn/tree/export.py +++ b/sklearn/tree/export.py @@ -71,7 +71,8 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None, feature_names=None, class_names=None, label='all', filled=False, leaves_parallel=False, impurity=True, node_ids=False, proportion=False, rotate=False, - rounded=False, special_characters=False): + rounded=False, special_characters=False, + show_missing_dir=False): """Export a decision tree in DOT format. This function generates a GraphViz representation of the decision tree, @@ -226,6 +227,18 @@ def node_to_str(tree, node_id, criterion): round(tree.threshold[node_id], 4), characters[4]) + if show_missing_dir: + missing_direction = tree.missing_direction[node_id] + if missing_direction == 0: + missing_direction = "Left" + elif missing_direction == 1: + missing_direction = "Right" + else: + missing_direction = "Undefined" + + node_string += ('Missing dir. 
- %s%s' + % (missing_direction, characters[4])) + # Write impurity if impurity: if isinstance(criterion, _criterion.FriedmanMSE): diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index c3e8e795b32f0..bf20bb0273d9a 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -28,11 +28,13 @@ from sklearn.utils.testing import assert_less from sklearn.utils.testing import assert_less_equal from sklearn.utils.testing import assert_true +from sklearn.utils.testing import assert_false from sklearn.utils.testing import assert_warns from sklearn.utils.testing import raises from sklearn.utils.testing import ignore_warnings from sklearn.utils.validation import check_random_state +from sklearn.utils.validation import assert_all_finite from sklearn.exceptions import NotFittedError @@ -49,6 +51,11 @@ from sklearn.utils import compute_sample_weight + +MISSING_DIR_LEFT = 0 +MISSING_DIR_RIGHT = 1 +TREE_UNDEFINED = -2 + CLF_CRITERIONS = ("gini", "entropy") REG_CRITERIONS = ("mse", "mae", "friedman_mse") @@ -56,6 +63,11 @@ "DecisionTreeClassifier": DecisionTreeClassifier, "Presort-DecisionTreeClassifier": partial(DecisionTreeClassifier, presort=True), + "MV-DecisionTreeClassifier": partial(DecisionTreeClassifier, + missing_values="NaN"), + "MV-Presort-DecisionTreeClassifier": partial(DecisionTreeClassifier, + missing_values="NaN", + presort=True), "ExtraTreeClassifier": ExtraTreeClassifier, } @@ -602,7 +614,6 @@ def test_min_samples_split(): "Failed with {0}".format(name)) - def test_min_samples_leaf(): # Test if leaves contain more than leaf_count training examples X = np.asfortranarray(iris.data.astype(tree._tree.DTYPE)) @@ -617,8 +628,10 @@ def test_min_samples_leaf(): est = TreeEstimator(min_samples_leaf=5, max_leaf_nodes=max_leaf_nodes, random_state=0) + # Get the missing mask + missing_mask = est._validate_missing_mask(X) est.fit(X, y) - out = est.tree_.apply(X) + out = est.tree_.apply(X, missing_mask=missing_mask) node_counts = np.bincount(out) # drop inner nodes leaf_count = node_counts[node_counts != 0] @@ -630,7 +643,7 @@ def test_min_samples_leaf(): max_leaf_nodes=max_leaf_nodes, random_state=0) est.fit(X, y) - out = est.tree_.apply(X) + out = est.tree_.apply(X, missing_mask=missing_mask) node_counts = np.bincount(out) # drop inner nodes leaf_count = node_counts[node_counts != 0] @@ -661,10 +674,13 @@ def check_min_weight_fraction_leaf(name, datasets, sparse=False): est.fit(X, y, sample_weight=weights) if sparse: - out = est.tree_.apply(X.tocsr()) + # Get the missing mask + missing_mask = est._validate_missing_mask(X.tocsr()) + out = est.tree_.apply(X.tocsr(), missing_mask=missing_mask) else: - out = est.tree_.apply(X) + missing_mask = est._validate_missing_mask(X) + out = est.tree_.apply(X, missing_mask=missing_mask) node_weights = np.bincount(out, weights=weights) # drop inner nodes @@ -686,9 +702,13 @@ def check_min_weight_fraction_leaf(name, datasets, sparse=False): est.fit(X, y) if sparse: - out = est.tree_.apply(X.tocsr()) + # Get the missing mask + missing_mask = est._validate_missing_mask(X.tocsr()) + out = est.tree_.apply(X.tocsr(), missing_mask=missing_mask) + else: - out = est.tree_.apply(X) + missing_mask = est._validate_missing_mask(X) + out = est.tree_.apply(X, missing_mask=missing_mask) node_weights = np.bincount(out) # drop inner nodes @@ -732,9 +752,13 @@ def check_min_weight_fraction_leaf_with_min_samples_leaf(name, datasets, est.fit(X, y) if sparse: - out = est.tree_.apply(X.tocsr()) + # Get the missing mask + missing_mask 
= est._validate_missing_mask(X.tocsr()) + out = est.tree_.apply(X.tocsr(), missing_mask=missing_mask) + else: - out = est.tree_.apply(X) + missing_mask = est._validate_missing_mask(X) + out = est.tree_.apply(X, missing_mask=missing_mask) node_weights = np.bincount(out) # drop inner nodes @@ -757,9 +781,13 @@ def check_min_weight_fraction_leaf_with_min_samples_leaf(name, datasets, est.fit(X, y) if sparse: - out = est.tree_.apply(X.tocsr()) + # Get the missing mask + missing_mask = est._validate_missing_mask(X.tocsr()) + out = est.tree_.apply(X.tocsr(), missing_mask=missing_mask) + else: - out = est.tree_.apply(X) + missing_mask = est._validate_missing_mask(X) + out = est.tree_.apply(X, missing_mask=missing_mask) node_weights = np.bincount(out) # drop inner nodes @@ -1430,16 +1458,29 @@ def check_explicit_sparse_zeros(tree, max_depth=3, Xs = (X_test, X_sparse_test) for X1, X2 in product(Xs, Xs): - assert_array_almost_equal(s.tree_.apply(X1), d.tree_.apply(X2)) + # Get the missing masks + missing_mask1 = s._validate_missing_mask(X1) + missing_mask2 = d._validate_missing_mask(X2) + assert_array_almost_equal( + s.tree_.apply(X1, missing_mask=missing_mask1), + d.tree_.apply(X2, missing_mask=missing_mask2)) + assert_array_almost_equal(s.apply(X1), d.apply(X2)) - assert_array_almost_equal(s.apply(X1), s.tree_.apply(X1)) - assert_array_almost_equal(s.tree_.decision_path(X1).toarray(), - d.tree_.decision_path(X2).toarray()) + assert_array_almost_equal( + s.apply(X1), + s.tree_.apply(X1, missing_mask=missing_mask1)) + + assert_array_almost_equal( + s.tree_.decision_path(X1, missing_mask=missing_mask1).toarray(), + d.tree_.decision_path(X2, missing_mask=missing_mask2).toarray()) + assert_array_almost_equal(s.decision_path(X1).toarray(), d.decision_path(X2).toarray()) - assert_array_almost_equal(s.decision_path(X1).toarray(), - s.tree_.decision_path(X1).toarray()) + + assert_array_almost_equal( + s.decision_path(X1).toarray(), + s.tree_.decision_path(X1, missing_mask=missing_mask1).toarray()) assert_array_almost_equal(s.predict(X1), d.predict(X2)) @@ -1507,18 +1548,24 @@ def check_public_apply(name): X_small32 = X_small.astype(tree._tree.DTYPE) est = ALL_TREES[name]() + + # Get the missing mask + missing_mask = est._validate_missing_mask(X_small32) est.fit(X_small, y_small) assert_array_equal(est.apply(X_small), - est.tree_.apply(X_small32)) + est.tree_.apply(X_small32, missing_mask=missing_mask)) def check_public_apply_sparse(name): X_small32 = csr_matrix(X_small.astype(tree._tree.DTYPE)) est = ALL_TREES[name]() + + # Get the missing mask + missing_mask = est._validate_missing_mask(X_small32) est.fit(X_small, y_small) assert_array_equal(est.apply(X_small), - est.tree_.apply(X_small32)) + est.tree_.apply(X_small32, missing_mask=missing_mask)) def test_public_apply(): @@ -1624,19 +1671,260 @@ def test_criterion_copy(): def _pickle_copy(obj): return pickle.loads(pickle.dumps(obj)) + for copy_func in [copy.copy, copy.deepcopy, _pickle_copy]: - for _, typename in CRITERIA_CLF.items(): - criteria = typename(n_outputs, n_classes) - result = copy_func(criteria).__reduce__() - typename_, (n_outputs_, n_classes_), _ = result - assert_equal(typename, typename_) - assert_equal(n_outputs, n_outputs_) - assert_array_equal(n_classes, n_classes_) - - for _, typename in CRITERIA_REG.items(): - criteria = typename(n_outputs, n_samples) - result = copy_func(criteria).__reduce__() - typename_, (n_outputs_, n_samples_), _ = result - assert_equal(typename, typename_) - assert_equal(n_outputs, n_outputs_) - 
assert_equal(n_samples, n_samples_) + for allow_missing in (True, False): + for _, typename in CRITERIA_CLF.items(): + criteria = typename(n_outputs, n_classes, allow_missing) + result = copy_func(criteria).__reduce__() + typename_, (n_outputs_, n_classes_, allow_missing_), _ = result + assert_equal(typename, typename_) + assert_equal(n_outputs, n_outputs_) + assert_array_equal(n_classes, n_classes_) + assert_equal(allow_missing, allow_missing_) + + for _, typename in CRITERIA_REG.items(): + criteria = typename(n_outputs, n_samples, allow_missing) + result = copy_func(criteria).__reduce__() + typename_, (n_outputs_, n_samples_, allow_missing_), _ = result + assert_equal(typename, typename_) + assert_equal(n_outputs, n_outputs_) + assert_equal(n_samples, n_samples_) + assert_equal(allow_missing, allow_missing_) + + +def test_tree_missing_value_handling_corner_cases_best_splitter(): + # All the missing values should be sent to a separate child in one of the + # nodes + X_de = np.array([[np.nan], [np.nan], [np.nan], [np.nan], + [0], [1], [2], [3], [4], [5], + [10], [11], [12], [13], [15]]) + X_sp = coo_matrix(X_de) + y = np.array([1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]) + + # Test both the dense and sparse splitters + for X in (X_de, X_sp): + dtc = DecisionTreeClassifier(missing_values="NaN", + random_state=42).fit(X, y) + + # The 2nd partition of this tree splits the missing values to one side, + # for the given random_state + assert_equal(dtc.tree_.threshold[2], np.inf) + assert_equal(dtc.tree_.missing_direction[2], + MISSING_DIR_RIGHT) + assert_equal(dtc.tree_.missing_direction[0], + MISSING_DIR_RIGHT) + # The leaf should have missing direction undefined + assert_equal(dtc.tree_.missing_direction[1], TREE_UNDEFINED) + # assert_all_finite cannot be used as one node will have a inf thres. 
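+        # (An infinite threshold is expected for the split that isolates
+        # the missing values, but a NaN threshold would indicate a bug,
+        # hence the explicit np.isnan check below.)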
+ assert_false(np.any(np.isnan(dtc.tree_.threshold))) + assert_array_equal(dtc.predict(X), y) + + # The missing should be sent along with available to left child + X_de = np.array([[np.nan], [np.nan], [np.nan], [np.nan], + [0], [1], [2], [3], [4], [5]]) + X_sp = coo_matrix(X_de) + y = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2, 2]) + + for X in (X_de, X_sp): + dtc = DecisionTreeClassifier(missing_values="NaN", + random_state=42).fit(X, y) + + assert_equal(dtc.tree_.threshold[0], 0.5) + # Send the missing along with the sample [0,] to the left + assert_equal(dtc.tree_.missing_direction[0], + MISSING_DIR_LEFT) + # The leaf should have missing direction undefined + assert_equal(dtc.tree_.missing_direction[1], TREE_UNDEFINED) + assert_equal(dtc.tree_.missing_direction[2], TREE_UNDEFINED) + assert_all_finite(dtc.tree_.threshold) + assert_array_equal(dtc.predict(X), y) + + # The missing should be sent along with available to right child + y = np.array([2, 2, 2, 2, 1, 1, 1, 1, 2, 2]) + + for X in (X_de, X_sp): + dtc = DecisionTreeClassifier(missing_values="NaN", + random_state=42).fit(X, y) + + assert_equal(dtc.tree_.threshold[0], 3.5) + # Send the missing along with the samples [4,] and [5,] to the right + assert_equal(dtc.tree_.missing_direction[0], + MISSING_DIR_RIGHT) + # The leaf should have missing direction undefined + assert_equal(dtc.tree_.missing_direction[1], TREE_UNDEFINED) + assert_equal(dtc.tree_.missing_direction[2], TREE_UNDEFINED) + assert_all_finite(dtc.tree_.threshold) + assert_array_equal(dtc.predict(X), y) + + # When no missing exist, none of the thresholds should be NaN/Inf + # And missing direction should either be the default RIGHT / TREE_UNDEFINED + X_de = rng.random_sample((10, 2)) + X_sp = coo_matrix(X_de) + y = rng.randint(0, 10, (10,)) + + for X in (X_de, X_sp): + dtc = DecisionTreeClassifier(missing_values="NaN", + random_state=42).fit(X, y) + # All the missing_direction for non-leaf nodes should be the default of + # MISSING_DIR_RIGHT and for leaves it should be TREE_UNDEFINED + assert_true( + np.all((dtc.tree_.missing_direction == MISSING_DIR_RIGHT) ^ + (dtc.tree_.children_left == TREE_LEAF))) + assert_true( + np.all((dtc.tree_.missing_direction == TREE_UNDEFINED) ^ + (dtc.tree_.children_left != TREE_LEAF))) + assert_all_finite(dtc.tree_.threshold) + + # When the missing values are equally from all the classes + # the tree building should split it into separate node + X_de = np.array([[110], [100], [1], [2], [0], [np.nan], [500], + [600], [np.nan], [5]]) + X_sp = coo_matrix(X_de) + y = np.array([1, 1, 0, 0, 0, 0, 1, 1, 1, 0]) + + for X in (X_de, X_sp): + dtc = DecisionTreeClassifier(missing_values="NaN", + random_state=42).fit(X, y) + + assert_equal(dtc.tree_.threshold[0], 52.5) + assert_equal(dtc.tree_.missing_direction[0], MISSING_DIR_RIGHT) + # The right child should partition the missing samples to the right + assert_equal(dtc.tree_.missing_direction[2], + MISSING_DIR_RIGHT) + assert_equal(dtc.tree_.threshold[2], np.inf) + # All other non-leaf nodes should have the default missing_direction as + # RIGHT + assert_true( + np.all((dtc.tree_.missing_direction == MISSING_DIR_RIGHT) ^ + (dtc.tree_.children_left == TREE_LEAF))) + # No nan thresholds + assert_false(np.any(np.isnan(dtc.tree_.threshold))) + try: + # This should not pass as both the missing values are grouped to + # single class + assert_array_equal(dtc.predict(X), y) + except AssertionError: + pass + + +def test_tree_missing_value_handling_corner_cases_random_splitter(): + # Using test cases as in the best 
splitter tests.
+    # Testing for actual thresholds or missing directions at each node does
+    # not make sense here. However, in all of these corner cases the tree
+    # should capture the structure of the data exactly.
+
+    # All the missing values belong to a single class
+    X_de = np.array([[np.nan], [np.nan], [np.nan], [np.nan],
+                     [0], [1], [2], [3], [4], [5],
+                     [10], [11], [12], [13], [15]])
+    X_sp = coo_matrix(X_de)
+    y = np.array([1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
+
+    # Test both the dense and sparse splitters
+    for X in (X_de, X_sp):
+        dtc = DecisionTreeClassifier(missing_values="NaN",
+                                     splitter="random",
+                                     random_state=42).fit(X, y)
+        # No nan thresholds (INF can be a threshold however)
+        assert_false(np.any(np.isnan(dtc.tree_.threshold)))
+        assert_array_equal(dtc.predict(X), y)
+
+    # The missing values and some available values in one class
+    X_de = np.array([[np.nan], [np.nan], [np.nan], [np.nan],
+                     [0], [1], [2], [3], [4], [5]])
+    X_sp = coo_matrix(X_de)
+    y = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2, 2])
+
+    for X in (X_de, X_sp):
+        dtc = DecisionTreeClassifier(missing_values="NaN",
+                                     splitter="random",
+                                     random_state=42).fit(X, y)
+        # No nan thresholds (INF can be a threshold however)
+        assert_false(np.any(np.isnan(dtc.tree_.threshold)))
+        assert_array_equal(dtc.predict(X), y)
+
+    # The missing values should be sent along with the available ones to
+    # the right child
+    y = np.array([2, 2, 2, 2, 1, 1, 1, 1, 2, 2])
+
+    for X in (X_de, X_sp):
+        dtc = DecisionTreeClassifier(missing_values="NaN",
+                                     splitter="random",
+                                     random_state=42).fit(X, y)
+        # No nan thresholds (INF can be a threshold however)
+        assert_false(np.any(np.isnan(dtc.tree_.threshold)))
+        assert_array_equal(dtc.predict(X), y)
+
+    # When no missing values exist, none of the thresholds should be NaN/Inf
+    # and the missing direction should be either the default RIGHT or
+    # TREE_UNDEFINED
+    X_de = rng.random_sample((10, 2))
+    X_sp = coo_matrix(X_de)
+    y = rng.randint(0, 10, (10,))
+
+    for X in (X_de, X_sp):
+        dtc = DecisionTreeClassifier(missing_values="NaN",
+                                     splitter="random",
+                                     random_state=42).fit(X, y)
+        # All the missing_direction for non-leaf nodes should be the default
+        # of MISSING_DIR_RIGHT and for leaves it should be TREE_UNDEFINED
+        assert_true(
+            np.all((dtc.tree_.missing_direction == MISSING_DIR_RIGHT) ^
+                   (dtc.tree_.children_left == TREE_LEAF)))
+        assert_true(
+            np.all((dtc.tree_.missing_direction == TREE_UNDEFINED) ^
+                   (dtc.tree_.children_left != TREE_LEAF)))
+        assert_all_finite(dtc.tree_.threshold)
+        assert_array_equal(dtc.predict(X), y)
+
+    # When the missing values come equally from all the classes,
+    # the tree building should split them into a separate node
+    X_de = np.array([[110], [100], [1], [2], [0], [np.nan], [500],
+                     [600], [np.nan], [5]])
+    X_sp = coo_matrix(X_de)
+    y = np.array([1, 1, 0, 0, 0, 0, 1, 1, 1, 0])
+
+    for X in (X_de, X_sp):
+        dtc = DecisionTreeClassifier(missing_values="NaN",
+                                     splitter="random",
+                                     random_state=42).fit(X, y)
+
+        # No nan thresholds (INF can be a threshold however)
+        assert_false(np.any(np.isnan(dtc.tree_.threshold)))
+        try:
+            # This is not guaranteed to pass, as the two missing values are
+            # grouped into a single class
+            assert_array_equal(dtc.predict(X), y)
+        except AssertionError:
+            pass
+
+
+def test_tree_explicit_missing_mask():
+    # All the missing values belong to a single class
+    X_de = np.array([[np.nan], [np.nan], [np.nan], [np.nan],
+                     [0], [1], [2], [3], [4], [5],
+                     [10], [11], [12], [13], [15]])
+    X_sp = coo_matrix(X_de)
+    y = np.array([1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3,
+                  3, 3, 3, 3])
+
+    # Check if validation of missing mask is correct
+    est = DecisionTreeClassifier(missing_values="NaN")
+    missing_mask_sp = np.isnan(X_sp.data)
+    assert_array_equal(missing_mask_sp, est._validate_missing_mask(X_sp))
+    missing_mask_de = np.isnan(X_de)
+    assert_array_equal(missing_mask_de, est._validate_missing_mask(X_de))
+
+    # Test both the dense and sparse splitters to see if passing a
+    # missing mask produces the same effect as an autogenerated missing mask
+    for X, missing_mask in ((X_de, missing_mask_de), (X_sp, missing_mask_sp)):
+        dtc = DecisionTreeClassifier(missing_values="NaN",
+                                     splitter="random",
+                                     random_state=42).fit(
+                                         X, y)
+        dtc2 = DecisionTreeClassifier(missing_values="NaN",
+                                      splitter="random",
+                                      random_state=42).fit(
+                                          X, y, missing_mask=missing_mask)
+        # No nan thresholds (INF can be a threshold however)
+        assert_false(np.any(np.isnan(dtc.tree_.threshold)))
+        assert_array_equal(dtc.predict(X), dtc2.predict(X))
diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py
index c3567e864c10b..f537931b8539f 100644
--- a/sklearn/tree/tree.py
+++ b/sklearn/tree/tree.py
@@ -11,12 +11,12 @@
 #          Joly Arnaud
 #          Fares Hedayati
 #          Nelson Liu
+#          Raghav R V
 #
 # License: BSD 3 clause

 from __future__ import division
-
 import numbers
 from abc import ABCMeta
 from abc import abstractmethod
@@ -92,7 +92,8 @@ def __init__(self,
                  random_state,
                  min_impurity_split,
                  class_weight=None,
-                 presort=False):
+                 presort=False,
+                 missing_values=None):
         self.criterion = criterion
         self.splitter = splitter
         self.max_depth = max_depth
@@ -114,12 +115,29 @@ def __init__(self,
         self.tree_ = None
         self.max_features_ = None

-    def fit(self, X, y, sample_weight=None, check_input=True,
-            X_idx_sorted=None):
+        self.missing_values = missing_values
+        self.allow_missing = missing_values is not None
+
+        # _allow_nan stays False when missing_values is an int (or None)
+        self._allow_nan = False
+        if self.allow_missing:
+            if (isinstance(missing_values, str) and
+                    missing_values.strip().lower() == "nan"):
+                self._allow_nan = True
+                self.missing_values = np.nan
+            elif not isinstance(missing_values, int):
+                raise ValueError("missing_values should be 'NaN' or int. "
+                                 "Got %s" % missing_values)
+
+    def fit(self, X, y, sample_weight=None, check_input=True,
+            X_idx_sorted=None, missing_mask=None):

         random_state = check_random_state(self.random_state)
+
         if check_input:
-            X = check_array(X, dtype=DTYPE, accept_sparse="csc")
+            X = check_array(X, dtype=DTYPE, accept_sparse="csc",
+                            allow_nan=self._allow_nan)
             y = check_array(y, ensure_2d=False, dtype=None)
             if issparse(X):
                 X.sort_indices()
@@ -300,6 +318,7 @@ def fit(self, X, y, sample_weight=None, check_input=True,
         # but do not handle any presorting themselves. Ensemble algorithms
         # which desire presorting must do presorting themselves and pass that
         # matrix into each tree.
+        # Same goes for missing_mask
         if X_idx_sorted is None and presort:
             X_idx_sorted = np.asfortranarray(np.argsort(X, axis=0),
                                              dtype=np.int32)
@@ -310,15 +329,28 @@ def fit(self, X, y, sample_weight=None, check_input=True,
                                  ".shape = {})".format(X.shape,
                                                        X_idx_sorted.shape))

+        missing_mask = self._validate_missing_mask(X, missing_mask)
+        if self.allow_missing:
+            if issparse(X) and missing_mask.shape != X.data.shape:
+                raise ValueError("The shape of X.data (X.data.shape = {}) "
+                                 "doesn't match the shape of missing_mask "
+                                 "(missing_mask.shape = {})"
+                                 .format(X.data.shape, missing_mask.shape))
+            elif not issparse(X) and missing_mask.shape != X.shape:
+                raise ValueError("The shape of X (X.shape = {}) "
+                                 "doesn't match the shape of missing_mask "
+                                 "(missing_mask.shape = {})"
+                                 .format(X.shape, missing_mask.shape))
+
         # Build tree
         criterion = self.criterion
         if not isinstance(criterion, Criterion):
             if is_classification:
-                criterion = CRITERIA_CLF[self.criterion](self.n_outputs_,
-                                                         self.n_classes_)
+                criterion = CRITERIA_CLF[self.criterion](
+                    self.n_outputs_, self.n_classes_, self.allow_missing)
             else:
-                criterion = CRITERIA_REG[self.criterion](self.n_outputs_,
-                                                         n_samples)
+                criterion = CRITERIA_REG[self.criterion](
+                    self.n_outputs_, n_samples, self.allow_missing)

         SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
@@ -329,9 +361,14 @@ def fit(self, X, y, sample_weight=None, check_input=True,
                                            min_samples_leaf,
                                            min_weight_leaf,
                                            random_state,
-                                           self.presort)
+                                           self.presort,
+                                           self.allow_missing)

-        self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_)
+        self.tree_ = Tree(n_features=self.n_features_,
+                          n_classes=self.n_classes_,
+                          n_outputs=self.n_outputs_,
+                          allow_missing=self.allow_missing,
+                          missing_values=self.missing_values)

         # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
         if max_leaf_nodes < 0:
@@ -347,7 +384,8 @@ def fit(self, X, y, sample_weight=None, check_input=True,
                                                max_leaf_nodes,
                                                self.min_impurity_split)

-        builder.build(self.tree_, X, y, sample_weight, X_idx_sorted)
+        builder.build(self.tree_, X=X, y=y, sample_weight=sample_weight,
+                      X_idx_sorted=X_idx_sorted, missing_mask=missing_mask)

         if self.n_outputs_ == 1:
             self.n_classes_ = self.n_classes_[0]
@@ -357,12 +395,14 @@ def fit(self, X, y, sample_weight=None, check_input=True,

     def _validate_X_predict(self, X, check_input):
         """Validate X whenever one tries to predict, apply, predict_proba"""
+        # NaN values are allowed in X only when self._allow_nan is True
        if self.tree_ is None:
             raise NotFittedError("Estimator not fitted, "
                                  "call `fit` before exploiting the model.")

         if check_input:
-            X = check_array(X, dtype=DTYPE, accept_sparse="csr")
+            X = check_array(X, dtype=DTYPE, accept_sparse="csr",
+                            allow_nan=self._allow_nan)
             if issparse(X) and (X.indices.dtype != np.intc or
                                 X.indptr.dtype != np.intc):
                 raise ValueError("No support for np.int64 index based "
@@ -377,7 +417,25 @@ def _validate_X_predict(self, X, check_input):

         return X

-    def predict(self, X, check_input=True):
+    def _validate_missing_mask(self, X, missing_mask=None):
+        """Generate a missing_mask from X if one was not given"""
+        if self.allow_missing and missing_mask is None:
+            # Fortran ordered 8 bit boolean mask
+            if issparse(X):
+                to_mask = X.data
+            else:
+                to_mask = X
+
+            if self._allow_nan:  # Missing value is a NaN
+                missing_mask = np.asfortranarray(np.isnan(to_mask),
+                                                 dtype=np.bool8)
+            else:
+                missing_mask = np.zeros(to_mask.shape,
+                                        dtype=np.bool8, order='F')
+                missing_mask[to_mask == self.missing_values] = True
+        return missing_mask
+
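For review, a minimal standalone sketch of what the mask generation above
computes may help; the helper name `make_missing_mask` is mine and not part
of the patch, but the dense/sparse and NaN/int-placeholder branches mirror
`_validate_missing_mask`:

    import numpy as np
    from scipy.sparse import csc_matrix, issparse

    def make_missing_mask(X, missing_values):
        # Fortran-ordered bool8 mask over X (dense) or X.data (sparse)
        to_mask = X.data if issparse(X) else X
        if isinstance(missing_values, str):  # the "NaN" placeholder
            return np.asfortranarray(np.isnan(to_mask), dtype=np.bool8)
        mask = np.zeros(to_mask.shape, dtype=np.bool8, order='F')
        mask[to_mask == missing_values] = True
        return mask

    X_dense = np.array([[np.nan, 1.], [2., np.nan]])
    print(make_missing_mask(X_dense, "NaN"))  # [[ True False] [False  True]]

    X_sparse = csc_matrix(np.array([[0., -1.], [3., 0.]]))
    print(make_missing_mask(X_sparse, -1))    # mask over X_sparse.data

Note that for sparse input the mask covers only the explicitly stored
entries, which is why fit above compares its shape against X.data.shape.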
+    def predict(self, X, check_input=True, missing_mask=None):
         """Predict class or regression value for X.

         For a classification model, the predicted class for each sample in X is
@@ -402,7 +460,8 @@ def predict(self, X, check_input=True):
         """

         X = self._validate_X_predict(X, check_input)
-        proba = self.tree_.predict(X)
+        missing_mask = self._validate_missing_mask(X, missing_mask)
+        proba = self.tree_.predict(X, missing_mask=missing_mask)
         n_samples = X.shape[0]

         # Classification
@@ -428,7 +487,7 @@ def predict(self, X, check_input=True):
         else:
             return proba[:, :, 0]

-    def apply(self, X, check_input=True):
+    def apply(self, X, check_input=True, missing_mask=None):
         """
         Returns the index of the leaf that each sample is predicted as.
@@ -454,9 +513,10 @@ def apply(self, X, check_input=True):
             numbering.
         """
         X = self._validate_X_predict(X, check_input)
-        return self.tree_.apply(X)
+        missing_mask = self._validate_missing_mask(X, missing_mask)
+        return self.tree_.apply(X, missing_mask=missing_mask)

-    def decision_path(self, X, check_input=True):
+    def decision_path(self, X, check_input=True, missing_mask=None):
         """Return the decision path in the tree

         .. versionadded:: 0.18
@@ -480,7 +540,8 @@ def decision_path(self, X, check_input=True):
         """
         X = self._validate_X_predict(X, check_input)
-        return self.tree_.decision_path(X)
+        missing_mask = self._validate_missing_mask(X, missing_mask)
+        return self.tree_.decision_path(X, missing_mask=missing_mask)

     @property
     def feature_importances_(self):
@@ -533,9 +594,9 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin):
             - If "log2", then `max_features=log2(n_features)`.
             - If None, then `max_features=n_features`.

-        Note: the search for a split does not stop until at least one
-        valid partition of the node samples is found, even if it requires to
-        effectively inspect more than ``max_features`` features.
+        Note: the search for a split does not stop until at least
+        one valid partition of the node samples is found, even if it requires
+        to effectively inspect more than ``max_features`` features.

     max_depth : int or None, optional (default=None)
         The maximum depth of the tree. If None, then nodes are expanded until
@@ -608,6 +669,12 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin):
         When using either a smaller dataset or a restricted depth, this may
         speed up the training.

+    missing_values : integer or "NaN", optional (default=None)
+        The placeholder for the missing values. If not None, entries equal
+        to the placeholder are treated as missing and are sent down the
+        missing-value direction learned at each node.
+
+        For missing values encoded as np.nan, use the string value "NaN".
+
     Attributes
     ----------
     classes_ : array of shape = [n_classes] or a list of such arrays
@@ -680,7 +747,8 @@ def __init__(self,
                  max_leaf_nodes=None,
                  min_impurity_split=1e-7,
                  class_weight=None,
-                 presort=False):
+                 presort=False,
+                 missing_values=None):
         super(DecisionTreeClassifier, self).__init__(
             criterion=criterion,
             splitter=splitter,
@@ -693,10 +761,11 @@ def __init__(self,
             class_weight=class_weight,
             random_state=random_state,
             min_impurity_split=min_impurity_split,
-            presort=presort)
+            presort=presort,
+            missing_values=missing_values)

     def fit(self, X, y, sample_weight=None, check_input=True,
-            X_idx_sorted=None):
+            X_idx_sorted=None, missing_mask=None):
         """Build a decision tree classifier from the training set (X, y).
         Parameters
@@ -736,11 +805,11 @@ def fit(self, X, y, sample_weight=None, check_input=True,
             X, y,
             sample_weight=sample_weight,
             check_input=check_input,
-            X_idx_sorted=X_idx_sorted)
+            X_idx_sorted=X_idx_sorted,
+            missing_mask=missing_mask)
         return self

-
-    def predict_proba(self, X, check_input=True):
+    def predict_proba(self, X, check_input=True, missing_mask=None):
         """Predict class probabilities of the input samples X.

         The predicted class probability is the fraction of samples of the same
@@ -765,7 +834,8 @@ class in a leaf.
             classes corresponds to that in the attribute `classes_`.
         """
         X = self._validate_X_predict(X, check_input)
-        proba = self.tree_.predict(X)
+        missing_mask = self._validate_missing_mask(X, missing_mask)
+        proba = self.tree_.predict(X, missing_mask=missing_mask)

         if self.n_outputs_ == 1:
             proba = proba[:, :self.n_classes_]
@@ -787,7 +857,7 @@ class in a leaf.

         return all_proba

-    def predict_log_proba(self, X):
+    def predict_log_proba(self, X, missing_mask=None):
         """Predict class log-probabilities of the input samples X.

         Parameters
@@ -804,7 +874,7 @@ def predict_log_proba(self, X):
             The class log-probabilities of the input samples. The order of the
             classes corresponds to that in the attribute `classes_`.
         """
-        proba = self.predict_proba(X)
+        proba = self.predict_proba(X, missing_mask=missing_mask)

         if self.n_outputs_ == 1:
             return np.log(proba)
@@ -909,6 +979,12 @@ class DecisionTreeRegressor(BaseDecisionTree, RegressorMixin):
         When using either a smaller dataset or a restricted depth, this may
         speed up the training.

+    missing_values : integer or "NaN", optional (default=None)
+        The placeholder for the missing values. If not None, entries equal
+        to the placeholder are treated as missing and are sent down the
+        missing-value direction learned at each node.
+
+        For missing values encoded as np.nan, use the string value "NaN".
+
     Attributes
     ----------
     feature_importances_ : array of shape = [n_features]
@@ -972,7 +1048,8 @@ def __init__(self,
                  random_state=None,
                  max_leaf_nodes=None,
                  min_impurity_split=1e-7,
-                 presort=False):
+                 presort=False,
+                 missing_values=None):
         super(DecisionTreeRegressor, self).__init__(
             criterion=criterion,
             splitter=splitter,
@@ -984,10 +1061,11 @@ def __init__(self,
             max_leaf_nodes=max_leaf_nodes,
             random_state=random_state,
             min_impurity_split=min_impurity_split,
-            presort=presort)
+            presort=presort,
+            missing_values=missing_values)

     def fit(self, X, y, sample_weight=None, check_input=True,
-            X_idx_sorted=None):
+            X_idx_sorted=None, missing_mask=None):
         """Build a decision tree regressor from the training set (X, y).
         Parameters
@@ -1026,7 +1104,8 @@ def fit(self, X, y, sample_weight=None, check_input=True,
             X, y,
             sample_weight=sample_weight,
             check_input=check_input,
-            X_idx_sorted=X_idx_sorted)
+            X_idx_sorted=X_idx_sorted,
+            missing_mask=missing_mask)

         return self

@@ -1065,7 +1144,8 @@ def __init__(self,
                  random_state=None,
                  max_leaf_nodes=None,
                  min_impurity_split=1e-7,
-                 class_weight=None):
+                 class_weight=None,
+                 missing_values=None):
         super(ExtraTreeClassifier, self).__init__(
             criterion=criterion,
             splitter=splitter,
@@ -1077,7 +1157,8 @@ def __init__(self,
             max_leaf_nodes=max_leaf_nodes,
             class_weight=class_weight,
             min_impurity_split=min_impurity_split,
-            random_state=random_state)
+            random_state=random_state,
+            missing_values=missing_values)


 class ExtraTreeRegressor(DecisionTreeRegressor):
@@ -1114,7 +1195,8 @@ def __init__(self,
                  max_features="auto",
                  random_state=None,
                  min_impurity_split=1e-7,
-                 max_leaf_nodes=None):
+                 max_leaf_nodes=None,
+                 missing_values=None):
         super(ExtraTreeRegressor, self).__init__(
             criterion=criterion,
             splitter=splitter,
@@ -1125,4 +1207,5 @@ def __init__(self,
             max_features=max_features,
             max_leaf_nodes=max_leaf_nodes,
             min_impurity_split=min_impurity_split,
-            random_state=random_state)
+            random_state=random_state,
+            missing_values=missing_values)
diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py
index 9c57ab4e96de8..d8d97556af382 100644
--- a/sklearn/utils/tests/test_validation.py
+++ b/sklearn/utils/tests/test_validation.py
@@ -141,6 +141,11 @@ def test_check_array():
     X_nan[0, 0] = np.nan
     assert_raises(ValueError, check_array, X_nan)
     check_array(X_inf, force_all_finite=False)  # no raise
+    # allow_nan check
+    check_array(X_nan, force_all_finite=True, allow_nan=True)  # no raise
+    # allow_nan should not hinder the check for inf
+    assert_raises(ValueError, check_array, X_inf, force_all_finite=True,
+                  allow_nan=True)

     # dtype and order enforcement.
     X_C = np.arange(4).reshape(2, 2).copy("C")
diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py
index deb98eef85039..cefe18048af1b 100644
--- a/sklearn/utils/validation.py
+++ b/sklearn/utils/validation.py
@@ -46,26 +46,33 @@ class NotFittedError(_NotFittedError):
 warnings.simplefilter('ignore', _NonBLASDotWarning)


-def _assert_all_finite(X):
+def _assert_all_finite(X, allow_nan=False):
     """Like assert_all_finite, but only for ndarray."""
     X = np.asanyarray(X)
     # First try an O(n) time, O(1) space solution for the common case that
     # everything is finite; fall back to O(n) space np.isfinite to prevent
     # false positives from overflow in sum method.
-    if (X.dtype.char in np.typecodes['AllFloat'] and not np.isfinite(X.sum())
-            and not np.isfinite(X).all()):
-        raise ValueError("Input contains NaN, infinity"
-                         " or a value too large for %r." % X.dtype)
+    if allow_nan:
+        def any_not_isfinite(X):
+            return np.isinf(X).any()
+        np_sum = np.nansum
+    else:
+        def any_not_isfinite(X):
+            return not np.isfinite(X).all()
+        np_sum = np.sum
+
+    if (X.dtype.char in np.typecodes['AllFloat'] and
+            not np.isfinite(np_sum(X)) and any_not_isfinite(X)):
+        raise ValueError("Input contains %sinfinity or a value too large for "
+                         "%r." % ("" if allow_nan else "NaN, ", X.dtype))


-def assert_all_finite(X):
-    """Throw a ValueError if X contains NaN or infinity.
+def assert_all_finite(X, allow_nan=False):
+    """Throw a ValueError if X contains infinity, or NaN when allow_nan
+    is False.

     Input MUST be an np.ndarray instance or a scipy.sparse matrix."""
-    _assert_all_finite(X.data if sp.issparse(X) else X)
+    _assert_all_finite(X.data if sp.issparse(X) else X, allow_nan)


-def as_float_array(X, copy=True, force_all_finite=True):
+def as_float_array(X, copy=True, force_all_finite=True, allow_nan=False):
     """Converts an array-like to an array of floats

     The new dtype will be np.float32 or np.float64, depending on the original
@@ -83,6 +90,9 @@ def as_float_array(X, copy=True, force_all_finite=True):
     force_all_finite : boolean (default=True)
         Whether to raise an error on np.inf and np.nan in X.

+    allow_nan : boolean (default=False)
+        Whether to allow np.nan values in X.
+
     Returns
     -------
     XT : {array, sparse matrix}
@@ -92,7 +102,7 @@ def as_float_array(X, copy=True, force_all_finite=True):
             and not sp.issparse(X)):
         return check_array(X, ['csr', 'csc', 'coo'], dtype=np.float64,
                            copy=copy, force_all_finite=force_all_finite,
-                           ensure_2d=False)
+                           allow_nan=allow_nan, ensure_2d=False)
     elif sp.issparse(X) and X.dtype in [np.float32, np.float64]:
         return X.copy() if copy else X
     elif X.dtype in [np.float32, np.float64]:  # is numpy array
@@ -208,7 +218,7 @@ def indexable(*iterables):


 def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy,
-                          force_all_finite):
+                          force_all_finite, allow_nan):
     """Convert a sparse matrix to a given format.

     Checks the sparse format of spmatrix and converts if necessary.
@@ -232,7 +242,11 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy,
         be triggered by a conversion.

     force_all_finite : boolean (default=True)
-        Whether to raise an error on np.inf and np.nan in X.
+        Whether to raise an error on np.inf in X, and on np.nan when
+        allow_nan is False.
+
+    allow_nan : boolean
+        Whether to allow np.nan values in X.

     Returns
     -------
@@ -265,14 +279,14 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, copy,
             warnings.warn("Can't check %s sparse matrix for nan or inf."
                           % spmatrix.format)
         else:
-            _assert_all_finite(spmatrix.data)
+            _assert_all_finite(spmatrix.data, allow_nan)
     return spmatrix


 def check_array(array, accept_sparse=None, dtype="numeric", order=None,
-                copy=False, force_all_finite=True, ensure_2d=True,
-                allow_nd=False, ensure_min_samples=1, ensure_min_features=1,
-                warn_on_dtype=False, estimator=None):
+                copy=False, force_all_finite=True, allow_nan=False,
+                ensure_2d=True, allow_nd=False, ensure_min_samples=1,
+                ensure_min_features=1, warn_on_dtype=False, estimator=None):
     """Input validation on an array, list, sparse matrix or similar.

     By default, the input is converted to an at least 2D numpy array.
@@ -308,7 +322,10 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
         be triggered by a conversion.

     force_all_finite : boolean (default=True)
-        Whether to raise an error on np.inf and np.nan in X.
+        Whether to raise an error on np.inf in X, and on np.nan when
+        allow_nan is False.
+
+    allow_nan : boolean (default=False)
+        Whether to allow np.nan values in X.

     ensure_2d : boolean (default=True)
         Whether to make X at least 2d.
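To make the allow_nan semantics above concrete, here is a minimal standalone
re-implementation (the name `finite_check` and the omitted dtype guard are
simplifications for illustration, not part of the patch): NaN passes when
allow_nan is True, while infinity is always rejected thanks to the np.nansum
fast path.

    import numpy as np

    def finite_check(X, allow_nan=False):
        # O(n) sum-based fast path; exact scan only if the sum is suspect
        X = np.asanyarray(X)
        total = np.nansum(X) if allow_nan else np.sum(X)
        if not np.isfinite(total):
            bad = (np.isinf(X).any() if allow_nan
                   else not np.isfinite(X).all())
            if bad:
                raise ValueError("Input contains %sinfinity or a value too "
                                 "large for %r."
                                 % ("" if allow_nan else "NaN, ", X.dtype))

    finite_check(np.array([1., np.nan]), allow_nan=True)   # no raise
    try:
        finite_check(np.array([1., np.inf]), allow_nan=True)
    except ValueError as e:
        print(e)  # infinity is still rejected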
@@ -377,7 +394,7 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,

     if sp.issparse(array):
         array = _ensure_sparse_format(array, accept_sparse, dtype, copy,
-                                      force_all_finite)
+                                      force_all_finite, allow_nan)
     else:
         array = np.array(array, dtype=dtype, order=order, copy=copy)

@@ -404,7 +421,7 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,
             raise ValueError("Found array with dim %d. %s expected <= 2."
                              % (array.ndim, estimator_name))
         if force_all_finite:
-            _assert_all_finite(array)
+            _assert_all_finite(array, allow_nan)

     shape_repr = _shape_repr(array.shape)
     if ensure_min_samples > 0:
@@ -431,9 +448,9 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None,


 def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None,
-              copy=False, force_all_finite=True, ensure_2d=True,
-              allow_nd=False, multi_output=False, ensure_min_samples=1,
-              ensure_min_features=1, y_numeric=False,
+              copy=False, force_all_finite=True, allow_nan=False,
+              ensure_2d=True, allow_nd=False, multi_output=False,
+              ensure_min_samples=1, ensure_min_features=1, y_numeric=False,
               warn_on_dtype=False, estimator=None):
     """Input validation for standard estimators.

@@ -474,6 +491,9 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None,
         Whether to raise an error on np.inf and np.nan in X. This parameter
         does not influence whether y can have np.inf or np.nan values.

+    allow_nan : boolean (default=False)
+        Whether to allow np.nan values in X and y.
+
     ensure_2d : boolean (default=True)
         Whether to make X at least 2d.

@@ -517,14 +537,14 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None,
         The converted and validated y.
     """
     X = check_array(X, accept_sparse, dtype, order, copy, force_all_finite,
-                    ensure_2d, allow_nd, ensure_min_samples,
+                    allow_nan, ensure_2d, allow_nd, ensure_min_samples,
                     ensure_min_features, warn_on_dtype, estimator)
     if multi_output:
-        y = check_array(y, 'csr', force_all_finite=True, ensure_2d=False,
-                        dtype=None)
+        y = check_array(y, 'csr', force_all_finite=True, allow_nan=allow_nan,
+                        ensure_2d=False, dtype=None)
     else:
         y = column_or_1d(y, warn=True)
-        _assert_all_finite(y)
+        _assert_all_finite(y, allow_nan)

     if y_numeric and y.dtype.kind == 'O':
         y = y.astype(np.float64)
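For completeness, a usage sketch of the estimator-level API this patch adds
(illustrative only; the exact predictions depend on the learned missing-value
directions, though the corner-case tests above pin down this simple case):

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier

    # With this branch, NaN entries flow straight through fit/predict
    X = np.array([[np.nan], [np.nan], [0.], [1.], [10.], [11.]])
    y = np.array([0, 0, 1, 1, 2, 2])

    clf = DecisionTreeClassifier(missing_values="NaN",
                                 random_state=0).fit(X, y)
    print(clf.predict([[np.nan], [0.5], [12.]]))  # expected: [0 1 2]

    # An explicit mask can also be supplied, e.g. when the sentinel value
    # is a legal observation elsewhere in X
    mask = np.isnan(X)
    clf2 = DecisionTreeClassifier(missing_values="NaN",
                                  random_state=0).fit(X, y,
                                                      missing_mask=mask)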