@@ -166,6 +166,7 @@ from ..utils._typedefs import DTYPE, ITYPE
166
166
from ..utils ._heap cimport heap_push
167
167
from ..utils ._sorting cimport simultaneous_sort as _simultaneous_sort
168
168
169
+ # TODO: use cnp.PyArray_ENABLEFLAGS when Cython>=3.0 is used.
169
170
cdef extern from "numpy/arrayobject.h" :
170
171
void PyArray_ENABLEFLAGS (cnp .ndarray arr , int flags )
171
172
@@ -511,8 +512,8 @@ cdef class NeighborsHeap:
511
512
n_nbrs : int
512
513
the size of each heap.
513
514
"""
514
- cdef cnp . ndarray distances_arr
515
- cdef cnp . ndarray indices_arr
515
+ cdef DTYPE_t [:, :: 1 ] distances_arr
516
+ cdef ITYPE_t [:, :: 1 ] indices_arr
516
517
517
518
cdef DTYPE_t [:, ::1 ] distances
518
519
cdef ITYPE_t [:, ::1 ] indices
@@ -538,7 +539,7 @@ cdef class NeighborsHeap:
538
539
"""
539
540
if sort :
540
541
self ._sort ()
541
- return self .distances_arr , self .indices_arr
542
+ return self .distances_arr . base , self .indices_arr . base
542
543
543
544
cdef inline DTYPE_t largest (self , ITYPE_t row ) nogil except - 1 :
544
545
"""Return the largest distance in the given row"""
@@ -643,8 +644,8 @@ cdef class NodeHeap:
643
644
644
645
heap[i].val < min(heap[2 * i + 1].val, heap[2 * i + 2].val)
645
646
"""
646
- cdef cnp . ndarray data_arr
647
- cdef NodeHeapData_t [:: 1 ] data
647
+ cdef NodeHeapData_t [:] data_arr
648
+ cdef NodeHeapData_t [:] data
648
649
cdef ITYPE_t n
649
650
650
651
def __cinit__ (self ):
@@ -660,13 +661,16 @@ cdef class NodeHeap:
660
661
661
662
cdef int resize (self , ITYPE_t new_size ) except - 1 :
662
663
"""Resize the heap to be either larger or smaller"""
663
- cdef NodeHeapData_t * data_ptr
664
- cdef NodeHeapData_t * new_data_ptr
665
- cdef ITYPE_t i
666
- cdef ITYPE_t size = self .data .shape [0 ]
667
- cdef cnp .ndarray new_data_arr = np .zeros (new_size ,
668
- dtype = NodeHeapData )
669
- cdef NodeHeapData_t [::1 ] new_data = new_data_arr
664
+ cdef :
665
+ NodeHeapData_t * data_ptr
666
+ NodeHeapData_t * new_data_ptr
667
+ ITYPE_t i
668
+ ITYPE_t size = self .data .shape [0 ]
669
+ NodeHeapData_t [:] new_data_arr = np .zeros (
670
+ new_size ,
671
+ dtype = NodeHeapData ,
672
+ )
673
+ NodeHeapData_t [:] new_data = new_data_arr
670
674
671
675
if size > 0 and new_size > 0 :
672
676
data_ptr = & self .data [0 ]
@@ -769,11 +773,11 @@ VALID_METRIC_IDS = get_valid_metric_ids(VALID_METRICS)
769
773
# Binary Tree class
770
774
cdef class BinaryTree :
771
775
772
- cdef cnp . ndarray data_arr
773
- cdef cnp . ndarray sample_weight_arr
774
- cdef cnp . ndarray idx_array_arr
775
- cdef cnp . ndarray node_data_arr
776
- cdef cnp . ndarray node_bounds_arr
776
+ cdef const DTYPE_t [:, :: 1 ] data_arr
777
+ cdef const DTYPE_t [:: 1 ] sample_weight_arr
778
+ cdef const ITYPE_t [:: 1 ] idx_array_arr
779
+ cdef const NodeData_t [:: 1 ] node_data_arr
780
+ cdef const DTYPE_t [:, :, :: 1 ] node_bounds_arr
777
781
778
782
cdef readonly const DTYPE_t [:, ::1 ] data
779
783
cdef readonly const DTYPE_t [::1 ] sample_weight
@@ -869,7 +873,7 @@ cdef class BinaryTree:
869
873
# Allocate tree-specific data
870
874
allocate_data (self , self .n_nodes , n_features )
871
875
self ._recursive_build (
872
- node_data = self .node_data_arr ,
876
+ node_data = self .node_data_arr . base ,
873
877
i_node = 0 ,
874
878
idx_start = 0 ,
875
879
idx_end = n_samples
@@ -905,15 +909,15 @@ cdef class BinaryTree:
905
909
"""
906
910
if self .sample_weight is not None :
907
911
# pass the numpy array
908
- sample_weight_arr = self .sample_weight_arr
912
+ sample_weight_arr = self .sample_weight_arr . base
909
913
else :
910
914
# pass None to avoid confusion with the empty place holder
911
915
# of size 1 from __cinit__
912
916
sample_weight_arr = None
913
- return (self .data_arr ,
914
- self .idx_array_arr ,
915
- self .node_data_arr ,
916
- self .node_bounds_arr ,
917
+ return (self .data_arr . base ,
918
+ self .idx_array_arr . base ,
919
+ self .node_data_arr . base ,
920
+ self .node_bounds_arr . base ,
917
921
int (self .leaf_size ),
918
922
int (self .n_levels ),
919
923
int (self .n_nodes ),
@@ -993,8 +997,12 @@ cdef class BinaryTree:
993
997
arrays: tuple of array
994
998
Arrays for storing tree data, index, node data and node bounds.
995
999
"""
996
- return (self .data_arr , self .idx_array_arr ,
997
- self .node_data_arr , self .node_bounds_arr )
1000
+ return (
1001
+ self .data_arr .base ,
1002
+ self .idx_array_arr .base ,
1003
+ self .node_data_arr .base ,
1004
+ self .node_bounds_arr .base ,
1005
+ )
998
1006
999
1007
cdef inline DTYPE_t dist (self , DTYPE_t * x1 , DTYPE_t * x2 ,
1000
1008
ITYPE_t size ) nogil except - 1 :
@@ -1340,14 +1348,14 @@ cdef class BinaryTree:
1340
1348
# make a new numpy array that wraps the existing data
1341
1349
indices_npy [i ] = cnp .PyArray_SimpleNewFromData (1 , & counts [i ], cnp .NPY_INTP , indices [i ])
1342
1350
# make sure the data will be freed when the numpy array is garbage collected
1343
- PyArray_ENABLEFLAGS (indices_npy [i ], cnp .NPY_OWNDATA )
1351
+ PyArray_ENABLEFLAGS (indices_npy [i ], cnp .NPY_ARRAY_OWNDATA )
1344
1352
# make sure the data is not freed twice
1345
1353
indices [i ] = NULL
1346
1354
1347
1355
# make a new numpy array that wraps the existing data
1348
1356
distances_npy [i ] = cnp .PyArray_SimpleNewFromData (1 , & counts [i ], cnp .NPY_DOUBLE , distances [i ])
1349
1357
# make sure the data will be freed when the numpy array is garbage collected
1350
- PyArray_ENABLEFLAGS (distances_npy [i ], cnp .NPY_OWNDATA )
1358
+ PyArray_ENABLEFLAGS (distances_npy [i ], cnp .NPY_ARRAY_OWNDATA )
1351
1359
# make sure the data is not freed twice
1352
1360
distances [i ] = NULL
1353
1361
@@ -1360,7 +1368,7 @@ cdef class BinaryTree:
1360
1368
# make a new numpy array that wraps the existing data
1361
1369
indices_npy [i ] = cnp .PyArray_SimpleNewFromData (1 , & counts [i ], cnp .NPY_INTP , indices [i ])
1362
1370
# make sure the data will be freed when the numpy array is garbage collected
1363
- PyArray_ENABLEFLAGS (indices_npy [i ], cnp .NPY_OWNDATA )
1371
+ PyArray_ENABLEFLAGS (indices_npy [i ], cnp .NPY_ARRAY_OWNDATA )
1364
1372
# make sure the data is not freed twice
1365
1373
indices [i ] = NULL
1366
1374
0 commit comments