diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx index 00ad2c08a0ee3..49524cd697751 100644 --- a/sklearn/ensemble/_gradient_boosting.pyx +++ b/sklearn/ensemble/_gradient_boosting.pyx @@ -14,6 +14,7 @@ np.import_array() from sklearn.tree._tree cimport Node from sklearn.tree._tree cimport Tree +from sklearn.tree._utils cimport goes_left ctypedef np.int32_t int32 @@ -31,6 +32,7 @@ from numpy import float64 as np_float64 DTYPE = np.float32 ctypedef np.float32_t DTYPE_t ctypedef np.npy_intp SIZE_t +ctypedef np.npy_int32 INT32_t # constant to mark tree leafs @@ -44,6 +46,7 @@ cdef void _predict_regression_tree_inplace_fast(DTYPE_t *X, Py_ssize_t K, Py_ssize_t n_samples, Py_ssize_t n_features, + INT32_t* n_categories, float64 *out): """Predicts output for regression tree and stores it in ``out[i, k]``. @@ -78,6 +81,9 @@ cdef void _predict_regression_tree_inplace_fast(DTYPE_t *X, ``n_samples == X.shape[0]``. n_features : int The number of features; ``n_samples == X.shape[1]``. + n_categories : INT32_t pointer + Pointer to array of shape [n_features] containing the number of + categories for each feature, or -1 for non-categorical features. out : np.float64_t pointer The pointer to the data array where the predictions are stored. ``out`` is assumed to be a two-dimensional array of @@ -90,7 +96,8 @@ cdef void _predict_regression_tree_inplace_fast(DTYPE_t *X, node = root_node # While node not a leaf while node.left_child != -1 and node.right_child != -1: - if X[i * n_features + node.feature] <= node.threshold: + if goes_left(X[i * n_features + node.feature], node.split_value, + n_categories[node.feature], node._bit_cache): node = root_node + node.left_child else: node = root_node + node.right_child @@ -116,6 +123,8 @@ def predict_stages(np.ndarray[object, ndim=2] estimators, for k in range(K): tree = estimators[i, k].tree_ + tree.populate_bit_caches() + # avoid buffer validation by casting to ndarray # and get data pointer # need brackets because of casting operator priority @@ -123,9 +132,12 @@ def predict_stages(np.ndarray[object, ndim=2] estimators, X.data, tree.nodes, tree.value, scale, k, K, X.shape[0], X.shape[1], + tree.n_categories, ( out).data) ## out += scale * tree.predict(X).reshape((X.shape[0], 1)) + tree.delete_bit_caches() + @cython.nonecheck(False) def predict_stage(np.ndarray[object, ndim=2] estimators, @@ -204,65 +216,73 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, underlying_stack = np_zeros((stack_capacity,), dtype=np.intp) node_stack = ( underlying_stack).data - for i in range(X.shape[0]): - # init stacks for new example - stack_size = 1 - node_stack[0] = root_node - weight_stack[0] = 1.0 - total_weight = 0.0 - - while stack_size > 0: - # get top node on stack - stack_size -= 1 - current_node = node_stack[stack_size] - - if current_node.left_child == LEAF: - out[i] += weight_stack[stack_size] * value[current_node - root_node] * \ - learn_rate - total_weight += weight_stack[stack_size] - else: - # non-terminal node - feature_index = array_index(current_node.feature, target_feature) - if feature_index != -1: - # split feature in target set - # push left or right child on stack - if X[i, feature_index] <= current_node.threshold: - # left - node_stack[stack_size] = (root_node + - current_node.left_child) - else: - # right - node_stack[stack_size] = (root_node + - current_node.right_child) - stack_size += 1 + tree.populate_bit_caches() + + try: + for i in range(X.shape[0]): + # init stacks for new example + stack_size = 
1 + node_stack[0] = root_node + weight_stack[0] = 1.0 + total_weight = 0.0 + + while stack_size > 0: + # get top node on stack + stack_size -= 1 + current_node = node_stack[stack_size] + + if current_node.left_child == LEAF: + out[i] += weight_stack[stack_size] * value[current_node - root_node] * \ + learn_rate + total_weight += weight_stack[stack_size] else: - # split feature in complement set - # push both children onto stack - - # push left child - node_stack[stack_size] = root_node + current_node.left_child - current_weight = weight_stack[stack_size] - left_sample_frac = root_node[current_node.left_child].n_node_samples / \ - current_node.n_node_samples - if left_sample_frac <= 0.0 or left_sample_frac >= 1.0: - raise ValueError("left_sample_frac:%f, " - "n_samples current: %d, " - "n_samples left: %d" - % (left_sample_frac, - current_node.n_node_samples, - root_node[current_node.left_child].n_node_samples)) - weight_stack[stack_size] = current_weight * left_sample_frac - stack_size +=1 - - # push right child - node_stack[stack_size] = root_node + current_node.right_child - weight_stack[stack_size] = current_weight * \ - (1.0 - left_sample_frac) - stack_size +=1 - - if not (0.999 < total_weight < 1.001): - raise ValueError("Total weight should be 1.0 but was %.9f" % - total_weight) + # non-terminal node + feature_index = array_index(current_node.feature, target_feature) + if feature_index != -1: + # split feature in target set + # push left or right child on stack + if goes_left(X[i, feature_index], current_node.split_value, + tree.n_categories[current_node.feature], + current_node._bit_cache): + # left + node_stack[stack_size] = (root_node + + current_node.left_child) + else: + # right + node_stack[stack_size] = (root_node + + current_node.right_child) + stack_size += 1 + else: + # split feature in complement set + # push both children onto stack + + # push left child + node_stack[stack_size] = root_node + current_node.left_child + current_weight = weight_stack[stack_size] + left_sample_frac = root_node[current_node.left_child].n_node_samples / \ + current_node.n_node_samples + if left_sample_frac <= 0.0 or left_sample_frac >= 1.0: + raise ValueError("left_sample_frac:%f, " + "n_samples current: %d, " + "n_samples left: %d" + % (left_sample_frac, + current_node.n_node_samples, + root_node[current_node.left_child].n_node_samples)) + weight_stack[stack_size] = current_weight * left_sample_frac + stack_size +=1 + + # push right child + node_stack[stack_size] = root_node + current_node.right_child + weight_stack[stack_size] = current_weight * \ + (1.0 - left_sample_frac) + stack_size +=1 + + if not (0.999 < total_weight < 1.001): + raise ValueError("Total weight should be 1.0 but was %.9f" % + total_weight) + + finally: + tree.delete_bit_caches() def _random_sample_mask(np.npy_intp n_total_samples, diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index e26323f65bfee..8afabe098bb27 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -57,7 +57,8 @@ class calls the ``fit`` method of each sub-estimator on random samples from ..metrics import r2_score from ..preprocessing import OneHotEncoder from ..tree import (DecisionTreeClassifier, DecisionTreeRegressor, - ExtraTreeClassifier, ExtraTreeRegressor) + ExtraTreeClassifier, ExtraTreeRegressor, + preproc_categorical, validate_categorical) from ..tree._tree import DTYPE, DOUBLE from ..utils import check_random_state, check_array, compute_sample_weight from ..exceptions import DataConversionWarning, 
NotFittedError @@ -90,8 +91,8 @@ def _generate_unsampled_indices(random_state, n_samples): return unsampled_indices -def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees, - verbose=0, class_weight=None): +def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, + n_trees, verbose=0, class_weight=None): """Private function used to fit a single tree in parallel.""" if verbose > 1: print("building tree %d of %d" % (tree_idx + 1, n_trees)) @@ -153,6 +154,7 @@ def __init__(self, self.verbose = verbose self.warm_start = warm_start self.class_weight = class_weight + self.category_map_ = None def apply(self, X): """Apply trees in the forest to X, return leaf indices. @@ -245,6 +247,10 @@ def fit(self, X, y, sample_weight=None): # ensemble sorts the indices. X.sort_indices() + # Preprocess categorical variables + X, _, self.category_map_ = preproc_categorical( + X, self.categorical, check_input=True) + # Remap output n_samples, self.n_features_ = X.shape @@ -346,7 +352,10 @@ def _validate_X_predict(self, X): raise NotFittedError("Estimator not fitted, " "call `fit` before exploiting the model.") - return self.estimators_[0]._validate_X_predict(X, check_input=True) + X = self.estimators_[0]._validate_X_predict(X, check_input=True) + X = validate_categorical(X, self.category_map_) + + return X @property def feature_importances_(self): @@ -805,6 +814,19 @@ class RandomForestClassifier(ForestClassifier): If None then unlimited number of leaf nodes. If not None then ``max_depth`` will be ignored. + categorical : array-like or str + Array of feature indices, boolean array of length + n_features, ``'all'``, or ``'none'``. Indicates which + features should be considered as categorical rather than + ordinal. For decision trees, the maximum number of + categories per feature is 64, though the real-world limit + will be much lower because evaluating splits has + :math:`O(2^N)` time complexity, for :math:`N` + categories. Extra-randomized trees do not have this + limitation because they do not try to find the best + split. For these trees, the maximum number of categories + per feature is :math:`2^{31}`. + bootstrap : boolean, optional (default=True) Whether bootstrap samples are used when building trees. @@ -898,6 +920,7 @@ def __init__(self, min_weight_fraction_leaf=0., max_features="auto", max_leaf_nodes=None, + categorical="none", bootstrap=True, oob_score=False, n_jobs=1, @@ -911,7 +934,7 @@ def __init__(self, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -927,6 +950,7 @@ def __init__(self, self.min_weight_fraction_leaf = min_weight_fraction_leaf self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes + self.categorical = categorical class RandomForestRegressor(ForestRegressor): @@ -998,6 +1022,19 @@ class RandomForestRegressor(ForestRegressor): If None then unlimited number of leaf nodes. If not None then ``max_depth`` will be ignored. + categorical : array-like or str + Array of feature indices, boolean array of length + n_features, ``'all'``, or ``'none'``. Indicates which + features should be considered as categorical rather than + ordinal. 
For decision trees, the maximum number of + categories per feature is 64, though the real-world limit + will be much lower because evaluating splits has + :math:`O(2^N)` time complexity, for :math:`N` + categories. Extra-randomized trees do not have this + limitation because they do not try to find the best + split. For these trees, the maximum number of categories + per feature is :math:`2^{31}`. + bootstrap : boolean, optional (default=True) Whether bootstrap samples are used when building trees. @@ -1061,6 +1098,7 @@ def __init__(self, min_weight_fraction_leaf=0., max_features="auto", max_leaf_nodes=None, + categorical="none", bootstrap=True, oob_score=False, n_jobs=1, @@ -1073,7 +1111,7 @@ def __init__(self, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1088,6 +1126,7 @@ def __init__(self, self.min_weight_fraction_leaf = min_weight_fraction_leaf self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes + self.categorical = categorical class ExtraTreesClassifier(ForestClassifier): @@ -1157,6 +1196,19 @@ class ExtraTreesClassifier(ForestClassifier): If None then unlimited number of leaf nodes. If not None then ``max_depth`` will be ignored. + categorical : array-like or str + Array of feature indices, boolean array of length + n_features, ``'all'``, or ``'none'``. Indicates which + features should be considered as categorical rather than + ordinal. For decision trees, the maximum number of + categories per feature is 64, though the real-world limit + will be much lower because evaluating splits has + :math:`O(2^N)` time complexity, for :math:`N` + categories. Extra-randomized trees do not have this + limitation because they do not try to find the best + split. For these trees, the maximum number of categories + per feature is :math:`2^{31}`. + bootstrap : boolean, optional (default=False) Whether bootstrap samples are used when building trees. @@ -1252,6 +1304,7 @@ def __init__(self, min_weight_fraction_leaf=0., max_features="auto", max_leaf_nodes=None, + categorical="none", bootstrap=False, oob_score=False, n_jobs=1, @@ -1265,7 +1318,7 @@ def __init__(self, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1281,6 +1334,7 @@ def __init__(self, self.min_weight_fraction_leaf = min_weight_fraction_leaf self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes + self.categorical = categorical class ExtraTreesRegressor(ForestRegressor): @@ -1350,6 +1404,19 @@ class ExtraTreesRegressor(ForestRegressor): If None then unlimited number of leaf nodes. If not None then ``max_depth`` will be ignored. + categorical : array-like or str + Array of feature indices, boolean array of length + n_features, ``'all'``, or ``'none'``. Indicates which + features should be considered as categorical rather than + ordinal. For decision trees, the maximum number of + categories per feature is 64, though the real-world limit + will be much lower because evaluating splits has + :math:`O(2^N)` time complexity, for :math:`N` + categories. Extra-randomized trees do not have this + limitation because they do not try to find the best + split. 
For these trees, the maximum number of categories + per feature is :math:`2^{31}`. + bootstrap : boolean, optional (default=False) Whether bootstrap samples are used when building trees. @@ -1414,6 +1481,7 @@ def __init__(self, min_weight_fraction_leaf=0., max_features="auto", max_leaf_nodes=None, + categorical="none", bootstrap=False, oob_score=False, n_jobs=1, @@ -1426,7 +1494,7 @@ def __init__(self, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1441,6 +1509,7 @@ def __init__(self, self.min_weight_fraction_leaf = min_weight_fraction_leaf self.max_features = max_features self.max_leaf_nodes = max_leaf_nodes + self.categorical = categorical class RandomTreesEmbedding(BaseForest): @@ -1517,6 +1586,19 @@ class RandomTreesEmbedding(BaseForest): and add more estimators to the ensemble, otherwise, just fit a whole new forest. + categorical : array-like or str + Array of feature indices, boolean array of length + n_features, ``'all'``, or ``'none'``. Indicates which + features should be considered as categorical rather than + ordinal. For decision trees, the maximum number of + categories per feature is 64, though the real-world limit + will be much lower because evaluating splits has + :math:`O(2^N)` time complexity, for :math:`N` + categories. Extra-randomized trees do not have this + limitation because they do not try to find the best + split. For these trees, the maximum number of categories + per feature is :math:`2^{31}`. + Attributes ---------- estimators_ : list of DecisionTreeClassifier @@ -1543,14 +1625,15 @@ def __init__(self, n_jobs=1, random_state=None, verbose=0, - warm_start=False): + warm_start=False, + categorical="none"): super(RandomTreesEmbedding, self).__init__( base_estimator=ExtraTreeRegressor(), n_estimators=n_estimators, estimator_params=("criterion", "max_depth", "min_samples_split", "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", - "random_state"), + "random_state", "categorical"), bootstrap=False, oob_score=False, n_jobs=n_jobs, @@ -1566,6 +1649,7 @@ def __init__(self, self.max_features = 1 self.max_leaf_nodes = max_leaf_nodes self.sparse_output = sparse_output + self.categorical = categorical def _set_oob_score(self, X, y): raise NotImplementedError("OOB score not supported by tree embedding") diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index b17d726cb122a..66af4b3741f39 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -47,6 +47,8 @@ from time import time from ..tree.tree import DecisionTreeRegressor +from ..tree.tree import preproc_categorical +from ..tree.tree import validate_categorical from ..tree._tree import DTYPE from ..tree._tree import TREE_LEAF @@ -724,7 +726,7 @@ def __init__(self, loss, learning_rate, n_estimators, min_samples_split, min_samples_leaf, min_weight_fraction_leaf, max_depth, init, subsample, max_features, random_state, alpha=0.9, verbose=0, max_leaf_nodes=None, - warm_start=False, presort='auto'): + warm_start=False, presort='auto', categorical='none'): self.n_estimators = n_estimators self.learning_rate = learning_rate @@ -742,6 +744,8 @@ def __init__(self, loss, learning_rate, n_estimators, min_samples_split, self.max_leaf_nodes = max_leaf_nodes self.warm_start = warm_start self.presort = presort + 
self.categorical = categorical + self.category_map_ = None self.estimators_ = np.empty((0, 0), dtype=np.object) @@ -771,28 +775,21 @@ def _fit_stage(self, i, X, y, y_pred, sample_weight, sample_mask, max_features=self.max_features, max_leaf_nodes=self.max_leaf_nodes, random_state=random_state, - presort=self.presort) + presort=self.presort, + categorical=self.categorical) if self.subsample < 1.0: # no inplace multiplication! sample_weight = sample_weight * sample_mask.astype(np.float64) - if X_csc is not None: - tree.fit(X_csc, residual, sample_weight=sample_weight, - check_input=False, X_idx_sorted=X_idx_sorted) - else: - tree.fit(X, residual, sample_weight=sample_weight, - check_input=False, X_idx_sorted=X_idx_sorted) + tree.fit(X_csc if X_csc is not None else X, residual, + sample_weight=sample_weight, check_input=False, + X_idx_sorted=X_idx_sorted) # update tree leaves - if X_csr is not None: - loss.update_terminal_regions(tree.tree_, X_csr, y, residual, y_pred, - sample_weight, sample_mask, - self.learning_rate, k=k) - else: - loss.update_terminal_regions(tree.tree_, X, y, residual, y_pred, - sample_weight, sample_mask, - self.learning_rate, k=k) + loss.update_terminal_regions( + tree.tree_, X_csr if X_csr is not None else X, y, residual, + y_pred, sample_weight, sample_mask, self.learning_rate, k=k) # add tree to ensemble self.estimators_[i, k] = tree @@ -979,6 +976,10 @@ def fit(self, X, y, sample_weight=None, monitor=None): y = self._validate_y(y) + # Preprocess categorical variables + X, _, self.category_map_ = preproc_categorical( + X, self.categorical, check_input=True) + random_state = check_random_state(self.random_state) self._check_params() @@ -1077,8 +1078,8 @@ def _fit_stages(self, X, y, y_pred, sample_weight, random_state, # fit next stage of trees y_pred = self._fit_stage(i, X, y, y_pred, sample_weight, - sample_mask, random_state, X_idx_sorted, - X_csc, X_csr) + sample_mask, random_state, + X_idx_sorted, X_csc, X_csr) # track deviance (= loss) if do_oob: @@ -1140,9 +1141,10 @@ def decision_function(self, X): Regression and binary classification produce an array of shape [n_samples]. """ - self._check_initialized() X = self.estimators_[0, 0]._validate_X_predict(X, check_input=True) + X = validate_categorical(X, self.category_map_) + score = self._decision_function(X) if score.shape[1] == 1: return score.ravel() @@ -1168,6 +1170,8 @@ def _staged_decision_function(self, X): ``k == 1``, otherwise ``k==n_classes``. """ X = check_array(X, dtype=DTYPE, order="C") + X = validate_categorical(X, self.category_map_) + score = self._init_decision_function(X) for i in range(self.estimators_.shape[0]): predict_stage(self.estimators_, i, X, self.learning_rate, score) @@ -1379,6 +1383,15 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): .. versionadded:: 0.17 *presort* parameter. + categorical : array-like or str + Array of feature indices, boolean array of length + n_features, ``'all'``, or ``'none'``. Indicates which + features should be considered as categorical rather than + ordinal. The maximum number of categories per feature is + 64, though the real-world limit will be much lower because + evaluating splits has :math:`O(2^N)` time complexity, for + :math:`N` categories. 
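As a concrete illustration of the ``categorical`` docstring above, a minimal usage sketch. The data and the encoding are illustrative assumptions: categorical columns are passed as small non-negative integer codes, which the ``preproc_categorical``/``validate_categorical`` calls appear to normalize at fit and predict time.

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier

    rng = np.random.RandomState(0)
    # Column 0 is ordinal, column 1 holds a 4-valued categorical code.
    X = np.column_stack([rng.normal(size=200),
                         rng.randint(0, 4, size=200)]).astype(np.float32)
    y = (X[:, 1] == 2).astype(int)          # target depends on one category

    # Mark column 1 as categorical; a boolean mask [False, True] or 'all'
    # would also be accepted according to the docstring.
    clf = RandomForestClassifier(n_estimators=10, categorical=[1],
                                 random_state=0).fit(X, y)
    print(clf.predict(X[:5]))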
+ Attributes ---------- feature_importances_ : array, shape = [n_features] @@ -1431,7 +1444,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100, max_depth=3, init=None, random_state=None, max_features=None, verbose=0, max_leaf_nodes=None, warm_start=False, - presort='auto'): + presort='auto', categorical='none'): super(GradientBoostingClassifier, self).__init__( loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, @@ -1442,7 +1455,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100, max_features=max_features, random_state=random_state, verbose=verbose, max_leaf_nodes=max_leaf_nodes, warm_start=warm_start, - presort=presort) + presort=presort, categorical=categorical) def _validate_y(self, y): check_classification_targets(y) @@ -1467,6 +1480,8 @@ def decision_function(self, X): [n_samples]. """ X = check_array(X, dtype=DTYPE, order="C") + X = validate_categorical(X, self.category_map_) + score = self._decision_function(X) if score.shape[1] == 1: return score.ravel() @@ -1728,6 +1743,15 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin): .. versionadded:: 0.17 optional parameter *presort*. + categorical : array-like or str + Array of feature indices, boolean array of length + n_features, ``'all'``, or ``'none'``. Indicates which + features should be considered as categorical rather than + ordinal. The maximum number of categories per feature is + 64, though the real-world limit will be much lower because + evaluating splits has :math:`O(2^N)` time complexity, for + :math:`N` categories. + Attributes ---------- feature_importances_ : array, shape = [n_features] @@ -1776,7 +1800,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100, min_samples_leaf=1, min_weight_fraction_leaf=0., max_depth=3, init=None, random_state=None, max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None, - warm_start=False, presort='auto'): + warm_start=False, presort='auto', categorical='none'): super(GradientBoostingRegressor, self).__init__( loss=loss, learning_rate=learning_rate, n_estimators=n_estimators, @@ -1787,7 +1811,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100, max_features=max_features, random_state=random_state, alpha=alpha, verbose=verbose, max_leaf_nodes=max_leaf_nodes, warm_start=warm_start, - presort=presort) + presort=presort, categorical=categorical) def predict(self, X): """Predict regression target for X. @@ -1803,6 +1827,8 @@ def predict(self, X): The predicted values. 
""" X = check_array(X, dtype=DTYPE, order="C") + X = validate_categorical(X, self.category_map_) + return self._decision_function(X).ravel() def staged_predict(self, X): diff --git a/sklearn/ensemble/tests/test_bagging.py b/sklearn/ensemble/tests/test_bagging.py index 5d85713a76210..67dde9bdf32e6 100644 --- a/sklearn/ensemble/tests/test_bagging.py +++ b/sklearn/ensemble/tests/test_bagging.py @@ -5,6 +5,8 @@ # Author: Gilles Louppe # License: BSD 3 clause +import sys + import numpy as np from sklearn.base import BaseEstimator @@ -419,40 +421,49 @@ def test_parallel_classification(): X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=rng) - + print('nocats checkpoint 1', file=sys.stderr) ensemble = BaggingClassifier(DecisionTreeClassifier(), n_jobs=3, random_state=0).fit(X_train, y_train) # predict_proba + print('nocats checkpoint 2', file=sys.stderr) ensemble.set_params(n_jobs=1) y1 = ensemble.predict_proba(X_test) + print('nocats checkpoint 3', file=sys.stderr) ensemble.set_params(n_jobs=2) y2 = ensemble.predict_proba(X_test) assert_array_almost_equal(y1, y2) + print('nocats checkpoint 4', file=sys.stderr) ensemble = BaggingClassifier(DecisionTreeClassifier(), n_jobs=1, random_state=0).fit(X_train, y_train) + print('nocats checkpoint 5', file=sys.stderr) y3 = ensemble.predict_proba(X_test) assert_array_almost_equal(y1, y3) # decision_function + print('nocats checkpoint 6', file=sys.stderr) ensemble = BaggingClassifier(SVC(decision_function_shape='ovr'), n_jobs=3, random_state=0).fit(X_train, y_train) + print('nocats checkpoint 7', file=sys.stderr) ensemble.set_params(n_jobs=1) decisions1 = ensemble.decision_function(X_test) + print('nocats checkpoint 8', file=sys.stderr) ensemble.set_params(n_jobs=2) decisions2 = ensemble.decision_function(X_test) assert_array_almost_equal(decisions1, decisions2) + print('nocats checkpoint 9', file=sys.stderr) ensemble = BaggingClassifier(SVC(decision_function_shape='ovr'), n_jobs=1, random_state=0).fit(X_train, y_train) + print('nocats checkpoint 10', file=sys.stderr) decisions3 = ensemble.decision_function(X_test) assert_array_almost_equal(decisions1, decisions3) diff --git a/sklearn/tree/__init__.py b/sklearn/tree/__init__.py index 1394bd914d27c..42bab93f089ab 100644 --- a/sklearn/tree/__init__.py +++ b/sklearn/tree/__init__.py @@ -7,6 +7,7 @@ from .tree import DecisionTreeRegressor from .tree import ExtraTreeClassifier from .tree import ExtraTreeRegressor +from .tree import preproc_categorical, validate_categorical from .export import export_graphviz __all__ = ["DecisionTreeClassifier", "DecisionTreeRegressor", diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index e716736e1cf91..df009c988f925 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -17,8 +17,30 @@ from ._criterion cimport Criterion ctypedef np.npy_float32 DTYPE_t # Type of X ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight ctypedef np.npy_intp SIZE_t # Type for indices and counters +ctypedef np.npy_uint8 UINT8_t # Unsigned 8 bit integer ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer +ctypedef np.npy_uint64 UINT64_t # Unsigned 64 bit integer + +ctypedef union SplitValue: + # Union type to generalize the concept of a threshold to + # categorical features. For non-categorical features, use the + # threshold member. 
It acts just as before, where feature values + # less than or equal to the threshold go left, and values greater + # than the threshold go right. + # + # For categorical features, use the cat_split member. It works in + # one of two ways, indicated by the value of its least significant + # bit (LSB). If the LSB is 0, then cat_split acts as a bitfield + # for up to 64 categories, sending samples left if the bit + # corresponding to their category is 1 or right if it is 0. If the + # LSB is 1, then the more significant 32 bits of cat_split is a + # random seed. To evaluate a sample, use the random seed to flip a + # coin (category_value + 1) times and send it left if the last + # flip gives 1; otherwise right. This second method allows up to + # 2**31 category values, but can only be used for RandomSplitter. + DOUBLE_t threshold + UINT64_t cat_split cdef struct SplitRecord: # Data to track sample split @@ -26,7 +48,8 @@ cdef struct SplitRecord: SIZE_t pos # Split samples array at the given position, # i.e. count of samples below threshold for feature. # pos is >= end if the node is a leaf. - double threshold # Threshold to split at. + SplitValue split_value # Generalized threshold for categorical and + # non-categorical features. double improvement # Impurity improvement given parent node. double impurity_left # Impurity of the left split. double impurity_right # Impurity of the right split. @@ -59,10 +82,15 @@ cdef class Splitter: cdef bint presort # Whether to use presorting, only # allowed on dense data + cdef bint shortcut # Whether decision trees are allowed to use the + # Breiman shortcut for categorical features cdef DOUBLE_t* y cdef SIZE_t y_stride cdef DOUBLE_t* sample_weight + cdef INT32_t* n_categories # (n_features) array giving number of + # categories (<0 for non-categorical) + cdef UINT8_t* _bit_cache # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. With this setting, @@ -83,6 +111,7 @@ cdef class Splitter: # Methods cdef void init(self, object X, np.ndarray y, DOUBLE_t* sample_weight, + INT32_t* n_categories, np.ndarray X_idx_sorted=*) except * cdef void node_reset(self, SIZE_t start, SIZE_t end, @@ -95,4 +124,4 @@ cdef class Splitter: cdef void node_value(self, double* dest) nogil - cdef double node_impurity(self) nogil \ No newline at end of file + cdef double node_impurity(self) nogil diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 4c59b6960e7a0..aebe988e15e6d 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -29,10 +29,13 @@ np.import_array() from scipy.sparse import csc_matrix from ._utils cimport log +from ._utils cimport our_rand_r from ._utils cimport rand_int from ._utils cimport rand_uniform from ._utils cimport RAND_R_MAX from ._utils cimport safe_realloc +from ._utils cimport make_bit_cache +from ._utils cimport goes_left cdef double INFINITY = np.inf @@ -48,7 +51,7 @@ cdef inline void _init_split(SplitRecord* self, SIZE_t start_pos) nogil: self.impurity_right = INFINITY self.pos = start_pos self.feature = 0 - self.threshold = 0. + self.split_value.threshold = 0. 
self.improvement = -INFINITY cdef class Splitter: @@ -60,7 +63,7 @@ cdef class Splitter: def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint shortcut): """ Parameters ---------- @@ -95,12 +98,15 @@ cdef class Splitter: self.y = NULL self.y_stride = 0 self.sample_weight = NULL + self.n_categories = NULL + self._bit_cache = NULL self.max_features = max_features self.min_samples_leaf = min_samples_leaf self.min_weight_leaf = min_weight_leaf self.random_state = random_state self.presort = presort + self.shortcut = shortcut def __dealloc__(self): """Destructor.""" @@ -109,6 +115,8 @@ cdef class Splitter: free(self.features) free(self.constant_features) free(self.feature_values) + free(self.n_categories) + free(self._bit_cache) def __getstate__(self): return {} @@ -120,6 +128,7 @@ cdef class Splitter: object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, + INT32_t* n_categories, np.ndarray X_idx_sorted=None) except *: """Initialize the splitter. @@ -181,6 +190,19 @@ cdef class Splitter: self.sample_weight = sample_weight + # Initialize the number of categories for each feature + # A value of -1 indicates a non-categorical feature + safe_realloc(&self.n_categories, n_features) + for i in range(n_features): + self.n_categories[i] = (-1 if n_categories == NULL + else n_categories[i]) + + # If needed, allocate cache space to hold split info + cdef INT32_t max_n_categories = max( + [self.n_categories[i] for i in range(n_features)]) + if max_n_categories > 0: + safe_realloc(&self._bit_cache, (max_n_categories + 7) // 8) + cdef void node_reset(self, SIZE_t start, SIZE_t end, double* weighted_n_node_samples) nogil: """Reset splitter on node samples[start:end]. @@ -242,7 +264,7 @@ cdef class BaseDenseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint shortcut): self.X = NULL self.X_sample_stride = 0 @@ -261,11 +283,12 @@ cdef class BaseDenseSplitter(Splitter): object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, + INT32_t* n_categories, np.ndarray X_idx_sorted=None) except *: """Initialize the splitter.""" # Call parent init - Splitter.init(self, X, y, sample_weight) + Splitter.init(self, X, y, sample_weight, n_categories) # Initialize X cdef np.ndarray X_ndarray = X @@ -295,8 +318,51 @@ cdef class BestSplitter(BaseDenseSplitter): self.random_state, self.presort), self.__getstate__()) + + cdef void _shortcut_catlist(self, SIZE_t start, SIZE_t end, INT32_t ncat, + INT32_t ncat_present, const INT32_t *cat_offs, + SIZE_t *shortcut_cat) nogil: + """The Breiman shortcut for finding the best split involves a + preprocessing step wherein we sort the categories by + increasing (weighted) mean of the outcome y (whether 0/1 + binary for classification or quantitative for + regression). This function implements this preprocessing step + and produces a sorted list of category values. 
+ + """ + cdef SIZE_t *samples = self.samples + cdef DTYPE_t *Xf = self.feature_values + cdef DOUBLE_t *y = self.y + cdef SIZE_t y_stride = self.y_stride + cdef DOUBLE_t *sample_weight = self.sample_weight + cdef DOUBLE_t w + cdef SIZE_t cat, localcat + cdef SIZE_t q, partition_end + cdef DTYPE_t sort_value[64] + cdef DTYPE_t sort_den[64] + + for cat in range(ncat): + sort_value[cat] = 0 + sort_den[cat] = 0 + + for q in range(start, end): + cat = Xf[q] + w = sample_weight[samples[q]] if sample_weight else 1.0 + sort_value[cat] += w * (y[y_stride * samples[q]]) + sort_den[cat] += w + + for localcat in range(ncat_present): + cat = localcat + cat_offs[localcat] + sort_value[localcat] = sort_value[cat] / sort_den[cat] + shortcut_cat[localcat] = cat + + # Second step: sort by decreasing impurity + sort(&sort_value[0], shortcut_cat, ncat_present) + + cdef void node_split(self, double impurity, SplitRecord* split, SIZE_t* n_constant_features) nogil: + """Find the best split on node samples[start:end].""" # Find the best split cdef SIZE_t* samples = self.samples @@ -325,8 +391,8 @@ cdef class BestSplitter(BaseDenseSplitter): cdef SIZE_t f_i = n_features cdef SIZE_t f_j - cdef SIZE_t tmp cdef SIZE_t p + cdef SIZE_t q cdef SIZE_t feature_idx_offset cdef SIZE_t feature_offset cdef SIZE_t i @@ -342,6 +408,12 @@ cdef class BestSplitter(BaseDenseSplitter): cdef SIZE_t n_total_constants = n_known_constants cdef DTYPE_t current_feature_value cdef SIZE_t partition_end + cdef bint is_categorical + cdef UINT64_t cat_split + cdef INT32_t ncat_present + cdef INT32_t cat_offs[64] + cdef SIZE_t shortcut_cat[64] + cdef bint shortcut = self.shortcut _init_split(&best, end) @@ -383,9 +455,8 @@ cdef class BestSplitter(BaseDenseSplitter): if f_j < n_known_constants: # f_j in the interval [n_drawn_constants, n_known_constants[ - tmp = features[f_j] - features[f_j] = features[n_drawn_constants] - features[n_drawn_constants] = tmp + features[f_j], features[n_drawn_constants] = ( + features[n_drawn_constants], features[f_j]) n_drawn_constants += 1 @@ -429,61 +500,124 @@ cdef class BestSplitter(BaseDenseSplitter): # Evaluate all splits self.criterion.reset() - p = start + is_categorical = self.n_categories[current.feature] > 0 + if is_categorical: + p = 0 + # Identify the subset of categories present (for performance reasons) + cat_split = 0 + ncat_present = 0 + for q in range(start, end): + cat_split |= 1 << (Xf[q]) + for q in range(self.n_categories[current.feature]): + if cat_split & (1 << q): + cat_offs[ncat_present] = q - ncat_present + ncat_present += 1 + shortcut = self.shortcut if ncat_present > 3 else 0 # No benefit for small N + if shortcut: + self._shortcut_catlist(start, end, self.n_categories[current.feature], + ncat_present, cat_offs, &shortcut_cat[0]) + else: + p = start + + while True: + if is_categorical: + if shortcut: + p += 1 + if p >= ncat_present: + break + cat_split = 0 + for q in range(p): + cat_split |= ( 1) << shortcut_cat[q] + if cat_split & 1: + cat_split = (~cat_split) & ( + (~( 0)) >> (64 - self.n_categories[current.feature])) + else: + if p > ((~( 0)) >> (64 - ncat_present)): + break + else: + p += 2 # LSB must always be 0 + + # Expand the bits of p out into cat_split + cat_split = 0 + for q in range(ncat_present): + cat_split |= (p & (( 1) << q)) << cat_offs[q] + + # Partition + q = start + partition_end = end + while q < partition_end: + if (cat_split >> Xf[q]) & 1: + q += 1 + else: + partition_end -= 1 + Xf[q], Xf[partition_end] = Xf[partition_end], Xf[q] + samples[q], 
samples[partition_end] = ( + samples[partition_end], samples[q]) + current.pos = q + + # Must reset criterion since we've reordered the samples + self.criterion.reset() + else: + # Non-categorical feature + while (p + 1 < end and + Xf[p + 1] <= Xf[p] + FEATURE_THRESHOLD): + p += 1 - while p < end: - while (p + 1 < end and - Xf[p + 1] <= Xf[p] + FEATURE_THRESHOLD): + # (p + 1 >= end) or (X[samples[p + 1], current.feature] > + # X[samples[p], current.feature]) p += 1 + # (p >= end) or (X[samples[p], current.feature] > + # X[samples[p - 1], current.feature]) - # (p + 1 >= end) or (X[samples[p + 1], current.feature] > - # X[samples[p], current.feature]) - p += 1 - # (p >= end) or (X[samples[p], current.feature] > - # X[samples[p - 1], current.feature]) + if p >= end: + break - if p < end: current.pos = p - # Reject if min_samples_leaf is not guaranteed - if (((current.pos - start) < min_samples_leaf) or - ((end - current.pos) < min_samples_leaf)): - continue - - self.criterion.update(current.pos) + # Reject if min_samples_leaf is not guaranteed + if (((current.pos - start) < min_samples_leaf) or + ((end - current.pos) < min_samples_leaf)): + continue - # Reject if min_weight_leaf is not satisfied - if ((self.criterion.weighted_n_left < min_weight_leaf) or - (self.criterion.weighted_n_right < min_weight_leaf)): - continue + self.criterion.update(current.pos) - current_proxy_improvement = self.criterion.proxy_impurity_improvement() + # Reject if min_weight_leaf is not satisfied + if ((self.criterion.weighted_n_left < min_weight_leaf) or + (self.criterion.weighted_n_right < min_weight_leaf)): + continue - if current_proxy_improvement > best_proxy_improvement: - best_proxy_improvement = current_proxy_improvement - current.threshold = (Xf[p - 1] + Xf[p]) / 2.0 + current_proxy_improvement = self.criterion.proxy_impurity_improvement() - if current.threshold == Xf[p]: - current.threshold = Xf[p - 1] + if current_proxy_improvement > best_proxy_improvement: + best_proxy_improvement = current_proxy_improvement + if is_categorical: + current.split_value.cat_split = cat_split + else: + current.split_value.threshold = (Xf[p - 1] + Xf[p]) / 2.0 + if current.split_value.threshold == Xf[p]: + current.split_value.threshold = Xf[p - 1] - best = current # copy + best = current # copy # Reorganize into samples[start:best.pos] + samples[best.pos:end] if best.pos < end: + make_bit_cache(best.split_value, self.n_categories[best.feature], + self._bit_cache) feature_offset = X_feature_stride * best.feature partition_end = end p = start while p < partition_end: - if X[X_sample_stride * samples[p] + feature_offset] <= best.threshold: + if goes_left(X[X_sample_stride * samples[p] + feature_offset], + best.split_value, self.n_categories[best.feature], + self._bit_cache): p += 1 else: partition_end -= 1 - tmp = samples[partition_end] - samples[partition_end] = samples[p] - samples[p] = tmp + samples[p], samples[partition_end] = ( + samples[partition_end], samples[p]) self.criterion.reset() self.criterion.update(best.pos) @@ -659,7 +793,7 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef SIZE_t f_i = n_features cdef SIZE_t f_j cdef SIZE_t p - cdef SIZE_t tmp + cdef SIZE_t q cdef SIZE_t feature_stride # Number of features discovered to be constant during the split search cdef SIZE_t n_found_constants = 0 @@ -673,6 +807,8 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef DTYPE_t max_feature_value cdef DTYPE_t current_feature_value cdef SIZE_t partition_end + cdef bint is_categorical + cdef UINT64_t split_seed 
_init_split(&best, end) @@ -709,9 +845,8 @@ cdef class RandomSplitter(BaseDenseSplitter): if f_j < n_known_constants: # f_j in the interval [n_drawn_constants, n_known_constants[ - tmp = features[f_j] - features[f_j] = features[n_drawn_constants] - features[n_drawn_constants] = tmp + features[f_j], features[n_drawn_constants] = ( + features[n_drawn_constants], features[f_j]) n_drawn_constants += 1 @@ -748,32 +883,44 @@ cdef class RandomSplitter(BaseDenseSplitter): f_i -= 1 features[f_i], features[f_j] = features[f_j], features[f_i] - # Draw a random threshold - current.threshold = rand_uniform(min_feature_value, - max_feature_value, - random_state) - - if current.threshold == max_feature_value: - current.threshold = min_feature_value - - # Partition - partition_end = end - p = start - while p < partition_end: - current_feature_value = Xf[p] - if current_feature_value <= current.threshold: - p += 1 + # Construct a random split + # Repeat up to 60 times if a trivial split is constructed + # (this can only happen with a categorical feature) + for q in range(60): + is_categorical = self.n_categories[current.feature] > 0 + if is_categorical: + split_seed = our_rand_r(random_state) + current.split_value.cat_split = (split_seed << 32) | 1 else: - partition_end -= 1 + current.split_value.threshold = rand_uniform( + min_feature_value, max_feature_value, random_state) + if current.split_value.threshold == max_feature_value: + current.split_value.threshold = min_feature_value + + # Partition + make_bit_cache(current.split_value, self.n_categories[current.feature], + self._bit_cache) + partition_end = end + p = start + while p < partition_end: + current_feature_value = Xf[p] + if goes_left(current_feature_value, current.split_value, + self.n_categories[current.feature], self._bit_cache): + p += 1 + else: + partition_end -= 1 + + Xf[p] = Xf[partition_end] + Xf[partition_end] = current_feature_value - Xf[p] = Xf[partition_end] - Xf[partition_end] = current_feature_value + samples[p], samples[partition_end] = ( + samples[partition_end], samples[p]) - tmp = samples[partition_end] - samples[partition_end] = samples[p] - samples[p] = tmp + current.pos = partition_end - current.pos = partition_end + # Break early if a non-trivial split was found + if partition_end != start and partition_end != end: + break # Reject if min_samples_leaf is not guaranteed if (((current.pos - start) < min_samples_leaf) or @@ -798,20 +945,23 @@ cdef class RandomSplitter(BaseDenseSplitter): # Reorganize into samples[start:best.pos] + samples[best.pos:end] feature_stride = X_feature_stride * best.feature if best.pos < end: + make_bit_cache(best.split_value, self.n_categories[best.feature], + self._bit_cache) if current.feature != best.feature: partition_end = end p = start while p < partition_end: - if X[X_sample_stride * samples[p] + feature_stride] <= best.threshold: + if goes_left(X[X_sample_stride * samples[p] + feature_stride], + best.split_value, self.n_categories[best.feature], + self._bit_cache): p += 1 else: partition_end -= 1 - tmp = samples[partition_end] - samples[partition_end] = samples[p] - samples[p] = tmp + samples[p], samples[partition_end] = ( + samples[partition_end], samples[p]) self.criterion.reset() @@ -848,7 +998,7 @@ cdef class BaseSparseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint shortcut): # Parent __cinit__ is automatically called 
self.X_data = NULL @@ -869,11 +1019,12 @@ cdef class BaseSparseSplitter(Splitter): object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, + INT32_t* n_categories, np.ndarray X_idx_sorted=None) except *: """Initialize the splitter.""" # Call parent init - Splitter.init(self, X, y, sample_weight) + Splitter.init(self, X, y, sample_weight, n_categories) if not isinstance(X, csc_matrix): raise ValueError("X should be in csc format") @@ -1347,9 +1498,9 @@ cdef class BestSparseSplitter(BaseSparseSplitter): if current_proxy_improvement > best_proxy_improvement: best_proxy_improvement = current_proxy_improvement - current.threshold = (Xf[p_prev] + Xf[p]) / 2.0 - if current.threshold == Xf[p]: - current.threshold = Xf[p_prev] + current.split_value.threshold = (Xf[p_prev] + Xf[p]) / 2.0 + if current.split_value.threshold == Xf[p]: + current.split_value.threshold = Xf[p_prev] best = current @@ -1358,7 +1509,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): self.extract_nnz(best.feature, &end_negative, &start_positive, &is_samples_sorted) - self._partition(best.threshold, end_negative, start_positive, + self._partition(best.split_value.threshold, end_negative, start_positive, best.pos) self.criterion.reset() @@ -1542,15 +1693,14 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): features[f_i], features[f_j] = features[f_j], features[f_i] # Draw a random threshold - current.threshold = rand_uniform(min_feature_value, - max_feature_value, - random_state) + current.split_value.threshold = rand_uniform( + min_feature_value, max_feature_value, random_state) - if current.threshold == max_feature_value: - current.threshold = min_feature_value + if current.split_value.threshold == max_feature_value: + current.split_value.threshold = min_feature_value # Partition - current.pos = self._partition(current.threshold, + current.pos = self._partition(current.split_value.threshold, end_negative, start_positive, start_positive + @@ -1586,7 +1736,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): self.extract_nnz(best.feature, &end_negative, &start_positive, &is_samples_sorted) - self._partition(best.threshold, end_negative, start_positive, + self._partition(best.split_value.threshold, end_negative, start_positive, best.pos) self.criterion.reset() diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 6b34e51a07c8e..e330848fb85a5 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -15,11 +15,14 @@ cimport numpy as np ctypedef np.npy_float32 DTYPE_t # Type of X ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight ctypedef np.npy_intp SIZE_t # Type for indices and counters +ctypedef np.npy_uint8 UINT8_t # Unsigned 8 bit integer ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer +ctypedef np.npy_uint64 UINT64_t # Unsigned 64 bit integer from ._splitter cimport Splitter from ._splitter cimport SplitRecord +from ._splitter cimport SplitValue cdef struct Node: # Base storage structure for the nodes in a Tree object @@ -27,10 +30,12 @@ cdef struct Node: SIZE_t left_child # id of the left child of the node SIZE_t right_child # id of the right child of the node SIZE_t feature # Feature used for splitting the node - DOUBLE_t threshold # Threshold value at the node + SplitValue split_value # Generalized threshold for categorical and + # non-categorical features. 
DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) SIZE_t n_node_samples # Number of samples at the node DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node + UINT8_t* _bit_cache cdef class Tree: @@ -52,17 +57,22 @@ cdef class Tree: cdef Node* nodes # Array of nodes cdef double* value # (capacity, n_outputs, max_n_classes) array of values cdef SIZE_t value_stride # = n_outputs * max_n_classes + cdef INT32_t* n_categories # (n_features) array giving number of + # categories (<0 for non-categorical) # Methods cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, - SIZE_t feature, double threshold, double impurity, - SIZE_t n_node_samples, + SIZE_t feature, SplitValue split_value, + double impurity, SIZE_t n_node_samples, double weighted_n_samples) nogil cdef void _resize(self, SIZE_t capacity) except * cdef int _resize_c(self, SIZE_t capacity=*) nogil cdef np.ndarray _get_value_ndarray(self) cdef np.ndarray _get_node_ndarray(self) + cdef np.ndarray _get_ncat_ndarray(self) + cdef void populate_bit_caches(self) + cdef void delete_bit_caches(self) cpdef np.ndarray predict(self, object X) @@ -98,5 +108,6 @@ cdef class TreeBuilder: cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=*, + np.ndarray n_categories=*, np.ndarray X_idx_sorted=*) cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 8bc02ef44d94d..401de169975a8 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -35,6 +35,8 @@ from ._utils cimport StackRecord from ._utils cimport PriorityHeap from ._utils cimport PriorityHeapRecord from ._utils cimport safe_realloc +from ._utils cimport make_bit_cache +from ._utils cimport goes_left from ._utils cimport sizet_ptr_to_ndarray cdef extern from "numpy/arrayobject.h": @@ -67,18 +69,20 @@ cdef DTYPE_t MIN_IMPURITY_SPLIT = 1e-7 # Repeat struct definition for numpy NODE_DTYPE = np.dtype({ - 'names': ['left_child', 'right_child', 'feature', 'threshold', 'impurity', - 'n_node_samples', 'weighted_n_node_samples'], - 'formats': [np.intp, np.intp, np.intp, np.float64, np.float64, np.intp, - np.float64], + 'names': ['left_child', 'right_child', 'feature', 'threshold', + 'impurity', 'n_node_samples', 'weighted_n_node_samples', + '_bit_cache'], + 'formats': [np.intp, np.intp, np.intp, np.float64, np.float64, + np.intp, np.float64, np.intp], 'offsets': [ &( NULL).left_child, &( NULL).right_child, &( NULL).feature, - &( NULL).threshold, + &( NULL).split_value, &( NULL).impurity, &( NULL).n_node_samples, - &( NULL).weighted_n_node_samples + &( NULL).weighted_n_node_samples, + &( NULL)._bit_cache ] }) @@ -91,6 +95,7 @@ cdef class TreeBuilder: cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None, + np.ndarray n_categories=None, np.ndarray X_idx_sorted=None): """Build a decision tree from the training set (X, y).""" pass @@ -140,6 +145,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None, + np.ndarray n_categories=None, np.ndarray X_idx_sorted=None): """Build a decision tree from the training set (X, y).""" @@ -150,6 +156,14 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): if sample_weight is not None: sample_weight_ptr = sample_weight.data + cdef INT32_t *n_categories_ptr = NULL + if n_categories is not None: + if ((n_categories.dtype != np.int32) or + (not n_categories.flags.contiguous)): + n_categories = 
np.asarray(n_categories, + dtype=np.int32, order="C") + n_categories_ptr = n_categories.data + # Initial capacity cdef int init_capacity @@ -168,7 +182,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef SIZE_t min_samples_split = self.min_samples_split # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight_ptr, X_idx_sorted) + splitter.init(X, y, sample_weight_ptr, n_categories_ptr, X_idx_sorted) cdef SIZE_t start cdef SIZE_t end @@ -181,7 +195,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef SplitRecord split cdef SIZE_t node_id - cdef double threshold cdef double impurity = INFINITY cdef SIZE_t n_constant_features cdef bint is_leaf @@ -229,8 +242,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): is_leaf = is_leaf or (split.pos >= end) node_id = tree._add_node(parent, is_left, is_leaf, split.feature, - split.threshold, impurity, n_node_samples, - weighted_n_node_samples) + split.split_value, impurity, + n_node_samples, weighted_n_node_samples) if node_id == (-1): rc = -1 @@ -298,6 +311,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None, + np.ndarray n_categories=None, np.ndarray X_idx_sorted=None): """Build a decision tree from the training set (X, y).""" @@ -308,6 +322,14 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if sample_weight is not None: sample_weight_ptr = sample_weight.data + cdef INT32_t *n_categories_ptr = NULL + if n_categories is not None: + if ((n_categories.dtype != np.int32) or + (not n_categories.flags.contiguous)): + n_categories = np.asarray(n_categories, + dtype=np.int32, order="C") + n_categories_ptr = n_categories.data + # Parameters cdef Splitter splitter = self.splitter cdef SIZE_t max_leaf_nodes = self.max_leaf_nodes @@ -316,7 +338,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef SIZE_t min_samples_split = self.min_samples_split # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight_ptr, X_idx_sorted) + splitter.init(X, y, sample_weight_ptr, n_categories_ptr, X_idx_sorted) cdef PriorityHeap frontier = PriorityHeap(INITIAL_STACK_SIZE) cdef PriorityHeapRecord record @@ -356,7 +378,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): node.left_child = _TREE_LEAF node.right_child = _TREE_LEAF node.feature = _TREE_UNDEFINED - node.threshold = _TREE_UNDEFINED + node.split_value.threshold = _TREE_UNDEFINED else: # Node is expandable @@ -444,8 +466,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if parent != NULL else _TREE_UNDEFINED, is_left, is_leaf, - split.feature, split.threshold, impurity, n_node_samples, - weighted_n_node_samples) + split.feature, split.split_value, + impurity, n_node_samples, weighted_n_node_samples) if node_id == (-1): return -1 @@ -524,6 +546,10 @@ cdef class Tree: value : array of double, shape [node_count, n_outputs, max_n_classes] Contains the constant prediction value of each node. + n_categories : array of int, shape [n_features] + Number of expected categorie values for categorical features, or + -1 for non-categorical features. + impurity : array of double, shape [node_count] impurity[i] holds the impurity (i.e., the value of the splitting criterion) at node i. 
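The tree builders above expect a per-feature ``n_categories`` array: int32, C-contiguous, of length ``n_features``, holding the number of category codes for categorical features and -1 otherwise. The mapping from the public ``categorical`` argument to that array is done by ``preproc_categorical``, which is not part of this section; the helper below (the name ``as_n_categories`` is invented here) only sketches the convention, assuming category values are already coded 0..k-1.

    import numpy as np

    def as_n_categories(X, categorical):
        # -1 marks an ordinal column, otherwise the number of category codes.
        n_features = X.shape[1]
        if isinstance(categorical, str):
            mask = np.full(n_features, categorical == 'all', dtype=bool)
        else:
            mask = np.zeros(n_features, dtype=bool)
            mask[np.asarray(categorical)] = True   # indices or boolean mask
        n_categories = np.full(n_features, -1, dtype=np.int32)
        for j in np.flatnonzero(mask):
            n_categories[j] = int(X[:, j].max()) + 1
        return np.ascontiguousarray(n_categories)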
@@ -576,6 +602,15 @@ cdef class Tree: def __get__(self): return self._get_value_ndarray()[:self.node_count] + property n_categories: + def __get__(self): + return self._get_ncat_ndarray()[:self.n_features] + + def __set__(self, np.ndarray[INT32_t, ndim=1] value): + cdef SIZE_t i + for i in range(self.n_features): + self.n_categories[i] = value[i] + def __cinit__(self, int n_features, np.ndarray[SIZE_t, ndim=1] n_classes, int n_outputs): """Constructor.""" @@ -598,13 +633,19 @@ cdef class Tree: self.capacity = 0 self.value = NULL self.nodes = NULL + self.n_categories = NULL + safe_realloc(&self.n_categories, n_features) + for k in range(n_features): + self.n_categories[k] = -1 def __dealloc__(self): """Destructor.""" # Free all inner structures + self.delete_bit_caches() free(self.n_classes) free(self.value) free(self.nodes) + free(self.n_categories) def __reduce__(self): """Reduce re-implementation, for pickling.""" @@ -620,6 +661,7 @@ cdef class Tree: d["node_count"] = self.node_count d["nodes"] = self._get_node_ndarray() d["values"] = self._get_value_ndarray() + d["n_categories"] = self._get_ncat_ndarray() return d def __setstate__(self, d): @@ -633,6 +675,7 @@ cdef class Tree: node_ndarray = d['nodes'] value_ndarray = d['values'] + ncat_ndarray = d['n_categories'] value_shape = (node_ndarray.shape[0], self.n_outputs, self.max_n_classes) @@ -641,7 +684,10 @@ cdef class Tree: not node_ndarray.flags.c_contiguous or value_ndarray.shape != value_shape or not value_ndarray.flags.c_contiguous or - value_ndarray.dtype != np.float64): + value_ndarray.dtype != np.float64 or + ncat_ndarray.shape != (self.n_features,) or + ncat_ndarray.dtype != np.int32 or + not ncat_ndarray.flags.c_contiguous): raise ValueError('Did not recognise loaded array layout') self.capacity = node_ndarray.shape[0] @@ -651,6 +697,8 @@ cdef class Tree: self.capacity * sizeof(Node)) value = memcpy(self.value, ( value_ndarray).data, self.capacity * self.value_stride * sizeof(double)) + ncat = memcpy(self.n_categories, ( ncat_ndarray).data, + self.n_features * sizeof(INT32_t)) cdef void _resize(self, SIZE_t capacity) except *: """Resize all inner arrays to `capacity`, if `capacity` == -1, then @@ -696,7 +744,7 @@ cdef class Tree: return 0 cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, - SIZE_t feature, double threshold, double impurity, + SIZE_t feature, SplitValue split_value, double impurity, SIZE_t n_node_samples, double weighted_n_node_samples) nogil: """Add a node to the tree. @@ -725,17 +773,46 @@ cdef class Tree: node.left_child = _TREE_LEAF node.right_child = _TREE_LEAF node.feature = _TREE_UNDEFINED - node.threshold = _TREE_UNDEFINED + node.split_value.threshold = _TREE_UNDEFINED else: # left_child and right_child will be set later node.feature = feature - node.threshold = threshold + node.split_value = split_value + + node._bit_cache = NULL self.node_count += 1 return node_id + cdef void populate_bit_caches(self): + """Allocates and populates bit caches for nodes that split on + categorical features. 
Should be run before every tree traversal.""" + cdef Node* node = self.nodes + cdef Node* end_node = self.nodes + self.node_count + cdef INT32_t n_categories = 0 + + while node != end_node: + if node.left_child != _TREE_LEAF: + n_categories = self.n_categories[node.feature] + if n_categories > 0: + safe_realloc(&node._bit_cache, (n_categories + 7) // 8) + make_bit_cache(node.split_value, n_categories, + node._bit_cache) + node += 1 + + cdef void delete_bit_caches(self): + """Deallocates the bit cache of each node in the tree. Should be run + after tree traversal.""" + cdef Node* node = self.nodes + cdef Node* end_node = self.nodes + self.node_count + + while node != end_node: + free(node._bit_cache) + node._bit_cache = NULL + node += 1 + cpdef np.ndarray predict(self, object X): """Predict target for X.""" out = self._get_value_ndarray().take(self.apply(X), axis=0, @@ -777,20 +854,24 @@ cdef class Tree: cdef Node* node = NULL cdef SIZE_t i = 0 + self.populate_bit_caches() + with nogil: for i in range(n_samples): node = self.nodes # While node not a leaf while node.left_child != _TREE_LEAF: # ... and node.right_child != _TREE_LEAF: - if X_ptr[X_sample_stride * i + - X_fx_stride * node.feature] <= node.threshold: + if goes_left(X_ptr[X_sample_stride * i + X_fx_stride * node.feature], + node.split_value, self.n_categories[node.feature], node._bit_cache): node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] out_ptr[i] = (node - self.nodes) # node offset + self.delete_bit_caches() + return out cdef inline np.ndarray _apply_sparse_csr(self, object X): @@ -855,7 +936,8 @@ cdef class Tree: else: feature_value = 0. - if feature_value <= node.threshold: + if goes_left(feature_value, node.split_value, + self.n_categories[node.feature], node._bit_cache): node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] @@ -918,7 +1000,7 @@ cdef class Tree: indptr_ptr[i + 1] += 1 if X_ptr[X_sample_stride * i + - X_fx_stride * node.feature] <= node.threshold: + X_fx_stride * node.feature] <= node.split_value.threshold: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] @@ -1006,7 +1088,7 @@ cdef class Tree: else: feature_value = 0. - if feature_value <= node.threshold: + if feature_value <= node.split_value.threshold: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] @@ -1082,6 +1164,20 @@ cdef class Tree: arr.base = self return arr + cdef np.ndarray _get_ncat_ndarray(self): + """Wraps n_categories as a 3-d Numpy array + + The array keeps a reference to this Tree, which manages the underlying + memory. + """ + cdef np.npy_intp shape[1] + shape[0] = self.n_features + cdef np.ndarray arr + arr = np.PyArray_SimpleNewFromData(1, shape, np.NPY_INT32, self.n_categories) + Py_INCREF(self) + arr.base = self + return arr + cdef np.ndarray _get_node_ndarray(self): """Wraps nodes as a NumPy struct array. 
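The traversal changes above all funnel through the per-node bit cache: for a categorical split, each category index owns one bit, a set bit sends the sample left, and any category code at or beyond `n_categories` (i.e. never seen at fit time) goes right. A pure-Python sketch of that lookup, mirroring the `goes_left` logic defined in `_utils.pyx` below (the function name is reused for clarity; the hand-built cache is illustrative only):

```python
def goes_left(feature_value, threshold, n_categories, bit_cache):
    """Decide the branch for one sample at one node (illustrative sketch)."""
    if n_categories < 1:
        # Ordinal feature: the usual threshold comparison.
        return feature_value <= threshold
    cat = int(feature_value)
    if cat >= n_categories:
        # Category code not seen during fitting: send it right.
        return False
    # One bit per category, packed 8 per byte, low bit first.
    return bool((bit_cache[cat // 8] >> (cat % 8)) & 1)

# A node splitting 3 categories where codes 0 and 2 go left: bits 0b101.
cache = bytes([0b00000101])
print(goes_left(2.0, None, 3, cache))  # True  (bit 2 is set)
print(goes_left(1.0, None, 3, cache))  # False (bit 1 is clear)
print(goes_left(7.0, None, 3, cache))  # False (unseen category)
```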
diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 8f659aa86374e..5d2bf007c3926 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -10,9 +10,12 @@ import numpy as np cimport numpy as np +from ._splitter cimport SplitValue + ctypedef np.npy_float32 DTYPE_t # Type of X ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight ctypedef np.npy_intp SIZE_t # Type for indices and counters +ctypedef np.npy_uint8 UINT8_t # Unsigned 8 bit integer ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer @@ -33,6 +36,7 @@ ctypedef fused realloc_ptr: (DTYPE_t*) (SIZE_t*) (unsigned char*) + (INT32_t*) cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) except * @@ -40,6 +44,9 @@ cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) except * cdef np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size) +cdef UINT32_t our_rand_r(UINT32_t* seed) nogil + + cdef SIZE_t rand_int(SIZE_t low, SIZE_t high, UINT32_t* random_state) nogil @@ -50,6 +57,14 @@ cdef double rand_uniform(double low, double high, cdef double log(double x) nogil + +# Functions for traversing a tree +cdef void make_bit_cache(SplitValue split, INT32_t n_categories, + UINT8_t* bit_cache) nogil + +cdef bint goes_left(DTYPE_t feature_value, SplitValue split, + INT32_t n_categories, UINT8_t* bit_cache) nogil + # ============================================================================= # Stack data structure # ============================================================================= diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 70232e3851050..4f4e2488bc12e 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -82,6 +82,47 @@ cdef inline double rand_uniform(double low, double high, cdef inline double log(double x) nogil: return ln(x) / ln(2.0) +cdef inline void make_bit_cache(SplitValue split, INT32_t n_categories, + UINT8_t* bit_cache) nogil: + """Regenerate and store the random numbers for a split.""" + cdef UINT32_t rng_seed + cdef SIZE_t q + cdef UINT32_t val + + if (n_categories <= 0): + # Non-categorical feature; bit cache not used + return + + if (split.cat_split & 1 == 0): + # Bitfield model + for q in range((n_categories + 7) // 8): + bit_cache[q] = (split.cat_split >> (q * 8)) & 0xFF + else: + # Random model + for q in range((n_categories + 7) // 8): + bit_cache[q] = 0 + rng_seed = split.cat_split >> 32 + for q in range(n_categories): + val = rand_int(0, 2, &rng_seed) + bit_cache[q // 8] |= val << (q % 8) + +cdef inline bint goes_left(DTYPE_t feature_value, SplitValue split, + INT32_t n_categories, UINT8_t* bit_cache) nogil: + """Determine whether a sample goes to the left or right child node.""" + cdef SIZE_t idx, shift + + if n_categories < 1: + # Non-categorical feature + return feature_value <= split.threshold + else: + # Categorical feature, using bit cache + if ( feature_value) < n_categories: + idx = ( feature_value) // 8 + shift = ( feature_value) % 8 + return (bit_cache[idx] >> shift) & 1 + else: + return 0 + # ============================================================================= # Stack data structure diff --git a/sklearn/tree/log b/sklearn/tree/log new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index e4ca2be5e452a..c03420def7723 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -1443,3 +1443,51 @@ def test_no_sparse_y_support(): # Currently 
we don't support sparse y for name in ALL_TREES: yield (check_no_sparse_y_support, name) + + +def test_invalid_categorical_str(): + check = lambda nm: assert_raises(ValueError, ALL_TREES[nm], + categorical='example invalid string') + for name in ALL_TREES: + yield check, name + + +def test_invalid_categorical_bool(): + check = lambda nm: assert_raises( + ValueError, ALL_TREES[nm](categorical=[False, False, False]).fit, X, y) + for name in ALL_TREES: + yield check, name + + +def test_invalid_categorical_shape(): + check = lambda nm: assert_raises( + ValueError, ALL_TREES[nm], categorical=[[0]]) + for name in ALL_TREES: + yield check, name + + +def check_invalid_categorical_idx(name): + Tree = ALL_TREES[name] + bad_catvals = ([1, 2], [-3], [0, 0, 1]) + for catval in bad_catvals: + assert_raises(ValueError, Tree(categorical=catval).fit, X, y) + + +def test_invalid_categorical_idx(): + for name in ALL_TREES: + yield check_invalid_categorical_idx, name + + +def check_no_sparse_with_categorical(name): + X, y, X_sparse = [DATASETS['toy'][z] for z in + ['X', 'y', 'X_sparse']] + Tree = ALL_TREES[name] + assert_raises(NotImplementedError, Tree(categorical='all').fit, + X_sparse, y) + assert_raises(NotImplementedError, + Tree(categorical='all').fit(X, y).predict, X_sparse) + + +def test_no_sparse_with_categorical(): + for name in SPARSE_TREES: + yield check_no_sparse_with_categorical, name diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py index d33f2fbadcb80..87a6919067633 100644 --- a/sklearn/tree/tree.py +++ b/sklearn/tree/tree.py @@ -20,6 +20,7 @@ from abc import ABCMeta from abc import abstractmethod from math import ceil +from itertools import count import numpy as np from scipy.sparse import issparse @@ -28,6 +29,7 @@ from ..base import ClassifierMixin from ..base import RegressorMixin from ..externals import six +from ..externals.six.moves import zip from ..feature_selection.from_model import _LearntSelectorMixin from ..utils import check_array from ..utils import check_random_state @@ -64,6 +66,119 @@ SPARSE_SPLITTERS = {"best": _splitter.BestSparseSplitter, "random": _splitter.RandomSparseSplitter} + +# ============================================================================= +# Support functions +# ============================================================================= +def preproc_categorical(X, categorical, check_input): + """Preprocess categorical features by mapping them to + range(n_categories). Used for fitting. + + Parameters + ---------- + X : array-like, shape=(n_samples, n_features) + Feature array + categorical : array-like or str + Specification of which features are categorical. See fit(). + check_input : bool + If False, bypass creation of category map and transformation + of X. Use only if you know what you are doing. 
+ + Returns + ------- + X : array, shape=(n_samples, n_features) + Transformed copy of the feature array (or the original if + there are no categorical features) + n_categories : array, shape=(n_features,) + Number of categories for each feature (-1 if non-categorical) + category_map : list, length n_features + For each feature, a dictionary relating values to transformed + values, or an empty dictionary for non-categorical features + + """ + n_features = np.shape(X)[1] + if isinstance(categorical, str): + if categorical == 'none': + categorical = np.array([]) + elif categorical == 'all': + categorical = np.arange(n_features) + else: + # Should have been caught in the constructor, but just in case + raise ValueError("Invalid value for categorical: {}. Allowed" + " strings are 'all' or 'none'" + "".format(categorical)) + categorical = np.asarray(categorical) + if categorical.dtype == np.bool: + if categorical.shape != (n_features,): + raise ValueError("Shape of boolean parameter categorical must" + " be (n_features,)") + categorical = np.nonzero(categorical)[0] + if (len(categorical.shape) != 1 or + categorical.size > n_features or + (categorical.size > 0 and + (np.min(categorical) < 0 or + np.max(categorical) >= n_features))): + raise ValueError("Invalid shape or invalid feature index for" + " parameter categorical") + if issparse(X) and categorical.size > 0: + raise NotImplementedError("Categorical features not supported with" + " sparse inputs") + + n_categories = -np.ones(n_features, dtype=np.int32) + category_map = [{}] * n_features + if categorical.size > 0 and check_input: + X = np.copy(X) + for feature in categorical: + rounded = np.round(X[:, feature]).astype(np.int64) + unique_rounded = np.unique(rounded) + if check_input: + category_map[feature] = dict(zip(unique_rounded, count())) + X[:, feature] = np.array([category_map[feature][x] + for x in rounded], dtype=DTYPE) + n_categories[feature] = len(unique_rounded) + + return X, n_categories, category_map + + +def validate_categorical(X, category_map): + """Map categorical features onto sequential integers. Used for + predicting. 
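Both helpers revolve around the same per-feature mapping: `preproc_categorical` (above) builds a dict from rounded raw values to codes `0..K-1` at fit time, and `validate_categorical` reapplies it at predict time, extending it with fresh codes for values it has never seen. A standalone sketch of that mapping on a toy column (not part of the patch):

```python
from itertools import count
import numpy as np

col = np.array([10., 30., 10., 20., 30.])        # raw categorical values
rounded = np.round(col).astype(np.int64)

# Fit time: map each distinct rounded value to 0..K-1.
category_map = dict(zip(np.unique(rounded), count()))  # {10: 0, 20: 1, 30: 2}
encoded = np.array([category_map[v] for v in rounded], dtype=np.float32)
print(encoded)   # [0. 2. 0. 1. 2.]

# Predict time: unseen values get fresh codes >= K, which the tree
# then routes to the right child (see goes_left above).
new_rounded = np.round(np.array([20., 40.])).astype(np.int64)
unseen = set(new_rounded) - set(category_map)
extra = dict(zip(unseen, count(len(category_map))))
codes = [category_map.get(v, extra.get(v)) for v in new_rounded]
print(codes)     # [1, 3]
```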
+ + Parameters + ---------- + X : array-like, shape=(n_samples, n_features) + Feature array + category_map : list, length n_features + For each feature, a dictionary relating values to transformed + values, or an empty dictionary for non-categorical features + + Returns + ------- + X : array, shape=(n_samples, n_features) + Transformed copy of the feature array (or the original if + there are no categorical features) + """ + if category_map is None: + return X + + n_categories = np.array([len(x) for x in category_map]) + categorical_features = np.nonzero(n_categories > 0)[0] + if categorical_features.size > 0: + if issparse(X): + raise NotImplementedError("Categorical features not supported" + " with sparse inputs") + X = np.copy(X) + for feature in categorical_features: + rounded = np.round(X[:, feature]).astype('int64') + new_cat = set(rounded) - set(category_map[feature]) + new_cat_map = dict(zip(new_cat, count(n_categories[feature]))) + X[:, feature] = np.array( + [category_map[feature].get(x, new_cat_map.get(x)) + for x in rounded]).astype(DTYPE) + + return X + + # ============================================================================= # Base decision tree # ============================================================================= @@ -89,7 +204,8 @@ def __init__(self, max_leaf_nodes, random_state, class_weight=None, - presort=False): + presort=False, + categorical='none'): self.criterion = criterion self.splitter = splitter self.max_depth = max_depth @@ -101,15 +217,26 @@ def __init__(self, self.max_leaf_nodes = max_leaf_nodes self.class_weight = class_weight self.presort = presort + self.categorical = categorical self.n_features_ = None self.n_outputs_ = None self.classes_ = None self.n_classes_ = None + self.category_map_ = None self.tree_ = None self.max_features_ = None + # Input validation for parameter categorical + if isinstance(self.categorical, str): + if categorical not in ('all', 'none'): + raise ValueError("Invalid value for categorical: {}. Allowed" + " strings are 'all' or 'none'" + "".format(categorical)) + elif len(np.shape(categorical)) != 1: + raise ValueError("Invalid shape for parameter categorical") + def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): """Build a decision tree from the training set (X, y). @@ -135,7 +262,7 @@ def fit(self, X, y, sample_weight=None, check_input=True, check_input : boolean, (default=True) Allow to bypass several input checking. - Don't use this parameter unless you know what you do. + Don't use this parameter unless you know what you are doing. X_idx_sorted : array-like, shape = [n_samples, n_features], optional The indexes of the sorted training input samples. If many tree @@ -296,6 +423,10 @@ def fit(self, X, y, sample_weight=None, check_input=True, else: sample_weight = expanded_class_weight + # Do preprocessing of categorical variables + X, n_categories, self.category_map_ = preproc_categorical( + X, self.categorical, check_input) + # Set min_weight_leaf from min_weight_fraction_leaf if self.min_weight_fraction_leaf != 0. 
and sample_weight is not None: min_weight_leaf = (self.min_weight_fraction_leaf * @@ -338,19 +469,33 @@ def fit(self, X, y, sample_weight=None, check_input=True, self.n_classes_) else: criterion = CRITERIA_REG[self.criterion](self.n_outputs_) + if is_classification: + use_shortcut = (self.n_classes_.tolist() == [2] and + (isinstance(criterion, _criterion.Gini) or + isinstance(criterion, _criterion.Entropy))) + else: + use_shortcut = isinstance(criterion, _criterion.MSE) SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS splitter = self.splitter - if not isinstance(self.splitter, Splitter): + if not isinstance(splitter, Splitter): splitter = SPLITTERS[self.splitter](criterion, self.max_features_, min_samples_leaf, min_weight_leaf, random_state, - self.presort) + self.presort, + use_shortcut) + + if (not isinstance(splitter, _splitter.RandomSplitter) and + np.max(n_categories) > 64): + raise ValueError('A feature with {} categories was detected; to' + ' use more than 64, use ExtraTree rather than' + ' DecisionTree.'.format(np.max(n_categories))) self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_) + self.tree_.n_categories = n_categories # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise if max_leaf_nodes < 0: @@ -365,7 +510,8 @@ def fit(self, X, y, sample_weight=None, check_input=True, max_depth, max_leaf_nodes) - builder.build(self.tree_, X, y, sample_weight, X_idx_sorted) + builder.build(self.tree_, X, y, sample_weight, n_categories, + X_idx_sorted) if self.n_outputs_ == 1: self.n_classes_ = self.n_classes_[0] @@ -411,7 +557,7 @@ def predict(self, X, check_input=True): check_input : boolean, (default=True) Allow to bypass several input checking. - Don't use this parameter unless you know what you do. + Don't use this parameter unless you know what you are doing. Returns ------- @@ -420,6 +566,9 @@ def predict(self, X, check_input=True): """ X = self._validate_X_predict(X, check_input) + if check_input: + X = validate_categorical(X, self.category_map_) + proba = self.tree_.predict(X) n_samples = X.shape[0] @@ -461,7 +610,7 @@ def apply(self, X, check_input=True): check_input : boolean, (default=True) Allow to bypass several input checking. - Don't use this parameter unless you know what you do. + Don't use this parameter unless you know what you are doing. Returns ------- @@ -472,6 +621,9 @@ def apply(self, X, check_input=True): numbering. """ X = self._validate_X_predict(X, check_input) + if check_input: + X = validate_categorical(X, self.category_map_) + return self.tree_.apply(X) def decision_path(self, X, check_input=True): @@ -613,6 +765,19 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin): When using either a smaller dataset or a restricted depth, this may speed up the training. + categorical : array-like or str + Array of feature indices, boolean array of length + n_features, ``'all'``, or ``'none'``. Indicates which + features should be considered as categorical rather than + ordinal. For decision trees, the maximum number of + categories per feature is 64, though the real-world limit + will be much lower because evaluating splits has + :math:`O(2^N)` time complexity, for :math:`N` + categories. Extra-randomized trees do not have this + limitation because they do not try to find the best + split. For these trees, the maximum number of categories + per feature is :math:`2^{31}`. 
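The :math:`O(2^N)` cost quoted above comes from exhaustive best-split search: :math:`N` categories admit :math:`2^{N-1}-1` distinct non-trivial left/right partitions, which is why the best-split trees are capped at 64 categories while the randomized splitters are not. A hypothetical end-to-end use of the new parameter, assuming this branch is installed (outputs marked "e.g." are indicative only):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Feature 0 is a category code (0, 1, 2); feature 1 stays ordinal.
X = np.array([[0., 5.1],
              [1., 4.9],
              [2., 6.0],
              [1., 5.8]])
y = np.array([0, 0, 1, 1])

clf = DecisionTreeClassifier(categorical=[0], random_state=0).fit(X, y)
print(clf.predict(X))          # e.g. [0 0 1 1]
print(clf.tree_.n_categories)  # e.g. [ 3 -1]: feature 1 is non-categorical
```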
+ Attributes ---------- classes_ : array of shape = [n_classes] or a list of such arrays @@ -684,7 +849,8 @@ def __init__(self, random_state=None, max_leaf_nodes=None, class_weight=None, - presort=False): + presort=False, + categorical='none'): super(DecisionTreeClassifier, self).__init__( criterion=criterion, splitter=splitter, @@ -696,7 +862,8 @@ def __init__(self, max_leaf_nodes=max_leaf_nodes, class_weight=class_weight, random_state=random_state, - presort=presort) + presort=presort, + categorical=categorical) def predict_proba(self, X, check_input=True): """Predict class probabilities of the input samples X. @@ -706,7 +873,7 @@ class in a leaf. check_input : boolean, (default=True) Allow to bypass several input checking. - Don't use this parameter unless you know what you do. + Don't use this parameter unless you know what you are doing. Parameters ---------- @@ -723,6 +890,9 @@ class in a leaf. classes corresponds to that in the attribute `classes_`. """ X = self._validate_X_predict(X, check_input) + if check_input: + X = validate_categorical(X, self.category_map_) + proba = self.tree_.predict(X) if self.n_outputs_ == 1: @@ -852,6 +1022,19 @@ class DecisionTreeRegressor(BaseDecisionTree, RegressorMixin): When using either a smaller dataset or a restricted depth, this may speed up the training. + categorical : array-like or str + Array of feature indices, boolean array of length + n_features, ``'all'``, or ``'none'``. Indicates which + features should be considered as categorical rather than + ordinal. For decision trees, the maximum number of + categories per feature is 64, though the real-world limit + will be much lower because evaluating splits has + :math:`O(2^N)` time complexity, for :math:`N` + categories. Extra-randomized trees do not have this + limitation because they do not try to find the best + split. For these trees, the maximum number of categories + per feature is :math:`2^{31}`. + Attributes ---------- feature_importances_ : array of shape = [n_features] @@ -914,7 +1097,8 @@ def __init__(self, max_features=None, random_state=None, max_leaf_nodes=None, - presort=False): + presort=False, + categorical='none'): super(DecisionTreeRegressor, self).__init__( criterion=criterion, splitter=splitter, @@ -925,7 +1109,8 @@ def __init__(self, max_features=max_features, max_leaf_nodes=max_leaf_nodes, random_state=random_state, - presort=presort) + presort=presort, + categorical=categorical) class ExtraTreeClassifier(DecisionTreeClassifier): @@ -962,7 +1147,8 @@ def __init__(self, max_features="auto", random_state=None, max_leaf_nodes=None, - class_weight=None): + class_weight=None, + categorical='none'): super(ExtraTreeClassifier, self).__init__( criterion=criterion, splitter=splitter, @@ -973,7 +1159,8 @@ def __init__(self, max_features=max_features, max_leaf_nodes=max_leaf_nodes, class_weight=class_weight, - random_state=random_state) + random_state=random_state, + categorical=categorical) class ExtraTreeRegressor(DecisionTreeRegressor): @@ -1009,7 +1196,8 @@ def __init__(self, min_weight_fraction_leaf=0., max_features="auto", random_state=None, - max_leaf_nodes=None): + max_leaf_nodes=None, + categorical='none'): super(ExtraTreeRegressor, self).__init__( criterion=criterion, splitter=splitter, @@ -1019,4 +1207,5 @@ def __init__(self, min_weight_fraction_leaf=min_weight_fraction_leaf, max_features=max_features, max_leaf_nodes=max_leaf_nodes, - random_state=random_state) + random_state=random_state, + categorical=categorical)
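The constructors above only thread `categorical` through to the base class; the practical difference between the tree types is the category limit enforced in `fit`. A sketch of that boundary, assuming this branch is installed and using illustrative random data:

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, ExtraTreeRegressor

rng = np.random.RandomState(0)
X = rng.randint(0, 100, size=(500, 1)).astype(float)  # ~100 category codes
y = rng.normal(size=500)

# Randomized splits: high-cardinality categorical features are accepted.
ExtraTreeRegressor(categorical='all', random_state=0).fit(X, y)

# Best-split search: more than 64 categories raises at fit time.
try:
    DecisionTreeRegressor(categorical='all').fit(X, y)
except ValueError as exc:
    print(exc)  # suggests ExtraTree for features with more than 64 categories
```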