Skip to content

Commit ce2376e

Browse files
lestevejeremiedbb
authored and committed
FIX Fix ExtraTreeRegressor missing data handling (#30318)
1 parent cce5062 commit ce2376e

File tree

4 files changed

+23
-6
lines changed

4 files changed

+23
-6
lines changed

doc/whats_new/upcoming_changes/sklearn.tree/27966.feature.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
support missing-values in the data matrix ``X``. Missing-values are handled by
33
randomly moving all of the samples to the left or right child node as the tree is
44
traversed.
5-
By :user:`Adam Li <adam2392>`
5+
By :user:`Adam Li <adam2392>` and :user:`Loïc Estève <lesteve>`
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
- :class:`tree.ExtraTreeClassifier` and :class:`tree.ExtraTreeRegressor` now
2+
support missing-values in the data matrix ``X``. Missing-values are handled by
3+
randomly moving all of the samples to the left or right child node as the tree is
4+
traversed.
5+
By :user:`Adam Li <adam2392>` and :user:`Loïc Estève <lesteve>`

sklearn/tree/_partitioner.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ cdef class DensePartitioner:
194194
"""Partition samples for feature_values at the current_threshold."""
195195
cdef:
196196
intp_t p = self.start
197-
intp_t partition_end = self.end
197+
intp_t partition_end = self.end - self.n_missing
198198
intp_t[::1] samples = self.samples
199199
float32_t[::1] feature_values = self.feature_values
200200

sklearn/tree/tests/test_tree.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2689,10 +2689,8 @@ def test_regression_tree_missing_values_toy(Tree, X, criterion):
26892689
impurity = tree.tree_.impurity
26902690
assert all(impurity >= 0), impurity.min() # MSE should always be positive
26912691

2692-
# Note: the impurity matches after the first split only on greedy trees
2693-
if Tree is DecisionTreeRegressor:
2694-
# Check the impurity match after the first split
2695-
assert_allclose(tree.tree_.impurity[:2], tree_ref.tree_.impurity[:2])
2692+
# Check the impurity match after the first split
2693+
assert_allclose(tree.tree_.impurity[:2], tree_ref.tree_.impurity[:2])
26962694

26972695
# Find the leaves with a single sample where the MSE should be 0
26982696
leaves_idx = np.flatnonzero(
@@ -2701,6 +2699,20 @@ def test_regression_tree_missing_values_toy(Tree, X, criterion):
27012699
assert_allclose(tree.tree_.impurity[leaves_idx], 0.0)
27022700

27032701

2702+
def test_regression_extra_tree_missing_values_toy(global_random_seed):
2703+
rng = np.random.RandomState(global_random_seed)
2704+
n_samples = 100
2705+
X = np.arange(n_samples, dtype=np.float64).reshape(-1, 1)
2706+
X[-20:, :] = np.nan
2707+
rng.shuffle(X)
2708+
y = np.arange(n_samples)
2709+
2710+
tree = ExtraTreeRegressor(random_state=global_random_seed, max_depth=5).fit(X, y)
2711+
2712+
impurity = tree.tree_.impurity
2713+
assert all(impurity >= 0), impurity # MSE should always be positive
2714+
2715+
27042716
def test_classification_tree_missing_values_toy():
27052717
"""Check that we properly handle missing values in classification trees using a toy
27062718
dataset.

0 commit comments

Comments
 (0)