Skip to content

MAINT Use newest NumPy C API in tree._criterion #25615

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
"sklearn.svm._libsvm",
"sklearn.svm._libsvm_sparse",
"sklearn.svm._newrand",
"sklearn.tree._criterion",
"sklearn.tree._splitter",
"sklearn.tree._tree",
"sklearn.tree._utils",
Expand Down
28 changes: 15 additions & 13 deletions sklearn/tree/_criterion.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,8 @@ cdef class MAE(RegressionCriterion):

cdef cnp.ndarray left_child
cdef cnp.ndarray right_child
cdef void** left_child_ptr
cdef void** right_child_ptr
cdef DOUBLE_t[::1] node_medians

def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples):
Expand Down Expand Up @@ -923,6 +925,9 @@ cdef class MAE(RegressionCriterion):
self.left_child[k] = WeightedMedianCalculator(n_samples)
self.right_child[k] = WeightedMedianCalculator(n_samples)

self.left_child_ptr = <void**> cnp.PyArray_DATA(self.left_child)
self.right_child_ptr = <void**> cnp.PyArray_DATA(self.right_child)

cdef int init(
self,
const DOUBLE_t[:, ::1] y,
Expand Down Expand Up @@ -950,11 +955,8 @@ cdef class MAE(RegressionCriterion):
self.weighted_n_samples = weighted_n_samples
self.weighted_n_node_samples = 0.

cdef void** left_child
cdef void** right_child

left_child = <void**> self.left_child.data
right_child = <void**> self.right_child.data
cdef void** left_child = self.left_child_ptr
cdef void** right_child = self.right_child_ptr

for k in range(self.n_outputs):
(<WeightedMedianCalculator> left_child[k]).reset()
Expand Down Expand Up @@ -991,8 +993,8 @@ cdef class MAE(RegressionCriterion):
cdef DOUBLE_t value
cdef DOUBLE_t weight

cdef void** left_child = <void**> self.left_child.data
cdef void** right_child = <void**> self.right_child.data
cdef void** left_child = self.left_child_ptr
cdef void** right_child = self.right_child_ptr

self.weighted_n_left = 0.0
self.weighted_n_right = self.weighted_n_node_samples
Expand Down Expand Up @@ -1024,8 +1026,8 @@ cdef class MAE(RegressionCriterion):

cdef DOUBLE_t value
cdef DOUBLE_t weight
cdef void** left_child = <void**> self.left_child.data
cdef void** right_child = <void**> self.right_child.data
cdef void** left_child = self.left_child_ptr
cdef void** right_child = self.right_child_ptr

# reverse reset the WeightedMedianCalculators, right should have no
# elements and left should have all elements.
Expand All @@ -1049,8 +1051,8 @@ cdef class MAE(RegressionCriterion):
cdef const DOUBLE_t[:] sample_weight = self.sample_weight
cdef const SIZE_t[:] sample_indices = self.sample_indices

cdef void** left_child = <void**> self.left_child.data
cdef void** right_child = <void**> self.right_child.data
cdef void** left_child = self.left_child_ptr
cdef void** right_child = self.right_child_ptr

cdef SIZE_t pos = self.pos
cdef SIZE_t end = self.end
Expand Down Expand Up @@ -1147,8 +1149,8 @@ cdef class MAE(RegressionCriterion):
cdef DOUBLE_t impurity_left = 0.0
cdef DOUBLE_t impurity_right = 0.0

cdef void** left_child = <void**> self.left_child.data
cdef void** right_child = <void**> self.right_child.data
cdef void** left_child = self.left_child_ptr
cdef void** right_child = self.right_child_ptr

for k in range(self.n_outputs):
median = (<WeightedMedianCalculator> left_child[k]).get_median()
Expand Down