diff --git a/sklearn/neighbors/_quad_tree.pxd b/sklearn/neighbors/_quad_tree.pxd index 71c4c3071344c..9ed033e747314 100644 --- a/sklearn/neighbors/_quad_tree.pxd +++ b/sklearn/neighbors/_quad_tree.pxd @@ -4,11 +4,7 @@ # See quad_tree.pyx for details. cimport numpy as cnp - -ctypedef cnp.npy_float32 DTYPE_t # Type of X -ctypedef cnp.npy_intp SIZE_t # Type for indices and counters -ctypedef cnp.npy_int32 INT32_t # Signed 32 bit integer -ctypedef cnp.npy_uint32 UINT32_t # Unsigned 32 bit integer +from ..utils._typedefs cimport float32_t, intp_t # This is effectively an ifdef statement in Cython # It allows us to write printf debugging lines @@ -25,26 +21,26 @@ cdef struct Cell: # Base storage structure for cells in a QuadTree object # Tree structure - SIZE_t parent # Parent cell of this cell - SIZE_t[8] children # Array pointing to children of this cell + intp_t parent # Parent cell of this cell + intp_t[8] children # Array pointing to children of this cell # Cell description - SIZE_t cell_id # Id of the cell in the cells array in the Tree - SIZE_t point_index # Index of the point at this cell (only defined - # # in non empty leaf) - bint is_leaf # Does this cell have children? - DTYPE_t squared_max_width # Squared value of the maximum width w - SIZE_t depth # Depth of the cell in the tree - SIZE_t cumulative_size # Number of points included in the subtree with - # # this cell as a root. + intp_t cell_id # Id of the cell in the cells array in the Tree + intp_t point_index # Index of the point at this cell (only defined + # # in non empty leaf) + bint is_leaf # Does this cell have children? + float32_t squared_max_width # Squared value of the maximum width w + intp_t depth # Depth of the cell in the tree + intp_t cumulative_size # Number of points included in the subtree with + # # this cell as a root. # Internal constants - DTYPE_t[3] center # Store the center for quick split of cells - DTYPE_t[3] barycenter # Keep track of the center of mass of the cell + float32_t[3] center # Store the center for quick split of cells + float32_t[3] barycenter # Keep track of the center of mass of the cell # Cell boundaries - DTYPE_t[3] min_bounds # Inferior boundaries of this cell (inclusive) - DTYPE_t[3] max_bounds # Superior boundaries of this cell (exclusive) + float32_t[3] min_bounds # Inferior boundaries of this cell (inclusive) + float32_t[3] max_bounds # Superior boundaries of this cell (exclusive) cdef class _QuadTree: @@ -57,40 +53,40 @@ cdef class _QuadTree: # Parameters of the tree cdef public int n_dimensions # Number of dimensions in X cdef public int verbose # Verbosity of the output - cdef SIZE_t n_cells_per_cell # Number of children per node. (2 ** n_dimension) + cdef intp_t n_cells_per_cell # Number of children per node. (2 ** n_dimension) # Tree inner structure - cdef public SIZE_t max_depth # Max depth of the tree - cdef public SIZE_t cell_count # Counter for node IDs - cdef public SIZE_t capacity # Capacity of tree, in terms of nodes - cdef public SIZE_t n_points # Total number of points + cdef public intp_t max_depth # Max depth of the tree + cdef public intp_t cell_count # Counter for node IDs + cdef public intp_t capacity # Capacity of tree, in terms of nodes + cdef public intp_t n_points # Total number of points cdef Cell* cells # Array of nodes # Point insertion methods - cdef int insert_point(self, DTYPE_t[3] point, SIZE_t point_index, - SIZE_t cell_id=*) except -1 nogil - cdef SIZE_t _insert_point_in_new_child(self, DTYPE_t[3] point, Cell* cell, - SIZE_t point_index, SIZE_t size=* + cdef int insert_point(self, float32_t[3] point, intp_t point_index, + intp_t cell_id=*) except -1 nogil + cdef intp_t _insert_point_in_new_child(self, float32_t[3] point, Cell* cell, + intp_t point_index, intp_t size=* ) noexcept nogil - cdef SIZE_t _select_child(self, DTYPE_t[3] point, Cell* cell) noexcept nogil - cdef bint _is_duplicate(self, DTYPE_t[3] point1, DTYPE_t[3] point2) noexcept nogil + cdef intp_t _select_child(self, float32_t[3] point, Cell* cell) noexcept nogil + cdef bint _is_duplicate(self, float32_t[3] point1, float32_t[3] point2) noexcept nogil # Create a summary of the Tree compare to a query point - cdef long summarize(self, DTYPE_t[3] point, DTYPE_t* results, - float squared_theta=*, SIZE_t cell_id=*, long idx=* + cdef long summarize(self, float32_t[3] point, float32_t* results, + float squared_theta=*, intp_t cell_id=*, long idx=* ) noexcept nogil # Internal cell initialization methods - cdef void _init_cell(self, Cell* cell, SIZE_t parent, SIZE_t depth) noexcept nogil - cdef void _init_root(self, DTYPE_t[3] min_bounds, DTYPE_t[3] max_bounds + cdef void _init_cell(self, Cell* cell, intp_t parent, intp_t depth) noexcept nogil + cdef void _init_root(self, float32_t[3] min_bounds, float32_t[3] max_bounds ) noexcept nogil # Private methods - cdef int _check_point_in_cell(self, DTYPE_t[3] point, Cell* cell + cdef int _check_point_in_cell(self, float32_t[3] point, Cell* cell ) except -1 nogil # Private array manipulation to manage the ``cells`` array - cdef int _resize(self, SIZE_t capacity) except -1 nogil - cdef int _resize_c(self, SIZE_t capacity=*) except -1 nogil - cdef int _get_cell(self, DTYPE_t[3] point, SIZE_t cell_id=*) except -1 nogil + cdef int _resize(self, intp_t capacity) except -1 nogil + cdef int _resize_c(self, intp_t capacity=*) except -1 nogil + cdef int _get_cell(self, float32_t[3] point, intp_t cell_id=*) except -1 nogil cdef Cell[:] _get_cell_ndarray(self) diff --git a/sklearn/neighbors/_quad_tree.pyx b/sklearn/neighbors/_quad_tree.pyx index 1da59c9f29206..e481e41ca65e4 100644 --- a/sklearn/neighbors/_quad_tree.pyx +++ b/sklearn/neighbors/_quad_tree.pyx @@ -80,11 +80,11 @@ cdef class _QuadTree: """Build a tree from an array of points X.""" cdef: int i - DTYPE_t[3] pt - DTYPE_t[3] min_bounds, max_bounds + float32_t[3] pt + float32_t[3] min_bounds, max_bounds # validate X and prepare for query - # X = check_array(X, dtype=DTYPE_t, order='C') + # X = check_array(X, dtype=float32_t, order='C') n_samples = X.shape[0] capacity = 100 @@ -113,13 +113,13 @@ cdef class _QuadTree: # Shrink the cells array to reduce memory usage self._resize(capacity=self.cell_count) - cdef int insert_point(self, DTYPE_t[3] point, SIZE_t point_index, - SIZE_t cell_id=0) except -1 nogil: + cdef int insert_point(self, float32_t[3] point, intp_t point_index, + intp_t cell_id=0) except -1 nogil: """Insert a point in the QuadTree.""" cdef int ax - cdef SIZE_t selected_child + cdef intp_t selected_child cdef Cell* cell = &self.cells[cell_id] - cdef SIZE_t n_point = cell.cumulative_size + cdef intp_t n_point = cell.cumulative_size if self.verbose > 10: printf("[QuadTree] Inserting depth %li\n", cell.depth) @@ -177,16 +177,16 @@ cdef class _QuadTree: return self.insert_point(point, point_index, cell_id) # XXX: This operation is not Thread safe - cdef SIZE_t _insert_point_in_new_child( - self, DTYPE_t[3] point, Cell* cell, SIZE_t point_index, SIZE_t size=1 + cdef intp_t _insert_point_in_new_child( + self, float32_t[3] point, Cell* cell, intp_t point_index, intp_t size=1 ) noexcept nogil: """Create a child of cell which will contain point.""" # Local variable definition cdef: - SIZE_t cell_id, cell_child_id, parent_id - DTYPE_t[3] save_point - DTYPE_t width + intp_t cell_id, cell_child_id, parent_id + float32_t[3] save_point + float32_t width Cell* child int i @@ -247,7 +247,7 @@ cdef class _QuadTree: return cell_id - cdef bint _is_duplicate(self, DTYPE_t[3] point1, DTYPE_t[3] point2) noexcept nogil: + cdef bint _is_duplicate(self, float32_t[3] point1, float32_t[3] point2) noexcept nogil: """Check if the two given points are equals.""" cdef int i cdef bint res = True @@ -256,11 +256,11 @@ cdef class _QuadTree: res &= fabsf(point1[i] - point2[i]) <= EPSILON return res - cdef SIZE_t _select_child(self, DTYPE_t[3] point, Cell* cell) noexcept nogil: + cdef intp_t _select_child(self, float32_t[3] point, Cell* cell) noexcept nogil: """Select the child of cell which contains the given query point.""" cdef: int i - SIZE_t selected_child = 0 + intp_t selected_child = 0 for i in range(self.n_dimensions): # Select the correct child cell to insert the point by comparing @@ -270,7 +270,7 @@ cdef class _QuadTree: selected_child += 1 return cell.children[selected_child] - cdef void _init_cell(self, Cell* cell, SIZE_t parent, SIZE_t depth) noexcept nogil: + cdef void _init_cell(self, Cell* cell, intp_t parent, intp_t depth) noexcept nogil: """Initialize a cell structure with some constants.""" cell.parent = parent cell.is_leaf = True @@ -280,12 +280,12 @@ cdef class _QuadTree: for i in range(self.n_cells_per_cell): cell.children[i] = SIZE_MAX - cdef void _init_root(self, DTYPE_t[3] min_bounds, DTYPE_t[3] max_bounds + cdef void _init_root(self, float32_t[3] min_bounds, float32_t[3] max_bounds ) noexcept nogil: """Initialize the root node with the given space boundaries""" cdef: int i - DTYPE_t width + float32_t width Cell* root = &self.cells[0] self._init_cell(root, -1, 0) @@ -299,7 +299,7 @@ cdef class _QuadTree: self.cell_count += 1 - cdef int _check_point_in_cell(self, DTYPE_t[3] point, Cell* cell + cdef int _check_point_in_cell(self, float32_t[3] point, Cell* cell ) except -1 nogil: """Check that the given point is in the cell boundaries.""" @@ -366,8 +366,8 @@ cdef class _QuadTree: "in children." .format(self.n_points, self.cells[0].cumulative_size)) - cdef long summarize(self, DTYPE_t[3] point, DTYPE_t* results, - float squared_theta=.5, SIZE_t cell_id=0, long idx=0 + cdef long summarize(self, float32_t[3] point, float32_t* results, + float squared_theta=.5, intp_t cell_id=0, long idx=0 ) noexcept nogil: """Summarize the tree compared to a query point. @@ -429,7 +429,7 @@ cdef class _QuadTree: # Otherwise, we go a higher level of resolution and into the leaves. if cell.is_leaf or ( (cell.squared_max_width / results[idx_d]) < squared_theta): - results[idx_d + 1] = cell.cumulative_size + results[idx_d + 1] = cell.cumulative_size return idx + self.n_dimensions + 2 else: @@ -446,7 +446,7 @@ cdef class _QuadTree: """return the id of the cell containing the query point or raise ValueError if the point is not in the tree """ - cdef DTYPE_t[3] query_pt + cdef float32_t[3] query_pt cdef int i assert len(point) == self.n_dimensions, ( @@ -458,14 +458,14 @@ cdef class _QuadTree: return self._get_cell(query_pt, 0) - cdef int _get_cell(self, DTYPE_t[3] point, SIZE_t cell_id=0 + cdef int _get_cell(self, float32_t[3] point, intp_t cell_id=0 ) except -1 nogil: """guts of get_cell. Return the id of the cell containing the query point or raise ValueError if the point is not in the tree""" cdef: - SIZE_t selected_child + intp_t selected_child Cell* cell = &self.cells[cell_id] if cell.is_leaf: @@ -562,7 +562,7 @@ cdef class _QuadTree: raise ValueError("Can't initialize array!") return arr - cdef int _resize(self, SIZE_t capacity) except -1 nogil: + cdef int _resize(self, intp_t capacity) except -1 nogil: """Resize all inner arrays to `capacity`, if `capacity` == -1, then double the size of the inner arrays. @@ -574,7 +574,7 @@ cdef class _QuadTree: with gil: raise MemoryError() - cdef int _resize_c(self, SIZE_t capacity=SIZE_MAX) except -1 nogil: + cdef int _resize_c(self, intp_t capacity=SIZE_MAX) except -1 nogil: """Guts of _resize Returns -1 in case of failure to allocate memory (and raise MemoryError) @@ -598,10 +598,10 @@ cdef class _QuadTree: self.capacity = capacity return 0 - def _py_summarize(self, DTYPE_t[:] query_pt, DTYPE_t[:, :] X, float angle): + def _py_summarize(self, float32_t[:] query_pt, float32_t[:, :] X, float angle): # Used for testing summarize cdef: - DTYPE_t[:] summary + float32_t[:] summary int n_samples n_samples = X.shape[0]