Skip to content

Commit cc3eb48

Browse files
Added categorical_features argument to trees
1 parent f31eec6 commit cc3eb48

File tree

1 file changed

+62
-4
lines changed

1 file changed

+62
-4
lines changed

sklearn/tree/tree.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(self,
7373
min_weight_fraction_leaf,
7474
max_features,
7575
max_leaf_nodes,
76+
categorical_features,
7677
random_state):
7778
self.criterion = criterion
7879
self.splitter = splitter
@@ -83,11 +84,13 @@ def __init__(self,
8384
self.max_features = max_features
8485
self.random_state = random_state
8586
self.max_leaf_nodes = max_leaf_nodes
87+
self.categorical_features = categorical_features
8688

8789
self.n_features_ = None
8890
self.n_outputs_ = None
8991
self.classes_ = None
9092
self.n_classes_ = None
93+
self.categorical_dicts = None
9194

9295
self.tree_ = None
9396
self.max_features_ = None
@@ -237,6 +240,43 @@ def fit(self, X, y, sample_mask=None, X_argsorted=None, check_input=True,
237240
"number of samples=%d" %
238241
(len(sample_weight), n_samples))
239242

243+
n_features = X.shape[1]
244+
# Convert the categorical_features argument into a boolean mask
245+
if not self.categorical_features or self.categorical_features == "None":
246+
categorical_mask = np.zeros(n_features, dtype=bool)
247+
has_categorical = True
248+
elif self.categorical_features == "all":
249+
categorical_mask = np.ones(n_features, dtype=bool)
250+
has_categorical = True
251+
else:
252+
try:
253+
self.categorical_features = list(self.categorical_features)
254+
except TypeError:
255+
raise ValueError("categorical_features not recognized. Must "
256+
"be 'None', 'all', a mask or a list")
257+
if len(self.categorical_features) == n_features:
258+
categorical_mask = self.categorical_features
259+
has_categorical = sum(self.categorical_features) > 0
260+
else:
261+
categorical_mask = np.zeros(n_features, dtype=bool)
262+
categorical_mask[np.asarray(self.categorical_features)] = True
263+
has_categorical = len(categorical_mask) > 0
264+
# Re-encode each categorical feature's values as integers 0..n-1
265+
self.categorical_dicts = [
266+
dict((e, i) for (i, e) in enumerate(set(X[:, feature])))
267+
if categorical_mask[feature] else None
268+
for feature in xrange(n_features) ]
269+
if has_categorical:
270+
X = np.copy(X)
271+
for feature in xrange(n_features):
272+
if categorical_mask[feature]:
273+
hashing = self.categorical_dicts[feature]
274+
if len(hashing) > 32:
275+
raise ValueError(
276+
"Too many factors for feature {}. 32 maximum, "
277+
"found {}".format(feature, len(hashing)))
278+
X[:, feature] = [hashing[e] for e in X[:, feature]]
279+
240280
# Set min_weight_leaf from min_weight_fraction_leaf
241281
if self.min_weight_fraction_leaf != 0. and sample_weight is not None:
242282
min_weight_leaf = (self.min_weight_fraction_leaf *
@@ -418,6 +458,13 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin):
418458
If None then unlimited number of leaf nodes.
419459
If not None then ``max_depth`` will be ignored.
420460
461+
categorical_features : array of indices or mask, optional (default=None)
462+
Specify which features are treated as categorical.
463+
- None (default): no features are treated as categorical.
464+
- 'all': all features are treated as categorical.
465+
- array of indices: array of categorical feature indices.
466+
- mask: boolean array of length n_features; True marks a categorical feature.
467+
421468
random_state : int, RandomState instance or None, optional (default=None)
422469
If int, random_state is the seed used by the random number generator;
423470
If RandomState instance, random_state is the random number generator;
@@ -489,7 +536,8 @@ def __init__(self,
489536
random_state=None,
490537
min_density=None,
491538
compute_importances=None,
492-
max_leaf_nodes=None):
539+
max_leaf_nodes=None,
540+
categorical_features=None):
493541
super(DecisionTreeClassifier, self).__init__(
494542
criterion=criterion,
495543
splitter=splitter,
@@ -499,7 +547,8 @@ def __init__(self,
499547
min_weight_fraction_leaf=min_weight_fraction_leaf,
500548
max_features=max_features,
501549
max_leaf_nodes=max_leaf_nodes,
502-
random_state=random_state)
550+
random_state=random_state,
551+
categorical_features=categorical_features)
503552

504553
if min_density is not None:
505554
warn("The min_density parameter is deprecated as of version 0.14 "
@@ -641,6 +690,13 @@ class DecisionTreeRegressor(BaseDecisionTree, RegressorMixin):
641690
If None then unlimited number of leaf nodes.
642691
If not None then ``max_depth`` will be ignored.
643692
693+
categorical_features : array of indices or mask, optional (default=None)
694+
Specify which features are treated as categorical.
695+
- None (default): no features are treated as categorical.
696+
- 'all': all features are treated as categorical.
697+
- array of indices: array of categorical feature indices.
698+
- mask: boolean array of length n_features; True marks a categorical feature.
699+
644700
random_state : int, RandomState instance or None, optional (default=None)
645701
If int, random_state is the seed used by the random number generator;
646702
If RandomState instance, random_state is the random number generator;
@@ -704,7 +760,8 @@ def __init__(self,
704760
random_state=None,
705761
min_density=None,
706762
compute_importances=None,
707-
max_leaf_nodes=None):
763+
max_leaf_nodes=None,
764+
categorical_features=None):
708765
super(DecisionTreeRegressor, self).__init__(
709766
criterion=criterion,
710767
splitter=splitter,
@@ -714,7 +771,8 @@ def __init__(self,
714771
min_weight_fraction_leaf=min_weight_fraction_leaf,
715772
max_features=max_features,
716773
max_leaf_nodes=max_leaf_nodes,
717-
random_state=random_state)
774+
random_state=random_state,
775+
categorical_features=categorical_features)
718776

719777
if min_density is not None:
720778
warn("The min_density parameter is deprecated as of version 0.14 "

0 commit comments

Comments
 (0)