diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 57dacc0726c0a..229a6bc2874e3 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -53,12 +53,12 @@ cdef class Criterion: # statistics correspond to samples[start:pos] and samples[pos:end]. # Methods - cdef void init(self, DOUBLE_t* y, SIZE_t y_stride, DOUBLE_t* sample_weight, - double weighted_n_samples, SIZE_t* samples, SIZE_t start, - SIZE_t end) nogil - cdef void reset(self) nogil - cdef void reverse_reset(self) nogil - cdef void update(self, SIZE_t new_pos) nogil + cdef int init(self, DOUBLE_t* y, SIZE_t y_stride, DOUBLE_t* sample_weight, + double weighted_n_samples, SIZE_t* samples, SIZE_t start, + SIZE_t end) nogil except -1 + cdef int reset(self) nogil except -1 + cdef int reverse_reset(self) nogil except -1 + cdef int update(self, SIZE_t new_pos) nogil except -1 cdef double node_impurity(self) nogil cdef void children_impurity(self, double* impurity_left, double* impurity_right) nogil diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index 3d71818846b9f..5187a5066bb2e 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -51,11 +51,14 @@ cdef class Criterion: def __setstate__(self, d): pass - cdef void init(self, DOUBLE_t* y, SIZE_t y_stride, DOUBLE_t* sample_weight, - double weighted_n_samples, SIZE_t* samples, SIZE_t start, - SIZE_t end) nogil: + cdef int init(self, DOUBLE_t* y, SIZE_t y_stride, DOUBLE_t* sample_weight, + double weighted_n_samples, SIZE_t* samples, SIZE_t start, + SIZE_t end) nogil except -1: """Placeholder for a method which will initialize the criterion. + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + Parameters ---------- y : array-like, dtype=DOUBLE_t @@ -79,7 +82,7 @@ cdef class Criterion: pass - cdef void reset(self) nogil: + cdef int reset(self) nogil except -1: """Reset the criterion at pos=start. This method must be implemented by the subclass. @@ -87,14 +90,14 @@ cdef class Criterion: pass - cdef void reverse_reset(self) nogil: + cdef int reverse_reset(self) nogil except -1: """Reset the criterion at pos=end. This method must be implemented by the subclass. """ pass - cdef void update(self, SIZE_t new_pos) nogil: + cdef int update(self, SIZE_t new_pos) nogil except -1: """Updated statistics by moving samples[pos:new_pos] to the left child. This updates the collected statistics by moving samples[pos:new_pos] @@ -281,12 +284,15 @@ cdef class ClassificationCriterion(Criterion): sizet_ptr_to_ndarray(self.n_classes, self.n_outputs)), self.__getstate__()) - cdef void init(self, DOUBLE_t* y, SIZE_t y_stride, - DOUBLE_t* sample_weight, double weighted_n_samples, - SIZE_t* samples, SIZE_t start, SIZE_t end) nogil: + cdef int init(self, DOUBLE_t* y, SIZE_t y_stride, + DOUBLE_t* sample_weight, double weighted_n_samples, + SIZE_t* samples, SIZE_t start, SIZE_t end) nogil except -1: """Initialize the criterion at node samples[start:end] and children samples[start:start] and samples[start:end]. + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + Parameters ---------- y : array-like, dtype=DOUBLE_t @@ -347,10 +353,14 @@ cdef class ClassificationCriterion(Criterion): # Reset to pos=start self.reset() + return 0 - cdef void reset(self) nogil: - """Reset the criterion at pos=start.""" + cdef int reset(self) nogil except -1: + """Reset the criterion at pos=start + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + """ self.pos = self.start self.weighted_n_left = 0.0 @@ -370,9 +380,14 @@ cdef class ClassificationCriterion(Criterion): sum_total += self.sum_stride sum_left += self.sum_stride sum_right += self.sum_stride + return 0 - cdef void reverse_reset(self) nogil: - """Reset the criterion at pos=end.""" + cdef int reverse_reset(self) nogil except -1: + """Reset the criterion at pos=end + + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + """ self.pos = self.end self.weighted_n_left = self.weighted_n_node_samples @@ -392,10 +407,14 @@ cdef class ClassificationCriterion(Criterion): sum_total += self.sum_stride sum_left += self.sum_stride sum_right += self.sum_stride + return 0 - cdef void update(self, SIZE_t new_pos) nogil: + cdef int update(self, SIZE_t new_pos) nogil except -1: """Updated statistics by moving samples[pos:new_pos] to the left child. + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + Parameters ---------- new_pos : SIZE_t @@ -470,6 +489,7 @@ cdef class ClassificationCriterion(Criterion): sum_total += self.sum_stride self.pos = new_pos + return 0 cdef double node_impurity(self) nogil: pass @@ -736,9 +756,9 @@ cdef class RegressionCriterion(Criterion): def __reduce__(self): return (type(self), (self.n_outputs, self.n_samples), self.__getstate__()) - cdef void init(self, DOUBLE_t* y, SIZE_t y_stride, DOUBLE_t* sample_weight, - double weighted_n_samples, SIZE_t* samples, SIZE_t start, - SIZE_t end) nogil: + cdef int init(self, DOUBLE_t* y, SIZE_t y_stride, DOUBLE_t* sample_weight, + double weighted_n_samples, SIZE_t* samples, SIZE_t start, + SIZE_t end) nogil except -1: """Initialize the criterion at node samples[start:end] and children samples[start:start] and samples[start:end].""" # Initialize fields @@ -778,8 +798,9 @@ cdef class RegressionCriterion(Criterion): # Reset to pos=start self.reset() + return 0 - cdef void reset(self) nogil: + cdef int reset(self) nogil except -1: """Reset the criterion at pos=start.""" cdef SIZE_t n_bytes = self.n_outputs * sizeof(double) memset(self.sum_left, 0, n_bytes) @@ -788,8 +809,9 @@ cdef class RegressionCriterion(Criterion): self.weighted_n_left = 0.0 self.weighted_n_right = self.weighted_n_node_samples self.pos = self.start + return 0 - cdef void reverse_reset(self) nogil: + cdef int reverse_reset(self) nogil except -1: """Reset the criterion at pos=end.""" cdef SIZE_t n_bytes = self.n_outputs * sizeof(double) memset(self.sum_right, 0, n_bytes) @@ -798,8 +820,9 @@ cdef class RegressionCriterion(Criterion): self.weighted_n_right = 0.0 self.weighted_n_left = self.weighted_n_node_samples self.pos = self.end + return 0 - cdef void update(self, SIZE_t new_pos) nogil: + cdef int update(self, SIZE_t new_pos) nogil except -1: """Updated statistics by moving samples[pos:new_pos] to the left.""" cdef double* sum_left = self.sum_left @@ -859,6 +882,7 @@ cdef class RegressionCriterion(Criterion): sum_right[k] = sum_total[k] - sum_left[k] self.pos = new_pos + return 0 cdef double node_impurity(self) nogil: pass @@ -1018,9 +1042,6 @@ cdef class MAE(RegressionCriterion): # Allocate memory for the accumulators safe_realloc(&self.node_medians, n_outputs) - if (self.node_medians == NULL): - raise MemoryError() - self.left_child = np.empty(n_outputs, dtype='object') self.right_child = np.empty(n_outputs, dtype='object') # initialize WeightedMedianCalculators @@ -1028,9 +1049,9 @@ cdef class MAE(RegressionCriterion): self.left_child[k] = WeightedMedianCalculator(n_samples) self.right_child[k] = WeightedMedianCalculator(n_samples) - cdef void init(self, DOUBLE_t* y, SIZE_t y_stride, DOUBLE_t* sample_weight, - double weighted_n_samples, SIZE_t* samples, SIZE_t start, - SIZE_t end) nogil: + cdef int init(self, DOUBLE_t* y, SIZE_t y_stride, DOUBLE_t* sample_weight, + double weighted_n_samples, SIZE_t* samples, SIZE_t start, + SIZE_t end) nogil except -1: """Initialize the criterion at node samples[start:end] and children samples[start:start] and samples[start:end].""" @@ -1068,6 +1089,7 @@ cdef class MAE(RegressionCriterion): for k in range(self.n_outputs): y_ik = y[i * y_stride + k] + # push method ends up calling safe_realloc, hence `except -1` # push all values to the right side, # since pos = start initially anyway ( right_child[k]).push(y_ik, w) @@ -1079,9 +1101,14 @@ cdef class MAE(RegressionCriterion): # Reset to pos=start self.reset() + return 0 - cdef void reset(self) nogil: - """Reset the criterion at pos=start.""" + cdef int reset(self) nogil except -1: + """Reset the criterion at pos=start + + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + """ cdef SIZE_t i, k cdef DOUBLE_t value @@ -1103,11 +1130,17 @@ cdef class MAE(RegressionCriterion): # remove everything from left and put it into right ( left_child[k]).pop(&value, &weight) + # push method ends up calling safe_realloc, hence `except -1` ( right_child[k]).push(value, weight) + return 0 - cdef void reverse_reset(self) nogil: - """Reset the criterion at pos=end.""" + cdef int reverse_reset(self) nogil except -1: + """Reset the criterion at pos=end + + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + """ self.weighted_n_right = 0.0 self.weighted_n_left = self.weighted_n_node_samples @@ -1126,11 +1159,17 @@ cdef class MAE(RegressionCriterion): # remove everything from right and put it into left ( right_child[k]).pop(&value, &weight) + # push method ends up calling safe_realloc, hence `except -1` ( left_child[k]).push(value, weight) + return 0 - cdef void update(self, SIZE_t new_pos) nogil: - """Updated statistics by moving samples[pos:new_pos] to the left.""" + cdef int update(self, SIZE_t new_pos) nogil except -1: + """Updated statistics by moving samples[pos:new_pos] to the left + + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + """ cdef DOUBLE_t* sample_weight = self.sample_weight cdef SIZE_t* samples = self.samples @@ -1162,6 +1201,7 @@ cdef class MAE(RegressionCriterion): y_ik = y[i * self.y_stride + k] # remove y_ik and its weight w from right and add to left ( right_child[k]).remove(y_ik, w) + # push method ends up calling safe_realloc, hence except -1 ( left_child[k]).push(y_ik, w) self.weighted_n_left += w @@ -1185,6 +1225,7 @@ cdef class MAE(RegressionCriterion): self.weighted_n_right = (self.weighted_n_node_samples - self.weighted_n_left) self.pos = new_pos + return 0 cdef void node_value(self, double* dest) nogil: """Computes the node value of samples[start:end] into dest.""" diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 4b97d01614b9f..4d5c5ae46bceb 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -81,18 +81,18 @@ cdef class Splitter: # This allows optimization with depth-based tree building. # Methods - cdef void init(self, object X, np.ndarray y, - DOUBLE_t* sample_weight, - np.ndarray X_idx_sorted=*) except * + cdef int init(self, object X, np.ndarray y, + DOUBLE_t* sample_weight, + np.ndarray X_idx_sorted=*) except -1 - cdef void node_reset(self, SIZE_t start, SIZE_t end, - double* weighted_n_node_samples) nogil + cdef int node_reset(self, SIZE_t start, SIZE_t end, + double* weighted_n_node_samples) nogil except -1 - cdef void node_split(self, - double impurity, # Impurity of the node - SplitRecord* split, - SIZE_t* n_constant_features) nogil + cdef int node_split(self, + double impurity, # Impurity of the node + SplitRecord* split, + SIZE_t* n_constant_features) nogil except -1 cdef void node_value(self, double* dest) nogil - cdef double node_impurity(self) nogil \ No newline at end of file + cdef double node_impurity(self) nogil diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 5fa7ee553fe2d..06dfab587493c 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -116,15 +116,18 @@ cdef class Splitter: def __setstate__(self, d): pass - cdef void init(self, + cdef int init(self, object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, - np.ndarray X_idx_sorted=None) except *: + np.ndarray X_idx_sorted=None) except -1: """Initialize the splitter. Take in the input data X, the target Y, and optional sample weights. + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + Parameters ---------- X : object @@ -180,11 +183,15 @@ cdef class Splitter: self.y_stride = y.strides[0] / y.itemsize self.sample_weight = sample_weight + return 0 - cdef void node_reset(self, SIZE_t start, SIZE_t end, - double* weighted_n_node_samples) nogil: + cdef int node_reset(self, SIZE_t start, SIZE_t end, + double* weighted_n_node_samples) nogil except -1: """Reset splitter on node samples[start:end]. + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + Parameters ---------- start : SIZE_t @@ -207,13 +214,16 @@ cdef class Splitter: end) weighted_n_node_samples[0] = self.criterion.weighted_n_node_samples + return 0 - cdef void node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) nogil: + cdef int node_split(self, double impurity, SplitRecord* split, + SIZE_t* n_constant_features) nogil except -1: """Find the best split on node samples[start:end]. This is a placeholder method. The majority of computation will be done here. + + It should return -1 upon errors. """ pass @@ -257,12 +267,16 @@ cdef class BaseDenseSplitter(Splitter): if self.presort == 1: free(self.sample_mask) - cdef void init(self, - object X, - np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, - DOUBLE_t* sample_weight, - np.ndarray X_idx_sorted=None) except *: - """Initialize the splitter.""" + cdef int init(self, + object X, + np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, + DOUBLE_t* sample_weight, + np.ndarray X_idx_sorted=None) except -1: + """Initialize the splitter + + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + """ # Call parent init Splitter.init(self, X, y, sample_weight) @@ -284,6 +298,8 @@ cdef class BaseDenseSplitter(Splitter): safe_realloc(&self.sample_mask, self.n_total_samples) memset(self.sample_mask, 0, self.n_total_samples*sizeof(SIZE_t)) + return 0 + cdef class BestSplitter(BaseDenseSplitter): """Splitter for finding the best split.""" @@ -295,9 +311,13 @@ cdef class BestSplitter(BaseDenseSplitter): self.random_state, self.presort), self.__getstate__()) - cdef void node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) nogil: - """Find the best split on node samples[start:end].""" + cdef int node_split(self, double impurity, SplitRecord* split, + SIZE_t* n_constant_features) nogil except -1: + """Find the best split on node samples[start:end] + + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + """ # Find the best split cdef SIZE_t* samples = self.samples cdef SIZE_t start = self.start @@ -509,6 +529,7 @@ cdef class BestSplitter(BaseDenseSplitter): # Return values split[0] = best n_constant_features[0] = n_total_constants + return 0 # Sort n-element arrays pointed to by Xf and samples, simultaneously, @@ -518,7 +539,8 @@ cdef inline void sort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil: introsort(Xf, samples, n, maxd) -cdef inline void swap(DTYPE_t* Xf, SIZE_t* samples, SIZE_t i, SIZE_t j) nogil: +cdef inline void swap(DTYPE_t* Xf, SIZE_t* samples, + SIZE_t i, SIZE_t j) nogil: # Helper for sort Xf[i], Xf[j] = Xf[j], Xf[i] samples[i], samples[j] = samples[j], samples[i] @@ -546,7 +568,8 @@ cdef inline DTYPE_t median3(DTYPE_t* Xf, SIZE_t n) nogil: # Introsort with median of 3 pivot selection and 3-way partition function # (robust to repeated elements, e.g. lots of zero features). -cdef void introsort(DTYPE_t* Xf, SIZE_t *samples, SIZE_t n, int maxd) nogil: +cdef void introsort(DTYPE_t* Xf, SIZE_t *samples, + SIZE_t n, int maxd) nogil: cdef DTYPE_t pivot cdef SIZE_t i, l, r @@ -631,9 +654,13 @@ cdef class RandomSplitter(BaseDenseSplitter): self.random_state, self.presort), self.__getstate__()) - cdef void node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) nogil: - """Find the best random split on node samples[start:end].""" + cdef int node_split(self, double impurity, SplitRecord* split, + SIZE_t* n_constant_features) nogil except -1: + """Find the best random split on node samples[start:end] + + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + """ # Draw random splits and pick the best cdef SIZE_t* samples = self.samples cdef SIZE_t start = self.start @@ -833,6 +860,7 @@ cdef class RandomSplitter(BaseDenseSplitter): # Return values split[0] = best n_constant_features[0] = n_total_constants + return 0 cdef class BaseSparseSplitter(Splitter): @@ -865,13 +893,16 @@ cdef class BaseSparseSplitter(Splitter): free(self.index_to_samples) free(self.sorted_samples) - cdef void init(self, - object X, - np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, - DOUBLE_t* sample_weight, - np.ndarray X_idx_sorted=None) except *: - """Initialize the splitter.""" + cdef int init(self, + object X, + np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, + DOUBLE_t* sample_weight, + np.ndarray X_idx_sorted=None) except -1: + """Initialize the splitter + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + """ # Call parent init Splitter.init(self, X, y, sample_weight) @@ -903,6 +934,7 @@ cdef class BaseSparseSplitter(Splitter): for p in range(n_samples): index_to_samples[samples[p]] = p + return 0 cdef inline SIZE_t _partition(self, double threshold, SIZE_t end_negative, SIZE_t start_positive, @@ -1148,7 +1180,7 @@ cdef inline void extract_nnz_binary_search(INT32_t* X_indices, cdef inline void sparse_swap(SIZE_t* index_to_samples, SIZE_t* samples, - SIZE_t pos_1, SIZE_t pos_2) nogil : + SIZE_t pos_1, SIZE_t pos_2) nogil: """Swap sample pos_1 and pos_2 preserving sparse invariant.""" samples[pos_1], samples[pos_2] = samples[pos_2], samples[pos_1] index_to_samples[samples[pos_1]] = pos_1 @@ -1166,10 +1198,12 @@ cdef class BestSparseSplitter(BaseSparseSplitter): self.random_state, self.presort), self.__getstate__()) - cdef void node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) nogil: - """Find the best split on node samples[start:end], using sparse - features. + cdef int node_split(self, double impurity, SplitRecord* split, + SIZE_t* n_constant_features) nogil except -1: + """Find the best split on node samples[start:end], using sparse features + + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. """ # Find the best split cdef SIZE_t* samples = self.samples @@ -1380,6 +1414,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): # Return values split[0] = best n_constant_features[0] = n_total_constants + return 0 cdef class RandomSparseSplitter(BaseSparseSplitter): @@ -1393,10 +1428,12 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): self.random_state, self.presort), self.__getstate__()) - cdef void node_split(self, double impurity, SplitRecord* split, - SIZE_t* n_constant_features) nogil: - """Find a random split on node samples[start:end], using sparse - features. + cdef int node_split(self, double impurity, SplitRecord* split, + SIZE_t* n_constant_features) nogil except -1: + """Find a random split on node samples[start:end], using sparse features + + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. """ # Find the best split cdef SIZE_t* samples = self.samples @@ -1608,3 +1645,4 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): # Return values split[0] = best n_constant_features[0] = n_total_constants + return 0 diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index dbf0545b1e1d5..4f9f359725646 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -58,9 +58,9 @@ cdef class Tree: cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, SIZE_t feature, double threshold, double impurity, SIZE_t n_node_samples, - double weighted_n_samples) nogil - cdef void _resize(self, SIZE_t capacity) except * - cdef int _resize_c(self, SIZE_t capacity=*) nogil + double weighted_n_samples) nogil except -1 + cdef int _resize(self, SIZE_t capacity) nogil except -1 + cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1 cdef np.ndarray _get_value_ndarray(self) cdef np.ndarray _get_node_ndarray(self) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index f8632ab1640d8..69ab8572d2ae5 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -19,7 +19,6 @@ from cpython cimport Py_INCREF, PyObject from libc.stdlib cimport free -from libc.stdlib cimport realloc from libc.string cimport memcpy from libc.string cimport memset @@ -272,9 +271,12 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): # Best first builder ---------------------------------------------------------- cdef inline int _add_to_frontier(PriorityHeapRecord* rec, - PriorityHeap frontier) nogil: - """Adds record ``rec`` to the priority queue ``frontier``; returns -1 - on memory-error. """ + PriorityHeap frontier) nogil except -1: + """Adds record ``rec`` to the priority queue ``frontier`` + + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + """ return frontier.push(rec.node_id, rec.start, rec.end, rec.pos, rec.depth, rec.is_leaf, rec.improvement, rec.impurity, rec.impurity_left, rec.impurity_right) @@ -417,7 +419,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): SIZE_t start, SIZE_t end, double impurity, bint is_first, bint is_left, Node* parent, SIZE_t depth, - PriorityHeapRecord* res) nogil: + PriorityHeapRecord* res) nogil except -1: """Adds node w/ partition ``[start, end)`` to the frontier. """ cdef SplitRecord split cdef SIZE_t node_id @@ -657,16 +659,26 @@ cdef class Tree: value = memcpy(self.value, ( value_ndarray).data, self.capacity * self.value_stride * sizeof(double)) - cdef void _resize(self, SIZE_t capacity) except *: + cdef int _resize(self, SIZE_t capacity) nogil except -1: """Resize all inner arrays to `capacity`, if `capacity` == -1, then - double the size of the inner arrays.""" + double the size of the inner arrays. + + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + """ if self._resize_c(capacity) != 0: - raise MemoryError() + # Acquire gil only if we need to raise + with gil: + raise MemoryError() # XXX using (size_t)(-1) is ugly, but SIZE_MAX is not available in C89 # (i.e., older MSVC). - cdef int _resize_c(self, SIZE_t capacity=(-1)) nogil: - """Guts of _resize. Returns 0 for success, -1 for error.""" + cdef int _resize_c(self, SIZE_t capacity=(-1)) nogil except -1: + """Guts of _resize + + Returns -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + """ if capacity == self.capacity and self.nodes != NULL: return 0 @@ -676,16 +688,8 @@ cdef class Tree: else: capacity = 2 * self.capacity - # XXX no safe_realloc here because we need to grab the GIL - cdef void* ptr = realloc(self.nodes, capacity * sizeof(Node)) - if ptr == NULL: - return -1 - self.nodes = ptr - ptr = realloc(self.value, - capacity * self.value_stride * sizeof(double)) - if ptr == NULL: - return -1 - self.value = ptr + safe_realloc(&self.nodes, capacity) + safe_realloc(&self.value, capacity * self.value_stride) # value memory is initialised to 0 to enable classifier argmax if capacity > self.capacity: @@ -702,7 +706,8 @@ cdef class Tree: cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, SIZE_t feature, double threshold, double impurity, - SIZE_t n_node_samples, double weighted_n_node_samples) nogil: + SIZE_t n_node_samples, + double weighted_n_node_samples) nogil except -1: """Add a node to the tree. The new node registers itself as the child of its parent. diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index cc9649030ed65..017888ab41db7 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -40,19 +40,21 @@ ctypedef fused realloc_ptr: (DOUBLE_t**) (Node*) (Node**) + (StackRecord*) + (PriorityHeapRecord*) -cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) except * +cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) nogil except * cdef np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size) cdef SIZE_t rand_int(SIZE_t low, SIZE_t high, - UINT32_t* random_state) nogil + UINT32_t* random_state) nogil cdef double rand_uniform(double low, double high, - UINT32_t* random_state) nogil + UINT32_t* random_state) nogil cdef double log(double x) nogil @@ -79,7 +81,7 @@ cdef class Stack: cdef bint is_empty(self) nogil cdef int push(self, SIZE_t start, SIZE_t end, SIZE_t depth, SIZE_t parent, bint is_left, double impurity, - SIZE_t n_constant_features) nogil + SIZE_t n_constant_features) nogil except -1 cdef int pop(self, StackRecord* res) nogil @@ -111,7 +113,7 @@ cdef class PriorityHeap: cdef int push(self, SIZE_t node_id, SIZE_t start, SIZE_t end, SIZE_t pos, SIZE_t depth, bint is_leaf, double improvement, double impurity, double impurity_left, - double impurity_right) nogil + double impurity_right) nogil except -1 cdef int pop(self, PriorityHeapRecord* res) nogil # ============================================================================= @@ -129,9 +131,9 @@ cdef class WeightedPQueue: cdef WeightedPQueueRecord* array_ cdef bint is_empty(self) nogil - cdef void reset(self) nogil + cdef int reset(self) nogil except -1 cdef SIZE_t size(self) nogil - cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil + cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil except -1 cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) nogil cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) nogil cdef int peek(self, DOUBLE_t* data, DOUBLE_t* weight) nogil @@ -152,14 +154,14 @@ cdef class WeightedMedianCalculator: # = w[0] + w[1] + ... + w[k-1] cdef SIZE_t size(self) nogil - cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil - cdef void reset(self) nogil - cdef int update_median_parameters_post_push(self, DOUBLE_t data, - DOUBLE_t weight, - DOUBLE_t original_median) nogil + cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil except -1 + cdef int reset(self) nogil except -1 + cdef int update_median_parameters_post_push( + self, DOUBLE_t data, DOUBLE_t weight, + DOUBLE_t original_median) nogil cdef int remove(self, DOUBLE_t data, DOUBLE_t weight) nogil cdef int pop(self, DOUBLE_t* data, DOUBLE_t* weight) nogil - cdef int update_median_parameters_post_remove(self, DOUBLE_t data, - DOUBLE_t weight, - DOUBLE_t original_median) nogil + cdef int update_median_parameters_post_remove( + self, DOUBLE_t data, DOUBLE_t weight, + DOUBLE_t original_median) nogil cdef DOUBLE_t get_median(self) nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 465afebc99ffa..faf2e5b777448 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -13,7 +13,6 @@ from libc.stdlib cimport free from libc.stdlib cimport malloc -from libc.stdlib cimport calloc from libc.stdlib cimport realloc from libc.math cimport log as ln @@ -25,17 +24,19 @@ np.import_array() # Helper functions # ============================================================================= -cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) except *: +cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) nogil except *: # sizeof(realloc_ptr[0]) would be more like idiomatic C, but causes Cython # 0.20.1 to crash. cdef size_t nbytes = nelems * sizeof(p[0][0]) if nbytes / sizeof(p[0][0]) != nelems: # Overflow in the multiplication - raise MemoryError("could not allocate (%d * %d) bytes" - % (nelems, sizeof(p[0][0]))) + with gil: + raise MemoryError("could not allocate (%d * %d) bytes" + % (nelems, sizeof(p[0][0]))) cdef realloc_ptr tmp = realloc(p[0], nbytes) if tmp == NULL: - raise MemoryError("could not allocate %d bytes" % nbytes) + with gil: + raise MemoryError("could not allocate %d bytes" % nbytes) p[0] = tmp return tmp # for convenience @@ -109,8 +110,6 @@ cdef class Stack: self.capacity = capacity self.top = 0 self.stack_ = malloc(capacity * sizeof(StackRecord)) - if self.stack_ == NULL: - raise MemoryError() def __dealloc__(self): free(self.stack_) @@ -120,10 +119,11 @@ cdef class Stack: cdef int push(self, SIZE_t start, SIZE_t end, SIZE_t depth, SIZE_t parent, bint is_left, double impurity, - SIZE_t n_constant_features) nogil: + SIZE_t n_constant_features) nogil except -1: """Push a new element onto the stack. - Returns 0 if successful; -1 on out of memory error. + Return -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. """ cdef SIZE_t top = self.top cdef StackRecord* stack = NULL @@ -131,12 +131,8 @@ cdef class Stack: # Resize if capacity not sufficient if top >= self.capacity: self.capacity *= 2 - stack = realloc(self.stack_, - self.capacity * sizeof(StackRecord)) - if stack == NULL: - # no free; __dealloc__ handles that - return -1 - self.stack_ = stack + # Since safe_realloc can raise MemoryError, use `except -1` + safe_realloc(&self.stack_, self.capacity) stack = self.stack_ stack[top].start = start @@ -196,9 +192,7 @@ cdef class PriorityHeap: def __cinit__(self, SIZE_t capacity): self.capacity = capacity self.heap_ptr = 0 - self.heap_ = malloc(capacity * sizeof(PriorityHeapRecord)) - if self.heap_ == NULL: - raise MemoryError() + safe_realloc(&self.heap_, capacity) def __dealloc__(self): free(self.heap_) @@ -241,10 +235,11 @@ cdef class PriorityHeap: cdef int push(self, SIZE_t node_id, SIZE_t start, SIZE_t end, SIZE_t pos, SIZE_t depth, bint is_leaf, double improvement, double impurity, double impurity_left, - double impurity_right) nogil: + double impurity_right) nogil except -1: """Push record on the priority heap. - Returns 0 if successful; -1 on out of memory error. + Return -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. """ cdef SIZE_t heap_ptr = self.heap_ptr cdef PriorityHeapRecord* heap = NULL @@ -252,13 +247,8 @@ cdef class PriorityHeap: # Resize if capacity not sufficient if heap_ptr >= self.capacity: self.capacity *= 2 - heap = realloc(self.heap_, - self.capacity * - sizeof(PriorityHeapRecord)) - if heap == NULL: - # no free; __dealloc__ handles that - return -1 - self.heap_ = heap + # Since safe_realloc can raise MemoryError, use `except -1` + safe_realloc(&self.heap_, self.capacity) # Put element as last element of heap heap = self.heap_ @@ -330,17 +320,19 @@ cdef class WeightedPQueue: self.array_ptr = 0 safe_realloc(&self.array_, capacity) - if self.array_ == NULL: - raise MemoryError() - def __dealloc__(self): free(self.array_) - cdef void reset(self) nogil: - """Reset the WeightedPQueue to its state at construction""" + cdef int reset(self) nogil except -1: + """Reset the WeightedPQueue to its state at construction + + Return -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + """ self.array_ptr = 0 - self.array_ = calloc(self.capacity, - sizeof(WeightedPQueueRecord)) + # Since safe_realloc can raise MemoryError, use `except *` + safe_realloc(&self.array_, self.capacity) + return 0 cdef bint is_empty(self) nogil: return self.array_ptr <= 0 @@ -348,9 +340,11 @@ cdef class WeightedPQueue: cdef SIZE_t size(self) nogil: return self.array_ptr - cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil: + cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil except -1: """Push record on the array. - Returns 0 if successful; -1 on out of memory error. + + Return -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. """ cdef SIZE_t array_ptr = self.array_ptr cdef WeightedPQueueRecord* array = NULL @@ -359,14 +353,8 @@ cdef class WeightedPQueue: # Resize if capacity not sufficient if array_ptr >= self.capacity: self.capacity *= 2 - array = realloc(self.array_, - self.capacity * - sizeof(WeightedPQueueRecord)) - - if array == NULL: - # no free; __dealloc__ handles that - return -1 - self.array_ = array + # Since safe_realloc can raise MemoryError, use `except -1` + safe_realloc(&self.array_, self.capacity) # Put element as last element of array array = self.array_ @@ -510,31 +498,40 @@ cdef class WeightedMedianCalculator: WeightedMedianCalculator""" return self.samples.size() - cdef void reset(self) nogil: - """Reset the WeightedMedianCalculator to its state at construction""" + cdef int reset(self) nogil except -1: + """Reset the WeightedMedianCalculator to its state at construction + + Return -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. + """ + # samples.reset (WeightedPQueue.reset) uses safe_realloc, hence + # except -1 self.samples.reset() self.total_weight = 0 self.k = 0 self.sum_w_0_k = 0 + return 0 + + cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil except -1: + """Push a value and its associated weight to the WeightedMedianCalculator - cdef int push(self, DOUBLE_t data, DOUBLE_t weight) nogil: - """Push a value and its associated weight - to the WeightedMedianCalculator to be considered - in the median calculation. + Return -1 in case of failure to allocate memory (and raise MemoryError) + or 0 otherwise. """ cdef int return_value cdef DOUBLE_t original_median if self.size() != 0: original_median = self.get_median() + # samples.push (WeightedPQueue.push) uses safe_realloc, hence except -1 return_value = self.samples.push(data, weight) self.update_median_parameters_post_push(data, weight, original_median) return return_value - cdef int update_median_parameters_post_push(self, DOUBLE_t data, - DOUBLE_t weight, - DOUBLE_t original_median) nogil: + cdef int update_median_parameters_post_push( + self, DOUBLE_t data, DOUBLE_t weight, + DOUBLE_t original_median) nogil: """Update the parameters used in the median calculation, namely `k` and `sum_w_0_k` after an insertion""" @@ -609,9 +606,9 @@ cdef class WeightedMedianCalculator: original_median) return return_value - cdef int update_median_parameters_post_remove(self, DOUBLE_t data, - DOUBLE_t weight, - double original_median) nogil: + cdef int update_median_parameters_post_remove( + self, DOUBLE_t data, DOUBLE_t weight, + double original_median) nogil: """Update the parameters used in the median calculation, namely `k` and `sum_w_0_k` after a removal""" # reset parameters because it there are no elements