diff --git a/benchmarks/bench_binary_tree_partition.py b/benchmarks/bench_binary_tree_partition.py new file mode 100644 index 0000000000000..b10e9b65c5657 --- /dev/null +++ b/benchmarks/bench_binary_tree_partition.py @@ -0,0 +1,23 @@ +from sklearn.neighbors import KDTree +import numpy as np +from time import time + + +def main(): + test_cases = [ + ("random", np.random.rand(50000)), + ("ordered", np.arange(50000, dtype=float)), + ("reverse ordered", np.arange(50000, 0, -1, dtype=float)), + ("duplicated", np.zeros([50000], dtype=float)) + ] + for name, case in test_cases: + expanded_case = np.expand_dims(case, -1) + begin = time() + tree = KDTree(expanded_case, leaf_size=1) + end = time() + del tree + print("{name}: {time}s".format(name=name, time=end - begin)) + + +if __name__ == "__main__": + main() diff --git a/sklearn/neighbors/_binary_tree.pxi b/sklearn/neighbors/_binary_tree.pxi index ef6a2a2d5d330..0ff0f94dfa563 100755 --- a/sklearn/neighbors/_binary_tree.pxi +++ b/sklearn/neighbors/_binary_tree.pxi @@ -159,6 +159,8 @@ from ._typedefs import DTYPE, ITYPE from ._dist_metrics cimport (DistanceMetric, euclidean_dist, euclidean_rdist, euclidean_dist_to_rdist, euclidean_rdist_to_dist) +from ._nth_element cimport partition_node_indices + cdef extern from "numpy/arrayobject.h": void PyArray_ENABLEFLAGS(np.ndarray arr, int flags) @@ -773,73 +775,6 @@ cdef ITYPE_t find_node_split_dim(DTYPE_t* data, return j_max -cdef int partition_node_indices(DTYPE_t* data, - ITYPE_t* node_indices, - ITYPE_t split_dim, - ITYPE_t split_index, - ITYPE_t n_features, - ITYPE_t n_points) except -1: - """Partition points in the node into two equal-sized groups. - - Upon return, the values in node_indices will be rearranged such that - (assuming numpy-style indexing): - - data[node_indices[0:split_index], split_dim] - <= data[node_indices[split_index], split_dim] - - and - - data[node_indices[split_index], split_dim] - <= data[node_indices[split_index:n_points], split_dim] - - The algorithm is essentially a partial in-place quicksort around a - set pivot. - - Parameters - ---------- - data : double pointer - Pointer to a 2D array of the training data, of shape [N, n_features]. - N must be greater than any of the values in node_indices. - node_indices : int pointer - Pointer to a 1D array of length n_points. This lists the indices of - each of the points within the current node. This will be modified - in-place. - split_dim : int - the dimension on which to split. This will usually be computed via - the routine ``find_node_split_dim`` - split_index : int - the index within node_indices around which to split the points. - - Returns - ------- - status : int - integer exit status. On return, the contents of node_indices are - modified as noted above. - """ - cdef ITYPE_t left, right, midindex, i - cdef DTYPE_t d1, d2 - left = 0 - right = n_points - 1 - - while True: - midindex = left - for i in range(left, right): - d1 = data[node_indices[i] * n_features + split_dim] - d2 = data[node_indices[right] * n_features + split_dim] - if d1 < d2: - swap(node_indices, i, midindex) - midindex += 1 - swap(node_indices, midindex, right) - if midindex == split_index: - break - elif midindex < split_index: - left = midindex + 1 - else: - right = midindex - 1 - - return 0 - - ###################################################################### # NodeHeap : min-heap used to keep track of nodes during # breadth-first query diff --git a/sklearn/neighbors/_nth_element.pxd b/sklearn/neighbors/_nth_element.pxd new file mode 100644 index 0000000000000..522e826632824 --- /dev/null +++ b/sklearn/neighbors/_nth_element.pxd @@ -0,0 +1,9 @@ +from ._typedefs cimport DTYPE_t, ITYPE_t + +cdef int partition_node_indices( + DTYPE_t *data, + ITYPE_t *node_indices, + ITYPE_t split_dim, + ITYPE_t split_index, + ITYPE_t n_features, + ITYPE_t n_points) except -1 diff --git a/sklearn/neighbors/_nth_element.pyx b/sklearn/neighbors/_nth_element.pyx new file mode 100644 index 0000000000000..58cf9563c515c --- /dev/null +++ b/sklearn/neighbors/_nth_element.pyx @@ -0,0 +1,25 @@ +cdef extern from "_nth_element_inner.h": + void partition_node_indices_inner[D, I]( + D *data, + I *node_indices, + I split_dim, + I split_index, + I n_features, + I n_points) except + + + +cdef int partition_node_indices( + DTYPE_t *data, + ITYPE_t *node_indices, + ITYPE_t split_dim, + ITYPE_t split_index, + ITYPE_t n_features, + ITYPE_t n_points) except -1: + partition_node_indices_inner( + data, + node_indices, + split_dim, + split_index, + n_features, + n_points) + return 0 diff --git a/sklearn/neighbors/_nth_element_inner.h b/sklearn/neighbors/_nth_element_inner.h new file mode 100644 index 0000000000000..816addcf6b3db --- /dev/null +++ b/sklearn/neighbors/_nth_element_inner.h @@ -0,0 +1,33 @@ +#include + +template +class IndexComparator { +private: + const D *data; + I split_dim, n_features; +public: + IndexComparator(const D *data, const I &split_dim, const I &n_features): + data(data), split_dim(split_dim), n_features(n_features) {} + + bool operator()(const I &a, const I &b) const { + D a_value = data[a * n_features + split_dim]; + D b_value = data[b * n_features + split_dim]; + return a_value == b_value ? a < b : a_value < b_value; + } +}; + +template +void partition_node_indices_inner( + const D *data, + I *node_indices, + const I &split_dim, + const I &split_index, + const I &n_features, + const I &n_points) { + IndexComparator index_comparator(data, split_dim, n_features); + std::nth_element( + node_indices, + node_indices + split_index, + node_indices + n_points, + index_comparator); +} diff --git a/sklearn/neighbors/setup.py b/sklearn/neighbors/setup.py index 9264044678193..fac1d63a75eb1 100644 --- a/sklearn/neighbors/setup.py +++ b/sklearn/neighbors/setup.py @@ -20,6 +20,12 @@ def configuration(parent_package='', top_path=None): include_dirs=[numpy.get_include()], libraries=libraries) + config.add_extension('_nth_element', + sources=['_nth_element.pyx'], + include_dirs=[numpy.get_include()], + language="c++", + libraries=libraries) + config.add_extension('_dist_metrics', sources=['_dist_metrics.pyx'], include_dirs=[numpy.get_include(),