ENH Adds missing value support for trees #23595

Merged
109 commits, merged May 4, 2023
Commits
604ad0b
ENH Adds missing value support for trees
thomasjpfan Jun 12, 2022
f7cb581
DOC Adds whats new number
thomasjpfan Jun 12, 2022
db242b0
REV Reduce diff
thomasjpfan Jun 12, 2022
6bffbd0
DOC Adds whats new number
thomasjpfan Jun 12, 2022
fb90bcb
DOC Fixes docstring
thomasjpfan Jun 12, 2022
05799f7
CLN Less diff
thomasjpfan Jun 23, 2022
aaafafb
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Jun 23, 2022
986cf4a
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Sep 28, 2022
a3edade
CLN Refactor code into a shared function
thomasjpfan Sep 28, 2022
824c191
CLN Remove unneeded code
thomasjpfan Sep 28, 2022
0d3ea36
DOC Adds more comments about order
thomasjpfan Sep 28, 2022
f062fc0
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Sep 28, 2022
576fdd3
FIX Fixes bagging tests
thomasjpfan Sep 28, 2022
ae8cce2
DOC Adds more comments to sorting code
thomasjpfan Sep 30, 2022
30084aa
DOC Update comments to be more clear
thomasjpfan Sep 30, 2022
13cd441
DOC Adds more comments explaining the logic
thomasjpfan Sep 30, 2022
ca9e27e
DOC Adds comments for the second time
thomasjpfan Sep 30, 2022
2eac79f
DOC Spelling mistake
thomasjpfan Sep 30, 2022
5a07c2c
DOC Clarify comment
thomasjpfan Sep 30, 2022
a42a2f6
DOC Clarify comment again
thomasjpfan Sep 30, 2022
3a05ce5
CLN Use current_end instead of j
thomasjpfan Oct 9, 2022
0167d32
Update doc/modules/tree.rst
thomasjpfan Oct 25, 2022
7c4cea2
DOC More details about splitter
thomasjpfan Oct 25, 2022
7db3826
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Oct 25, 2022
dc65406
STY Word wrapping
thomasjpfan Oct 25, 2022
a420edc
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Dec 7, 2022
cc1ca9f
FIX Fixes merge with main
thomasjpfan Dec 7, 2022
a9c9ead
DOC Adds more comments
thomasjpfan Dec 7, 2022
2560acc
DOC Use methods
thomasjpfan Dec 7, 2022
74b506d
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Dec 9, 2022
ce9cce4
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Dec 13, 2022
c607c5b
ENH Use memset directly
thomasjpfan Dec 13, 2022
282f4c4
ENH Init sum_missing only when required
thomasjpfan Dec 13, 2022
bccc133
CLN Address comments
thomasjpfan Dec 22, 2022
ffbaab6
CLN Remove mention of mae
thomasjpfan Dec 22, 2022
177ef51
DOC Update comment
thomasjpfan Dec 22, 2022
ed963f0
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Jan 17, 2023
c530e61
ENH Simplify split
thomasjpfan Jan 17, 2023
e31563f
CLN Improves API for n_missing
thomasjpfan Jan 17, 2023
f5adadf
CLN Shift left
thomasjpfan Jan 17, 2023
0a8572f
DOC Adds comments on how splitting works
thomasjpfan Jan 17, 2023
d581246
CLN Simplify move sums
thomasjpfan Jan 17, 2023
94bc4b7
CLN Simplifying
thomasjpfan Jan 17, 2023
cf4bbdb
ENH Uses Cython to accerlation any_isnan
thomasjpfan Jan 17, 2023
865d5b1
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Jan 17, 2023
1f91bf3
DOC Move to 1.3
thomasjpfan Jan 17, 2023
9a287b4
CLN Update docstrings
thomasjpfan Jan 17, 2023
22172a9
TST Use array
thomasjpfan Jan 17, 2023
0f6c09e
DOC Adds docstring regarding bagging estimator
thomasjpfan Jan 17, 2023
1a26458
CLN Improves variable names
thomasjpfan Jan 17, 2023
ad55632
CLN Better variables
thomasjpfan Jan 17, 2023
23164dc
CLN Move shifting missing values to a function call
thomasjpfan Jan 23, 2023
779a625
CLN Use variable for has_missing
thomasjpfan Jan 23, 2023
f822375
CLN Simplify loop
thomasjpfan Jan 23, 2023
93ce545
CLN Remove unused variable
thomasjpfan Jan 24, 2023
df7c2a1
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Jan 30, 2023
2df3d95
CLN Improves logic for final partitioner
thomasjpfan Jan 30, 2023
6007a3a
TST Adds new check with missing values
thomasjpfan Jan 30, 2023
1a8616e
CI Fix tests
thomasjpfan Jan 30, 2023
2161ab5
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Feb 6, 2023
c8cb9c6
TST Adds test on the predictive performance of missing values
thomasjpfan Feb 6, 2023
748d1da
Update doc/whats_new/v1.3.rst
thomasjpfan Feb 9, 2023
d025dcc
CLN Address comments
thomasjpfan Feb 9, 2023
72c9c80
CLN Address comments
thomasjpfan Feb 9, 2023
bbc7c86
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Feb 9, 2023
c3512f3
FIX Fixes build
thomasjpfan Feb 11, 2023
12317f0
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Feb 24, 2023
a0ae427
FIX Fixes merge bugs
thomasjpfan Feb 24, 2023
0ad1837
FIX Fix regression
thomasjpfan Feb 24, 2023
1b121b1
FIX Fix regression
thomasjpfan Feb 24, 2023
941b557
FIX Fixes merge issue
thomasjpfan Feb 25, 2023
ad5e198
FIX Fixes missing value algo
thomasjpfan Feb 25, 2023
0c01aeb
CLN Fix regression with criterion
thomasjpfan Mar 24, 2023
b390e25
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Mar 24, 2023
93ecff0
CLN Place edge case into it's own if statement
thomasjpfan Mar 24, 2023
e2e4169
CLN Compute missing mask and pass around
thomasjpfan Mar 24, 2023
cbe6185
CLN Simplify code
thomasjpfan Mar 25, 2023
1a84a61
CLN Reduce diff
thomasjpfan Mar 25, 2023
2cb8f73
CLN Reduce diff
thomasjpfan Mar 25, 2023
f38fe48
CLN Address comments
thomasjpfan Mar 28, 2023
20c3f51
CLN Adjust tags
thomasjpfan Mar 28, 2023
76a5f5e
CLN Remove isanynan
thomasjpfan Mar 28, 2023
1bc7e46
Apply suggestions from code review
thomasjpfan Mar 28, 2023
7f399ed
DOC Update user guide for ties
thomasjpfan Mar 28, 2023
947aae5
CLN Rename to missing_mask_by_feature
thomasjpfan Mar 28, 2023
3d84774
CLN Remove unneeded comment
thomasjpfan Mar 28, 2023
91c58a2
DOC Update comment
thomasjpfan Mar 28, 2023
a23e2d7
Apply suggestions from code review
thomasjpfan Mar 28, 2023
23dea83
FIX Include missing import
thomasjpfan Mar 28, 2023
4f421a5
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Mar 28, 2023
6aca4d0
CLN Address comments
thomasjpfan Mar 28, 2023
0e393e9
DOC Add doctests in the Missing Values Support section
jjerphan Mar 29, 2023
103cad2
fixup! DOC Add doctests in the Missing Values Support section
jjerphan Mar 29, 2023
3361c25
Update sklearn/tree/_classes.py
thomasjpfan Mar 29, 2023
57a6379
Apply suggestions from code review
thomasjpfan Mar 29, 2023
d94d650
STY Linting
thomasjpfan Mar 29, 2023
61d2fc7
CLN Strict check
thomasjpfan Mar 29, 2023
af4264f
DOC Adds docstring about missing values
thomasjpfan Mar 29, 2023
d12f336
CLN Improve logic for final split
thomasjpfan Mar 29, 2023
f39d63a
DOC Improve docstring
thomasjpfan Mar 29, 2023
fffc16c
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Apr 7, 2023
4f1532e
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Apr 21, 2023
40a39e4
CLN Removes variable name
thomasjpfan Apr 21, 2023
649f8c2
STY Cython lint
thomasjpfan Apr 23, 2023
98a1fe4
STY Consistent comment spacing
thomasjpfan Apr 26, 2023
e047412
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan Apr 26, 2023
042e86a
CLN Address comments
thomasjpfan May 3, 2023
5691677
DOC Adds more comments
thomasjpfan May 3, 2023
8c5755a
Merge remote-tracking branch 'upstream/main' into missing_values_tree…
thomasjpfan May 3, 2023
59 changes: 59 additions & 0 deletions doc/modules/tree.rst
@@ -572,6 +572,65 @@ Mean Absolute Error:

Note that it fits much slower than the MSE criterion.

.. _tree_missing_value_support:

Missing Values Support
======================

:class:`~tree.DecisionTreeClassifier` and :class:`~tree.DecisionTreeRegressor`
have built-in support for missing values when `splitter='best'` and the
criterion is `'gini'`, `'entropy'`, or `'log_loss'` for classification, or
`'squared_error'`, `'friedman_mse'`, or `'poisson'` for regression.

For each potential threshold on the non-missing data, the splitter will evaluate
the split with all the missing values going to the left node or the right node.
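
For instance, here is a pure-Python sketch of this double evaluation (an
editorial illustration only; the actual splitter is implemented in Cython and
operates on presorted feature values)::

    import numpy as np

    def gini(y):
        # Gini impurity of an integer label array; an empty node contributes 0.
        if len(y) == 0:
            return 0.0
        p = np.bincount(y) / len(y)
        return 1.0 - np.sum(p ** 2)

    def evaluate_threshold(x, y, threshold):
        # Score one candidate threshold twice: once with the missing values
        # sent to the left child, once with them sent to the right child.
        y = np.asarray(y)
        missing = np.isnan(x)
        goes_left = ~missing & (x <= threshold)
        scores = {}
        for name, left in (("missing_to_left", goes_left | missing),
                           ("missing_to_right", goes_left)):
            right = ~left
            scores[name] = (left.sum() * gini(y[left])
                            + right.sum() * gini(y[right])) / len(y)
        return scores  # the splitter keeps whichever direction scores lower

With ``x = np.array([0, 1, 6, np.nan])``, ``y = [0, 0, 1, 1]``, and
``threshold = 3.5``, sending the missing value to the right child yields an
impurity of 0 while sending it left does not, which is consistent with the
first example below.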

Decisions are made as follows:

- By default when predicting, the samples with missing values are classified
with the class used in the split found during training::

>>> from sklearn.tree import DecisionTreeClassifier
>>> import numpy as np

>>> X = np.array([0, 1, 6, np.nan]).reshape(-1, 1)
>>> y = [0, 0, 1, 1]

>>> tree = DecisionTreeClassifier(random_state=0).fit(X, y)
>>> tree.predict(X)
array([0, 0, 1, 1])

- If the criterion evaluation is the same for both nodes, the tie for
missing values at predict time is broken by going to the right node. The
splitter also evaluates the split where all the missing values go to one
child and all the non-missing values go to the other::

>>> from sklearn.tree import DecisionTreeClassifier
>>> import numpy as np

>>> X = np.array([np.nan, -1, np.nan, 1]).reshape(-1, 1)
>>> y = [0, 0, 1, 1]

>>> tree = DecisionTreeClassifier(random_state=0).fit(X, y)

>>> X_test = np.array([np.nan]).reshape(-1, 1)
>>> tree.predict(X_test)
array([1])

- If no missing values are seen during training for a given feature, then during
prediction missing values are mapped to the child with the most samples::

>>> from sklearn.tree import DecisionTreeClassifier
>>> import numpy as np

>>> X = np.array([0, 1, 2, 3]).reshape(-1, 1)
>>> y = [0, 1, 1, 1]

>>> tree = DecisionTreeClassifier(random_state=0).fit(X, y)

>>> X_test = np.array([np.nan]).reshape(-1, 1)
>>> tree.predict(X_test)
array([1])

.. _minimal_cost_complexity_pruning:

6 changes: 6 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -486,6 +486,12 @@ Changelog
:mod:`sklearn.tree`
...................

- |MajorFeature| :class:`tree.DecisionTreeRegressor` and
:class:`tree.DecisionTreeClassifier` support missing values when
`splitter='best'` and the criterion is `gini`, `entropy`, or `log_loss`
for classification, or `squared_error`, `friedman_mse`, or `poisson`
for regression. :pr:`23595` by `Thomas Fan`_.

- |Enhancement| Adds a `class_names` parameter to
:func:`tree.export_text`. This allows specifying the parameter `class_names`
for each target class in ascending numerical order.
92 changes: 85 additions & 7 deletions sklearn/tree/_classes.py
@@ -34,6 +34,8 @@
from ..utils import Bunch
from ..utils import check_random_state
from ..utils.validation import _check_sample_weight
from ..utils.validation import assert_all_finite
from ..utils.validation import _assert_all_finite_element_wise
from ..utils import compute_sample_weight
from ..utils.multiclass import check_classification_targets
from ..utils.validation import check_is_fitted
@@ -48,6 +50,7 @@
from ._tree import _build_pruned_tree_ccp
from ._tree import ccp_pruning_path
from . import _tree, _splitter, _criterion
from ._utils import _any_isnan_axis0

__all__ = [
"DecisionTreeClassifier",
@@ -174,19 +177,67 @@ def get_n_leaves(self):
check_is_fitted(self)
return self.tree_.n_leaves

def fit(self, X, y, sample_weight=None, check_input=True):
def _support_missing_values(self, X):
return not issparse(X) and self._get_tags()["allow_nan"]
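
# Editorial note: missing-value support is limited to dense inputs; for
# sparse X this returns False even when the estimator's allow_nan tag is set.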

def _compute_feature_has_missing(self, X):
"""Return boolean mask denoting if there are missing values for each feature.

This method also ensures that X is finite.

Parameters
----------
X : array-like of shape (n_samples, n_features), dtype=DOUBLE
Input data.

Returns
-------
feature_has_missing : ndarray of shape (n_features,), or None
Missing value mask. If missing values are not supported or there
are no missing values, return None.
"""
common_kwargs = dict(estimator_name=self.__class__.__name__, input_name="X")

if not self._support_missing_values(X):
assert_all_finite(X, **common_kwargs)
return None

with np.errstate(over="ignore"):
overall_sum = np.sum(X)

if not np.isfinite(overall_sum):
# Raise a ValueError in case of the presence of an infinite element.
_assert_all_finite_element_wise(X, xp=np, allow_nan=True, **common_kwargs)

# If the sum is not nan, then there are no missing values
if not np.isnan(overall_sum):
return None

feature_has_missing = _any_isnan_axis0(X)
return feature_has_missing
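
# Editorial sketch: `_any_isnan_axis0` (imported above from the Cython
# module `._utils`) computes the same per-feature mask as the pure-NumPy
# expression
#
#     np.isnan(X).any(axis=0)
#
# e.g. X = [[0., nan], [1., 2.]] -> feature_has_missing == [False, True]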

def _fit(
self, X, y, sample_weight=None, check_input=True, feature_has_missing=None
):
self._validate_params()
random_state = check_random_state(self.random_state)

if check_input:
# Need to validate separately here.
# We can't pass multi_output=True because that would allow y to be
# csr.
check_X_params = dict(dtype=DTYPE, accept_sparse="csc")

# _compute_feature_has_missing will check for finite values and
# compute the missing mask if the tree supports missing values
check_X_params = dict(
dtype=DTYPE, accept_sparse="csc", force_all_finite=False
)
check_y_params = dict(ensure_2d=False, dtype=None)
X, y = self._validate_data(
X, y, validate_separately=(check_X_params, check_y_params)
)

feature_has_missing = self._compute_feature_has_missing(X)
if issparse(X):
X.sort_indices()

@@ -381,7 +432,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
self.min_impurity_decrease,
)

builder.build(self.tree_, X, y, sample_weight)
builder.build(self.tree_, X, y, sample_weight, feature_has_missing)

if self.n_outputs_ == 1 and is_classifier(self):
self.n_classes_ = self.n_classes_[0]
@@ -394,7 +445,17 @@
def _validate_X_predict(self, X, check_input):
"""Validate the training data on predict (probabilities)."""
if check_input:
X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", reset=False)
if self._support_missing_values(X):
force_all_finite = "allow-nan"
else:
force_all_finite = True
X = self._validate_data(
X,
dtype=DTYPE,
accept_sparse="csr",
reset=False,
force_all_finite=force_all_finite,
)
if issparse(X) and (
X.indices.dtype != np.intc or X.indptr.dtype != np.intc
):
@@ -886,7 +947,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
Fitted estimator.
"""

super().fit(
super()._fit(
X,
y,
sample_weight=sample_weight,
@@ -971,7 +1032,14 @@ def predict_log_proba(self, X):
return proba

def _more_tags(self):
return {"multilabel": True}
# XXX: NaN is only supported for dense arrays, but we set this flag for the
# common tests to pass, specifically: check_estimators_nan_inf
allow_nan = self.splitter == "best" and self.criterion in {
"gini",
"log_loss",
"entropy",
}
return {"multilabel": True, "allow_nan": allow_nan}
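
# Editorial sketch of the consequence (assuming scikit-learn's private
# _get_tags API): only splitter="best" combined with one of the criteria
# listed above advertises NaN support, e.g.
#
#     DecisionTreeClassifier()._get_tags()["allow_nan"]                   # True
#     DecisionTreeClassifier(splitter="random")._get_tags()["allow_nan"]  # False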


class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
@@ -1239,7 +1307,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
Fitted estimator.
"""

super().fit(
super()._fit(
X,
y,
sample_weight=sample_weight,
@@ -1274,6 +1342,16 @@ def _compute_partial_dependence_recursion(self, grid, target_features):
)
return averaged_predictions

def _more_tags(self):
# XXX: NaN is only supported for dense arrays, but we set this flag for the
# common tests to pass, specifically: check_estimators_nan_inf
allow_nan = self.splitter == "best" and self.criterion in {
"squared_error",
"friedman_mse",
"poisson",
}
return {"allow_nan": allow_nan}


class ExtraTreeClassifier(DecisionTreeClassifier):
"""An extremely randomized tree classifier.
19 changes: 13 additions & 6 deletions sklearn/tree/_criterion.pxd
@@ -28,6 +28,8 @@ cdef class Criterion:
cdef SIZE_t start # samples[start:pos] are the samples in the left node
cdef SIZE_t pos # samples[pos:end] are the samples in the right node
cdef SIZE_t end
cdef SIZE_t n_missing # Number of missing values for the feature being evaluated
cdef bint missing_go_to_left # Whether missing values go to the left node

cdef SIZE_t n_outputs # Number of outputs
cdef SIZE_t n_samples # Number of samples
@@ -36,6 +38,7 @@
cdef double weighted_n_node_samples # Weighted number of samples in the node
cdef double weighted_n_left # Weighted number of samples in the left node
cdef double weighted_n_right # Weighted number of samples in the right node
cdef double weighted_n_missing # Weighted number of samples that are missing

# The criterion object is maintained such that left and right collected
# statistics correspond to samples[start:pos] and samples[pos:end].
@@ -50,6 +53,8 @@
SIZE_t start,
SIZE_t end
) except -1 nogil
cdef void init_sum_missing(self)
cdef void init_missing(self, SIZE_t n_missing) noexcept nogil
cdef int reset(self) except -1 nogil
cdef int reverse_reset(self) except -1 nogil
cdef int update(self, SIZE_t new_pos) except -1 nogil
@@ -77,15 +82,17 @@
cdef SIZE_t[::1] n_classes
cdef SIZE_t max_n_classes

cdef double[:, ::1] sum_total # The sum of the weighted count of each label.
cdef double[:, ::1] sum_left # Same as above, but for the left side of the split
cdef double[:, ::1] sum_right # Same as above, but for the right side of the split
cdef double[:, ::1] sum_total # The sum of the weighted count of each label.
cdef double[:, ::1] sum_left # Same as above, but for the left side of the split
cdef double[:, ::1] sum_right # Same as above, but for the right side of the split
cdef double[:, ::1] sum_missing # Same as above, but for missing values in X

cdef class RegressionCriterion(Criterion):
"""Abstract regression criterion."""

cdef double sq_sum_total

cdef double[::1] sum_total # The sum of w*y.
cdef double[::1] sum_left # Same as above, but for the left side of the split
cdef double[::1] sum_right # Same as above, but for the right side of the split
cdef double[::1] sum_total # The sum of w*y.
cdef double[::1] sum_left # Same as above, but for the left side of the split
cdef double[::1] sum_right # Same as above, but for the right side of the split
cdef double[::1] sum_missing # Same as above, but for missing values in X
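
To see how these declarations fit together, here is a pure-Python sketch (an
editorial illustration, not the Cython implementation) of how the accumulated
missing-value statistics are credited to one child once `missing_go_to_left`
has been decided for the chosen split::

    def move_missing_to_child(sum_left, sum_right, sum_missing,
                              weighted_n_left, weighted_n_right,
                              weighted_n_missing, missing_go_to_left):
        # Credit the per-feature missing-value sums and sample weights to the
        # child that the chosen split sends the missing samples to.
        if missing_go_to_left:
            sum_left = sum_left + sum_missing
            weighted_n_left += weighted_n_missing
        else:
            sum_right = sum_right + sum_missing
            weighted_n_right += weighted_n_missing
        return sum_left, sum_right, weighted_n_left, weighted_n_right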