From 8d2332b89f637d1e06ae1c0a5caff5ddd47d4991 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Tue, 12 Sep 2023 15:05:28 -0400 Subject: [PATCH 1/2] Fix ctypedef types in tree submodule Signed-off-by: Adam Li --- sklearn/ensemble/_gradient_boosting.pyx | 42 +-- sklearn/tree/_criterion.pxd | 45 ++- sklearn/tree/_criterion.pyx | 332 ++++++++--------- sklearn/tree/_splitter.pxd | 49 ++- sklearn/tree/_splitter.pyx | 468 ++++++++++++------------ sklearn/tree/_tree.pxd | 60 ++- sklearn/tree/_tree.pyx | 322 ++++++++-------- sklearn/tree/_utils.pxd | 74 ++-- sklearn/tree/_utils.pyx | 76 ++-- 9 files changed, 726 insertions(+), 742 deletions(-) diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx index b7846cc38f4af..c6b1819bf012d 100644 --- a/sklearn/ensemble/_gradient_boosting.pyx +++ b/sklearn/ensemble/_gradient_boosting.pyx @@ -11,11 +11,9 @@ cnp.import_array() from scipy.sparse import issparse +from ..utils._typedefs cimport float32_t, float64_t, intp_t, int32_t from ..tree._tree cimport Node from ..tree._tree cimport Tree -from ..tree._tree cimport DTYPE_t -from ..tree._tree cimport SIZE_t -from ..tree._tree cimport INT32_t from ..tree._utils cimport safe_realloc @@ -24,10 +22,10 @@ from numpy import zeros as np_zeros # constant to mark tree leafs -cdef SIZE_t TREE_LEAF = -1 +cdef intp_t TREE_LEAF = -1 cdef void _predict_regression_tree_inplace_fast_dense( - const DTYPE_t[:, ::1] X, + const float32_t[:, ::1] X, Node* root_node, double *value, double scale, @@ -45,7 +43,7 @@ cdef void _predict_regression_tree_inplace_fast_dense( Parameters ---------- - X : DTYPE_t 2d memory view + X : float32_t 2d memory view The memory view on the data ndarray of the input ``X``. Assumes that the array is c-continuous. root_node : tree Node pointer @@ -63,7 +61,7 @@ cdef void _predict_regression_tree_inplace_fast_dense( ``out`` is assumed to be a two-dimensional array of shape ``(n_samples, K)``. """ - cdef SIZE_t n_samples = X.shape[0] + cdef intp_t n_samples = X.shape[0] cdef Py_ssize_t i cdef Node *node for i in range(n_samples): @@ -87,20 +85,20 @@ def _predict_regression_tree_stages_sparse( The function assumes that the ndarray that wraps ``X`` is csr_matrix. """ - cdef const DTYPE_t[::1] X_data = X.data - cdef const INT32_t[::1] X_indices = X.indices - cdef const INT32_t[::1] X_indptr = X.indptr + cdef const float32_t[::1] X_data = X.data + cdef const int32_t[::1] X_indices = X.indices + cdef const int32_t[::1] X_indptr = X.indptr - cdef SIZE_t n_samples = X.shape[0] - cdef SIZE_t n_features = X.shape[1] - cdef SIZE_t n_stages = estimators.shape[0] - cdef SIZE_t n_outputs = estimators.shape[1] + cdef intp_t n_samples = X.shape[0] + cdef intp_t n_features = X.shape[1] + cdef intp_t n_stages = estimators.shape[0] + cdef intp_t n_outputs = estimators.shape[1] # Indices and temporary variables - cdef SIZE_t sample_i - cdef SIZE_t feature_i - cdef SIZE_t stage_i - cdef SIZE_t output_i + cdef intp_t sample_i + cdef intp_t feature_i + cdef intp_t stage_i + cdef intp_t output_i cdef Node *root_node = NULL cdef Node *node = NULL cdef double *value = NULL @@ -117,18 +115,18 @@ def _predict_regression_tree_stages_sparse( values[stage_i * n_outputs + output_i] = tree.value # Initialize auxiliary data-structure - cdef DTYPE_t feature_value = 0. - cdef DTYPE_t* X_sample = NULL + cdef float32_t feature_value = 0. 
+ cdef float32_t* X_sample = NULL # feature_to_sample as a data structure records the last seen sample # for each feature; functionally, it is an efficient way to identify # which features are nonzero in the present sample. - cdef SIZE_t* feature_to_sample = NULL + cdef intp_t* feature_to_sample = NULL safe_realloc(&X_sample, n_features) safe_realloc(&feature_to_sample, n_features) - memset(feature_to_sample, -1, n_features * sizeof(SIZE_t)) + memset(feature_to_sample, -1, n_features * sizeof(intp_t)) # Cycle through all samples for sample_i in range(n_samples): diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index b765d324bebb9..16d8f35d44013 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -10,11 +10,8 @@ # See _criterion.pyx for implementation details. cimport numpy as cnp -from ._tree cimport DTYPE_t # Type of X -from ._tree cimport DOUBLE_t # Type of y, sample_weight -from ._tree cimport SIZE_t # Type for indices and counters -from ._tree cimport INT32_t # Signed 32 bit integer -from ._tree cimport UINT32_t # Unsigned 32 bit integer +from ..utils._typedefs cimport float64_t, intp_t + cdef class Criterion: # The criterion computes the impurity of a node and the reduction of @@ -22,19 +19,19 @@ cdef class Criterion: # such as the mean in regression and class probabilities in classification. # Internal structures - cdef const DOUBLE_t[:, ::1] y # Values of y - cdef const DOUBLE_t[:] sample_weight # Sample weights + cdef const float64_t[:, ::1] y # Values of y + cdef const float64_t[:] sample_weight # Sample weights - cdef const SIZE_t[:] sample_indices # Sample indices in X, y - cdef SIZE_t start # samples[start:pos] are the samples in the left node - cdef SIZE_t pos # samples[pos:end] are the samples in the right node - cdef SIZE_t end - cdef SIZE_t n_missing # Number of missing values for the feature being evaluated + cdef const intp_t[:] sample_indices # Sample indices in X, y + cdef intp_t start # samples[start:pos] are the samples in the left node + cdef intp_t pos # samples[pos:end] are the samples in the right node + cdef intp_t end + cdef intp_t n_missing # Number of missing values for the feature being evaluated cdef bint missing_go_to_left # Whether missing values go to the left node - cdef SIZE_t n_outputs # Number of outputs - cdef SIZE_t n_samples # Number of samples - cdef SIZE_t n_node_samples # Number of samples in the node (end-start) + cdef intp_t n_outputs # Number of outputs + cdef intp_t n_samples # Number of samples + cdef intp_t n_node_samples # Number of samples in the node (end-start) cdef double weighted_n_samples # Weighted number of samples (in total) cdef double weighted_n_node_samples # Weighted number of samples in the node cdef double weighted_n_left # Weighted number of samples in the left node @@ -47,18 +44,18 @@ cdef class Criterion: # Methods cdef int init( self, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, double weighted_n_samples, - const SIZE_t[:] sample_indices, - SIZE_t start, - SIZE_t end + const intp_t[:] sample_indices, + intp_t start, + intp_t end ) except -1 nogil cdef void init_sum_missing(self) - cdef void init_missing(self, SIZE_t n_missing) noexcept nogil + cdef void init_missing(self, intp_t n_missing) noexcept nogil cdef int reset(self) except -1 nogil cdef int reverse_reset(self) except -1 nogil - cdef int update(self, SIZE_t new_pos) except -1 nogil + cdef int update(self, intp_t new_pos) except -1 
nogil cdef double node_impurity(self) noexcept nogil cdef void children_impurity( self, @@ -101,8 +98,8 @@ cdef class Criterion: cdef class ClassificationCriterion(Criterion): """Abstract criterion for classification.""" - cdef SIZE_t[::1] n_classes - cdef SIZE_t max_n_classes + cdef intp_t[::1] n_classes + cdef intp_t max_n_classes cdef double[:, ::1] sum_total # The sum of the weighted count of each label. cdef double[:, ::1] sum_left # Same as above, but for the left side of the split diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index ed8a12065554e..13d18888db79f 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -42,12 +42,12 @@ cdef class Criterion: cdef int init( self, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, double weighted_n_samples, - const SIZE_t[:] sample_indices, - SIZE_t start, - SIZE_t end, + const intp_t[:] sample_indices, + intp_t start, + intp_t end, ) except -1 nogil: """Placeholder for a method which will initialize the criterion. @@ -56,25 +56,25 @@ cdef class Criterion: Parameters ---------- - y : ndarray, dtype=DOUBLE_t + y : ndarray, dtype=float64_t y is a buffer that can store values for n_outputs target variables stored as a Cython memoryview. - sample_weight : ndarray, dtype=DOUBLE_t + sample_weight : ndarray, dtype=float64_t The weight of each sample stored as a Cython memoryview. weighted_n_samples : double The total weight of the samples being considered - sample_indices : ndarray, dtype=SIZE_t + sample_indices : ndarray, dtype=intp_t A mask on the samples. Indices of the samples in X and y we want to use, where sample_indices[start:end] correspond to the samples in this node. - start : SIZE_t + start : intp_t The first sample to be used on this node - end : SIZE_t + end : intp_t The last sample used on this node """ pass - cdef void init_missing(self, SIZE_t n_missing) noexcept nogil: + cdef void init_missing(self, intp_t n_missing) noexcept nogil: """Initialize sum_missing if there are missing values. This method assumes that caller placed the missing samples in @@ -82,7 +82,7 @@ cdef class Criterion: Parameters ---------- - n_missing: SIZE_t + n_missing: intp_t Number of missing values for specific feature. """ pass @@ -101,7 +101,7 @@ cdef class Criterion: """ pass - cdef int update(self, SIZE_t new_pos) except -1 nogil: + cdef int update(self, intp_t new_pos) except -1 nogil: """Updated statistics by moving sample_indices[pos:new_pos] to the left child. This updates the collected statistics by moving sample_indices[pos:new_pos] @@ -110,7 +110,7 @@ cdef class Criterion: Parameters ---------- - new_pos : SIZE_t + new_pos : intp_t New starting index position of the sample_indices in the right child """ pass @@ -273,7 +273,7 @@ cdef inline void _move_sums_classification( sum_1 = 0 sum_2 = sum_total """ - cdef SIZE_t k, c, n_bytes + cdef intp_t k, c, n_bytes if criterion.n_missing != 0 and put_missing_in_1: for k in range(criterion.n_outputs): n_bytes = criterion.n_classes[k] * sizeof(double) @@ -299,15 +299,15 @@ cdef inline void _move_sums_classification( cdef class ClassificationCriterion(Criterion): """Abstract criterion for classification.""" - def __cinit__(self, SIZE_t n_outputs, - cnp.ndarray[SIZE_t, ndim=1] n_classes): + def __cinit__(self, intp_t n_outputs, + cnp.ndarray[intp_t, ndim=1] n_classes): """Initialize attributes for this criterion. 
         Parameters
         ----------
-        n_outputs : SIZE_t
+        n_outputs : intp_t
             The number of targets, the dimensionality of the prediction
-        n_classes : numpy.ndarray, dtype=SIZE_t
+        n_classes : numpy.ndarray, dtype=intp_t
             The number of unique classes in each target
         """
         self.start = 0
@@ -325,8 +325,8 @@ cdef class ClassificationCriterion(Criterion):
 
         self.n_classes = np.empty(n_outputs, dtype=np.intp)
 
-        cdef SIZE_t k = 0
-        cdef SIZE_t max_n_classes = 0
+        cdef intp_t k = 0
+        cdef intp_t max_n_classes = 0
 
         # For each target, set the number of unique classes in that target,
         # and also compute the maximal stride of all targets
@@ -349,12 +349,12 @@ cdef class ClassificationCriterion(Criterion):
 
     cdef int init(
         self,
-        const DOUBLE_t[:, ::1] y,
-        const DOUBLE_t[:] sample_weight,
+        const float64_t[:, ::1] y,
+        const float64_t[:] sample_weight,
         double weighted_n_samples,
-        const SIZE_t[:] sample_indices,
-        SIZE_t start,
-        SIZE_t end
+        const intp_t[:] sample_indices,
+        intp_t start,
+        intp_t end
     ) except -1 nogil:
         """Initialize the criterion.
 
@@ -366,18 +366,18 @@ cdef class ClassificationCriterion(Criterion):
 
         Parameters
         ----------
-        y : ndarray, dtype=DOUBLE_t
+        y : ndarray, dtype=float64_t
            The target stored as a buffer for memory efficiency.
-        sample_weight : ndarray, dtype=DOUBLE_t
+        sample_weight : ndarray, dtype=float64_t
             The weight of each sample stored as a Cython memoryview.
         weighted_n_samples : double
            The total weight of all samples
-        sample_indices : ndarray, dtype=SIZE_t
+        sample_indices : ndarray, dtype=intp_t
            A mask on the samples. Indices of the samples in X and y we want to use,
            where sample_indices[start:end] correspond to the samples in this node.
-        start : SIZE_t
+        start : intp_t
            The first sample to use in the mask
-        end : SIZE_t
+        end : intp_t
            The last sample to use in the mask
         """
         self.y = y
@@ -389,11 +389,11 @@ cdef class ClassificationCriterion(Criterion):
         self.weighted_n_samples = weighted_n_samples
         self.weighted_n_node_samples = 0.0
 
-        cdef SIZE_t i
-        cdef SIZE_t p
-        cdef SIZE_t k
-        cdef SIZE_t c
-        cdef DOUBLE_t w = 1.0
+        cdef intp_t i
+        cdef intp_t p
+        cdef intp_t k
+        cdef intp_t c
+        cdef float64_t w = 1.0
 
         for k in range(self.n_outputs):
             memset(&self.sum_total[k, 0], 0, self.n_classes[k] * sizeof(double))
@@ -408,7 +408,7 @@ cdef class ClassificationCriterion(Criterion):
 
             # Count weighted class frequency for each target
             for k in range(self.n_outputs):
-                c = <SIZE_t> self.y[i, k]
+                c = <intp_t> self.y[i, k]
                 self.sum_total[k, c] += w
 
             self.weighted_n_node_samples += w
@@ -421,14 +421,14 @@ cdef class ClassificationCriterion(Criterion):
         """Init sum_missing to hold sums for missing values."""
         self.sum_missing = np.zeros((self.n_outputs, self.max_n_classes), dtype=np.float64)
 
-    cdef void init_missing(self, SIZE_t n_missing) noexcept nogil:
+    cdef void init_missing(self, intp_t n_missing) noexcept nogil:
         """Initialize sum_missing if there are missing values.
 
         This method assumes that caller placed the missing samples in
         self.sample_indices[-n_missing:]
         """
-        cdef SIZE_t i, p, k, c
-        cdef DOUBLE_t w = 1.0
+        cdef intp_t i, p, k, c
+        cdef float64_t w = 1.0
 
         self.n_missing = n_missing
         if n_missing == 0:
@@ -445,7 +445,7 @@ cdef class ClassificationCriterion(Criterion):
                 w = self.sample_weight[i]
 
             for k in range(self.n_outputs):
-                c = <SIZE_t> self.y[i, k]
+                c = <intp_t> self.y[i, k]
                 self.sum_missing[k, c] += w
 
             self.weighted_n_missing += w
@@ -484,7 +484,7 @@ cdef class ClassificationCriterion(Criterion):
         )
         return 0
 
-    cdef int update(self, SIZE_t new_pos) except -1 nogil:
+    cdef int update(self, intp_t new_pos) except -1 nogil:
         """Updated statistics by moving sample_indices[pos:new_pos] to the left child.
 
         Returns -1 in case of failure to allocate memory (and raise MemoryError)
@@ -492,24 +492,24 @@ cdef class ClassificationCriterion(Criterion):
 
         Parameters
         ----------
-        new_pos : SIZE_t
+        new_pos : intp_t
            The new ending position for which to move sample_indices from the right
            child to the left child.
        """
-        cdef SIZE_t pos = self.pos
+        cdef intp_t pos = self.pos
        # The missing samples are assumed to be in
        # self.sample_indices[-self.n_missing:] that is
        # self.sample_indices[end_non_missing:self.end].
-        cdef SIZE_t end_non_missing = self.end - self.n_missing
+        cdef intp_t end_non_missing = self.end - self.n_missing
 
-        cdef const SIZE_t[:] sample_indices = self.sample_indices
-        cdef const DOUBLE_t[:] sample_weight = self.sample_weight
+        cdef const intp_t[:] sample_indices = self.sample_indices
+        cdef const float64_t[:] sample_weight = self.sample_weight
 
-        cdef SIZE_t i
-        cdef SIZE_t p
-        cdef SIZE_t k
-        cdef SIZE_t c
-        cdef DOUBLE_t w = 1.0
+        cdef intp_t i
+        cdef intp_t p
+        cdef intp_t k
+        cdef intp_t c
+        cdef float64_t w = 1.0
 
         # Update statistics up to new_pos
         #
@@ -526,7 +526,7 @@ cdef class ClassificationCriterion(Criterion):
                     w = sample_weight[i]
 
                 for k in range(self.n_outputs):
-                    self.sum_left[k, <SIZE_t> self.y[i, k]] += w
+                    self.sum_left[k, <intp_t> self.y[i, k]] += w
 
                 self.weighted_n_left += w
 
@@ -540,7 +540,7 @@ cdef class ClassificationCriterion(Criterion):
                     w = sample_weight[i]
 
                 for k in range(self.n_outputs):
-                    self.sum_left[k, <SIZE_t> self.y[i, k]] -= w
+                    self.sum_left[k, <intp_t> self.y[i, k]] -= w
 
                 self.weighted_n_left -= w
 
@@ -568,7 +568,7 @@ cdef class ClassificationCriterion(Criterion):
         dest : double pointer
             The memory address which we will save the node value into.
""" - cdef SIZE_t k + cdef intp_t k for k in range(self.n_outputs): memcpy(dest, &self.sum_total[k, 0], self.n_classes[k] * sizeof(double)) @@ -641,8 +641,8 @@ cdef class Entropy(ClassificationCriterion): """ cdef double entropy = 0.0 cdef double count_k - cdef SIZE_t k - cdef SIZE_t c + cdef intp_t k + cdef intp_t c for k in range(self.n_outputs): for c in range(self.n_classes[k]): @@ -670,8 +670,8 @@ cdef class Entropy(ClassificationCriterion): cdef double entropy_left = 0.0 cdef double entropy_right = 0.0 cdef double count_k - cdef SIZE_t k - cdef SIZE_t c + cdef intp_t k + cdef intp_t c for k in range(self.n_outputs): for c in range(self.n_classes[k]): @@ -716,8 +716,8 @@ cdef class Gini(ClassificationCriterion): cdef double gini = 0.0 cdef double sq_count cdef double count_k - cdef SIZE_t k - cdef SIZE_t c + cdef intp_t k + cdef intp_t c for k in range(self.n_outputs): sq_count = 0.0 @@ -750,8 +750,8 @@ cdef class Gini(ClassificationCriterion): cdef double sq_count_left cdef double sq_count_right cdef double count_k - cdef SIZE_t k - cdef SIZE_t c + cdef intp_t k + cdef intp_t c for k in range(self.n_outputs): sq_count_left = 0.0 @@ -794,8 +794,8 @@ cdef inline void _move_sums_regression( sum_2 = sum_total """ cdef: - SIZE_t i - SIZE_t n_bytes = criterion.n_outputs * sizeof(double) + intp_t i + intp_t n_bytes = criterion.n_outputs * sizeof(double) bint has_missing = criterion.n_missing != 0 if has_missing and put_missing_in_1: @@ -824,15 +824,15 @@ cdef class RegressionCriterion(Criterion): = (\sum_i^n y_i ** 2) - n_samples * y_bar ** 2 """ - def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples): + def __cinit__(self, intp_t n_outputs, intp_t n_samples): """Initialize parameters for this criterion. Parameters ---------- - n_outputs : SIZE_t + n_outputs : intp_t The number of targets to be predicted - n_samples : SIZE_t + n_samples : intp_t The total number of samples to fit on """ # Default values @@ -859,12 +859,12 @@ cdef class RegressionCriterion(Criterion): cdef int init( self, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, double weighted_n_samples, - const SIZE_t[:] sample_indices, - SIZE_t start, - SIZE_t end, + const intp_t[:] sample_indices, + intp_t start, + intp_t end, ) except -1 nogil: """Initialize the criterion. @@ -881,12 +881,12 @@ cdef class RegressionCriterion(Criterion): self.weighted_n_samples = weighted_n_samples self.weighted_n_node_samples = 0. - cdef SIZE_t i - cdef SIZE_t p - cdef SIZE_t k - cdef DOUBLE_t y_ik - cdef DOUBLE_t w_y_ik - cdef DOUBLE_t w = 1.0 + cdef intp_t i + cdef intp_t p + cdef intp_t k + cdef float64_t y_ik + cdef float64_t w_y_ik + cdef float64_t w = 1.0 self.sq_sum_total = 0.0 memset(&self.sum_total[0], 0, self.n_outputs * sizeof(double)) @@ -912,16 +912,16 @@ cdef class RegressionCriterion(Criterion): """Init sum_missing to hold sums for missing values.""" self.sum_missing = np.zeros(self.n_outputs, dtype=np.float64) - cdef void init_missing(self, SIZE_t n_missing) noexcept nogil: + cdef void init_missing(self, intp_t n_missing) noexcept nogil: """Initialize sum_missing if there are missing values. 
This method assumes that caller placed the missing samples in self.sample_indices[-n_missing:] """ - cdef SIZE_t i, p, k - cdef DOUBLE_t y_ik - cdef DOUBLE_t w_y_ik - cdef DOUBLE_t w = 1.0 + cdef intp_t i, p, k + cdef float64_t y_ik + cdef float64_t w_y_ik + cdef float64_t w = 1.0 self.n_missing = n_missing if n_missing == 0: @@ -970,21 +970,21 @@ cdef class RegressionCriterion(Criterion): ) return 0 - cdef int update(self, SIZE_t new_pos) except -1 nogil: + cdef int update(self, intp_t new_pos) except -1 nogil: """Updated statistics by moving sample_indices[pos:new_pos] to the left.""" - cdef const DOUBLE_t[:] sample_weight = self.sample_weight - cdef const SIZE_t[:] sample_indices = self.sample_indices + cdef const float64_t[:] sample_weight = self.sample_weight + cdef const intp_t[:] sample_indices = self.sample_indices - cdef SIZE_t pos = self.pos + cdef intp_t pos = self.pos # The missing samples are assumed to be in # self.sample_indices[-self.n_missing:] that is # self.sample_indices[end_non_missing:self.end]. - cdef SIZE_t end_non_missing = self.end - self.n_missing - cdef SIZE_t i - cdef SIZE_t p - cdef SIZE_t k - cdef DOUBLE_t w = 1.0 + cdef intp_t end_non_missing = self.end - self.n_missing + cdef intp_t i + cdef intp_t p + cdef intp_t k + cdef float64_t w = 1.0 # Update statistics up to new_pos # @@ -1035,7 +1035,7 @@ cdef class RegressionCriterion(Criterion): cdef void node_value(self, double* dest) noexcept nogil: """Compute the node value of sample_indices[start:end] into dest.""" - cdef SIZE_t k + cdef intp_t k for k in range(self.n_outputs): dest[k] = self.sum_total[k] / self.weighted_n_node_samples @@ -1086,7 +1086,7 @@ cdef class MSE(RegressionCriterion): better. """ cdef double impurity - cdef SIZE_t k + cdef intp_t k impurity = self.sq_sum_total / self.weighted_n_node_samples for k in range(self.n_outputs): @@ -1114,7 +1114,7 @@ cdef class MSE(RegressionCriterion): - 1/n_L * sum_{i left}(y_i)^2 - 1/n_R * sum_{i right}(y_i)^2 """ - cdef SIZE_t k + cdef intp_t k cdef double proxy_impurity_left = 0.0 cdef double proxy_impurity_right = 0.0 @@ -1132,20 +1132,20 @@ cdef class MSE(RegressionCriterion): i.e. the impurity of the left child (sample_indices[start:pos]) and the impurity the right child (sample_indices[pos:end]). """ - cdef const DOUBLE_t[:] sample_weight = self.sample_weight - cdef const SIZE_t[:] sample_indices = self.sample_indices - cdef SIZE_t pos = self.pos - cdef SIZE_t start = self.start + cdef const float64_t[:] sample_weight = self.sample_weight + cdef const intp_t[:] sample_indices = self.sample_indices + cdef intp_t pos = self.pos + cdef intp_t start = self.start - cdef DOUBLE_t y_ik + cdef float64_t y_ik cdef double sq_sum_left = 0.0 cdef double sq_sum_right - cdef SIZE_t i - cdef SIZE_t p - cdef SIZE_t k - cdef DOUBLE_t w = 1.0 + cdef intp_t i + cdef intp_t p + cdef intp_t k + cdef float64_t w = 1.0 for p in range(start, pos): i = sample_indices[p] @@ -1180,17 +1180,17 @@ cdef class MAE(RegressionCriterion): cdef cnp.ndarray right_child cdef void** left_child_ptr cdef void** right_child_ptr - cdef DOUBLE_t[::1] node_medians + cdef float64_t[::1] node_medians - def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples): + def __cinit__(self, intp_t n_outputs, intp_t n_samples): """Initialize parameters for this criterion. 
Parameters ---------- - n_outputs : SIZE_t + n_outputs : intp_t The number of targets to be predicted - n_samples : SIZE_t + n_samples : intp_t The total number of samples to fit on """ # Default values @@ -1219,20 +1219,20 @@ cdef class MAE(RegressionCriterion): cdef int init( self, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, double weighted_n_samples, - const SIZE_t[:] sample_indices, - SIZE_t start, - SIZE_t end, + const intp_t[:] sample_indices, + intp_t start, + intp_t end, ) except -1 nogil: """Initialize the criterion. This initializes the criterion at node sample_indices[start:end] and children sample_indices[start:start] and sample_indices[start:end]. """ - cdef SIZE_t i, p, k - cdef DOUBLE_t w = 1.0 + cdef intp_t i, p, k + cdef float64_t w = 1.0 # Initialize fields self.y = y @@ -1272,7 +1272,7 @@ cdef class MAE(RegressionCriterion): self.reset() return 0 - cdef void init_missing(self, SIZE_t n_missing) noexcept nogil: + cdef void init_missing(self, intp_t n_missing) noexcept nogil: """Raise error if n_missing != 0.""" if n_missing == 0: return @@ -1285,9 +1285,9 @@ cdef class MAE(RegressionCriterion): Returns -1 in case of failure to allocate memory (and raise MemoryError) or 0 otherwise. """ - cdef SIZE_t i, k - cdef DOUBLE_t value - cdef DOUBLE_t weight + cdef intp_t i, k + cdef float64_t value + cdef float64_t weight cdef void** left_child = self.left_child_ptr cdef void** right_child = self.right_child_ptr @@ -1320,8 +1320,8 @@ cdef class MAE(RegressionCriterion): self.weighted_n_left = self.weighted_n_node_samples self.pos = self.end - cdef DOUBLE_t value - cdef DOUBLE_t weight + cdef float64_t value + cdef float64_t weight cdef void** left_child = self.left_child_ptr cdef void** right_child = self.right_child_ptr @@ -1338,22 +1338,22 @@ cdef class MAE(RegressionCriterion): weight) return 0 - cdef int update(self, SIZE_t new_pos) except -1 nogil: + cdef int update(self, intp_t new_pos) except -1 nogil: """Updated statistics by moving sample_indices[pos:new_pos] to the left. Returns -1 in case of failure to allocate memory (and raise MemoryError) or 0 otherwise. """ - cdef const DOUBLE_t[:] sample_weight = self.sample_weight - cdef const SIZE_t[:] sample_indices = self.sample_indices + cdef const float64_t[:] sample_weight = self.sample_weight + cdef const intp_t[:] sample_indices = self.sample_indices cdef void** left_child = self.left_child_ptr cdef void** right_child = self.right_child_ptr - cdef SIZE_t pos = self.pos - cdef SIZE_t end = self.end - cdef SIZE_t i, p, k - cdef DOUBLE_t w = 1.0 + cdef intp_t pos = self.pos + cdef intp_t end = self.end + cdef intp_t i, p, k + cdef float64_t w = 1.0 # Update statistics up to new_pos # @@ -1397,7 +1397,7 @@ cdef class MAE(RegressionCriterion): cdef void node_value(self, double* dest) noexcept nogil: """Computes the node value of sample_indices[start:end] into dest.""" - cdef SIZE_t k + cdef intp_t k for k in range(self.n_outputs): dest[k] = self.node_medians[k] @@ -1433,11 +1433,11 @@ cdef class MAE(RegressionCriterion): i.e. the impurity of sample_indices[start:end]. The smaller the impurity the better. 
""" - cdef const DOUBLE_t[:] sample_weight = self.sample_weight - cdef const SIZE_t[:] sample_indices = self.sample_indices - cdef SIZE_t i, p, k - cdef DOUBLE_t w = 1.0 - cdef DOUBLE_t impurity = 0.0 + cdef const float64_t[:] sample_weight = self.sample_weight + cdef const intp_t[:] sample_indices = self.sample_indices + cdef intp_t i, p, k + cdef float64_t w = 1.0 + cdef float64_t impurity = 0.0 for k in range(self.n_outputs): for p in range(self.start, self.end): @@ -1457,18 +1457,18 @@ cdef class MAE(RegressionCriterion): i.e. the impurity of the left child (sample_indices[start:pos]) and the impurity the right child (sample_indices[pos:end]). """ - cdef const DOUBLE_t[:] sample_weight = self.sample_weight - cdef const SIZE_t[:] sample_indices = self.sample_indices + cdef const float64_t[:] sample_weight = self.sample_weight + cdef const intp_t[:] sample_indices = self.sample_indices - cdef SIZE_t start = self.start - cdef SIZE_t pos = self.pos - cdef SIZE_t end = self.end + cdef intp_t start = self.start + cdef intp_t pos = self.pos + cdef intp_t end = self.end - cdef SIZE_t i, p, k - cdef DOUBLE_t median - cdef DOUBLE_t w = 1.0 - cdef DOUBLE_t impurity_left = 0.0 - cdef DOUBLE_t impurity_right = 0.0 + cdef intp_t i, p, k + cdef float64_t median + cdef float64_t w = 1.0 + cdef float64_t impurity_left = 0.0 + cdef float64_t impurity_right = 0.0 cdef void** left_child = self.left_child_ptr cdef void** right_child = self.right_child_ptr @@ -1521,7 +1521,7 @@ cdef class FriedmanMSE(MSE): cdef double total_sum_left = 0.0 cdef double total_sum_right = 0.0 - cdef SIZE_t k + cdef intp_t k cdef double diff = 0.0 for k in range(self.n_outputs): @@ -1539,7 +1539,7 @@ cdef class FriedmanMSE(MSE): cdef double total_sum_left = 0.0 cdef double total_sum_right = 0.0 - cdef SIZE_t k + cdef intp_t k cdef double diff = 0.0 for k in range(self.n_outputs): @@ -1607,7 +1607,7 @@ cdef class Poisson(RegressionCriterion): - sum{i left }(y_i) * log(mean{i left}(y_i)) - sum{i right}(y_i) * log(mean{i right}(y_i)) """ - cdef SIZE_t k + cdef intp_t k cdef double proxy_impurity_left = 0.0 cdef double proxy_impurity_right = 0.0 cdef double y_mean_left = 0. @@ -1637,9 +1637,9 @@ cdef class Poisson(RegressionCriterion): i.e. the impurity of the left child (sample_indices[start:pos]) and the impurity of the right child (sample_indices[pos:end]) for Poisson. """ - cdef SIZE_t start = self.start - cdef SIZE_t pos = self.pos - cdef SIZE_t end = self.end + cdef intp_t start = self.start + cdef intp_t pos = self.pos + cdef intp_t end = self.end impurity_left[0] = self.poisson_loss(start, pos, self.sum_left, self.weighted_n_left) @@ -1647,22 +1647,24 @@ cdef class Poisson(RegressionCriterion): impurity_right[0] = self.poisson_loss(pos, end, self.sum_right, self.weighted_n_right) - cdef inline DOUBLE_t poisson_loss(self, - SIZE_t start, - SIZE_t end, - const double[::1] y_sum, - DOUBLE_t weight_sum) noexcept nogil: + cdef inline float64_t poisson_loss( + self, + intp_t start, + intp_t end, + const double[::1] y_sum, + float64_t weight_sum + ) noexcept nogil: """Helper function to compute Poisson loss (~deviance) of a given node. """ - cdef const DOUBLE_t[:, ::1] y = self.y - cdef const DOUBLE_t[:] sample_weight = self.sample_weight - cdef const SIZE_t[:] sample_indices = self.sample_indices - - cdef DOUBLE_t y_mean = 0. - cdef DOUBLE_t poisson_loss = 0. 
- cdef DOUBLE_t w = 1.0 - cdef SIZE_t i, k, p - cdef SIZE_t n_outputs = self.n_outputs + cdef const float64_t[:, ::1] y = self.y + cdef const float64_t[:] sample_weight = self.sample_weight + cdef const intp_t[:] sample_indices = self.sample_indices + + cdef float64_t y_mean = 0. + cdef float64_t poisson_loss = 0. + cdef float64_t w = 1.0 + cdef intp_t i, k, p + cdef intp_t n_outputs = self.n_outputs for k in range(n_outputs): if y_sum[k] <= EPSILON: diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 2547e14b324df..981ac25f98ad6 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -12,16 +12,13 @@ cimport numpy as cnp from ._criterion cimport Criterion -from ._tree cimport DTYPE_t # Type of X -from ._tree cimport DOUBLE_t # Type of y, sample_weight -from ._tree cimport SIZE_t # Type for indices and counters -from ._tree cimport INT32_t # Signed 32 bit integer -from ._tree cimport UINT32_t # Unsigned 32 bit integer +from ..utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint32_t + cdef struct SplitRecord: # Data to track sample split - SIZE_t feature # Which feature to split on. - SIZE_t pos # Split samples array at the given position, + intp_t feature # Which feature to split on. + intp_t pos # Split samples array at the given position, # # i.e. count of samples below threshold for feature. # # pos is >= end if the node is a leaf. double threshold # Threshold to split at. @@ -31,7 +28,7 @@ cdef struct SplitRecord: double lower_bound # Lower bound on value of both children for monotonicity double upper_bound # Upper bound on value of both children for monotonicity unsigned char missing_go_to_left # Controls if missing values go to the left node. - SIZE_t n_missing # Number of missing values for the feature being split on + intp_t n_missing # Number of missing values for the feature being split on cdef class Splitter: # The splitter searches in the input space for a feature and a threshold @@ -41,25 +38,25 @@ cdef class Splitter: # Internal structures cdef public Criterion criterion # Impurity criterion - cdef public SIZE_t max_features # Number of features to test - cdef public SIZE_t min_samples_leaf # Min samples in a leaf + cdef public intp_t max_features # Number of features to test + cdef public intp_t min_samples_leaf # Min samples in a leaf cdef public double min_weight_leaf # Minimum weight in a leaf cdef object random_state # Random state - cdef UINT32_t rand_r_state # sklearn_rand_r random number state + cdef uint32_t rand_r_state # sklearn_rand_r random number state - cdef SIZE_t[::1] samples # Sample indices in X, y - cdef SIZE_t n_samples # X.shape[0] + cdef intp_t[::1] samples # Sample indices in X, y + cdef intp_t n_samples # X.shape[0] cdef double weighted_n_samples # Weighted number of samples - cdef SIZE_t[::1] features # Feature indices in X - cdef SIZE_t[::1] constant_features # Constant features indices - cdef SIZE_t n_features # X.shape[1] - cdef DTYPE_t[::1] feature_values # temp. array holding feature values + cdef intp_t[::1] features # Feature indices in X + cdef intp_t[::1] constant_features # Constant features indices + cdef intp_t n_features # X.shape[1] + cdef float32_t[::1] feature_values # temp. 
array holding feature values - cdef SIZE_t start # Start position for the current node - cdef SIZE_t end # End position for the current node + cdef intp_t start # Start position for the current node + cdef intp_t end # End position for the current node - cdef const DOUBLE_t[:, ::1] y + cdef const float64_t[:, ::1] y # Monotonicity constraints for each feature. # The encoding is as follows: # -1: monotonic decrease @@ -67,7 +64,7 @@ cdef class Splitter: # +1: monotonic increase cdef const cnp.int8_t[:] monotonic_cst cdef bint with_monotonic_cst - cdef const DOUBLE_t[:] sample_weight + cdef const float64_t[:] sample_weight # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. With this setting, @@ -89,15 +86,15 @@ cdef class Splitter: cdef int init( self, object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, const unsigned char[::1] missing_values_in_feature_mask, ) except -1 cdef int node_reset( self, - SIZE_t start, - SIZE_t end, + intp_t start, + intp_t end, double* weighted_n_node_samples ) except -1 nogil @@ -105,7 +102,7 @@ cdef class Splitter: self, double impurity, # Impurity of the node SplitRecord* split, - SIZE_t* n_constant_features, + intp_t* n_constant_features, double lower_bound, double upper_bound, ) except -1 nogil diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 5c30ba315a90a..f4e6a3419627a 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -32,13 +32,13 @@ from ._utils cimport RAND_R_MAX cdef double INFINITY = np.inf # Mitigate precision differences between 32 bit and 64 bit -cdef DTYPE_t FEATURE_THRESHOLD = 1e-7 +cdef float32_t FEATURE_THRESHOLD = 1e-7 # Constant to switch between algorithm non zero value extract algorithm # in SparsePartitioner -cdef DTYPE_t EXTRACT_NNZ_SWITCH = 0.1 +cdef float32_t EXTRACT_NNZ_SWITCH = 0.1 -cdef inline void _init_split(SplitRecord* self, SIZE_t start_pos) noexcept nogil: +cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil: self.impurity_left = INFINITY self.impurity_right = INFINITY self.pos = start_pos @@ -58,8 +58,8 @@ cdef class Splitter: def __cinit__( self, Criterion criterion, - SIZE_t max_features, - SIZE_t min_samples_leaf, + intp_t max_features, + intp_t min_samples_leaf, double min_weight_leaf, object random_state, const cnp.int8_t[:] monotonic_cst, @@ -70,11 +70,11 @@ cdef class Splitter: criterion : Criterion The criterion to measure the quality of a split. - max_features : SIZE_t + max_features : intp_t The maximal number of randomly selected features which can be considered for a split. - min_samples_leaf : SIZE_t + min_samples_leaf : intp_t The minimal number of samples each leaf can have, where splits which would result in having less samples in a leaf are not considered. @@ -120,8 +120,8 @@ cdef class Splitter: cdef int init( self, object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, const unsigned char[::1] missing_values_in_feature_mask, ) except -1: """Initialize the splitter. @@ -136,11 +136,11 @@ cdef class Splitter: X : object This contains the inputs. Usually it is a 2d numpy array. - y : ndarray, dtype=DOUBLE_t + y : ndarray, dtype=float64_t This is the vector of targets, or true labels, for the samples represented as a Cython memoryview. 
- sample_weight : ndarray, dtype=DOUBLE_t + sample_weight : ndarray, dtype=float64_t The weights of the samples, where higher weighted samples are fit closer than lower weight samples. If not provided, all samples are assumed to have uniform weight. This is represented @@ -151,14 +151,14 @@ cdef class Splitter: """ self.rand_r_state = self.random_state.randint(0, RAND_R_MAX) - cdef SIZE_t n_samples = X.shape[0] + cdef intp_t n_samples = X.shape[0] # Create a new array which will be used to store nonzero # samples from the feature of interest self.samples = np.empty(n_samples, dtype=np.intp) - cdef SIZE_t[::1] samples = self.samples + cdef intp_t[::1] samples = self.samples - cdef SIZE_t i, j + cdef intp_t i, j cdef double weighted_n_samples = 0.0 j = 0 @@ -177,7 +177,7 @@ cdef class Splitter: self.n_samples = j self.weighted_n_samples = weighted_n_samples - cdef SIZE_t n_features = X.shape[1] + cdef intp_t n_features = X.shape[1] self.features = np.arange(n_features, dtype=np.intp) self.n_features = n_features @@ -191,7 +191,7 @@ cdef class Splitter: self.criterion.init_sum_missing() return 0 - cdef int node_reset(self, SIZE_t start, SIZE_t end, + cdef int node_reset(self, intp_t start, intp_t end, double* weighted_n_node_samples) except -1 nogil: """Reset splitter on node samples[start:end]. @@ -200,9 +200,9 @@ cdef class Splitter: Parameters ---------- - start : SIZE_t + start : intp_t The index of the first sample to consider - end : SIZE_t + end : intp_t The index of the last sample to consider weighted_n_node_samples : ndarray, dtype=double pointer The total weight of those samples @@ -227,7 +227,7 @@ cdef class Splitter: self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features, + intp_t* n_constant_features, double lower_bound, double upper_bound, ) except -1 nogil: @@ -259,10 +259,10 @@ cdef class Splitter: cdef inline void shift_missing_values_to_left_if_required( SplitRecord* best, - SIZE_t[::1] samples, - SIZE_t end, + intp_t[::1] samples, + intp_t end, ) nogil: - cdef SIZE_t i, p, current_end + cdef intp_t i, p, current_end # The partitioner partitions the data such that the missing values are in # samples[-n_missing:] for the criterion to consume. If the missing values # are going to the right node, then the missing values are already in the @@ -290,7 +290,7 @@ cdef inline int node_split_best( Criterion criterion, double impurity, SplitRecord* split, - SIZE_t* n_constant_features, + intp_t* n_constant_features, bint with_monotonic_cst, const cnp.int8_t[:] monotonic_cst, double lower_bound, @@ -302,43 +302,43 @@ cdef inline int node_split_best( or 0 otherwise. 
""" # Find the best split - cdef SIZE_t start = splitter.start - cdef SIZE_t end = splitter.end - cdef SIZE_t end_non_missing - cdef SIZE_t n_missing = 0 + cdef intp_t start = splitter.start + cdef intp_t end = splitter.end + cdef intp_t end_non_missing + cdef intp_t n_missing = 0 cdef bint has_missing = 0 - cdef SIZE_t n_searches - cdef SIZE_t n_left, n_right + cdef intp_t n_searches + cdef intp_t n_left, n_right cdef bint missing_go_to_left - cdef SIZE_t[::1] samples = splitter.samples - cdef SIZE_t[::1] features = splitter.features - cdef SIZE_t[::1] constant_features = splitter.constant_features - cdef SIZE_t n_features = splitter.n_features + cdef intp_t[::1] samples = splitter.samples + cdef intp_t[::1] features = splitter.features + cdef intp_t[::1] constant_features = splitter.constant_features + cdef intp_t n_features = splitter.n_features - cdef DTYPE_t[::1] feature_values = splitter.feature_values - cdef SIZE_t max_features = splitter.max_features - cdef SIZE_t min_samples_leaf = splitter.min_samples_leaf + cdef float32_t[::1] feature_values = splitter.feature_values + cdef intp_t max_features = splitter.max_features + cdef intp_t min_samples_leaf = splitter.min_samples_leaf cdef double min_weight_leaf = splitter.min_weight_leaf - cdef UINT32_t* random_state = &splitter.rand_r_state + cdef uint32_t* random_state = &splitter.rand_r_state cdef SplitRecord best_split, current_split cdef double current_proxy_improvement = -INFINITY cdef double best_proxy_improvement = -INFINITY - cdef SIZE_t f_i = n_features - cdef SIZE_t f_j - cdef SIZE_t p - cdef SIZE_t p_prev + cdef intp_t f_i = n_features + cdef intp_t f_j + cdef intp_t p + cdef intp_t p_prev - cdef SIZE_t n_visited_features = 0 + cdef intp_t n_visited_features = 0 # Number of features discovered to be constant during the split search - cdef SIZE_t n_found_constants = 0 + cdef intp_t n_found_constants = 0 # Number of features known to be constant and drawn without replacement - cdef SIZE_t n_drawn_constants = 0 - cdef SIZE_t n_known_constants = n_constant_features[0] + cdef intp_t n_drawn_constants = 0 + cdef intp_t n_known_constants = n_constant_features[0] # n_total_constants = n_known_constants + n_found_constants - cdef SIZE_t n_total_constants = n_known_constants + cdef intp_t n_total_constants = n_known_constants _init_split(&best_split, end) @@ -539,12 +539,12 @@ cdef inline int node_split_best( # Respect invariant for constant features: the original order of # element in features[:n_known_constants] must be preserved for sibling # and child nodes - memcpy(&features[0], &constant_features[0], sizeof(SIZE_t) * n_known_constants) + memcpy(&features[0], &constant_features[0], sizeof(intp_t) * n_known_constants) # Copy newly found constant features memcpy(&constant_features[n_known_constants], &features[n_known_constants], - sizeof(SIZE_t) * n_found_constants) + sizeof(intp_t) * n_found_constants) # Return values split[0] = best_split @@ -554,24 +554,24 @@ cdef inline int node_split_best( # Sort n-element arrays pointed to by feature_values and samples, simultaneously, # by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997). 
-cdef inline void sort(DTYPE_t* feature_values, SIZE_t* samples, SIZE_t n) noexcept nogil:
+cdef inline void sort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil:
     if n == 0:
         return
     cdef int maxd = 2 * <int>log(n)
     introsort(feature_values, samples, n, maxd)
 
 
-cdef inline void swap(DTYPE_t* feature_values, SIZE_t* samples,
-                      SIZE_t i, SIZE_t j) noexcept nogil:
+cdef inline void swap(float32_t* feature_values, intp_t* samples,
+                      intp_t i, intp_t j) noexcept nogil:
     # Helper for sort
     feature_values[i], feature_values[j] = feature_values[j], feature_values[i]
     samples[i], samples[j] = samples[j], samples[i]
 
 
-cdef inline DTYPE_t median3(DTYPE_t* feature_values, SIZE_t n) noexcept nogil:
+cdef inline float32_t median3(float32_t* feature_values, intp_t n) noexcept nogil:
     # Median of three pivot selection, after Bentley and McIlroy (1993).
     # Engineering a sort function. SP&E. Requires 8/3 comparisons on average.
-    cdef DTYPE_t a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1]
+    cdef float32_t a = feature_values[0], b = feature_values[n / 2], c = feature_values[n - 1]
     if a < b:
         if b < c:
             return b
@@ -590,10 +590,10 @@ cdef inline DTYPE_t median3(DTYPE_t* feature_values, SIZE_t n) noexcept nogil:
 
 # Introsort with median of 3 pivot selection and 3-way partition function
 # (robust to repeated elements, e.g. lots of zero features).
-cdef void introsort(DTYPE_t* feature_values, SIZE_t *samples,
-                    SIZE_t n, int maxd) noexcept nogil:
-    cdef DTYPE_t pivot
-    cdef SIZE_t i, l, r
+cdef void introsort(float32_t* feature_values, intp_t *samples,
+                    intp_t n, int maxd) noexcept nogil:
+    cdef float32_t pivot
+    cdef intp_t i, l, r
 
     while n > 1:
         if maxd <= 0:   # max depth limit exceeded ("gone quadratic")
@@ -623,10 +623,10 @@ cdef void introsort(DTYPE_t* feature_values, SIZE_t *samples,
         n -= r
 
 
-cdef inline void sift_down(DTYPE_t* feature_values, SIZE_t* samples,
-                           SIZE_t start, SIZE_t end) noexcept nogil:
+cdef inline void sift_down(float32_t* feature_values, intp_t* samples,
+                           intp_t start, intp_t end) noexcept nogil:
     # Restore heap order in feature_values[start:end] by moving the max element to start.
-    cdef SIZE_t child, maxind, root
+    cdef intp_t child, maxind, root
 
     root = start
     while True:
@@ -646,8 +646,8 @@ cdef inline void sift_down(DTYPE_t* feature_values, SIZE_t* samples,
         root = maxind
 
 
-cdef void heapsort(DTYPE_t* feature_values, SIZE_t* samples, SIZE_t n) noexcept nogil:
-    cdef SIZE_t start, end
+cdef void heapsort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil:
+    cdef intp_t start, end
 
     # heapify
     start = (n - 2) / 2
     end = n
     while True:
@@ -671,7 +671,7 @@ cdef inline int node_split_random(
     Criterion criterion,
     double impurity,
     SplitRecord* split,
-    SIZE_t* n_constant_features,
+    intp_t* n_constant_features,
     bint with_monotonic_cst,
     const cnp.int8_t[:] monotonic_cst,
     double lower_bound,
     double upper_bound,
 ) except -1 nogil:
     """Find the best random split on node samples[start:end]
 
     Returns -1 in case of failure to allocate memory (and raise MemoryError)
     or 0 otherwise.
""" # Draw random splits and pick the best - cdef SIZE_t start = splitter.start - cdef SIZE_t end = splitter.end + cdef intp_t start = splitter.start + cdef intp_t end = splitter.end - cdef SIZE_t[::1] features = splitter.features - cdef SIZE_t[::1] constant_features = splitter.constant_features - cdef SIZE_t n_features = splitter.n_features + cdef intp_t[::1] features = splitter.features + cdef intp_t[::1] constant_features = splitter.constant_features + cdef intp_t n_features = splitter.n_features - cdef SIZE_t max_features = splitter.max_features - cdef SIZE_t min_samples_leaf = splitter.min_samples_leaf + cdef intp_t max_features = splitter.max_features + cdef intp_t min_samples_leaf = splitter.min_samples_leaf cdef double min_weight_leaf = splitter.min_weight_leaf - cdef UINT32_t* random_state = &splitter.rand_r_state + cdef uint32_t* random_state = &splitter.rand_r_state cdef SplitRecord best_split, current_split cdef double current_proxy_improvement = - INFINITY cdef double best_proxy_improvement = - INFINITY - cdef SIZE_t f_i = n_features - cdef SIZE_t f_j + cdef intp_t f_i = n_features + cdef intp_t f_j # Number of features discovered to be constant during the split search - cdef SIZE_t n_found_constants = 0 + cdef intp_t n_found_constants = 0 # Number of features known to be constant and drawn without replacement - cdef SIZE_t n_drawn_constants = 0 - cdef SIZE_t n_known_constants = n_constant_features[0] + cdef intp_t n_drawn_constants = 0 + cdef intp_t n_known_constants = n_constant_features[0] # n_total_constants = n_known_constants + n_found_constants - cdef SIZE_t n_total_constants = n_known_constants - cdef SIZE_t n_visited_features = 0 - cdef DTYPE_t min_feature_value - cdef DTYPE_t max_feature_value + cdef intp_t n_total_constants = n_known_constants + cdef intp_t n_visited_features = 0 + cdef float32_t min_feature_value + cdef float32_t max_feature_value _init_split(&best_split, end) @@ -841,12 +841,12 @@ cdef inline int node_split_random( # Respect invariant for constant features: the original order of # element in features[:n_known_constants] must be preserved for sibling # and child nodes - memcpy(&features[0], &constant_features[0], sizeof(SIZE_t) * n_known_constants) + memcpy(&features[0], &constant_features[0], sizeof(intp_t) * n_known_constants) # Copy newly found constant features memcpy(&constant_features[n_known_constants], &features[n_known_constants], - sizeof(SIZE_t) * n_found_constants) + sizeof(intp_t) * n_found_constants) # Return values split[0] = best_split @@ -861,19 +861,19 @@ cdef class DensePartitioner: Note that this partitioner is agnostic to the splitting strategy (best vs. random). 
""" cdef: - const DTYPE_t[:, :] X - cdef SIZE_t[::1] samples - cdef DTYPE_t[::1] feature_values - cdef SIZE_t start - cdef SIZE_t end - cdef SIZE_t n_missing + const float32_t[:, :] X + cdef intp_t[::1] samples + cdef float32_t[::1] feature_values + cdef intp_t start + cdef intp_t end + cdef intp_t n_missing cdef const unsigned char[::1] missing_values_in_feature_mask def __init__( self, - const DTYPE_t[:, :] X, - SIZE_t[::1] samples, - DTYPE_t[::1] feature_values, + const float32_t[:, :] X, + intp_t[::1] samples, + float32_t[::1] feature_values, const unsigned char[::1] missing_values_in_feature_mask, ): self.X = X @@ -881,14 +881,14 @@ cdef class DensePartitioner: self.feature_values = feature_values self.missing_values_in_feature_mask = missing_values_in_feature_mask - cdef inline void init_node_split(self, SIZE_t start, SIZE_t end) noexcept nogil: + cdef inline void init_node_split(self, intp_t start, intp_t end) noexcept nogil: """Initialize splitter at the beginning of node_split.""" self.start = start self.end = end self.n_missing = 0 cdef inline void sort_samples_and_feature_values( - self, SIZE_t current_feature + self, intp_t current_feature ) noexcept nogil: """Simultaneously sort based on the feature_values. @@ -897,11 +897,11 @@ cdef class DensePartitioner: in self.n_missing. """ cdef: - SIZE_t i, current_end - DTYPE_t[::1] feature_values = self.feature_values - const DTYPE_t[:, :] X = self.X - SIZE_t[::1] samples = self.samples - SIZE_t n_missing = 0 + intp_t i, current_end + float32_t[::1] feature_values = self.feature_values + const float32_t[:, :] X = self.X + intp_t[::1] samples = self.samples + intp_t n_missing = 0 const unsigned char[::1] missing_values_in_feature_mask = self.missing_values_in_feature_mask # Sort samples along that feature; by @@ -938,19 +938,19 @@ cdef class DensePartitioner: cdef inline void find_min_max( self, - SIZE_t current_feature, - DTYPE_t* min_feature_value_out, - DTYPE_t* max_feature_value_out, + intp_t current_feature, + float32_t* min_feature_value_out, + float32_t* max_feature_value_out, ) noexcept nogil: """Find the minimum and maximum value for current_feature.""" cdef: - SIZE_t p - DTYPE_t current_feature_value - const DTYPE_t[:, :] X = self.X - SIZE_t[::1] samples = self.samples - DTYPE_t min_feature_value = X[samples[self.start], current_feature] - DTYPE_t max_feature_value = min_feature_value - DTYPE_t[::1] feature_values = self.feature_values + intp_t p + float32_t current_feature_value + const float32_t[:, :] X = self.X + intp_t[::1] samples = self.samples + float32_t min_feature_value = X[samples[self.start], current_feature] + float32_t max_feature_value = min_feature_value + float32_t[::1] feature_values = self.feature_values feature_values[self.start] = min_feature_value @@ -966,14 +966,14 @@ cdef class DensePartitioner: min_feature_value_out[0] = min_feature_value max_feature_value_out[0] = max_feature_value - cdef inline void next_p(self, SIZE_t* p_prev, SIZE_t* p) noexcept nogil: + cdef inline void next_p(self, intp_t* p_prev, intp_t* p) noexcept nogil: """Compute the next p_prev and p for iteratiing over feature values. The missing values are not included when iterating through the feature values. 
""" cdef: - DTYPE_t[::1] feature_values = self.feature_values - SIZE_t end_non_missing = self.end - self.n_missing + float32_t[::1] feature_values = self.feature_values + intp_t end_non_missing = self.end - self.n_missing while ( p[0] + 1 < end_non_missing and @@ -987,13 +987,13 @@ cdef class DensePartitioner: # (feature_values[p] >= end) or (feature_values[p] > feature_values[p - 1]) p[0] += 1 - cdef inline SIZE_t partition_samples(self, double current_threshold) noexcept nogil: + cdef inline intp_t partition_samples(self, double current_threshold) noexcept nogil: """Partition samples for feature_values at the current_threshold.""" cdef: - SIZE_t p = self.start - SIZE_t partition_end = self.end - SIZE_t[::1] samples = self.samples - DTYPE_t[::1] feature_values = self.feature_values + intp_t p = self.start + intp_t partition_end = self.end + intp_t[::1] samples = self.samples + float32_t[::1] feature_values = self.feature_values while p < partition_end: if feature_values[p] <= current_threshold: @@ -1010,10 +1010,10 @@ cdef class DensePartitioner: cdef inline void partition_samples_final( self, - SIZE_t best_pos, + intp_t best_pos, double best_threshold, - SIZE_t best_feature, - SIZE_t best_n_missing, + intp_t best_feature, + intp_t best_n_missing, ) noexcept nogil: """Partition samples for X at the best_threshold and best_feature. @@ -1023,13 +1023,13 @@ cdef class DensePartitioner: """ cdef: # Local invariance: start <= p <= partition_end <= end - SIZE_t start = self.start - SIZE_t p = start - SIZE_t end = self.end - 1 - SIZE_t partition_end = end - best_n_missing - SIZE_t[::1] samples = self.samples - const DTYPE_t[:, :] X = self.X - DTYPE_t current_value + intp_t start = self.start + intp_t p = start + intp_t end = self.end - 1 + intp_t partition_end = end - best_n_missing + intp_t[::1] samples = self.samples + const float32_t[:, :] X = self.X + float32_t current_value if best_n_missing != 0: # Move samples with missing values to the end while partitioning the @@ -1072,32 +1072,32 @@ cdef class SparsePartitioner: Note that this partitioner is agnostic to the splitting strategy (best vs. random). 
""" - cdef SIZE_t[::1] samples - cdef DTYPE_t[::1] feature_values - cdef SIZE_t start - cdef SIZE_t end - cdef SIZE_t n_missing + cdef intp_t[::1] samples + cdef float32_t[::1] feature_values + cdef intp_t start + cdef intp_t end + cdef intp_t n_missing cdef const unsigned char[::1] missing_values_in_feature_mask - cdef const DTYPE_t[::1] X_data - cdef const INT32_t[::1] X_indices - cdef const INT32_t[::1] X_indptr + cdef const float32_t[::1] X_data + cdef const int32_t[::1] X_indices + cdef const int32_t[::1] X_indptr - cdef SIZE_t n_total_samples + cdef intp_t n_total_samples - cdef SIZE_t[::1] index_to_samples - cdef SIZE_t[::1] sorted_samples + cdef intp_t[::1] index_to_samples + cdef intp_t[::1] sorted_samples - cdef SIZE_t start_positive - cdef SIZE_t end_negative + cdef intp_t start_positive + cdef intp_t end_negative cdef bint is_samples_sorted def __init__( self, object X, - SIZE_t[::1] samples, - SIZE_t n_samples, - DTYPE_t[::1] feature_values, + intp_t[::1] samples, + intp_t n_samples, + float32_t[::1] feature_values, const unsigned char[::1] missing_values_in_feature_mask, ): if not (issparse(X) and X.format == "csc"): @@ -1107,7 +1107,7 @@ cdef class SparsePartitioner: self.feature_values = feature_values # Initialize X - cdef SIZE_t n_total_samples = X.shape[0] + cdef intp_t n_total_samples = X.shape[0] self.X_data = X.data self.X_indices = X.indices @@ -1118,13 +1118,13 @@ cdef class SparsePartitioner: self.index_to_samples = np.full(n_total_samples, fill_value=-1, dtype=np.intp) self.sorted_samples = np.empty(n_samples, dtype=np.intp) - cdef SIZE_t p + cdef intp_t p for p in range(n_samples): self.index_to_samples[samples[p]] = p self.missing_values_in_feature_mask = missing_values_in_feature_mask - cdef inline void init_node_split(self, SIZE_t start, SIZE_t end) noexcept nogil: + cdef inline void init_node_split(self, intp_t start, intp_t end) noexcept nogil: """Initialize splitter at the beginning of node_split.""" self.start = start self.end = end @@ -1132,13 +1132,13 @@ cdef class SparsePartitioner: self.n_missing = 0 cdef inline void sort_samples_and_feature_values( - self, SIZE_t current_feature + self, intp_t current_feature ) noexcept nogil: """Simultaneously sort based on the feature_values.""" cdef: - DTYPE_t[::1] feature_values = self.feature_values - SIZE_t[::1] index_to_samples = self.index_to_samples - SIZE_t[::1] samples = self.samples + float32_t[::1] feature_values = self.feature_values + intp_t[::1] index_to_samples = self.index_to_samples + intp_t[::1] samples = self.samples self.extract_nnz(current_feature) # Sort the positive and negative parts of `feature_values` @@ -1171,15 +1171,15 @@ cdef class SparsePartitioner: cdef inline void find_min_max( self, - SIZE_t current_feature, - DTYPE_t* min_feature_value_out, - DTYPE_t* max_feature_value_out, + intp_t current_feature, + float32_t* min_feature_value_out, + float32_t* max_feature_value_out, ) noexcept nogil: """Find the minimum and maximum value for current_feature.""" cdef: - SIZE_t p - DTYPE_t current_feature_value, min_feature_value, max_feature_value - DTYPE_t[::1] feature_values = self.feature_values + intp_t p + float32_t current_feature_value, min_feature_value, max_feature_value + float32_t[::1] feature_values = self.feature_values self.extract_nnz(current_feature) @@ -1212,11 +1212,11 @@ cdef class SparsePartitioner: min_feature_value_out[0] = min_feature_value max_feature_value_out[0] = max_feature_value - cdef inline void next_p(self, SIZE_t* p_prev, SIZE_t* p) noexcept nogil: + cdef inline 
void next_p(self, intp_t* p_prev, intp_t* p) noexcept nogil: """Compute the next p_prev and p for iteratiing over feature values.""" cdef: - SIZE_t p_next - DTYPE_t[::1] feature_values = self.feature_values + intp_t p_next + float32_t[::1] feature_values = self.feature_values if p[0] + 1 != self.end_negative: p_next = p[0] + 1 @@ -1234,28 +1234,28 @@ cdef class SparsePartitioner: p_prev[0] = p[0] p[0] = p_next - cdef inline SIZE_t partition_samples(self, double current_threshold) noexcept nogil: + cdef inline intp_t partition_samples(self, double current_threshold) noexcept nogil: """Partition samples for feature_values at the current_threshold.""" return self._partition(current_threshold, self.start_positive) cdef inline void partition_samples_final( self, - SIZE_t best_pos, + intp_t best_pos, double best_threshold, - SIZE_t best_feature, - SIZE_t n_missing, + intp_t best_feature, + intp_t n_missing, ) noexcept nogil: """Partition samples for X at the best_threshold and best_feature.""" self.extract_nnz(best_feature) self._partition(best_threshold, best_pos) - cdef inline SIZE_t _partition(self, double threshold, SIZE_t zero_pos) noexcept nogil: + cdef inline intp_t _partition(self, double threshold, intp_t zero_pos) noexcept nogil: """Partition samples[start:end] based on threshold.""" cdef: - SIZE_t p, partition_end - SIZE_t[::1] index_to_samples = self.index_to_samples - DTYPE_t[::1] feature_values = self.feature_values - SIZE_t[::1] samples = self.samples + intp_t p, partition_end + intp_t[::1] index_to_samples = self.index_to_samples + float32_t[::1] feature_values = self.feature_values + intp_t[::1] samples = self.samples if threshold < 0.: p = self.start @@ -1281,7 +1281,7 @@ cdef class SparsePartitioner: return partition_end - cdef inline void extract_nnz(self, SIZE_t feature) noexcept nogil: + cdef inline void extract_nnz(self, intp_t feature) noexcept nogil: """Extract and partition values for a given feature. The extracted values are partitioned between negative values @@ -1297,19 +1297,19 @@ cdef class SparsePartitioner: Parameters ---------- - feature : SIZE_t, + feature : intp_t, Index of the feature we want to extract non zero value. """ - cdef SIZE_t[::1] samples = self.samples - cdef DTYPE_t[::1] feature_values = self.feature_values - cdef SIZE_t indptr_start = self.X_indptr[feature], - cdef SIZE_t indptr_end = self.X_indptr[feature + 1] - cdef SIZE_t n_indices = (indptr_end - indptr_start) - cdef SIZE_t n_samples = self.end - self.start - cdef SIZE_t[::1] index_to_samples = self.index_to_samples - cdef SIZE_t[::1] sorted_samples = self.sorted_samples - cdef const INT32_t[::1] X_indices = self.X_indices - cdef const DTYPE_t[::1] X_data = self.X_data + cdef intp_t[::1] samples = self.samples + cdef float32_t[::1] feature_values = self.feature_values + cdef intp_t indptr_start = self.X_indptr[feature], + cdef intp_t indptr_end = self.X_indptr[feature + 1] + cdef intp_t n_indices = (indptr_end - indptr_start) + cdef intp_t n_samples = self.end - self.start + cdef intp_t[::1] index_to_samples = self.index_to_samples + cdef intp_t[::1] sorted_samples = self.sorted_samples + cdef const int32_t[::1] X_indices = self.X_indices + cdef const float32_t[::1] X_data = self.X_data # Use binary search if n_samples * log(n_indices) < # n_indices and index_to_samples approach otherwise. 
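The crossover rule in that comment is the core of `extract_nnz`. A rough pure-Python sketch of the same idea (the helper names below are hypothetical, inferred from the surrounding code, and not part of this patch): scanning every stored nonzero of the CSC column costs O(n_indices), while binary-searching each node sample in the column's sorted row indices costs O(n_samples * log(n_indices)), so the partitioner picks whichever bound is smaller.

    import numpy as np

    def extract_nnz_linear(X_indices, X_data, indptr_start, indptr_end, in_node):
        # index_to_samples-style scan: visit every stored nonzero of the
        # CSC column once and keep rows that belong to the node: O(n_indices).
        return [(X_indices[k], X_data[k])
                for k in range(indptr_start, indptr_end)
                if in_node[X_indices[k]]]

    def extract_nnz_bsearch(X_indices, X_data, indptr_start, indptr_end, sorted_rows):
        # Binary-search each node sample in the column's sorted row indices:
        # O(n_samples * log(n_indices)). Rows are visited in increasing order,
        # so the search window only ever shrinks.
        out, lo = [], indptr_start
        for row in sorted_rows:
            k = lo + int(np.searchsorted(X_indices[lo:indptr_end], row))
            if k < indptr_end and X_indices[k] == row:
                out.append((row, X_data[k]))
            lo = k
        return out

    def extract_nnz(X_indices, X_data, indptr_start, indptr_end, node_rows, n_total):
        n_samples, n_indices = len(node_rows), indptr_end - indptr_start
        # The heuristic from the comment above: binary search wins when
        # n_samples * log(n_indices) < n_indices.
        if n_samples * np.log(max(n_indices, 2)) < n_indices:
            return extract_nnz_bsearch(X_indices, X_data, indptr_start,
                                       indptr_end, sorted(node_rows))
        in_node = np.zeros(n_total, dtype=bool)
        in_node[list(node_rows)] = True
        return extract_nnz_linear(X_indices, X_data, indptr_start, indptr_end, in_node)

In the real splitter the extracted values are additionally partitioned into negative and positive halves around the implicit zeros (end_negative / start_positive); the sketch above keeps only the selection logic.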
@@ -1339,18 +1339,18 @@ cdef class SparsePartitioner:
 cdef int compare_SIZE_t(const void* a, const void* b) noexcept nogil:
     """Comparison function for sort."""
-    return ((<SIZE_t*>a)[0] - (<SIZE_t*>b)[0])
+    return ((<intp_t*>a)[0] - (<intp_t*>b)[0])


-cdef inline void binary_search(const INT32_t[::1] sorted_array,
-                               INT32_t start, INT32_t end,
-                               SIZE_t value, SIZE_t* index,
-                               INT32_t* new_start) noexcept nogil:
+cdef inline void binary_search(const int32_t[::1] sorted_array,
+                               int32_t start, int32_t end,
+                               intp_t value, intp_t* index,
+                               int32_t* new_start) noexcept nogil:
     """Return the index of value in the sorted array.
     If not found, return -1. new_start is the last pivot + 1
     """
-    cdef INT32_t pivot
+    cdef int32_t pivot
     index[0] = -1
     while start < end:
         pivot = start + (end - start) / 2
@@ -1367,25 +1367,25 @@ cdef inline void binary_search(const INT32_t[::1] sorted_array,
     new_start[0] = start


-cdef inline void extract_nnz_index_to_samples(const INT32_t[::1] X_indices,
-                                              const DTYPE_t[::1] X_data,
-                                              INT32_t indptr_start,
-                                              INT32_t indptr_end,
-                                              SIZE_t[::1] samples,
-                                              SIZE_t start,
-                                              SIZE_t end,
-                                              SIZE_t[::1] index_to_samples,
-                                              DTYPE_t[::1] feature_values,
-                                              SIZE_t* end_negative,
-                                              SIZE_t* start_positive) noexcept nogil:
+cdef inline void extract_nnz_index_to_samples(const int32_t[::1] X_indices,
+                                              const float32_t[::1] X_data,
+                                              int32_t indptr_start,
+                                              int32_t indptr_end,
+                                              intp_t[::1] samples,
+                                              intp_t start,
+                                              intp_t end,
+                                              intp_t[::1] index_to_samples,
+                                              float32_t[::1] feature_values,
+                                              intp_t* end_negative,
+                                              intp_t* start_positive) noexcept nogil:
     """Extract and partition values for a feature using index_to_samples.

     Complexity is O(indptr_end - indptr_start).
     """
-    cdef INT32_t k
-    cdef SIZE_t index
-    cdef SIZE_t end_negative_ = start
-    cdef SIZE_t start_positive_ = end
+    cdef int32_t k
+    cdef intp_t index
+    cdef intp_t end_negative_ = start
+    cdef intp_t start_positive_ = end

     for k in range(indptr_start, indptr_end):
         if start <= index_to_samples[X_indices[k]] < end:
@@ -1406,18 +1406,18 @@ cdef inline void extract_nnz_index_to_samples(const INT32_t[::1] X_indices,
     start_positive[0] = start_positive_


-cdef inline void extract_nnz_binary_search(const INT32_t[::1] X_indices,
-                                           const DTYPE_t[::1] X_data,
-                                           INT32_t indptr_start,
-                                           INT32_t indptr_end,
-                                           SIZE_t[::1] samples,
-                                           SIZE_t start,
-                                           SIZE_t end,
-                                           SIZE_t[::1] index_to_samples,
-                                           DTYPE_t[::1] feature_values,
-                                           SIZE_t* end_negative,
-                                           SIZE_t* start_positive,
-                                           SIZE_t[::1] sorted_samples,
+cdef inline void extract_nnz_binary_search(const int32_t[::1] X_indices,
+                                           const float32_t[::1] X_data,
+                                           int32_t indptr_start,
+                                           int32_t indptr_end,
+                                           intp_t[::1] samples,
+                                           intp_t start,
+                                           intp_t end,
+                                           intp_t[::1] index_to_samples,
+                                           float32_t[::1] feature_values,
+                                           intp_t* end_negative,
+                                           intp_t* start_positive,
+                                           intp_t[::1] sorted_samples,
                                            bint* is_samples_sorted) noexcept nogil:
     """Extract and partition values for a given feature using binary search.

@@ -1427,13 +1427,13 @@ cdef inline void extract_nnz_binary_search(const INT32_t[::1] X_indices,
        O((1 - is_samples_sorted[0]) * n_samples * log(n_samples) +
          n_samples * log(n_indices)).
""" - cdef SIZE_t n_samples + cdef intp_t n_samples if not is_samples_sorted[0]: n_samples = end - start memcpy(&sorted_samples[start], &samples[start], - n_samples * sizeof(SIZE_t)) - qsort(&sorted_samples[start], n_samples, sizeof(SIZE_t), + n_samples * sizeof(intp_t)) + qsort(&sorted_samples[start], n_samples, sizeof(intp_t), compare_SIZE_t) is_samples_sorted[0] = 1 @@ -1445,11 +1445,11 @@ cdef inline void extract_nnz_binary_search(const INT32_t[::1] X_indices, sorted_samples[end - 1] < X_indices[indptr_end - 1]): indptr_end -= 1 - cdef SIZE_t p = start - cdef SIZE_t index - cdef SIZE_t k - cdef SIZE_t end_negative_ = start - cdef SIZE_t start_positive_ = end + cdef intp_t p = start + cdef intp_t index + cdef intp_t k + cdef intp_t end_negative_ = start + cdef intp_t start_positive_ = end while (p < end and indptr_start < indptr_end): # Find index of sorted_samples[p] in X_indices @@ -1477,8 +1477,8 @@ cdef inline void extract_nnz_binary_search(const INT32_t[::1] X_indices, start_positive[0] = start_positive_ -cdef inline void sparse_swap(SIZE_t[::1] index_to_samples, SIZE_t[::1] samples, - SIZE_t pos_1, SIZE_t pos_2) noexcept nogil: +cdef inline void sparse_swap(intp_t[::1] index_to_samples, intp_t[::1] samples, + intp_t pos_1, intp_t pos_2) noexcept nogil: """Swap sample pos_1 and pos_2 preserving sparse invariant.""" samples[pos_1], samples[pos_2] = samples[pos_2], samples[pos_1] index_to_samples[samples[pos_1]] = pos_1 @@ -1491,8 +1491,8 @@ cdef class BestSplitter(Splitter): cdef int init( self, object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, const unsigned char[::1] missing_values_in_feature_mask, ) except -1: Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) @@ -1504,7 +1504,7 @@ cdef class BestSplitter(Splitter): self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features, + intp_t* n_constant_features, double lower_bound, double upper_bound ) except -1 nogil: @@ -1527,8 +1527,8 @@ cdef class BestSparseSplitter(Splitter): cdef int init( self, object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, const unsigned char[::1] missing_values_in_feature_mask, ) except -1: Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) @@ -1540,7 +1540,7 @@ cdef class BestSparseSplitter(Splitter): self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features, + intp_t* n_constant_features, double lower_bound, double upper_bound ) except -1 nogil: @@ -1563,8 +1563,8 @@ cdef class RandomSplitter(Splitter): cdef int init( self, object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, const unsigned char[::1] missing_values_in_feature_mask, ) except -1: Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) @@ -1576,7 +1576,7 @@ cdef class RandomSplitter(Splitter): self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features, + intp_t* n_constant_features, double lower_bound, double upper_bound ) except -1 nogil: @@ -1599,8 +1599,8 @@ cdef class RandomSparseSplitter(Splitter): cdef int init( self, object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, const unsigned char[::1] missing_values_in_feature_mask, ) except -1: Splitter.init(self, X, y, sample_weight, 
missing_values_in_feature_mask) @@ -1611,7 +1611,7 @@ cdef class RandomSparseSplitter(Splitter): self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features, + intp_t* n_constant_features, double lower_bound, double upper_bound ) except -1 nogil: diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index b99f44c0472a2..cab8cef45a474 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -13,11 +13,7 @@ import numpy as np cimport numpy as cnp -ctypedef cnp.npy_float32 DTYPE_t # Type of X -ctypedef cnp.npy_float64 DOUBLE_t # Type of y, sample_weight -ctypedef cnp.npy_intp SIZE_t # Type for indices and counters -ctypedef cnp.npy_int32 INT32_t # Signed 32 bit integer -ctypedef cnp.npy_uint32 UINT32_t # Unsigned 32 bit integer +from ..utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint32_t from ._splitter cimport Splitter from ._splitter cimport SplitRecord @@ -25,13 +21,13 @@ from ._splitter cimport SplitRecord cdef struct Node: # Base storage structure for the nodes in a Tree object - SIZE_t left_child # id of the left child of the node - SIZE_t right_child # id of the right child of the node - SIZE_t feature # Feature used for splitting the node - DOUBLE_t threshold # Threshold value at the node - DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) - SIZE_t n_node_samples # Number of samples at the node - DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node + intp_t left_child # id of the left child of the node + intp_t right_child # id of the right child of the node + intp_t feature # Feature used for splitting the node + float64_t threshold # Threshold value at the node + float64_t impurity # Impurity of the node (i.e., the value of the criterion) + intp_t n_node_samples # Number of samples at the node + float64_t weighted_n_node_samples # Weighted number of samples at the node unsigned char missing_go_to_left # Whether features have missing values @@ -41,28 +37,28 @@ cdef class Tree: # feature importances. # Input/Output layout - cdef public SIZE_t n_features # Number of features in X - cdef SIZE_t* n_classes # Number of classes in y[:, k] - cdef public SIZE_t n_outputs # Number of outputs in y - cdef public SIZE_t max_n_classes # max(n_classes) + cdef public intp_t n_features # Number of features in X + cdef intp_t* n_classes # Number of classes in y[:, k] + cdef public intp_t n_outputs # Number of outputs in y + cdef public intp_t max_n_classes # max(n_classes) # Inner structures: values are stored separately from node structure, # since size is determined at runtime. 
- cdef public SIZE_t max_depth # Max depth of the tree - cdef public SIZE_t node_count # Counter for node IDs - cdef public SIZE_t capacity # Capacity of tree, in terms of nodes + cdef public intp_t max_depth # Max depth of the tree + cdef public intp_t node_count # Counter for node IDs + cdef public intp_t capacity # Capacity of tree, in terms of nodes cdef Node* nodes # Array of nodes cdef double* value # (capacity, n_outputs, max_n_classes) array of values - cdef SIZE_t value_stride # = n_outputs * max_n_classes + cdef intp_t value_stride # = n_outputs * max_n_classes # Methods - cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, - SIZE_t feature, double threshold, double impurity, - SIZE_t n_node_samples, + cdef intp_t _add_node(self, intp_t parent, bint is_left, bint is_leaf, + intp_t feature, double threshold, double impurity, + intp_t n_node_samples, double weighted_n_node_samples, unsigned char missing_go_to_left) except -1 nogil - cdef int _resize(self, SIZE_t capacity) except -1 nogil - cdef int _resize_c(self, SIZE_t capacity=*) except -1 nogil + cdef int _resize(self, intp_t capacity) except -1 nogil + cdef int _resize_c(self, intp_t capacity=*) except -1 nogil cdef cnp.ndarray _get_value_ndarray(self) cdef cnp.ndarray _get_node_ndarray(self) @@ -95,24 +91,24 @@ cdef class TreeBuilder: cdef Splitter splitter # Splitting algorithm - cdef SIZE_t min_samples_split # Minimum number of samples in an internal node - cdef SIZE_t min_samples_leaf # Minimum number of samples in a leaf + cdef intp_t min_samples_split # Minimum number of samples in an internal node + cdef intp_t min_samples_leaf # Minimum number of samples in a leaf cdef double min_weight_leaf # Minimum weight in a leaf - cdef SIZE_t max_depth # Maximal tree depth + cdef intp_t max_depth # Maximal tree depth cdef double min_impurity_decrease # Impurity threshold for early stopping cpdef build( self, Tree tree, object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight=*, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight=*, const unsigned char[::1] missing_values_in_feature_mask=*, ) cdef _check_input( self, object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, ) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index ce998e80a9d0a..eef59dfb3ec43 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -70,8 +70,8 @@ cdef int IS_NOT_LEFT = 0 TREE_LEAF = -1 TREE_UNDEFINED = -2 -cdef SIZE_t _TREE_LEAF = TREE_LEAF -cdef SIZE_t _TREE_UNDEFINED = TREE_UNDEFINED +cdef intp_t _TREE_LEAF = TREE_LEAF +cdef intp_t _TREE_UNDEFINED = TREE_UNDEFINED # Build the corresponding numpy dtype for Node. 
# This works by casting `dummy` to an array of Node of length 1, which numpy @@ -91,8 +91,8 @@ cdef class TreeBuilder: self, Tree tree, object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight=None, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight=None, const unsigned char[::1] missing_values_in_feature_mask=None, ): """Build a decision tree from the training set (X, y).""" @@ -101,8 +101,8 @@ cdef class TreeBuilder: cdef inline _check_input( self, object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight, ): """Check input dtype, layout and format""" if issparse(X): @@ -140,22 +140,22 @@ cdef class TreeBuilder: # Depth first builder --------------------------------------------------------- # A record on the stack for depth-first tree growing cdef struct StackRecord: - SIZE_t start - SIZE_t end - SIZE_t depth - SIZE_t parent + intp_t start + intp_t end + intp_t depth + intp_t parent bint is_left double impurity - SIZE_t n_constant_features + intp_t n_constant_features double lower_bound double upper_bound cdef class DepthFirstTreeBuilder(TreeBuilder): """Build a decision tree in depth-first fashion.""" - def __cinit__(self, Splitter splitter, SIZE_t min_samples_split, - SIZE_t min_samples_leaf, double min_weight_leaf, - SIZE_t max_depth, double min_impurity_decrease): + def __cinit__(self, Splitter splitter, intp_t min_samples_split, + intp_t min_samples_leaf, double min_weight_leaf, + intp_t max_depth, double min_impurity_decrease): self.splitter = splitter self.min_samples_split = min_samples_split self.min_samples_leaf = min_samples_leaf @@ -167,8 +167,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): self, Tree tree, object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight=None, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight=None, const unsigned char[::1] missing_values_in_feature_mask=None, ): """Build a decision tree from the training set (X, y).""" @@ -188,33 +188,33 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # Parameters cdef Splitter splitter = self.splitter - cdef SIZE_t max_depth = self.max_depth - cdef SIZE_t min_samples_leaf = self.min_samples_leaf + cdef intp_t max_depth = self.max_depth + cdef intp_t min_samples_leaf = self.min_samples_leaf cdef double min_weight_leaf = self.min_weight_leaf - cdef SIZE_t min_samples_split = self.min_samples_split + cdef intp_t min_samples_split = self.min_samples_split cdef double min_impurity_decrease = self.min_impurity_decrease # Recursive partition (without actual recursion) splitter.init(X, y, sample_weight, missing_values_in_feature_mask) - cdef SIZE_t start - cdef SIZE_t end - cdef SIZE_t depth - cdef SIZE_t parent + cdef intp_t start + cdef intp_t end + cdef intp_t depth + cdef intp_t parent cdef bint is_left - cdef SIZE_t n_node_samples = splitter.n_samples + cdef intp_t n_node_samples = splitter.n_samples cdef double weighted_n_node_samples cdef SplitRecord split - cdef SIZE_t node_id + cdef intp_t node_id cdef double impurity = INFINITY cdef double lower_bound cdef double upper_bound cdef double middle_value - cdef SIZE_t n_constant_features + cdef intp_t n_constant_features cdef bint is_leaf cdef bint first = 1 - cdef SIZE_t max_depth_seen = -1 + cdef intp_t max_depth_seen = -1 cdef int rc = 0 cdef stack[StackRecord] builder_stack @@ -369,11 +369,11 @@ cdef struct FrontierRecord: # Record of information of a Node, the frontier for a split. 
Those records are # maintained in a heap to access the Node with the best improvement in impurity, # allowing growing trees greedily on this improvement. - SIZE_t node_id - SIZE_t start - SIZE_t end - SIZE_t pos - SIZE_t depth + intp_t node_id + intp_t start + intp_t end + intp_t pos + intp_t depth bint is_leaf double impurity double impurity_left @@ -404,11 +404,11 @@ cdef class BestFirstTreeBuilder(TreeBuilder): The best node to expand is given by the node at the frontier that has the highest impurity improvement. """ - cdef SIZE_t max_leaf_nodes + cdef intp_t max_leaf_nodes - def __cinit__(self, Splitter splitter, SIZE_t min_samples_split, - SIZE_t min_samples_leaf, min_weight_leaf, - SIZE_t max_depth, SIZE_t max_leaf_nodes, + def __cinit__(self, Splitter splitter, intp_t min_samples_split, + intp_t min_samples_leaf, min_weight_leaf, + intp_t max_depth, intp_t max_leaf_nodes, double min_impurity_decrease): self.splitter = splitter self.min_samples_split = min_samples_split @@ -422,8 +422,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): self, Tree tree, object X, - const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight=None, + const float64_t[:, ::1] y, + const float64_t[:] sample_weight=None, const unsigned char[::1] missing_values_in_feature_mask=None, ): """Build a decision tree from the training set (X, y).""" @@ -433,7 +433,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): # Parameters cdef Splitter splitter = self.splitter - cdef SIZE_t max_leaf_nodes = self.max_leaf_nodes + cdef intp_t max_leaf_nodes = self.max_leaf_nodes # Recursive partition (without actual recursion) splitter.init(X, y, sample_weight, missing_values_in_feature_mask) @@ -447,15 +447,15 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef double right_child_min cdef double right_child_max - cdef SIZE_t n_node_samples = splitter.n_samples - cdef SIZE_t max_split_nodes = max_leaf_nodes - 1 + cdef intp_t n_node_samples = splitter.n_samples + cdef intp_t max_split_nodes = max_leaf_nodes - 1 cdef bint is_leaf - cdef SIZE_t max_depth_seen = -1 + cdef intp_t max_depth_seen = -1 cdef int rc = 0 cdef Node* node # Initial capacity - cdef SIZE_t init_capacity = max_split_nodes + max_leaf_nodes + cdef intp_t init_capacity = max_split_nodes + max_leaf_nodes tree._resize(init_capacity) with nogil: @@ -587,22 +587,22 @@ cdef class BestFirstTreeBuilder(TreeBuilder): self, Splitter splitter, Tree tree, - SIZE_t start, - SIZE_t end, + intp_t start, + intp_t end, double impurity, bint is_first, bint is_left, Node* parent, - SIZE_t depth, + intp_t depth, double lower_bound, double upper_bound, FrontierRecord* res ) except -1 nogil: """Adds node w/ partition ``[start, end)`` to the frontier. 
""" cdef SplitRecord split - cdef SIZE_t node_id - cdef SIZE_t n_node_samples - cdef SIZE_t n_constant_features = 0 + cdef intp_t node_id + cdef intp_t n_node_samples + cdef intp_t n_constant_features = 0 cdef double min_impurity_decrease = self.min_impurity_decrease cdef double weighted_n_node_samples cdef bint is_leaf @@ -795,7 +795,7 @@ cdef class Tree: # https://github.com/cython/cython/issues/5243 is fixed def __cinit__(self, int n_features, cnp.ndarray n_classes, int n_outputs): """Constructor.""" - cdef SIZE_t dummy = 0 + cdef intp_t dummy = 0 size_t_dtype = np.array(dummy).dtype n_classes = _check_n_classes(n_classes, size_t_dtype) @@ -809,7 +809,7 @@ cdef class Tree: self.max_n_classes = np.max(n_classes) self.value_stride = n_outputs * self.max_n_classes - cdef SIZE_t k + cdef intp_t k for k in range(n_outputs): self.n_classes[k] = n_classes[k] @@ -874,7 +874,7 @@ cdef class Tree: memcpy(self.value, cnp.PyArray_DATA(value_ndarray), self.capacity * self.value_stride * sizeof(double)) - cdef int _resize(self, SIZE_t capacity) except -1 nogil: + cdef int _resize(self, intp_t capacity) except -1 nogil: """Resize all inner arrays to `capacity`, if `capacity` == -1, then double the size of the inner arrays. @@ -886,7 +886,7 @@ cdef class Tree: with gil: raise MemoryError() - cdef int _resize_c(self, SIZE_t capacity=INTPTR_MAX) except -1 nogil: + cdef int _resize_c(self, intp_t capacity=INTPTR_MAX) except -1 nogil: """Guts of _resize Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -917,9 +917,9 @@ cdef class Tree: self.capacity = capacity return 0 - cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, - SIZE_t feature, double threshold, double impurity, - SIZE_t n_node_samples, + cdef intp_t _add_node(self, intp_t parent, bint is_left, bint is_leaf, + intp_t feature, double threshold, double impurity, + intp_t n_node_samples, double weighted_n_node_samples, unsigned char missing_go_to_left) except -1 nogil: """Add a node to the tree. @@ -928,7 +928,7 @@ cdef class Tree: Returns (size_t)(-1) on error. 
""" - cdef SIZE_t node_id = self.node_count + cdef intp_t node_id = self.node_count if node_id >= self.capacity: if self._resize_c() != 0: @@ -988,16 +988,16 @@ cdef class Tree: raise ValueError("X.dtype should be np.float32, got %s" % X.dtype) # Extract input - cdef const DTYPE_t[:, :] X_ndarray = X - cdef SIZE_t n_samples = X.shape[0] - cdef DTYPE_t X_i_node_feature + cdef const float32_t[:, :] X_ndarray = X + cdef intp_t n_samples = X.shape[0] + cdef float32_t X_i_node_feature # Initialize output - cdef SIZE_t[:] out = np.zeros(n_samples, dtype=np.intp) + cdef intp_t[:] out = np.zeros(n_samples, dtype=np.intp) # Initialize auxiliary data-structure cdef Node* node = NULL - cdef SIZE_t i = 0 + cdef intp_t i = 0 with nogil: for i in range(n_samples): @@ -1016,7 +1016,7 @@ cdef class Tree: else: node = &self.nodes[node.right_child] - out[i] = (node - self.nodes) # node offset + out[i] = (node - self.nodes) # node offset return np.asarray(out) @@ -1032,33 +1032,33 @@ cdef class Tree: raise ValueError("X.dtype should be np.float32, got %s" % X.dtype) # Extract input - cdef const DTYPE_t[:] X_data = X.data - cdef const INT32_t[:] X_indices = X.indices - cdef const INT32_t[:] X_indptr = X.indptr + cdef const float32_t[:] X_data = X.data + cdef const int32_t[:] X_indices = X.indices + cdef const int32_t[:] X_indptr = X.indptr - cdef SIZE_t n_samples = X.shape[0] - cdef SIZE_t n_features = X.shape[1] + cdef intp_t n_samples = X.shape[0] + cdef intp_t n_features = X.shape[1] # Initialize output - cdef SIZE_t[:] out = np.zeros(n_samples, dtype=np.intp) + cdef intp_t[:] out = np.zeros(n_samples, dtype=np.intp) # Initialize auxiliary data-structure - cdef DTYPE_t feature_value = 0. + cdef float32_t feature_value = 0. cdef Node* node = NULL - cdef DTYPE_t* X_sample = NULL - cdef SIZE_t i = 0 - cdef INT32_t k = 0 + cdef float32_t* X_sample = NULL + cdef intp_t i = 0 + cdef int32_t k = 0 # feature_to_sample as a data structure records the last seen sample # for each feature; functionally, it is an efficient way to identify # which features are nonzero in the present sample. - cdef SIZE_t* feature_to_sample = NULL + cdef intp_t* feature_to_sample = NULL safe_realloc(&X_sample, n_features) safe_realloc(&feature_to_sample, n_features) with nogil: - memset(feature_to_sample, -1, n_features * sizeof(SIZE_t)) + memset(feature_to_sample, -1, n_features * sizeof(intp_t)) for i in range(n_samples): node = self.nodes @@ -1081,7 +1081,7 @@ cdef class Tree: else: node = &self.nodes[node.right_child] - out[i] = (node - self.nodes) # node offset + out[i] = (node - self.nodes) # node offset # Free auxiliary arrays free(X_sample) @@ -1108,18 +1108,18 @@ cdef class Tree: raise ValueError("X.dtype should be np.float32, got %s" % X.dtype) # Extract input - cdef const DTYPE_t[:, :] X_ndarray = X - cdef SIZE_t n_samples = X.shape[0] + cdef const float32_t[:, :] X_ndarray = X + cdef intp_t n_samples = X.shape[0] # Initialize output - cdef SIZE_t[:] indptr = np.zeros(n_samples + 1, dtype=np.intp) - cdef SIZE_t[:] indices = np.zeros( + cdef intp_t[:] indptr = np.zeros(n_samples + 1, dtype=np.intp) + cdef intp_t[:] indices = np.zeros( n_samples * (1 + self.max_depth), dtype=np.intp ) # Initialize auxiliary data-structure cdef Node* node = NULL - cdef SIZE_t i = 0 + cdef intp_t i = 0 with nogil: for i in range(n_samples): @@ -1129,7 +1129,7 @@ cdef class Tree: # Add all external nodes while node.left_child != _TREE_LEAF: # ... 
and node.right_child != _TREE_LEAF: - indices[indptr[i + 1]] = (node - self.nodes) + indices[indptr[i + 1]] = (node - self.nodes) indptr[i + 1] += 1 if X_ndarray[i, node.feature] <= node.threshold: @@ -1138,11 +1138,11 @@ cdef class Tree: node = &self.nodes[node.right_child] # Add the leave node - indices[indptr[i + 1]] = (node - self.nodes) + indices[indptr[i + 1]] = (node - self.nodes) indptr[i + 1] += 1 indices = indices[:indptr[n_samples]] - cdef SIZE_t[:] data = np.ones(shape=len(indices), dtype=np.intp) + cdef intp_t[:] data = np.ones(shape=len(indices), dtype=np.intp) out = csr_matrix((data, indices, indptr), shape=(n_samples, self.node_count)) @@ -1160,36 +1160,36 @@ cdef class Tree: raise ValueError("X.dtype should be np.float32, got %s" % X.dtype) # Extract input - cdef const DTYPE_t[:] X_data = X.data - cdef const INT32_t[:] X_indices = X.indices - cdef const INT32_t[:] X_indptr = X.indptr + cdef const float32_t[:] X_data = X.data + cdef const int32_t[:] X_indices = X.indices + cdef const int32_t[:] X_indptr = X.indptr - cdef SIZE_t n_samples = X.shape[0] - cdef SIZE_t n_features = X.shape[1] + cdef intp_t n_samples = X.shape[0] + cdef intp_t n_features = X.shape[1] # Initialize output - cdef SIZE_t[:] indptr = np.zeros(n_samples + 1, dtype=np.intp) - cdef SIZE_t[:] indices = np.zeros( + cdef intp_t[:] indptr = np.zeros(n_samples + 1, dtype=np.intp) + cdef intp_t[:] indices = np.zeros( n_samples * (1 + self.max_depth), dtype=np.intp ) # Initialize auxiliary data-structure - cdef DTYPE_t feature_value = 0. + cdef float32_t feature_value = 0. cdef Node* node = NULL - cdef DTYPE_t* X_sample = NULL - cdef SIZE_t i = 0 - cdef INT32_t k = 0 + cdef float32_t* X_sample = NULL + cdef intp_t i = 0 + cdef int32_t k = 0 # feature_to_sample as a data structure records the last seen sample # for each feature; functionally, it is an efficient way to identify # which features are nonzero in the present sample. - cdef SIZE_t* feature_to_sample = NULL + cdef intp_t* feature_to_sample = NULL safe_realloc(&X_sample, n_features) safe_realloc(&feature_to_sample, n_features) with nogil: - memset(feature_to_sample, -1, n_features * sizeof(SIZE_t)) + memset(feature_to_sample, -1, n_features * sizeof(intp_t)) for i in range(n_samples): node = self.nodes @@ -1203,7 +1203,7 @@ cdef class Tree: while node.left_child != _TREE_LEAF: # ... and node.right_child != _TREE_LEAF: - indices[indptr[i + 1]] = (node - self.nodes) + indices[indptr[i + 1]] = (node - self.nodes) indptr[i + 1] += 1 if feature_to_sample[node.feature] == i: @@ -1218,7 +1218,7 @@ cdef class Tree: node = &self.nodes[node.right_child] # Add the leave node - indices[indptr[i + 1]] = (node - self.nodes) + indices[indptr[i + 1]] = (node - self.nodes) indptr[i + 1] += 1 # Free auxiliary arrays @@ -1226,7 +1226,7 @@ cdef class Tree: free(feature_to_sample) indices = indices[:indptr[n_samples]] - cdef SIZE_t[:] data = np.ones(shape=len(indices), dtype=np.intp) + cdef intp_t[:] data = np.ones(shape=len(indices), dtype=np.intp) out = csr_matrix((data, indices, indptr), shape=(n_samples, self.node_count)) @@ -1336,7 +1336,7 @@ cdef class Tree: raise ValueError("Can't initialize array.") return arr - def compute_partial_dependence(self, DTYPE_t[:, ::1] X, + def compute_partial_dependence(self, float32_t[:, ::1] X, int[::1] target_features, double[::1] out): """Partial dependence of the response on the ``target_feature`` set. 
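Both decision_path variants above assemble the same CSR output: row i flags every node that sample i traverses from the root to its leaf. A small self-contained sketch of reading such a matrix back (toy values, not produced by this code):

    import numpy as np
    from scipy.sparse import csr_matrix

    # Toy path matrix: sample 0 visits nodes 0 and 1; sample 1 visits 0 and 2.
    indptr = np.array([0, 2, 4])
    indices = np.array([0, 1, 0, 2])
    data = np.ones(len(indices), dtype=np.intp)

    paths = csr_matrix((data, indices, indptr), shape=(2, 3))
    for i in range(paths.shape[0]):
        row = paths.indices[paths.indptr[i]:paths.indptr[i + 1]]
        print(f"sample {i} visited nodes {row}")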
@@ -1369,18 +1369,18 @@ cdef class Tree: cdef: double[::1] weight_stack = np.zeros(self.node_count, dtype=np.float64) - SIZE_t[::1] node_idx_stack = np.zeros(self.node_count, + intp_t[::1] node_idx_stack = np.zeros(self.node_count, dtype=np.intp) - SIZE_t sample_idx - SIZE_t feature_idx + intp_t sample_idx + intp_t feature_idx int stack_size double left_sample_frac double current_weight double total_weight # used for sanity check only Node *current_node # use a pointer to avoid copying attributes - SIZE_t current_node_idx + intp_t current_node_idx bint is_target_feature - SIZE_t _TREE_LEAF = TREE_LEAF # to avoid python interactions + intp_t _TREE_LEAF = TREE_LEAF # to avoid python interactions for sample_idx in range(X.shape[0]): # init stacks for current sample @@ -1495,7 +1495,7 @@ def _dtype_to_dict(dtype): def _dtype_dict_with_modified_bitness(dtype_dict): - # field names in Node struct with SIZE_t types (see sklearn/tree/_tree.pxd) + # field names in Node struct with intp_t types (see sklearn/tree/_tree.pxd) indexing_field_names = ["left_child", "right_child", "feature", "n_node_samples"] expected_dtype_size = str(struct.calcsize("P")) @@ -1511,7 +1511,7 @@ def _dtype_dict_with_modified_bitness(dtype_dict): def _all_compatible_dtype_dicts(dtype): - # The Cython code for decision trees uses platform-specific SIZE_t + # The Cython code for decision trees uses platform-specific intp_t # typed indexing fields that correspond to either i4 or i8 dtypes for # the matching fields in the numpy array depending on the bitness of # the platform (32 bit or 64 bit respectively). @@ -1580,12 +1580,12 @@ cdef class _CCPPruneController: """Base class used by build_pruned_tree_ccp and ccp_pruning_path to control pruning. """ - cdef bint stop_pruning(self, DOUBLE_t effective_alpha) noexcept nogil: + cdef bint stop_pruning(self, float64_t effective_alpha) noexcept nogil: """Return 1 to stop pruning and 0 to continue pruning""" return 0 - cdef void save_metrics(self, DOUBLE_t effective_alpha, - DOUBLE_t subtree_impurities) noexcept nogil: + cdef void save_metrics(self, float64_t effective_alpha, + float64_t subtree_impurities) noexcept nogil: """Save metrics when pruning""" pass @@ -1596,14 +1596,14 @@ cdef class _CCPPruneController: cdef class _AlphaPruner(_CCPPruneController): """Use alpha to control when to stop pruning.""" - cdef DOUBLE_t ccp_alpha - cdef SIZE_t capacity + cdef float64_t ccp_alpha + cdef intp_t capacity - def __cinit__(self, DOUBLE_t ccp_alpha): + def __cinit__(self, float64_t ccp_alpha): self.ccp_alpha = ccp_alpha self.capacity = 0 - cdef bint stop_pruning(self, DOUBLE_t effective_alpha) noexcept nogil: + cdef bint stop_pruning(self, float64_t effective_alpha) noexcept nogil: # The subtree on the previous iteration has the greatest ccp_alpha # less than or equal to self.ccp_alpha return self.ccp_alpha < effective_alpha @@ -1617,9 +1617,9 @@ cdef class _AlphaPruner(_CCPPruneController): cdef class _PathFinder(_CCPPruneController): """Record metrics used to return the cost complexity path.""" - cdef DOUBLE_t[:] ccp_alphas - cdef DOUBLE_t[:] impurities - cdef UINT32_t count + cdef float64_t[:] ccp_alphas + cdef float64_t[:] impurities + cdef uint32_t count def __cinit__(self, int node_count): self.ccp_alphas = np.zeros(shape=(node_count), dtype=np.float64) @@ -1627,16 +1627,16 @@ cdef class _PathFinder(_CCPPruneController): self.count = 0 cdef void save_metrics(self, - DOUBLE_t effective_alpha, - DOUBLE_t subtree_impurities) noexcept nogil: + float64_t effective_alpha, + float64_t 
subtree_impurities) noexcept nogil: self.ccp_alphas[self.count] = effective_alpha self.impurities[self.count] = subtree_impurities self.count += 1 cdef struct CostComplexityPruningRecord: - SIZE_t node_idx - SIZE_t parent + intp_t node_idx + intp_t parent cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT Tree orig_tree, @@ -1660,41 +1660,41 @@ cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT """ cdef: - SIZE_t i - SIZE_t n_nodes = orig_tree.node_count + intp_t i + intp_t n_nodes = orig_tree.node_count # prior probability using weighted samples - DOUBLE_t[:] weighted_n_node_samples = orig_tree.weighted_n_node_samples - DOUBLE_t total_sum_weights = weighted_n_node_samples[0] - DOUBLE_t[:] impurity = orig_tree.impurity + float64_t[:] weighted_n_node_samples = orig_tree.weighted_n_node_samples + float64_t total_sum_weights = weighted_n_node_samples[0] + float64_t[:] impurity = orig_tree.impurity # weighted impurity of each node - DOUBLE_t[:] r_node = np.empty(shape=n_nodes, dtype=np.float64) + float64_t[:] r_node = np.empty(shape=n_nodes, dtype=np.float64) - SIZE_t[:] child_l = orig_tree.children_left - SIZE_t[:] child_r = orig_tree.children_right - SIZE_t[:] parent = np.zeros(shape=n_nodes, dtype=np.intp) + intp_t[:] child_l = orig_tree.children_left + intp_t[:] child_r = orig_tree.children_right + intp_t[:] parent = np.zeros(shape=n_nodes, dtype=np.intp) stack[CostComplexityPruningRecord] ccp_stack CostComplexityPruningRecord stack_record - SIZE_t node_idx - stack[SIZE_t] node_indices_stack + intp_t node_idx + stack[intp_t] node_indices_stack - SIZE_t[:] n_leaves = np.zeros(shape=n_nodes, dtype=np.intp) - DOUBLE_t[:] r_branch = np.zeros(shape=n_nodes, dtype=np.float64) - DOUBLE_t current_r - SIZE_t leaf_idx - SIZE_t parent_idx + intp_t[:] n_leaves = np.zeros(shape=n_nodes, dtype=np.intp) + float64_t[:] r_branch = np.zeros(shape=n_nodes, dtype=np.float64) + float64_t current_r + intp_t leaf_idx + intp_t parent_idx # candidate nodes that can be pruned unsigned char[:] candidate_nodes = np.zeros(shape=n_nodes, dtype=np.uint8) # nodes in subtree unsigned char[:] in_subtree = np.ones(shape=n_nodes, dtype=np.uint8) - SIZE_t pruned_branch_node_idx - DOUBLE_t subtree_alpha - DOUBLE_t effective_alpha - SIZE_t n_pruned_leaves - DOUBLE_t r_diff - DOUBLE_t max_float64 = np.finfo(np.float64).max + intp_t pruned_branch_node_idx + float64_t subtree_alpha + float64_t effective_alpha + intp_t n_pruned_leaves + float64_t r_diff + float64_t max_float64 = np.finfo(np.float64).max # find parent node ids and leaves with nogil: @@ -1800,7 +1800,7 @@ cdef _cost_complexity_prune(unsigned char[:] leaves_in_subtree, # OUT def _build_pruned_tree_ccp( Tree tree, # OUT Tree orig_tree, - DOUBLE_t ccp_alpha + float64_t ccp_alpha ): """Build a pruned tree from the original tree using cost complexity pruning. 
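The pruning loop above repeatedly collapses the branch with the smallest effective alpha, (r_node - r_branch) / (n_leaves - 1), until that value exceeds ccp_alpha. A hedged back-of-the-envelope sketch of the quantity (function name illustrative, not the module's API):

    def effective_alpha(r_node, r_branch, n_leaves):
        # r_node: weighted impurity of the node itself;
        # r_branch: summed weighted impurity of the branch's leaves;
        # n_leaves: number of leaves under the node.
        return (r_node - r_branch) / (n_leaves - 1)

    # Collapsing a 3-leaf branch whose node impurity is 0.5 and whose leaves
    # sum to 0.2 trades 0.3 impurity for 2 fewer splits:
    print(effective_alpha(0.5, 0.2, 3))  # 0.15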
@@ -1821,7 +1821,7 @@ def _build_pruned_tree_ccp( """ cdef: - SIZE_t n_nodes = orig_tree.node_count + intp_t n_nodes = orig_tree.node_count unsigned char[:] leaves_in_subtree = np.zeros( shape=n_nodes, dtype=np.uint8) @@ -1862,10 +1862,10 @@ def ccp_pruning_path(Tree orig_tree): _cost_complexity_prune(leaves_in_subtree, orig_tree, path_finder) cdef: - UINT32_t total_items = path_finder.count - DOUBLE_t[:] ccp_alphas = np.empty(shape=total_items, dtype=np.float64) - DOUBLE_t[:] impurities = np.empty(shape=total_items, dtype=np.float64) - UINT32_t count = 0 + uint32_t total_items = path_finder.count + float64_t[:] ccp_alphas = np.empty(shape=total_items, dtype=np.float64) + float64_t[:] impurities = np.empty(shape=total_items, dtype=np.float64) + uint32_t count = 0 while count < total_items: ccp_alphas[count] = path_finder.ccp_alphas[count] @@ -1879,16 +1879,16 @@ def ccp_pruning_path(Tree orig_tree): cdef struct BuildPrunedRecord: - SIZE_t start - SIZE_t depth - SIZE_t parent + intp_t start + intp_t depth + intp_t parent bint is_left cdef _build_pruned_tree( Tree tree, # OUT Tree orig_tree, const unsigned char[:] leaves_in_subtree, - SIZE_t capacity + intp_t capacity ): """Build a pruned tree. @@ -1903,22 +1903,22 @@ cdef _build_pruned_tree( Original tree leaves_in_subtree : unsigned char memoryview, shape=(node_count, ) Boolean mask for leaves to include in subtree - capacity : SIZE_t + capacity : intp_t Number of nodes to initially allocate in pruned tree """ tree._resize(capacity) cdef: - SIZE_t orig_node_id - SIZE_t new_node_id - SIZE_t depth - SIZE_t parent + intp_t orig_node_id + intp_t new_node_id + intp_t depth + intp_t parent bint is_left bint is_leaf # value_stride for original tree and new tree are the same - SIZE_t value_stride = orig_tree.value_stride - SIZE_t max_depth_seen = -1 + intp_t value_stride = orig_tree.value_stride + intp_t max_depth_seen = -1 int rc = 0 Node* node double* orig_value_ptr diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 4b953af2d9b2b..cd7a77cc1bbc9 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -11,13 +11,7 @@ cimport numpy as cnp from ._tree cimport Node from ..neighbors._quad_tree cimport Cell - -ctypedef cnp.npy_float32 DTYPE_t # Type of X -ctypedef cnp.npy_float64 DOUBLE_t # Type of y, sample_weight -ctypedef cnp.npy_intp SIZE_t # Type for indices and counters -ctypedef cnp.npy_int32 INT32_t # Signed 32 bit integer -ctypedef cnp.npy_uint32 UINT32_t # Unsigned 32 bit integer - +from ..utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint32_t cdef enum: # Max value for our rand_r replacement (near the bottom). @@ -30,17 +24,17 @@ cdef enum: # safe_realloc(&p, n) resizes the allocation of p to n * sizeof(*p) bytes or # raises a MemoryError. It never calls free, since that's __dealloc__'s job. -# cdef DTYPE_t *p = NULL +# cdef float32_t *p = NULL # safe_realloc(&p, n) # is equivalent to p = malloc(n * sizeof(*p)) with error checking. ctypedef fused realloc_ptr: # Add pointer types here as needed. 
- (DTYPE_t*) - (SIZE_t*) + (float32_t*) + (intp_t*) (unsigned char*) (WeightedPQueueRecord*) - (DOUBLE_t*) - (DOUBLE_t**) + (float64_t*) + (float64_t**) (Node*) (Cell*) (Node**) @@ -48,15 +42,15 @@ ctypedef fused realloc_ptr: cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) except * nogil -cdef cnp.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size) +cdef cnp.ndarray sizet_ptr_to_ndarray(intp_t* data, intp_t size) -cdef SIZE_t rand_int(SIZE_t low, SIZE_t high, - UINT32_t* random_state) noexcept nogil +cdef intp_t rand_int(intp_t low, intp_t high, + uint32_t* random_state) noexcept nogil cdef double rand_uniform(double low, double high, - UINT32_t* random_state) noexcept nogil + uint32_t* random_state) noexcept nogil cdef double log(double x) noexcept nogil @@ -67,23 +61,23 @@ cdef double log(double x) noexcept nogil # A record stored in the WeightedPQueue cdef struct WeightedPQueueRecord: - DOUBLE_t data - DOUBLE_t weight + float64_t data + float64_t weight cdef class WeightedPQueue: - cdef SIZE_t capacity - cdef SIZE_t array_ptr + cdef intp_t capacity + cdef intp_t array_ptr cdef WeightedPQueueRecord* array_ cdef bint is_empty(self) noexcept nogil cdef int reset(self) except -1 nogil - cdef SIZE_t size(self) noexcept nogil - cdef int push(self, DOUBLE_t data, DOUBLE_t weight) except -1 nogil - cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) noexcept nogil - cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) noexcept nogil - cdef int peek(self, DOUBLE_t* data, DOUBLE_t* weight) noexcept nogil - cdef DOUBLE_t get_weight_from_index(self, SIZE_t index) noexcept nogil - cdef DOUBLE_t get_value_from_index(self, SIZE_t index) noexcept nogil + cdef intp_t size(self) noexcept nogil + cdef int push(self, float64_t data, float64_t weight) except -1 nogil + cdef int remove(self, float64_t data, float64_t weight) noexcept nogil + cdef int pop(self, float64_t* data, float64_t* weight) noexcept nogil + cdef int peek(self, float64_t* data, float64_t* weight) noexcept nogil + cdef float64_t get_weight_from_index(self, intp_t index) noexcept nogil + cdef float64_t get_value_from_index(self, intp_t index) noexcept nogil # ============================================================================= @@ -91,20 +85,20 @@ cdef class WeightedPQueue: # ============================================================================= cdef class WeightedMedianCalculator: - cdef SIZE_t initial_capacity + cdef intp_t initial_capacity cdef WeightedPQueue samples - cdef DOUBLE_t total_weight - cdef SIZE_t k - cdef DOUBLE_t sum_w_0_k # represents sum(weights[0:k]) = w[0] + w[1] + ... + w[k-1] - cdef SIZE_t size(self) noexcept nogil - cdef int push(self, DOUBLE_t data, DOUBLE_t weight) except -1 nogil + cdef float64_t total_weight + cdef intp_t k + cdef float64_t sum_w_0_k # represents sum(weights[0:k]) = w[0] + w[1] + ... 
+ w[k-1]
+    cdef intp_t size(self) noexcept nogil
+    cdef int push(self, float64_t data, float64_t weight) except -1 nogil
     cdef int reset(self) except -1 nogil
     cdef int update_median_parameters_post_push(
-        self, DOUBLE_t data, DOUBLE_t weight,
-        DOUBLE_t original_median) noexcept nogil
-    cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) noexcept nogil
-    cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) noexcept nogil
+        self, float64_t data, float64_t weight,
+        float64_t original_median) noexcept nogil
+    cdef int remove(self, float64_t data, float64_t weight) noexcept nogil
+    cdef int pop(self, float64_t* data, float64_t* weight) noexcept nogil
     cdef int update_median_parameters_post_remove(
-        self, DOUBLE_t data, DOUBLE_t weight,
-        DOUBLE_t original_median) noexcept nogil
-    cdef DOUBLE_t get_median(self) noexcept nogil
+        self, float64_t data, float64_t weight,
+        float64_t original_median) noexcept nogil
+    cdef float64_t get_median(self) noexcept nogil
diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx
index 669d69409fdc3..98a8249928b6f 100644
--- a/sklearn/tree/_utils.pyx
+++ b/sklearn/tree/_utils.pyx
@@ -43,28 +43,28 @@ cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) except * nogil:
 def _realloc_test():
     # Helper for tests. Tries to allocate <size_t>(-1) / 2 * sizeof(size_t)
     # bytes, which will always overflow.
-    cdef SIZE_t* p = NULL
+    cdef intp_t* p = NULL
     safe_realloc(&p, <size_t>(-1) / 2)

     if p != NULL:
         free(p)
         assert False


-cdef inline cnp.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size):
+cdef inline cnp.ndarray sizet_ptr_to_ndarray(intp_t* data, intp_t size):
     """Return copied data as 1D numpy array of intp's."""
     cdef cnp.npy_intp shape[1]
     shape[0] = size
     return cnp.PyArray_SimpleNewFromData(1, shape, cnp.NPY_INTP, data).copy()


-cdef inline SIZE_t rand_int(SIZE_t low, SIZE_t high,
-                            UINT32_t* random_state) noexcept nogil:
+cdef inline intp_t rand_int(intp_t low, intp_t high,
+                            uint32_t* random_state) noexcept nogil:
     """Generate a random integer in [low; end)."""
     return low + our_rand_r(random_state) % (high - low)


 cdef inline double rand_uniform(double low, double high,
-                                UINT32_t* random_state) noexcept nogil:
+                                uint32_t* random_state) noexcept nogil:
     """Generate a random double in [low; high)."""
     return ((high - low) * <double> our_rand_r(random_state) /
             <double> RAND_R_MAX) + low
@@ -82,10 +82,10 @@ cdef class WeightedPQueue:
     Attributes
     ----------
-    capacity : SIZE_t
+    capacity : intp_t
         The capacity of the priority queue.

-    array_ptr : SIZE_t
+    array_ptr : intp_t
         The water mark of the priority queue; the priority queue grows from
         left to right in the array ``array_``. ``array_ptr`` is always
         less than ``capacity``.
@@ -96,7 +96,7 @@ cdef class WeightedPQueue:
         ``array_ptr-1``.
     """

-    def __cinit__(self, SIZE_t capacity):
+    def __cinit__(self, intp_t capacity):
         self.capacity = capacity
         self.array_ptr = 0
         safe_realloc(&self.array_, capacity)
@@ -118,18 +118,18 @@ cdef class WeightedPQueue:
     cdef bint is_empty(self) noexcept nogil:
         return self.array_ptr <= 0

-    cdef SIZE_t size(self) noexcept nogil:
+    cdef intp_t size(self) noexcept nogil:
         return self.array_ptr

-    cdef int push(self, DOUBLE_t data, DOUBLE_t weight) except -1 nogil:
+    cdef int push(self, float64_t data, float64_t weight) except -1 nogil:
         """Push record on the array.

         Return -1 in case of failure to allocate memory (and raise MemoryError)
         or 0 otherwise.
""" - cdef SIZE_t array_ptr = self.array_ptr + cdef intp_t array_ptr = self.array_ptr cdef WeightedPQueueRecord* array = NULL - cdef SIZE_t i + cdef intp_t i # Resize if capacity not sufficient if array_ptr >= self.capacity: @@ -153,13 +153,13 @@ cdef class WeightedPQueue: self.array_ptr = array_ptr + 1 return 0 - cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) noexcept nogil: + cdef int remove(self, float64_t data, float64_t weight) noexcept nogil: """Remove a specific value/weight record from the array. Returns 0 if successful, -1 if record not found.""" - cdef SIZE_t array_ptr = self.array_ptr + cdef intp_t array_ptr = self.array_ptr cdef WeightedPQueueRecord* array = self.array_ - cdef SIZE_t idx_to_remove = -1 - cdef SIZE_t i + cdef intp_t idx_to_remove = -1 + cdef intp_t i if array_ptr <= 0: return -1 @@ -181,12 +181,12 @@ cdef class WeightedPQueue: self.array_ptr = array_ptr - 1 return 0 - cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) noexcept nogil: + cdef int pop(self, float64_t* data, float64_t* weight) noexcept nogil: """Remove the top (minimum) element from array. Returns 0 if successful, -1 if nothing to remove.""" - cdef SIZE_t array_ptr = self.array_ptr + cdef intp_t array_ptr = self.array_ptr cdef WeightedPQueueRecord* array = self.array_ - cdef SIZE_t i + cdef intp_t i if array_ptr <= 0: return -1 @@ -202,7 +202,7 @@ cdef class WeightedPQueue: self.array_ptr = array_ptr - 1 return 0 - cdef int peek(self, DOUBLE_t* data, DOUBLE_t* weight) noexcept nogil: + cdef int peek(self, float64_t* data, float64_t* weight) noexcept nogil: """Write the top element from array to a pointer. Returns 0 if successful, -1 if nothing to write.""" cdef WeightedPQueueRecord* array = self.array_ @@ -213,7 +213,7 @@ cdef class WeightedPQueue: weight[0] = array[0].weight return 0 - cdef DOUBLE_t get_weight_from_index(self, SIZE_t index) noexcept nogil: + cdef float64_t get_weight_from_index(self, intp_t index) noexcept nogil: """Given an index between [0,self.current_capacity], access the appropriate heap and return the requested weight""" cdef WeightedPQueueRecord* array = self.array_ @@ -221,7 +221,7 @@ cdef class WeightedPQueue: # get weight at index return array[index].weight - cdef DOUBLE_t get_value_from_index(self, SIZE_t index) noexcept nogil: + cdef float64_t get_value_from_index(self, intp_t index) noexcept nogil: """Given an index between [0,self.current_capacity], access the appropriate heap and return the requested value""" cdef WeightedPQueueRecord* array = self.array_ @@ -245,21 +245,21 @@ cdef class WeightedMedianCalculator: Attributes ---------- - initial_capacity : SIZE_t + initial_capacity : intp_t The initial capacity of the WeightedMedianCalculator. samples : WeightedPQueue Holds the samples (consisting of values and their weights) used in the weighted median calculation. - total_weight : DOUBLE_t + total_weight : float64_t The sum of the weights of items in ``samples``. Represents the total weight of all samples used in the median calculation. - k : SIZE_t + k : intp_t Index used to calculate the median. - sum_w_0_k : DOUBLE_t + sum_w_0_k : float64_t The sum of the weights from samples[0:k]. 
Used in the weighted median calculation; minimizing the value of ``k`` such that ``sum_w_0_k`` >= ``total_weight / 2`` provides a mechanism for @@ -267,14 +267,14 @@ cdef class WeightedMedianCalculator: """ - def __cinit__(self, SIZE_t initial_capacity): + def __cinit__(self, intp_t initial_capacity): self.initial_capacity = initial_capacity self.samples = WeightedPQueue(initial_capacity) self.total_weight = 0 self.k = 0 self.sum_w_0_k = 0 - cdef SIZE_t size(self) noexcept nogil: + cdef intp_t size(self) noexcept nogil: """Return the number of samples in the WeightedMedianCalculator""" return self.samples.size() @@ -293,14 +293,14 @@ cdef class WeightedMedianCalculator: self.sum_w_0_k = 0 return 0 - cdef int push(self, DOUBLE_t data, DOUBLE_t weight) except -1 nogil: + cdef int push(self, float64_t data, float64_t weight) except -1 nogil: """Push a value and its associated weight to the WeightedMedianCalculator Return -1 in case of failure to allocate memory (and raise MemoryError) or 0 otherwise. """ cdef int return_value - cdef DOUBLE_t original_median = 0.0 + cdef float64_t original_median = 0.0 if self.size() != 0: original_median = self.get_median() @@ -311,8 +311,8 @@ cdef class WeightedMedianCalculator: return return_value cdef int update_median_parameters_post_push( - self, DOUBLE_t data, DOUBLE_t weight, - DOUBLE_t original_median) noexcept nogil: + self, float64_t data, float64_t weight, + float64_t original_median) noexcept nogil: """Update the parameters used in the median calculation, namely `k` and `sum_w_0_k` after an insertion""" @@ -352,12 +352,12 @@ cdef class WeightedMedianCalculator: self.sum_w_0_k += self.samples.get_weight_from_index(self.k-1) return 0 - cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) noexcept nogil: + cdef int remove(self, float64_t data, float64_t weight) noexcept nogil: """Remove a value from the MedianHeap, removing it from consideration in the median calculation """ cdef int return_value - cdef DOUBLE_t original_median = 0.0 + cdef float64_t original_median = 0.0 if self.size() != 0: original_median = self.get_median() @@ -367,7 +367,7 @@ cdef class WeightedMedianCalculator: original_median) return return_value - cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) noexcept nogil: + cdef int pop(self, float64_t* data, float64_t* weight) noexcept nogil: """Pop a value from the MedianHeap, starting from the left and moving to the right. 
""" @@ -388,7 +388,7 @@ cdef class WeightedMedianCalculator: return return_value cdef int update_median_parameters_post_remove( - self, DOUBLE_t data, DOUBLE_t weight, + self, float64_t data, float64_t weight, double original_median) noexcept nogil: """Update the parameters used in the median calculation, namely `k` and `sum_w_0_k` after a removal""" @@ -437,7 +437,7 @@ cdef class WeightedMedianCalculator: self.sum_w_0_k -= self.samples.get_weight_from_index(self.k) return 0 - cdef DOUBLE_t get_median(self) noexcept nogil: + cdef float64_t get_median(self) noexcept nogil: """Write the median to a pointer, taking into account sample weights.""" if self.sum_w_0_k == (self.total_weight / 2.0): @@ -449,7 +449,7 @@ cdef class WeightedMedianCalculator: return self.samples.get_value_from_index(self.k-1) -def _any_isnan_axis0(const DTYPE_t[:, :] X): +def _any_isnan_axis0(const float32_t[:, :] X): """Same as np.any(np.isnan(X), axis=0)""" cdef: int i, j From 74ce8b9f7e23c007634d4280b2f5e6fc6fab071f Mon Sep 17 00:00:00 2001 From: Adam Li Date: Tue, 12 Sep 2023 15:11:39 -0400 Subject: [PATCH 2/2] Fix spacing in in-line comments Signed-off-by: Adam Li --- sklearn/tree/_criterion.pxd | 26 +++++++++++++------------- sklearn/tree/_splitter.pxd | 2 +- sklearn/tree/_tree.pxd | 6 +++--- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 16d8f35d44013..d5341c4db3be1 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -22,21 +22,21 @@ cdef class Criterion: cdef const float64_t[:, ::1] y # Values of y cdef const float64_t[:] sample_weight # Sample weights - cdef const intp_t[:] sample_indices # Sample indices in X, y - cdef intp_t start # samples[start:pos] are the samples in the left node - cdef intp_t pos # samples[pos:end] are the samples in the right node + cdef const intp_t[:] sample_indices # Sample indices in X, y + cdef intp_t start # samples[start:pos] are the samples in the left node + cdef intp_t pos # samples[pos:end] are the samples in the right node cdef intp_t end - cdef intp_t n_missing # Number of missing values for the feature being evaluated - cdef bint missing_go_to_left # Whether missing values go to the left node + cdef intp_t n_missing # Number of missing values for the feature being evaluated + cdef bint missing_go_to_left # Whether missing values go to the left node - cdef intp_t n_outputs # Number of outputs - cdef intp_t n_samples # Number of samples - cdef intp_t n_node_samples # Number of samples in the node (end-start) - cdef double weighted_n_samples # Weighted number of samples (in total) - cdef double weighted_n_node_samples # Weighted number of samples in the node - cdef double weighted_n_left # Weighted number of samples in the left node - cdef double weighted_n_right # Weighted number of samples in the right node - cdef double weighted_n_missing # Weighted number of samples that are missing + cdef intp_t n_outputs # Number of outputs + cdef intp_t n_samples # Number of samples + cdef intp_t n_node_samples # Number of samples in the node (end-start) + cdef double weighted_n_samples # Weighted number of samples (in total) + cdef double weighted_n_node_samples # Weighted number of samples in the node + cdef double weighted_n_left # Weighted number of samples in the left node + cdef double weighted_n_right # Weighted number of samples in the right node + cdef double weighted_n_missing # Weighted number of samples that are missing # The criterion object is maintained such that 
left and right collected # statistics correspond to samples[start:pos] and samples[pos:end]. diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 981ac25f98ad6..a096014804847 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -51,7 +51,7 @@ cdef class Splitter: cdef intp_t[::1] features # Feature indices in X cdef intp_t[::1] constant_features # Constant features indices cdef intp_t n_features # X.shape[1] - cdef float32_t[::1] feature_values # temp. array holding feature values + cdef float32_t[::1] feature_values # temp. array holding feature values cdef intp_t start # Start position for the current node cdef intp_t end # End position for the current node diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index cab8cef45a474..97634748c3f42 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -24,10 +24,10 @@ cdef struct Node: intp_t left_child # id of the left child of the node intp_t right_child # id of the right child of the node intp_t feature # Feature used for splitting the node - float64_t threshold # Threshold value at the node - float64_t impurity # Impurity of the node (i.e., the value of the criterion) + float64_t threshold # Threshold value at the node + float64_t impurity # Impurity of the node (i.e., the value of the criterion) intp_t n_node_samples # Number of samples at the node - float64_t weighted_n_node_samples # Weighted number of samples at the node + float64_t weighted_n_node_samples # Weighted number of samples at the node unsigned char missing_go_to_left # Whether features have missing values
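As a quick end-to-end check that the Node fields above keep their expected widths after this change (assuming a scikit-learn build with these patches installed; the dataset choice is arbitrary):

    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier

    X, y = load_iris(return_X_y=True)
    tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y).tree_

    # intp_t-typed index fields surface as platform-sized integers,
    # float64_t-typed statistics as float64.
    print(tree.children_left.dtype, tree.threshold.dtype)
    print(tree.children_left[:3], tree.threshold[:3])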