Skip to content

Fix Kd-tree time complexity problem. #11103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions benchmarks/bench_binary_tree_partition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from sklearn.neighbors import KDTree
import numpy as np
from time import time


def main():
# Each case is a (label, 1-D float array of 50000 values) pair: random
# values, ascending values, descending values, and all zeros.
test_cases = [
("random", np.random.rand(50000)),
("ordered", np.arange(50000, dtype=float)),
("reverse ordered", np.arange(50000, 0, -1, dtype=float)),
("duplicated", np.zeros([50000], dtype=float))
]
for name, case in test_cases:
# Add a trailing axis so the array passed to KDTree has shape (50000, 1).
expanded_case = np.expand_dims(case, -1)
# Only the KDTree construction is inside the timed region.
begin = time()
tree = KDTree(expanded_case, leaf_size=1)
end = time()
del tree
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The garbage collector will take care of this, so I think we can remove it.

print("{name}: {time}s".format(name=name, time=end - begin))


if __name__ == "__main__":
main()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would just write the benchmark script directly, without defining a `main` function.

69 changes: 2 additions & 67 deletions sklearn/neighbors/_binary_tree.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ from ._typedefs import DTYPE, ITYPE
from ._dist_metrics cimport (DistanceMetric, euclidean_dist, euclidean_rdist,
euclidean_dist_to_rdist, euclidean_rdist_to_dist)

from ._nth_element cimport partition_node_indices

cdef extern from "numpy/arrayobject.h":
void PyArray_ENABLEFLAGS(np.ndarray arr, int flags)

Expand Down Expand Up @@ -773,73 +775,6 @@ cdef ITYPE_t find_node_split_dim(DTYPE_t* data,
return j_max


cdef int partition_node_indices(DTYPE_t* data,
ITYPE_t* node_indices,
ITYPE_t split_dim,
ITYPE_t split_index,
ITYPE_t n_features,
ITYPE_t n_points) except -1:
"""Partition points in the node into two equal-sized groups.

Upon return, the values in node_indices will be rearranged such that
(assuming numpy-style indexing):

data[node_indices[0:split_index], split_dim]
<= data[node_indices[split_index], split_dim]

and

data[node_indices[split_index], split_dim]
<= data[node_indices[split_index:n_points], split_dim]

The algorithm is essentially a partial in-place quicksort around a
set pivot.

Parameters
----------
data : double pointer
Pointer to a 2D array of the training data, of shape [N, n_features].
N must be greater than any of the values in node_indices.
node_indices : int pointer
Pointer to a 1D array of length n_points. This lists the indices of
each of the points within the current node. This will be modified
in-place.
split_dim : int
the dimension on which to split. This will usually be computed via
the routine ``find_node_split_dim``
split_index : int
the index within node_indices around which to split the points.
n_features : int
the number of values stored per point in ``data``; it is the row
stride used when reading ``data[index * n_features + split_dim]``.
n_points : int
the number of entries of node_indices that are partitioned.

Returns
-------
status : int
integer exit status. On return, the contents of node_indices are
modified as noted above.
"""
cdef ITYPE_t left, right, midindex, i
cdef DTYPE_t d1, d2
# [left, right] is the inclusive range of node_indices still being searched.
left = 0
right = n_points - 1

# Repeat partitioning passes until the pivot lands exactly on split_index.
while True:
midindex = left
# The entry at ``right`` acts as the pivot. Entries whose split_dim value
# is strictly smaller than the pivot's are swapped toward the front, and
# midindex counts how many have been moved.
for i in range(left, right):
d1 = data[node_indices[i] * n_features + split_dim]
d2 = data[node_indices[right] * n_features + split_dim]
if d1 < d2:
swap(node_indices, i, midindex)
midindex += 1
# Move the pivot entry from ``right`` into position midindex.
swap(node_indices, midindex, right)
# Narrow the range to the side that still contains split_index.
# NOTE(review): the pivot is always taken from ``right``, so ordered or
# all-equal input appears to shrink the range by a single entry per pass
# (quadratic total work) -- confirm.
if midindex == split_index:
break
elif midindex < split_index:
left = midindex + 1
else:
right = midindex - 1

return 0


######################################################################
# NodeHeap : min-heap used to keep track of nodes during
# breadth-first query
Expand Down
9 changes: 9 additions & 0 deletions sklearn/neighbors/_nth_element.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ._typedefs cimport DTYPE_t, ITYPE_t

# Declaration of the partition routine so that other Cython modules can
# cimport it; returning -1 signals a raised exception.
cdef int partition_node_indices(
        DTYPE_t *data, ITYPE_t *node_indices,
        ITYPE_t split_dim, ITYPE_t split_index,
        ITYPE_t n_features, ITYPE_t n_points) except -1
25 changes: 25 additions & 0 deletions sklearn/neighbors/_nth_element.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Templated C++ routine defined in the header. ``except +`` asks Cython to
# translate a C++ exception raised by the call into a Python exception.
cdef extern from "_nth_element_inner.h":
    void partition_node_indices_inner[D, I](
            D *data, I *node_indices,
            I split_dim, I split_index,
            I n_features, I n_points) except +


cdef int partition_node_indices(
        DTYPE_t *data, ITYPE_t *node_indices,
        ITYPE_t split_dim, ITYPE_t split_index,
        ITYPE_t n_features, ITYPE_t n_points) except -1:
    """Rearrange node_indices in place around position split_index.

    This is a thin wrapper: all arguments are forwarded unchanged to the
    C++ template ``partition_node_indices_inner``.

    Returns
    -------
    status : int
        0 once the forwarded call has returned.
    """
    partition_node_indices_inner(
        data, node_indices, split_dim, split_index, n_features, n_points)
    return 0
33 changes: 33 additions & 0 deletions sklearn/neighbors/_nth_element_inner.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include <algorithm>

/**
 * Ordering predicate over point indices.
 *
 * Index `a` is ordered before index `b` when the value of point `a` along
 * the chosen dimension is smaller than that of point `b`; when the two
 * values compare equal, the smaller index is ordered first.
 *
 * @tparam D element type of the flat, row-major data buffer.
 * @tparam I integer type used for indices, dimensions and strides.
 */
template<class D, class I>
class IndexComparator {
public:
    /**
     * @param data        flat buffer read as data[index * n_features + split_dim];
     *                    it is only borrowed, never copied or freed.
     * @param split_dim   column whose values are compared.
     * @param n_features  number of values per point (row stride).
     */
    IndexComparator(const D *data, const I &split_dim, const I &n_features):
        data_(data), split_dim_(split_dim), n_features_(n_features) {}

    /** @return true when index `a` must be ordered before index `b`. */
    bool operator()(const I &a, const I &b) const {
        const D lhs = data_[a * n_features_ + split_dim_];
        const D rhs = data_[b * n_features_ + split_dim_];
        if (lhs == rhs) {
            // Equal values: fall back to the indices themselves.
            return a < b;
        }
        return lhs < rhs;
    }

private:
    const D *data_;
    I split_dim_, n_features_;
};

/**
 * Partially order node_indices in place around position split_index.
 *
 * The entries are compared with IndexComparator, i.e. by the value of the
 * referenced point along split_dim, with the index itself as tie-breaker.
 * After the call, node_indices[split_index] holds the entry that a full sort
 * would place there, and no entry before it is ordered after it.
 *
 * @param data          flat buffer read as data[index * n_features + split_dim].
 * @param node_indices  array of n_points indices, rearranged in place.
 * @param split_dim     column whose values drive the ordering.
 * @param split_index   position within node_indices to partition around.
 * @param n_features    number of values per point (row stride).
 * @param n_points      number of entries in node_indices.
 */
template<class D, class I>
void partition_node_indices_inner(
    const D *data,
    I *node_indices,
    const I &split_dim,
    const I &split_index,
    const I &n_features,
    const I &n_points) {
    I *const first = node_indices;
    I *const nth = first + split_index;
    I *const last = first + n_points;
    std::nth_element(
        first, nth, last,
        IndexComparator<D, I>(data, split_dim, n_features));
}
6 changes: 6 additions & 0 deletions sklearn/neighbors/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ def configuration(parent_package='', top_path=None):
include_dirs=[numpy.get_include()],
libraries=libraries)

config.add_extension('_nth_element',
sources=['_nth_element.pyx'],
include_dirs=[numpy.get_include()],
language="c++",
libraries=libraries)

config.add_extension('_dist_metrics',
sources=['_dist_metrics.pyx'],
include_dirs=[numpy.get_include(),
Expand Down