Skip to content

Commit 9784fec

Browse files
committed
The final stash
1 parent 921f725 commit 9784fec

File tree

6 files changed

+213
-122
lines changed

6 files changed

+213
-122
lines changed

sklearn/tree/_criterion.pxd

+6-1
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,13 @@ cdef class Criterion:
2929
cdef DOUBLE_t* sample_weight # Sample weights
3030

3131
cdef SIZE_t* samples # Sample indices in X, y
32+
cdef SIZE_t* missing_samples # Indices of missing-valued samples
3233
cdef SIZE_t start # samples[start:pos] are the samples in the left node
3334
cdef SIZE_t pos # samples[pos:end] are the samples in the right node
3435
cdef SIZE_t end
36+
cdef SIZE_t missing_direction # 0 Consider missing to be a part of left side of the split
37+
# 1 Consider missing to be a part of right side of the split
38+
# 2 Ignore missing-valued samples.
3539

3640
cdef SIZE_t n_outputs # Number of outputs
3741
cdef SIZE_t n_node_samples # Number of samples in the node (end-start)
@@ -57,7 +61,8 @@ cdef class Criterion:
5761
SIZE_t end) nogil
5862
cdef void reset(self) nogil
5963
cdef void reverse_reset(self) nogil
60-
cdef void update(self, SIZE_t new_pos) nogil
64+
cdef void update_split_pos(self, SIZE_t new_pos) nogil
65+
cdef void update_missing_direction(self, SIZE_t missing_direction) nogil
6166
cdef double node_impurity(self) nogil
6267
cdef void children_impurity(self, double* impurity_left,
6368
double* impurity_right) nogil

sklearn/tree/_criterion.pyx

+20-19
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ cdef class Criterion:
5353

5454
cdef void init(self, DOUBLE_t* y, SIZE_t y_stride, DOUBLE_t* sample_weight,
5555
double weighted_n_samples, SIZE_t* samples,
56-
SIZE_t start, SIZE_t end,
57-
SIZE_t start_missing, SIZE_t end_missing) nogil:
56+
SIZE_t* missing_samples, SIZE_t start, SIZE_t end,
57+
SIZE_t missing_direction) nogil:
5858
"""Placeholder for a method which will initialize the criterion.
5959
6060
Parameters
@@ -77,10 +77,12 @@ cdef class Criterion:
7777
The first non-missing-valued sample to be used on this node
7878
end: SIZE_t
7979
The last non-missing-valued sample used on this node
80-
start_missing: SIZE_t
81-
The first missing-valued sample to be used on this node
82-
end_missing: SIZE_t
83-
The last missing-valued sample used on this node
80+
missing_direction: SIZE_t
81+
The initial value of the direction in which the missing values
82+
must be sent
83+
0 - To send the missing values left
84+
1 - To send the missing values right
85+
2 - To ignore the missing values
8486
8587
"""
8688

@@ -101,7 +103,7 @@ cdef class Criterion:
101103
"""
102104
pass
103105

104-
cdef void update(self, SIZE_t new_pos) nogil:
106+
cdef void update_split_pos(self, SIZE_t new_pos) nogil:
105107
"""Update statistics by moving samples[pos:new_pos] to the left child.
106108
107109
This updates the collected statistics by moving samples[pos:new_pos]
@@ -116,18 +118,19 @@ cdef class Criterion:
116118

117119
pass
118120

119-
cdef void move_missing(self, bint direction) nogil:
121+
cdef void update_missing_direction(self, SIZE_t missing_direction) nogil:
120122
"""Update statistics by moving the missing-valued samples left or right.
121123
122124
This updates the collected statistics by moving the missing-valued
123-
samples (samples[start_missing:end_nonmissing]) to the direction as
124-
specified.
125+
samples in the specified direction.
125126
126127
Parameters
127128
----------
128-
direction: bint
129-
0 (false) to move the missing-valued samples left.
130-
1 (true) to move the missing-valued samples right.
129+
missing_direction: SIZE_t
130+
Direction in which the missing values must be sent
131+
0 - To send the missing values left
132+
1 - To send the missing values right
133+
2 - To ignore the missing values
131134
132135
"""
133136

@@ -254,12 +257,10 @@ cdef class ClassificationCriterion(Criterion):
254257
self.sample_weight = NULL
255258

256259
self.samples = NULL
257-
self.start_nonmissing = 0
258-
self.pos_nonmissing = 0
259-
self.end_nonmissing = 0
260-
261-
self.start_missing = 0
262-
self.end_missing = 0
260+
self.start = 0
261+
self.pos = 0
262+
self.end = 0
263+
self.missing_direction = 2 # Ignore the missing values by default
263264

264265
self.n_outputs = n_outputs
265266

sklearn/tree/_splitter.pxd

+6-8
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@ cdef struct SplitRecord:
2727
# i.e. count of samples below threshold for feature.
2828
# pos is >= end if the node is a leaf.
2929
double threshold # Threshold to split at.
30+
SIZE_t missing_direction # Direction in which the missing values should
31+
# be sent
3032
double improvement # Impurity improvement given parent node.
3133
double impurity_left # Impurity of the left split.
3234
double impurity_right # Impurity of the right split.
33-
double send_missing_left # Whether to send the missing values left/right
3435

3536
cdef class Splitter:
3637
# The splitter searches in the input space for a feature and a threshold
@@ -48,11 +49,10 @@ cdef class Splitter:
4849
cdef UINT32_t rand_r_state # sklearn_rand_r random number state
4950

5051
cdef SIZE_t* samples # Sample indices in X, y
52+
cdef bint allow_missing # Whether to include the missing vals
5153
cdef SIZE_t* missing_samples # Sample indices with missing values
5254

5355
cdef SIZE_t n_samples # X.shape[0]
54-
# TODO selfnote we need n_missing?
55-
5656
cdef double weighted_n_samples # Weighted number of samples
5757

5858
cdef SIZE_t* features # Feature indices in X
@@ -65,11 +65,8 @@ cdef class Splitter:
6565
# for the non missing values
6666
cdef SIZE_t end # End pos for the current node
6767
# for the non missing values
68-
69-
cdef SIZE_t start_missing # Start position for the current node
70-
# for the missing values
71-
cdef SIZE_t end_missing # End pos for the current node
72-
# for the missing values
68+
cdef SIZE_t missing_direction # Direction in which the missing
69+
# values must be sent
7370

7471
cdef bint presort # Whether to use presorting, only
7572
# allowed on dense data
@@ -104,6 +101,7 @@ cdef class Splitter:
104101
object missing_samples) except *
105102

106103
cdef void node_reset(self, SIZE_t start, SIZE_t end,
104+
SIZE_t missing_direction,
107105
double* weighted_n_node_samples) nogil
108106

109107
cdef void node_split(self,

0 commit comments

Comments
 (0)