Skip to content

MAINT cython typedefs in _quad_tree #27351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 34 additions & 38 deletions sklearn/neighbors/_quad_tree.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
# See quad_tree.pyx for details.

cimport numpy as cnp

ctypedef cnp.npy_float32 DTYPE_t # Type of X
ctypedef cnp.npy_intp SIZE_t # Type for indices and counters
ctypedef cnp.npy_int32 INT32_t # Signed 32 bit integer
ctypedef cnp.npy_uint32 UINT32_t # Unsigned 32 bit integer
from ..utils._typedefs cimport float32_t, intp_t

# This is effectively an ifdef statement in Cython
# It allows us to write printf debugging lines
Expand All @@ -25,26 +21,26 @@ cdef struct Cell:
# Base storage structure for cells in a QuadTree object

# Tree structure
SIZE_t parent # Parent cell of this cell
SIZE_t[8] children # Array pointing to children of this cell
intp_t parent # Parent cell of this cell
intp_t[8] children # Array pointing to children of this cell

# Cell description
SIZE_t cell_id # Id of the cell in the cells array in the Tree
SIZE_t point_index # Index of the point at this cell (only defined
# # in non empty leaf)
bint is_leaf # Does this cell have children?
DTYPE_t squared_max_width # Squared value of the maximum width w
SIZE_t depth # Depth of the cell in the tree
SIZE_t cumulative_size # Number of points included in the subtree with
# # this cell as a root.
intp_t cell_id # Id of the cell in the cells array in the Tree
intp_t point_index # Index of the point at this cell (only defined
# # in non empty leaf)
bint is_leaf # Does this cell have children?
float32_t squared_max_width # Squared value of the maximum width w
intp_t depth # Depth of the cell in the tree
intp_t cumulative_size # Number of points included in the subtree with
# # this cell as a root.

# Internal constants
DTYPE_t[3] center # Store the center for quick split of cells
DTYPE_t[3] barycenter # Keep track of the center of mass of the cell
float32_t[3] center # Store the center for quick split of cells
float32_t[3] barycenter # Keep track of the center of mass of the cell

# Cell boundaries
DTYPE_t[3] min_bounds # Inferior boundaries of this cell (inclusive)
DTYPE_t[3] max_bounds # Superior boundaries of this cell (exclusive)
float32_t[3] min_bounds # Inferior boundaries of this cell (inclusive)
float32_t[3] max_bounds # Superior boundaries of this cell (exclusive)


cdef class _QuadTree:
Expand All @@ -57,40 +53,40 @@ cdef class _QuadTree:
# Parameters of the tree
cdef public int n_dimensions # Number of dimensions in X
cdef public int verbose # Verbosity of the output
cdef SIZE_t n_cells_per_cell # Number of children per node. (2 ** n_dimension)
cdef intp_t n_cells_per_cell # Number of children per node. (2 ** n_dimension)

# Tree inner structure
cdef public SIZE_t max_depth # Max depth of the tree
cdef public SIZE_t cell_count # Counter for node IDs
cdef public SIZE_t capacity # Capacity of tree, in terms of nodes
cdef public SIZE_t n_points # Total number of points
cdef public intp_t max_depth # Max depth of the tree
cdef public intp_t cell_count # Counter for node IDs
cdef public intp_t capacity # Capacity of tree, in terms of nodes
cdef public intp_t n_points # Total number of points
cdef Cell* cells # Array of nodes

# Point insertion methods
cdef int insert_point(self, DTYPE_t[3] point, SIZE_t point_index,
SIZE_t cell_id=*) except -1 nogil
cdef SIZE_t _insert_point_in_new_child(self, DTYPE_t[3] point, Cell* cell,
SIZE_t point_index, SIZE_t size=*
cdef int insert_point(self, float32_t[3] point, intp_t point_index,
intp_t cell_id=*) except -1 nogil
cdef intp_t _insert_point_in_new_child(self, float32_t[3] point, Cell* cell,
intp_t point_index, intp_t size=*
) noexcept nogil
cdef SIZE_t _select_child(self, DTYPE_t[3] point, Cell* cell) noexcept nogil
cdef bint _is_duplicate(self, DTYPE_t[3] point1, DTYPE_t[3] point2) noexcept nogil
cdef intp_t _select_child(self, float32_t[3] point, Cell* cell) noexcept nogil
cdef bint _is_duplicate(self, float32_t[3] point1, float32_t[3] point2) noexcept nogil

# Create a summary of the Tree compare to a query point
cdef long summarize(self, DTYPE_t[3] point, DTYPE_t* results,
float squared_theta=*, SIZE_t cell_id=*, long idx=*
cdef long summarize(self, float32_t[3] point, float32_t* results,
float squared_theta=*, intp_t cell_id=*, long idx=*
) noexcept nogil

# Internal cell initialization methods
cdef void _init_cell(self, Cell* cell, SIZE_t parent, SIZE_t depth) noexcept nogil
cdef void _init_root(self, DTYPE_t[3] min_bounds, DTYPE_t[3] max_bounds
cdef void _init_cell(self, Cell* cell, intp_t parent, intp_t depth) noexcept nogil
cdef void _init_root(self, float32_t[3] min_bounds, float32_t[3] max_bounds
) noexcept nogil

# Private methods
cdef int _check_point_in_cell(self, DTYPE_t[3] point, Cell* cell
cdef int _check_point_in_cell(self, float32_t[3] point, Cell* cell
) except -1 nogil

# Private array manipulation to manage the ``cells`` array
cdef int _resize(self, SIZE_t capacity) except -1 nogil
cdef int _resize_c(self, SIZE_t capacity=*) except -1 nogil
cdef int _get_cell(self, DTYPE_t[3] point, SIZE_t cell_id=*) except -1 nogil
cdef int _resize(self, intp_t capacity) except -1 nogil
cdef int _resize_c(self, intp_t capacity=*) except -1 nogil
cdef int _get_cell(self, float32_t[3] point, intp_t cell_id=*) except -1 nogil
cdef Cell[:] _get_cell_ndarray(self)
58 changes: 29 additions & 29 deletions sklearn/neighbors/_quad_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ cdef class _QuadTree:
"""Build a tree from an array of points X."""
cdef:
int i
DTYPE_t[3] pt
DTYPE_t[3] min_bounds, max_bounds
float32_t[3] pt
float32_t[3] min_bounds, max_bounds

# validate X and prepare for query
# X = check_array(X, dtype=DTYPE_t, order='C')
# X = check_array(X, dtype=float32_t, order='C')
n_samples = X.shape[0]

capacity = 100
Expand Down Expand Up @@ -113,13 +113,13 @@ cdef class _QuadTree:
# Shrink the cells array to reduce memory usage
self._resize(capacity=self.cell_count)

cdef int insert_point(self, DTYPE_t[3] point, SIZE_t point_index,
SIZE_t cell_id=0) except -1 nogil:
cdef int insert_point(self, float32_t[3] point, intp_t point_index,
intp_t cell_id=0) except -1 nogil:
"""Insert a point in the QuadTree."""
cdef int ax
cdef SIZE_t selected_child
cdef intp_t selected_child
cdef Cell* cell = &self.cells[cell_id]
cdef SIZE_t n_point = cell.cumulative_size
cdef intp_t n_point = cell.cumulative_size

if self.verbose > 10:
printf("[QuadTree] Inserting depth %li\n", cell.depth)
Expand Down Expand Up @@ -177,16 +177,16 @@ cdef class _QuadTree:
return self.insert_point(point, point_index, cell_id)

# XXX: This operation is not Thread safe
cdef SIZE_t _insert_point_in_new_child(
self, DTYPE_t[3] point, Cell* cell, SIZE_t point_index, SIZE_t size=1
cdef intp_t _insert_point_in_new_child(
self, float32_t[3] point, Cell* cell, intp_t point_index, intp_t size=1
) noexcept nogil:
"""Create a child of cell which will contain point."""

# Local variable definition
cdef:
SIZE_t cell_id, cell_child_id, parent_id
DTYPE_t[3] save_point
DTYPE_t width
intp_t cell_id, cell_child_id, parent_id
float32_t[3] save_point
float32_t width
Cell* child
int i

Expand Down Expand Up @@ -247,7 +247,7 @@ cdef class _QuadTree:

return cell_id

cdef bint _is_duplicate(self, DTYPE_t[3] point1, DTYPE_t[3] point2) noexcept nogil:
cdef bint _is_duplicate(self, float32_t[3] point1, float32_t[3] point2) noexcept nogil:
"""Check if the two given points are equals."""
cdef int i
cdef bint res = True
Expand All @@ -256,11 +256,11 @@ cdef class _QuadTree:
res &= fabsf(point1[i] - point2[i]) <= EPSILON
return res

cdef SIZE_t _select_child(self, DTYPE_t[3] point, Cell* cell) noexcept nogil:
cdef intp_t _select_child(self, float32_t[3] point, Cell* cell) noexcept nogil:
"""Select the child of cell which contains the given query point."""
cdef:
int i
SIZE_t selected_child = 0
intp_t selected_child = 0

for i in range(self.n_dimensions):
# Select the correct child cell to insert the point by comparing
Expand All @@ -270,7 +270,7 @@ cdef class _QuadTree:
selected_child += 1
return cell.children[selected_child]

cdef void _init_cell(self, Cell* cell, SIZE_t parent, SIZE_t depth) noexcept nogil:
cdef void _init_cell(self, Cell* cell, intp_t parent, intp_t depth) noexcept nogil:
"""Initialize a cell structure with some constants."""
cell.parent = parent
cell.is_leaf = True
Expand All @@ -280,12 +280,12 @@ cdef class _QuadTree:
for i in range(self.n_cells_per_cell):
cell.children[i] = SIZE_MAX

cdef void _init_root(self, DTYPE_t[3] min_bounds, DTYPE_t[3] max_bounds
cdef void _init_root(self, float32_t[3] min_bounds, float32_t[3] max_bounds
) noexcept nogil:
"""Initialize the root node with the given space boundaries"""
cdef:
int i
DTYPE_t width
float32_t width
Cell* root = &self.cells[0]

self._init_cell(root, -1, 0)
Expand All @@ -299,7 +299,7 @@ cdef class _QuadTree:

self.cell_count += 1

cdef int _check_point_in_cell(self, DTYPE_t[3] point, Cell* cell
cdef int _check_point_in_cell(self, float32_t[3] point, Cell* cell
) except -1 nogil:
"""Check that the given point is in the cell boundaries."""

Expand Down Expand Up @@ -366,8 +366,8 @@ cdef class _QuadTree:
"in children."
.format(self.n_points, self.cells[0].cumulative_size))

cdef long summarize(self, DTYPE_t[3] point, DTYPE_t* results,
float squared_theta=.5, SIZE_t cell_id=0, long idx=0
cdef long summarize(self, float32_t[3] point, float32_t* results,
float squared_theta=.5, intp_t cell_id=0, long idx=0
) noexcept nogil:
"""Summarize the tree compared to a query point.

Expand Down Expand Up @@ -429,7 +429,7 @@ cdef class _QuadTree:
# Otherwise, we go a higher level of resolution and into the leaves.
if cell.is_leaf or (
(cell.squared_max_width / results[idx_d]) < squared_theta):
results[idx_d + 1] = <DTYPE_t> cell.cumulative_size
results[idx_d + 1] = <float32_t> cell.cumulative_size
return idx + self.n_dimensions + 2

else:
Expand All @@ -446,7 +446,7 @@ cdef class _QuadTree:
"""return the id of the cell containing the query point or raise
ValueError if the point is not in the tree
"""
cdef DTYPE_t[3] query_pt
cdef float32_t[3] query_pt
cdef int i

assert len(point) == self.n_dimensions, (
Expand All @@ -458,14 +458,14 @@ cdef class _QuadTree:

return self._get_cell(query_pt, 0)

cdef int _get_cell(self, DTYPE_t[3] point, SIZE_t cell_id=0
cdef int _get_cell(self, float32_t[3] point, intp_t cell_id=0
) except -1 nogil:
"""guts of get_cell.

Return the id of the cell containing the query point or raise ValueError
if the point is not in the tree"""
cdef:
SIZE_t selected_child
intp_t selected_child
Cell* cell = &self.cells[cell_id]

if cell.is_leaf:
Expand Down Expand Up @@ -562,7 +562,7 @@ cdef class _QuadTree:
raise ValueError("Can't initialize array!")
return arr

cdef int _resize(self, SIZE_t capacity) except -1 nogil:
cdef int _resize(self, intp_t capacity) except -1 nogil:
"""Resize all inner arrays to `capacity`, if `capacity` == -1, then
double the size of the inner arrays.

Expand All @@ -574,7 +574,7 @@ cdef class _QuadTree:
with gil:
raise MemoryError()

cdef int _resize_c(self, SIZE_t capacity=SIZE_MAX) except -1 nogil:
cdef int _resize_c(self, intp_t capacity=SIZE_MAX) except -1 nogil:
"""Guts of _resize

Returns -1 in case of failure to allocate memory (and raise MemoryError)
Expand All @@ -598,10 +598,10 @@ cdef class _QuadTree:
self.capacity = capacity
return 0

def _py_summarize(self, DTYPE_t[:] query_pt, DTYPE_t[:, :] X, float angle):
def _py_summarize(self, float32_t[:] query_pt, float32_t[:, :] X, float angle):
# Used for testing summarize
cdef:
DTYPE_t[:] summary
float32_t[:] summary
int n_samples

n_samples = X.shape[0]
Expand Down