
FEA Categorical split support for DecisionTree*, ExtraTree*, RandomForest* and ExtraTrees* #29437


Draft pull request: adam2392 wants to merge 32 commits into main.

Changes from all 32 commits:
5477137
WIP
adam2392 Jul 8, 2024
4eb5ab5
Still working unit tests on the trees while compiling works
adam2392 Jul 8, 2024
df790d9
Working commits and basic skeleton is laid out now
adam2392 Jul 8, 2024
c35bcbf
Working compilation
adam2392 Jul 8, 2024
4cb02f1
Isort the cimport statements
adam2392 Jul 8, 2024
6a3943a
Working prototype without categorical splitting
adam2392 Jul 9, 2024
a60c212
Adding now breiman shortcut
adam2392 Jul 9, 2024
85b3c40
Merge branch 'main' into nocats
adam2392 Jul 10, 2024
ea6ea00
Fix lint
adam2392 Jul 10, 2024
42b6427
Maintenance show
adam2392 Jul 10, 2024
d533205
Include in setup
adam2392 Jul 10, 2024
454c535
Fix lint
adam2392 Jul 10, 2024
b2fe94e
Try some asv stuff
adam2392 Jul 10, 2024
bee1876
Update benchmarking
adam2392 Jul 11, 2024
7352549
Merging main
adam2392 Jul 12, 2024
1d764ac
Merge branch 'main' into partition
adam2392 Jul 13, 2024
dec2017
Add file headers
adam2392 Jul 15, 2024
7aa1b1e
Merge branch 'main' into partition
adam2392 Jul 15, 2024
3c8f08b
Specialize the interface
adam2392 Jul 16, 2024
a7504f1
Merge branch 'partition' of https://github.com/adam2392/scikit-learn …
adam2392 Jul 16, 2024
2ac5277
Merge branch 'main' into partition
adam2392 Jul 16, 2024
fedc609
Merge branch 'main' into partition
adam2392 Jul 16, 2024
929035e
Reverse diff
adam2392 Jul 16, 2024
75158f2
Reverse diff
adam2392 Jul 16, 2024
7df18bf
Merging
adam2392 Jul 16, 2024
af49fc0
Re-order
adam2392 Jul 17, 2024
38bb5f2
Working refactor
adam2392 Jul 17, 2024
f4a4a10
Merge branch 'main' into nocats
adam2392 Jul 17, 2024
af9be58
Significant regression on sparse... wip
adam2392 Jul 17, 2024
575ee89
Finish merge
adam2392 Jul 21, 2024
fcebee9
Merging main
adam2392 Aug 21, 2024
83aed27
Merging -s
adam2392 Mar 18, 2025
4 changes: 2 additions & 2 deletions asv_benchmarks/asv.conf.json
@@ -26,7 +26,7 @@

// List of branches to benchmark. If not provided, defaults to "master
// (for git) or "default" (for mercurial).
"branches": ["main"],
// "branches": ["main"],
// "branches": ["default"], // for mercurial

// The DVCS being used. If not set, it will be automatically
@@ -40,7 +40,7 @@
// If missing or the empty string, the tool will be automatically
// determined by looking for tools on the PATH environment
// variable.
"environment_type": "conda",
"environment_type": "mamba",

// timeout in seconds for installing any dependencies in environment
// defaults to 10 min
11 changes: 11 additions & 0 deletions doc/whats_new/v1.6.rst
@@ -430,6 +430,12 @@ Python and CPython ecosystem, for example :user:`Nathan Goldbaum <ngoldbaum>`,
`clone(frozenestimator)` returns the same estimator instead of an unfitted clone.
:pr:`29705` by `Adrin Jalali`_.

- |MajorFeature| :class:`ensemble.RandomForestClassifier`, :class:`ensemble.RandomForestRegressor`,
:class:`ensemble.ExtraTreesClassifier`, and :class:`ensemble.ExtraTreesRegressor`
now support splitting on categorical features using the `categorical_features`
parameter.
:pr:`29437` by :user:`Adam Li <adam2392>`.
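A minimal usage sketch of the new parameter, assuming the forest API lands exactly as described in the entry above (the toy data and settings are illustrative only):

import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Categorical columns are encoded as non-negative integer codes.
X = np.array(
    [[0, 1.5], [2, 0.3], [1, 2.2], [2, 1.1], [0, 0.7], [1, 1.9]]
)
y = np.array([0, 1, 0, 1, 0, 1])

# Treat column 0 as categorical; column 1 stays ordinal.
clf = RandomForestClassifier(
    n_estimators=10, categorical_features=[0], random_state=0
)
clf.fit(X, y)
print(clf.predict(X[:2]))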

:mod:`sklearn.impute`
---------------------

@@ -675,6 +681,11 @@ Python and CPython ecosystem, for example :user:`Nathan Goldbaum <ngoldbaum>`,
format.
By :user:`Santiago M. Mola <smola>`. :pr:`17575`

- |Enhancement| :class:`tree.DecisionTreeClassifier`, :class:`tree.DecisionTreeRegressor`,
:class:`tree.ExtraTreeClassifier`, and :class:`tree.ExtraTreeRegressor` now support
splitting on categorical features using the `categorical_features` parameter.
:pr:`29437` by :user:`Adam Li <adam2392>`.
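A similar sketch for the tree-level estimators, under the same assumption that the parameter behaves as documented; DecisionTree* caps categorical cardinality at 64 categories, with ExtraTree* suggested for higher-cardinality features:

import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
cat_col = rng.integers(0, 5, size=100)   # five integer-coded categories
num_col = rng.normal(size=100)           # ordinary numerical feature
X = np.column_stack([cat_col, num_col])
y = (cat_col >= 3).astype(int)           # target driven by the categorical column

tree = DecisionTreeClassifier(categorical_features=[0], random_state=0)
tree.fit(X, y)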

:mod:`sklearn.utils`
--------------------

172 changes: 164 additions & 8 deletions sklearn/tree/_classes.py
@@ -11,6 +11,7 @@
from abc import ABCMeta, abstractmethod
from math import ceil
from numbers import Integral, Real
from warnings import warn

import numpy as np
from scipy.sparse import issparse
@@ -76,7 +77,10 @@
"poisson": _criterion.Poisson,
}

DENSE_SPLITTERS = {"best": _splitter.BestSplitter, "random": _splitter.RandomSplitter}
DENSE_SPLITTERS = {
"best": _splitter.BestSplitter,
"random": _splitter.RandomSplitter,
}

SPARSE_SPLITTERS = {
"best": _splitter.BestSparseSplitter,
@@ -122,6 +126,11 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta):
"min_impurity_decrease": [Interval(Real, 0.0, None, closed="left")],
"ccp_alpha": [Interval(Real, 0.0, None, closed="left")],
"monotonic_cst": ["array-like", None],
"categorical_features": [
"array-like",
StrOptions({"from_dtype"}),
None,
],
}

@abstractmethod
Expand All @@ -141,6 +150,7 @@ def __init__(
class_weight=None,
ccp_alpha=0.0,
monotonic_cst=None,
categorical_features=None,
):
self.criterion = criterion
self.splitter = splitter
@@ -155,6 +165,7 @@ def __init__(
self.class_weight = class_weight
self.ccp_alpha = ccp_alpha
self.monotonic_cst = monotonic_cst
self.categorical_features = categorical_features

def get_depth(self):
"""Return the depth of the decision tree.
@@ -256,13 +267,19 @@ def _fit(
missing_values_in_feature_mask = (
self._compute_missing_values_in_feature_mask(X)
)
if issparse(X):

is_sparse_X = issparse(X)
if is_sparse_X:
X.sort_indices()

if X.indices.dtype != np.intc or X.indptr.dtype != np.intc:
raise ValueError(
"No support for np.int64 index based sparse matrices"
)
if is_sparse_X and self.categorical_features is not None:
raise NotImplementedError(
"Categorical features not supported" " with sparse inputs"
)

if self.criterion == "poisson":
if np.any(y < 0):
@@ -366,6 +383,78 @@ def _fit(
else:
sample_weight = expanded_class_weight

# Validate categorical features
if isinstance(self.categorical_features, str):
if self.categorical_features == "all":
categorical = np.arange(self.n_features_in_)
else:
raise ValueError(
"Invalid value for categorical_features: {!r}. The only"
" supported string is 'all'; use None for no categorical"
" features.".format(self.categorical_features)
)
elif self.categorical_features is None:
categorical = np.array([], dtype=np.intp)
else:
categorical = np.atleast_1d(self.categorical_features).flatten()
if categorical.dtype == np.bool_:
if categorical.size != self.n_features_in_:
raise ValueError(
"Invalid value for categorical: Shape of "
"boolean parameter categorical must "
"be (n_features,)"
)
categorical = np.nonzero(categorical)[0]
if np.size(categorical) > self.n_features_in_ or (
categorical.size > 0
and (categorical.min() < 0 or categorical.max() >= self.n_features_in_)
):
raise ValueError(
"Invalid value for categorical: invalid shape or "
"out-of-range feature index."
)
if issparse(X):
if categorical.size > 0:
raise NotImplementedError(
"Categorical features not supported" " with sparse inputs"
)
else:
if np.any(X[:, categorical].astype(np.intp) < 0):
raise ValueError(
"Invalid value for categorical: given values "
"for categorical features must be "
"non-negative."
)

# Calculate n_categories and verify they are all at least 1% populated
n_categories = np.array(
[
np.intp(X[:, i].max()) + 1 if i in categorical else -1
for i in range(self.n_features_in_)
],
dtype=np.int32,
)
n_cat_present = np.array(
[
np.unique(X[:, i].astype(np.intp)).size if i in categorical else -1
for i in range(self.n_features_in_)
],
dtype=np.int32,
)
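# Illustration: a categorical column whose observed codes are {0, 3, 7}
# gets n_categories = 8 (max code + 1) but n_cat_present = 3, since only
# three of the eight possible codes occur in the sample; non-categorical
# columns carry -1 in both arrays.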
if np.any((n_cat_present < 0.01 * n_categories)[categorical]):
warn(
"At least one categorical feature has less than 1%"
" of its categories present in the sample. Runtime"
" and memory usage will be much smaller if you"
" represent the categories as sequential integers.",
UserWarning,
)

self.n_categories_ = n_categories

# Set min_weight_leaf from min_weight_fraction_leaf
if sample_weight is None:
min_weight_leaf = self.min_weight_fraction_leaf * n_samples
@@ -386,9 +475,17 @@ def _fit(
# might be shared and modified concurrently during parallel fitting
criterion = copy.deepcopy(criterion)

SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
# We automatically use the Breiman shortcut for categorical splits
# if the criterion is Gini or Entropy and the tree is a binary classifier,
# or if the criterion is MSE and the tree is a regression tree.
if is_classification:
breiman_shortcut = self.n_classes_.tolist() == [2] and (
isinstance(criterion, _criterion.Gini)
or isinstance(criterion, _criterion.Entropy)
)
else:
breiman_shortcut = isinstance(criterion, _criterion.MSE)
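# Sketch of the shortcut (the actual split search lives in the Cython
# splitters): instead of evaluating all 2**(k - 1) - 1 two-way partitions
# of k categories, order the categories by their mean target value and
# scan only the k - 1 cut points along that ordering, e.g.
#   means = [y[X[:, f] == c].mean() for c in range(n_categories[f])]
#   order = np.argsort(means)
# That ordering contains the optimal split for the cases selected above:
# binary classification with Gini/Entropy and MSE regression
# (Breiman et al., 1984).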

splitter = self.splitter
if self.monotonic_cst is None:
monotonic_cst = None
else:
@@ -428,24 +525,41 @@ def _fit(
# *positive class*, all signs must be flipped.
monotonic_cst *= -1

SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
splitter = self.splitter
if not isinstance(self.splitter, Splitter):
# Note: random splitter does not use breiman shortcut
splitter = SPLITTERS[self.splitter](
criterion,
self.max_features_,
min_samples_leaf,
min_weight_leaf,
random_state,
monotonic_cst,
breiman_shortcut,
)

if (
not isinstance(splitter, _splitter.RandomSplitter)
and np.max(n_categories) > 64
):
raise ValueError(
"Categorical features with more than 64 categories are"
" not supported by DecisionTree; try ExtraTree."
)
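# Presumably the 64-category cap comes from encoding the selected
# category subset as a 64-bit bitset in the non-random splitters; that
# is an assumption here, the diff itself does not state the reason.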

if is_classifier(self):
self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_)
self.tree_ = Tree(
self.n_features_in_, self.n_classes_, self.n_outputs_, n_categories
)
else:
self.tree_ = Tree(
self.n_features_in_,
# TODO: tree shouldn't need this in this case
np.array([1] * self.n_outputs_, dtype=np.intp),
self.n_outputs_,
n_categories,
)

# Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
@@ -469,7 +583,14 @@ def _fit(
self.min_impurity_decrease,
)

builder.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask)
builder.build(
self.tree_,
X,
y,
sample_weight,
missing_values_in_feature_mask,
n_categories,
)

if self.n_outputs_ == 1 and is_classifier(self):
self.n_classes_ = self.n_classes_[0]
@@ -494,10 +615,15 @@ def _validate_X_predict(self, X, check_input):
reset=False,
ensure_all_finite=ensure_all_finite,
)
if issparse(X) and (
is_sparse_X = issparse(X)
if is_sparse_X and (
X.indices.dtype != np.intc or X.indptr.dtype != np.intc
):
raise ValueError("No support for np.int64 index based sparse matrices")
if is_sparse_X and np.any(self.n_categories_ > 0):
raise NotImplementedError(
"Categorical features not supported" " with sparse inputs"
)
else:
# The number of features is checked regardless of `check_input`
_check_n_features(self, X, reset=False)
@@ -617,13 +743,16 @@ def _prune_tree(self):
# build pruned tree
if is_classifier(self):
n_classes = np.atleast_1d(self.n_classes_)
pruned_tree = Tree(self.n_features_in_, n_classes, self.n_outputs_)
pruned_tree = Tree(
self.n_features_in_, n_classes, self.n_outputs_, self.n_categories_
)
else:
pruned_tree = Tree(
self.n_features_in_,
# TODO: the tree shouldn't need this param
np.array([1] * self.n_outputs_, dtype=np.intp),
self.n_outputs_,
self.n_categories_,
)
_build_pruned_tree_ccp(pruned_tree, self.tree_, self.ccp_alpha)

@@ -856,6 +985,25 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree):

.. versionadded:: 1.4

categorical_features : array-like of bool of shape (n_features,) or \
array-like of int of shape (n_categorical_features,), \
default=None
Array of feature indices, a boolean mask of length n_features, the
string ``'all'``, or ``None``. Indicates which features should be
treated as categorical rather than ordinal. For decision trees,
the maximum number of categories is 64. In practice, the limit is
often lower because the search for the best possible split grows
exponentially with the number of categories. However, a shortcut
due to Breiman (1984) is used when fitting data with binary labels
using the ``Gini`` or ``Entropy`` criteria; in that case, the
runtime is linear in the number of categories. Extra-random trees
have an upper limit of :math:`2^{31}` categories and runtimes
linear in the number of categories.

Read more in the :ref:`User Guide <categorical_support_gbdt>`.

.. versionadded:: 1.6

Attributes
----------
classes_ : ndarray of shape (n_classes,) or list of ndarray
@@ -973,6 +1121,7 @@ def __init__(
class_weight=None,
ccp_alpha=0.0,
monotonic_cst=None,
categorical_features=None,
):
super().__init__(
criterion=criterion,
Expand All @@ -988,6 +1137,7 @@ def __init__(
min_impurity_decrease=min_impurity_decrease,
monotonic_cst=monotonic_cst,
ccp_alpha=ccp_alpha,
categorical_features=categorical_features,
)

@_fit_context(prefer_skip_nested_validation=True)
@@ -1355,6 +1505,7 @@ def __init__(
min_impurity_decrease=0.0,
ccp_alpha=0.0,
monotonic_cst=None,
categorical_features=None,
):
super().__init__(
criterion=criterion,
Expand All @@ -1369,6 +1520,7 @@ def __init__(
min_impurity_decrease=min_impurity_decrease,
ccp_alpha=ccp_alpha,
monotonic_cst=monotonic_cst,
categorical_features=categorical_features,
)

@_fit_context(prefer_skip_nested_validation=True)
@@ -1708,6 +1860,7 @@ def __init__(
class_weight=None,
ccp_alpha=0.0,
monotonic_cst=None,
categorical_features=None,
):
super().__init__(
criterion=criterion,
Expand All @@ -1723,6 +1876,7 @@ def __init__(
random_state=random_state,
ccp_alpha=ccp_alpha,
monotonic_cst=monotonic_cst,
categorical_features=categorical_features,
)

def __sklearn_tags__(self):
@@ -1968,6 +2122,7 @@ def __init__(
max_leaf_nodes=None,
ccp_alpha=0.0,
monotonic_cst=None,
categorical_features=None,
):
super().__init__(
criterion=criterion,
Expand All @@ -1982,6 +2137,7 @@ def __init__(
random_state=random_state,
ccp_alpha=ccp_alpha,
monotonic_cst=monotonic_cst,
categorical_features=categorical_features,
)

def __sklearn_tags__(self):