Skip to content

[MAINT] Improve extensibility of the Tree/Splitter code #22756

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions sklearn/tree/_splitter.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,42 @@ cdef class Splitter:
cdef void node_value(self, double* dest) nogil

cdef double node_impurity(self) nogil

# Module-level helper declarations. Listing them in this .pxd exposes their
# signatures to other Cython modules that cimport this one.
# NOTE(review): several of these are declared `inline` without a body here;
# confirm Cython accepts inline declarations in a .pxd when the definitions
# live in the matching .pyx (the bodies are not visible in this diff).

# Helpers taking a feature-value buffer `Xf` and a `samples` index buffer of
# length `n`. Presumably the two buffers are permuted together while sorting;
# verify against the implementations.
cdef inline void sort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil
cdef inline void swap(DTYPE_t* Xf, SIZE_t* samples, SIZE_t i, SIZE_t j) nogil
cdef inline DTYPE_t median3(DTYPE_t* Xf, SIZE_t n) nogil
cdef void introsort(DTYPE_t* Xf, SIZE_t *samples, SIZE_t n, int maxd) nogil
cdef inline void sift_down(DTYPE_t* Xf, SIZE_t* samples,
SIZE_t start, SIZE_t end) nogil
cdef void heapsort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil
# Comparator with a qsort-style signature (two untyped const pointers,
# int result).
cdef int compare_SIZE_t(const void* a, const void* b) nogil
# Results are reported through the `index` and `new_start` output pointers
# (the function itself returns void).
cdef inline void binary_search(INT32_t* sorted_array,
INT32_t start, INT32_t end,
SIZE_t value, SIZE_t* index,
INT32_t* new_start) nogil
# Helpers taking sparse-style inputs (`X_indices`, `X_data` and an
# `indptr_start`/`indptr_end` range). Both write through the `end_negative`
# and `start_positive` output pointers; exact semantics are defined with the
# implementations, which are not shown here.
cdef inline void extract_nnz_index_to_samples(INT32_t* X_indices,
DTYPE_t* X_data,
INT32_t indptr_start,
INT32_t indptr_end,
SIZE_t* samples,
SIZE_t start,
SIZE_t end,
SIZE_t* index_to_samples,
DTYPE_t* Xf,
SIZE_t* end_negative,
SIZE_t* start_positive) nogil
# Same leading parameters as extract_nnz_index_to_samples, plus a
# `sorted_samples` buffer and an `is_samples_sorted` flag pointer.
cdef inline void extract_nnz_binary_search(INT32_t* X_indices,
DTYPE_t* X_data,
INT32_t indptr_start,
INT32_t indptr_end,
SIZE_t* samples,
SIZE_t start,
SIZE_t end,
SIZE_t* index_to_samples,
DTYPE_t* Xf,
SIZE_t* end_negative,
SIZE_t* start_positive,
SIZE_t* sorted_samples,
bint* is_samples_sorted) nogil
# Takes both `samples` and `index_to_samples` plus two positions; presumably
# swaps the entries at pos_1/pos_2 while keeping the two arrays consistent.
# TODO confirm against the implementation.
cdef inline void sparse_swap(SIZE_t* index_to_samples, SIZE_t* samples,
SIZE_t pos_1, SIZE_t pos_2) nogil
6 changes: 5 additions & 1 deletion sklearn/tree/_tree.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,13 @@ cdef class Tree:

# Methods
cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf,
SIZE_t feature, double threshold, double impurity,
SplitRecord split_node, double impurity,
SIZE_t n_node_samples,
double weighted_n_node_samples) nogil except -1
cdef int _set_node_values(self, SplitRecord split_node,
Node *node) nogil except -1
cdef DTYPE_t _compute_feature(self, const DTYPE_t[:] X_ndarray,
Node *node) nogil
cdef int _resize(self, SIZE_t capacity) nogil except -1
cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1

Expand Down
50 changes: 41 additions & 9 deletions sklearn/tree/_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
(split.improvement + EPSILON <
min_impurity_decrease))

node_id = tree._add_node(parent, is_left, is_leaf, split.feature,
split.threshold, impurity, n_node_samples,
node_id = tree._add_node(parent, is_left, is_leaf, split,
impurity, n_node_samples,
weighted_n_node_samples)

if node_id == SIZE_MAX:
Expand Down Expand Up @@ -487,7 +487,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
if parent != NULL
else _TREE_UNDEFINED,
is_left, is_leaf,
split.feature, split.threshold, impurity, n_node_samples,
split, impurity, n_node_samples,
weighted_n_node_samples)
if node_id == SIZE_MAX:
return -1
Expand Down Expand Up @@ -747,8 +747,27 @@ cdef class Tree:
self.capacity = capacity
return 0

cdef int _set_node_values(self, SplitRecord split_node,
Node *node) nogil except -1:
"""Copy the split description from a SplitRecord into a tree node.

Stores ``split_node.feature`` and ``split_node.threshold`` on ``node``.
No other field of the split record is read here.

Returns 1 on success. The ``except -1`` clause reserves -1 as the
error return value.
"""
node.feature = split_node.feature
node.threshold = split_node.threshold
# Success value: _add_node treats any return other than 1 as a failure.
return 1

cdef DTYPE_t _compute_feature(self, const DTYPE_t[:] X_ndarray,
Node *node) nogil:
"""Compute the feature value of one sample at a given node.

``X_ndarray`` is a 1-D memoryview (despite its name); the dense
prediction loop passes a single row of X. The returned value is what
the caller compares against ``node.threshold``.

In axis-aligned trees, this is simply the entry of that row at index
``node.feature``.
"""
# the feature value: the sample's entry at the node's feature index
cdef DTYPE_t feature = X_ndarray[node.feature]
return feature

cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf,
SIZE_t feature, double threshold, double impurity,
SplitRecord split_node, double impurity,
SIZE_t n_node_samples,
double weighted_n_node_samples) nogil except -1:
"""Add a node to the tree.
Expand Down Expand Up @@ -782,8 +801,9 @@ cdef class Tree:

else:
# left_child and right_child will be set later
node.feature = feature
node.threshold = threshold
if self._set_node_values(split_node, node) != 1:
with gil:
raise RuntimeError

self.node_count += 1

Expand Down Expand Up @@ -817,6 +837,7 @@ cdef class Tree:

# Extract input
cdef const DTYPE_t[:, :] X_ndarray = X
cdef const DTYPE_t[:] X_vector
cdef SIZE_t n_samples = X.shape[0]

# Initialize output
Expand All @@ -827,13 +848,19 @@ cdef class Tree:
cdef Node* node = NULL
cdef SIZE_t i = 0

# the feature value that is compared against the node threshold
cdef DOUBLE_t feature_value

with nogil:
for i in range(n_samples):
node = self.nodes
# While node not a leaf
while node.left_child != _TREE_LEAF:
# ... and node.right_child != _TREE_LEAF:
if X_ndarray[i, node.feature] <= node.threshold:
# compute the feature value to compare against threshold
X_vector = X_ndarray[i, :]
feature_value = self._compute_feature(X_vector, node)
if feature_value <= node.threshold:
node = &self.nodes[node.left_child]
else:
node = &self.nodes[node.right_child]
Expand Down Expand Up @@ -900,7 +927,6 @@ cdef class Tree:
# ... and node.right_child != _TREE_LEAF:
if feature_to_sample[node.feature] == i:
feature_value = X_sample[node.feature]

else:
feature_value = 0.

Expand Down Expand Up @@ -1741,6 +1767,8 @@ cdef _build_pruned_tree(
stack[BuildPrunedRecord] prune_stack
BuildPrunedRecord stack_record

SplitRecord split

with nogil:
# push root node onto stack
prune_stack.push({"start": 0, "depth": 0, "parent": _TREE_UNDEFINED, "is_left": 0})
Expand All @@ -1757,8 +1785,12 @@ cdef _build_pruned_tree(
is_leaf = leaves_in_subtree[orig_node_id]
node = &orig_tree.nodes[orig_node_id]

# redefine to a SplitRecord to pass into _add_node
split.feature = node.feature
split.threshold = node.threshold

new_node_id = tree._add_node(
parent, is_left, is_leaf, node.feature, node.threshold,
parent, is_left, is_leaf, split,
node.impurity, node.n_node_samples,
node.weighted_n_node_samples)

Expand Down