From 82eb949edf3ec5cd2e18c9abf13e504d3abe3f6e Mon Sep 17 00:00:00 2001 From: jiefangxuanyan <505745416@qq.com> Date: Thu, 3 May 2018 16:06:24 +0800 Subject: [PATCH 01/14] Use std::nth_element from C++ standard library. --- sklearn/neighbors/binary_tree.pxi | 69 +-------------------------- sklearn/neighbors/nth_element.pxd | 8 ++++ sklearn/neighbors/nth_element.pyx | 9 ++++ sklearn/neighbors/setup.py | 7 +++ sklearn/neighbors/src/nth_element.cpp | 31 ++++++++++++ sklearn/neighbors/src/nth_element.h | 9 ++++ 6 files changed, 66 insertions(+), 67 deletions(-) create mode 100644 sklearn/neighbors/nth_element.pxd create mode 100644 sklearn/neighbors/nth_element.pyx create mode 100644 sklearn/neighbors/src/nth_element.cpp create mode 100644 sklearn/neighbors/src/nth_element.h diff --git a/sklearn/neighbors/binary_tree.pxi b/sklearn/neighbors/binary_tree.pxi index edf78257c9b23..1764d0b07583e 100755 --- a/sklearn/neighbors/binary_tree.pxi +++ b/sklearn/neighbors/binary_tree.pxi @@ -161,6 +161,8 @@ from dist_metrics cimport (DistanceMetric, euclidean_dist, euclidean_rdist, cdef extern from "numpy/arrayobject.h": void PyArray_ENABLEFLAGS(np.ndarray arr, int flags) +from nth_element cimport partition_node_indices + np.import_array() # some handy constants @@ -791,73 +793,6 @@ cdef ITYPE_t find_node_split_dim(DTYPE_t* data, return j_max -cdef int partition_node_indices(DTYPE_t* data, - ITYPE_t* node_indices, - ITYPE_t split_dim, - ITYPE_t split_index, - ITYPE_t n_features, - ITYPE_t n_points) except -1: - """Partition points in the node into two equal-sized groups. - - Upon return, the values in node_indices will be rearranged such that - (assuming numpy-style indexing): - - data[node_indices[0:split_index], split_dim] - <= data[node_indices[split_index], split_dim] - - and - - data[node_indices[split_index], split_dim] - <= data[node_indices[split_index:n_points], split_dim] - - The algorithm is essentially a partial in-place quicksort around a - set pivot. 
- - Parameters - ---------- - data : double pointer - Pointer to a 2D array of the training data, of shape [N, n_features]. - N must be greater than any of the values in node_indices. - node_indices : int pointer - Pointer to a 1D array of length n_points. This lists the indices of - each of the points within the current node. This will be modified - in-place. - split_dim : int - the dimension on which to split. This will usually be computed via - the routine ``find_node_split_dim`` - split_index : int - the index within node_indices around which to split the points. - - Returns - ------- - status : int - integer exit status. On return, the contents of node_indices are - modified as noted above. - """ - cdef ITYPE_t left, right, midindex, i - cdef DTYPE_t d1, d2 - left = 0 - right = n_points - 1 - - while True: - midindex = left - for i in range(left, right): - d1 = data[node_indices[i] * n_features + split_dim] - d2 = data[node_indices[right] * n_features + split_dim] - if d1 < d2: - swap(node_indices, i, midindex) - midindex += 1 - swap(node_indices, midindex, right) - if midindex == split_index: - break - elif midindex < split_index: - left = midindex + 1 - else: - right = midindex - 1 - - return 0 - - ###################################################################### # NodeHeap : min-heap used to keep track of nodes during # breadth-first query diff --git a/sklearn/neighbors/nth_element.pxd b/sklearn/neighbors/nth_element.pxd new file mode 100644 index 0000000000000..89683c112bcbb --- /dev/null +++ b/sklearn/neighbors/nth_element.pxd @@ -0,0 +1,8 @@ +from typedefs cimport DTYPE_t, ITYPE_t + +cdef int partition_node_indices(DTYPE_t *data, + ITYPE_t *node_indices, + ITYPE_t split_dim, + ITYPE_t split_index, + ITYPE_t n_features, + ITYPE_t n_points) diff --git a/sklearn/neighbors/nth_element.pyx b/sklearn/neighbors/nth_element.pyx new file mode 100644 index 0000000000000..a897de30ede35 --- /dev/null +++ b/sklearn/neighbors/nth_element.pyx @@ -0,0 +1,9 @@ +from 
typedefs cimport DTYPE_t, ITYPE_t + +cdef extern from "nth_element.h": + int partition_node_indices(DTYPE_t*data, + ITYPE_t*node_indices, + ITYPE_t split_dim, + ITYPE_t split_index, + ITYPE_t n_features, + ITYPE_t n_points) diff --git a/sklearn/neighbors/setup.py b/sklearn/neighbors/setup.py index 8b1ad7bac9fab..1191234183523 100644 --- a/sklearn/neighbors/setup.py +++ b/sklearn/neighbors/setup.py @@ -20,6 +20,13 @@ def configuration(parent_package='', top_path=None): include_dirs=[numpy.get_include()], libraries=libraries) + config.add_extension('nth_element', + sources=['nth_element.pyx', + os.path.join('src', 'nth_element.cpp')], + include_dirs=['src'], + language="c++", + libraries=libraries) + config.add_extension('dist_metrics', sources=['dist_metrics.pyx'], include_dirs=[numpy.get_include(), diff --git a/sklearn/neighbors/src/nth_element.cpp b/sklearn/neighbors/src/nth_element.cpp new file mode 100644 index 0000000000000..630036eec22f2 --- /dev/null +++ b/sklearn/neighbors/src/nth_element.cpp @@ -0,0 +1,31 @@ +#include "nth_element.h" +#include + +class IndexComparator { + double *data; + Py_intptr_t split_dim, n_features; + +public: + IndexComparator(double *data, Py_intptr_t split_dim, Py_intptr_t n_features): + data(data), split_dim(split_dim), n_features(n_features) {} + + bool operator()(Py_intptr_t a, Py_intptr_t b) { + return data[a * n_features + split_dim] + < data[b * n_features + split_dim]; + } +}; + +int partition_node_indices(double *data, + Py_intptr_t *node_indices, + Py_intptr_t split_dim, + Py_intptr_t split_index, + Py_intptr_t n_features, + Py_intptr_t n_points) { + IndexComparator index_comparator(data, split_dim, n_features); + std::nth_element(node_indices, + node_indices + split_index, + node_indices + n_points, + index_comparator); + return 0; +} + diff --git a/sklearn/neighbors/src/nth_element.h b/sklearn/neighbors/src/nth_element.h new file mode 100644 index 0000000000000..cf2403271a65d --- /dev/null +++ 
b/sklearn/neighbors/src/nth_element.h @@ -0,0 +1,9 @@ +#include + +int partition_node_indices(double *data, + Py_intptr_t *node_indices, + Py_intptr_t split_dim, + Py_intptr_t split_index, + Py_intptr_t n_features, + Py_intptr_t n_points); + From 076069a39489e160f2d485697d2b13c55749612b Mon Sep 17 00:00:00 2001 From: jiefangxuanyan <505745416@qq.com> Date: Fri, 8 Jun 2018 15:07:47 +0800 Subject: [PATCH 02/14] Add a benchmark to expose time complexity degeneration problem of the original partition algorithm. --- benchmarks/bench_binary_tree_partition.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 benchmarks/bench_binary_tree_partition.py diff --git a/benchmarks/bench_binary_tree_partition.py b/benchmarks/bench_binary_tree_partition.py new file mode 100644 index 0000000000000..ee0aa6188ef9b --- /dev/null +++ b/benchmarks/bench_binary_tree_partition.py @@ -0,0 +1,22 @@ +from sklearn.neighbors import KDTree +import numpy as np +from time import time + + +def main(): + test_cases = [ + ("ordered", np.arange(50000, dtype=float)), + ("reverse ordered", np.arange(50000, 0, -1, dtype=float)), + ("duplicated", np.zeros([50000], dtype=float)) + ] + for name, case in test_cases: + expanded_case = np.expand_dims(case, -1) + begin = time() + tree = KDTree(expanded_case, leaf_size=1) + end = time() + del tree + print("{name}: {time}s".format(name=name, time=end - begin)) + + +if __name__ == "__main__": + main() From 3c34c7c8a47391b3ccf695744c905bc8c6181c53 Mon Sep 17 00:00:00 2001 From: jiefangxuanyan <505745416@qq.com> Date: Fri, 8 Jun 2018 15:12:43 +0800 Subject: [PATCH 03/14] Added a general case for binary tree partition algorithm benchmark. 
--- benchmarks/bench_binary_tree_partition.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/bench_binary_tree_partition.py b/benchmarks/bench_binary_tree_partition.py index ee0aa6188ef9b..b10e9b65c5657 100644 --- a/benchmarks/bench_binary_tree_partition.py +++ b/benchmarks/bench_binary_tree_partition.py @@ -5,6 +5,7 @@ def main(): test_cases = [ + ("random", np.random.rand(50000)), ("ordered", np.arange(50000, dtype=float)), ("reverse ordered", np.arange(50000, 0, -1, dtype=float)), ("duplicated", np.zeros([50000], dtype=float)) From cd112a89f0fdfa47a6ea797443addd8543004786 Mon Sep 17 00:00:00 2001 From: jiefangxuanyan <505745416@qq.com> Date: Sat, 11 Apr 2020 20:25:25 +0800 Subject: [PATCH 04/14] Multiple changes on the nth_element extension: 1. Add "_" in the file name; 2. Make the comparator stable; 3. Use a template implementation of partition_node_indices, the original one has hard-coded types(DTYPE_t -> double, ITYPE_t -> Py_intptr_t) which would be easily broken when these type changes. 
--- sklearn/neighbors/_binary_tree.pxi | 4 ++-- sklearn/neighbors/_nth_element.pxd | 9 +++++++ sklearn/neighbors/_nth_element.pyx | 25 +++++++++++++++++++ sklearn/neighbors/_nth_element_inner.h | 33 ++++++++++++++++++++++++++ sklearn/neighbors/nth_element.pxd | 8 ------- sklearn/neighbors/nth_element.pyx | 9 ------- sklearn/neighbors/setup.py | 7 +++--- sklearn/neighbors/src/nth_element.cpp | 31 ------------------------ sklearn/neighbors/src/nth_element.h | 9 ------- 9 files changed, 72 insertions(+), 63 deletions(-) create mode 100644 sklearn/neighbors/_nth_element.pxd create mode 100644 sklearn/neighbors/_nth_element.pyx create mode 100644 sklearn/neighbors/_nth_element_inner.h delete mode 100644 sklearn/neighbors/nth_element.pxd delete mode 100644 sklearn/neighbors/nth_element.pyx delete mode 100644 sklearn/neighbors/src/nth_element.cpp delete mode 100644 sklearn/neighbors/src/nth_element.h diff --git a/sklearn/neighbors/_binary_tree.pxi b/sklearn/neighbors/_binary_tree.pxi index 4e757534d4eab..0ff0f94dfa563 100755 --- a/sklearn/neighbors/_binary_tree.pxi +++ b/sklearn/neighbors/_binary_tree.pxi @@ -159,11 +159,11 @@ from ._typedefs import DTYPE, ITYPE from ._dist_metrics cimport (DistanceMetric, euclidean_dist, euclidean_rdist, euclidean_dist_to_rdist, euclidean_rdist_to_dist) +from ._nth_element cimport partition_node_indices + cdef extern from "numpy/arrayobject.h": void PyArray_ENABLEFLAGS(np.ndarray arr, int flags) -from nth_element cimport partition_node_indices - np.import_array() # some handy constants diff --git a/sklearn/neighbors/_nth_element.pxd b/sklearn/neighbors/_nth_element.pxd new file mode 100644 index 0000000000000..522e826632824 --- /dev/null +++ b/sklearn/neighbors/_nth_element.pxd @@ -0,0 +1,9 @@ +from ._typedefs cimport DTYPE_t, ITYPE_t + +cdef int partition_node_indices( + DTYPE_t *data, + ITYPE_t *node_indices, + ITYPE_t split_dim, + ITYPE_t split_index, + ITYPE_t n_features, + ITYPE_t n_points) except -1 diff --git 
a/sklearn/neighbors/_nth_element.pyx b/sklearn/neighbors/_nth_element.pyx new file mode 100644 index 0000000000000..58cf9563c515c --- /dev/null +++ b/sklearn/neighbors/_nth_element.pyx @@ -0,0 +1,25 @@ +cdef extern from "_nth_element_inner.h": + void partition_node_indices_inner[D, I]( + D *data, + I *node_indices, + I split_dim, + I split_index, + I n_features, + I n_points) except + + + +cdef int partition_node_indices( + DTYPE_t *data, + ITYPE_t *node_indices, + ITYPE_t split_dim, + ITYPE_t split_index, + ITYPE_t n_features, + ITYPE_t n_points) except -1: + partition_node_indices_inner( + data, + node_indices, + split_dim, + split_index, + n_features, + n_points) + return 0 diff --git a/sklearn/neighbors/_nth_element_inner.h b/sklearn/neighbors/_nth_element_inner.h new file mode 100644 index 0000000000000..816addcf6b3db --- /dev/null +++ b/sklearn/neighbors/_nth_element_inner.h @@ -0,0 +1,33 @@ +#include + +template +class IndexComparator { +private: + const D *data; + I split_dim, n_features; +public: + IndexComparator(const D *data, const I &split_dim, const I &n_features): + data(data), split_dim(split_dim), n_features(n_features) {} + + bool operator()(const I &a, const I &b) const { + D a_value = data[a * n_features + split_dim]; + D b_value = data[b * n_features + split_dim]; + return a_value == b_value ? 
a < b : a_value < b_value; + } +}; + +template +void partition_node_indices_inner( + const D *data, + I *node_indices, + const I &split_dim, + const I &split_index, + const I &n_features, + const I &n_points) { + IndexComparator index_comparator(data, split_dim, n_features); + std::nth_element( + node_indices, + node_indices + split_index, + node_indices + n_points, + index_comparator); +} diff --git a/sklearn/neighbors/nth_element.pxd b/sklearn/neighbors/nth_element.pxd deleted file mode 100644 index 89683c112bcbb..0000000000000 --- a/sklearn/neighbors/nth_element.pxd +++ /dev/null @@ -1,8 +0,0 @@ -from typedefs cimport DTYPE_t, ITYPE_t - -cdef int partition_node_indices(DTYPE_t *data, - ITYPE_t *node_indices, - ITYPE_t split_dim, - ITYPE_t split_index, - ITYPE_t n_features, - ITYPE_t n_points) diff --git a/sklearn/neighbors/nth_element.pyx b/sklearn/neighbors/nth_element.pyx deleted file mode 100644 index a897de30ede35..0000000000000 --- a/sklearn/neighbors/nth_element.pyx +++ /dev/null @@ -1,9 +0,0 @@ -from typedefs cimport DTYPE_t, ITYPE_t - -cdef extern from "nth_element.h": - int partition_node_indices(DTYPE_t*data, - ITYPE_t*node_indices, - ITYPE_t split_dim, - ITYPE_t split_index, - ITYPE_t n_features, - ITYPE_t n_points) diff --git a/sklearn/neighbors/setup.py b/sklearn/neighbors/setup.py index e6d1f2d0eb602..fac1d63a75eb1 100644 --- a/sklearn/neighbors/setup.py +++ b/sklearn/neighbors/setup.py @@ -20,10 +20,9 @@ def configuration(parent_package='', top_path=None): include_dirs=[numpy.get_include()], libraries=libraries) - config.add_extension('nth_element', - sources=['nth_element.pyx', - os.path.join('src', 'nth_element.cpp')], - include_dirs=['src'], + config.add_extension('_nth_element', + sources=['_nth_element.pyx'], + include_dirs=[numpy.get_include()], language="c++", libraries=libraries) diff --git a/sklearn/neighbors/src/nth_element.cpp b/sklearn/neighbors/src/nth_element.cpp deleted file mode 100644 index 630036eec22f2..0000000000000 --- 
a/sklearn/neighbors/src/nth_element.cpp +++ /dev/null @@ -1,31 +0,0 @@ -#include "nth_element.h" -#include - -class IndexComparator { - double *data; - Py_intptr_t split_dim, n_features; - -public: - IndexComparator(double *data, Py_intptr_t split_dim, Py_intptr_t n_features): - data(data), split_dim(split_dim), n_features(n_features) {} - - bool operator()(Py_intptr_t a, Py_intptr_t b) { - return data[a * n_features + split_dim] - < data[b * n_features + split_dim]; - } -}; - -int partition_node_indices(double *data, - Py_intptr_t *node_indices, - Py_intptr_t split_dim, - Py_intptr_t split_index, - Py_intptr_t n_features, - Py_intptr_t n_points) { - IndexComparator index_comparator(data, split_dim, n_features); - std::nth_element(node_indices, - node_indices + split_index, - node_indices + n_points, - index_comparator); - return 0; -} - diff --git a/sklearn/neighbors/src/nth_element.h b/sklearn/neighbors/src/nth_element.h deleted file mode 100644 index cf2403271a65d..0000000000000 --- a/sklearn/neighbors/src/nth_element.h +++ /dev/null @@ -1,9 +0,0 @@ -#include - -int partition_node_indices(double *data, - Py_intptr_t *node_indices, - Py_intptr_t split_dim, - Py_intptr_t split_index, - Py_intptr_t n_features, - Py_intptr_t n_points); - From 719e6c8fdd6b078f20c5e6cf834ec14038684de2 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Tue, 16 Feb 2021 09:45:24 +0100 Subject: [PATCH 05/14] Inline C++ comparator and interface --- sklearn/neighbors/_nth_element.pyx | 51 ++++++++++++++++++++++---- sklearn/neighbors/_nth_element_inner.h | 33 ----------------- 2 files changed, 44 insertions(+), 40 deletions(-) delete mode 100644 sklearn/neighbors/_nth_element_inner.h diff --git a/sklearn/neighbors/_nth_element.pyx b/sklearn/neighbors/_nth_element.pyx index 58cf9563c515c..06f331c44089a 100644 --- a/sklearn/neighbors/_nth_element.pyx +++ b/sklearn/neighbors/_nth_element.pyx @@ -1,11 +1,48 @@ -cdef extern from "_nth_element_inner.h": +# distutils : language = c++ + +cdef 
extern from *: + """ + #include <algorithm> + + template<class D, class I> + class IndexComparator { + private: + const D *data; + I split_dim, n_features; + public: + IndexComparator(const D *data, const I &split_dim, const I &n_features): + data(data), split_dim(split_dim), n_features(n_features) {} + + bool operator()(const I &a, const I &b) const { + D a_value = data[a * n_features + split_dim]; + D b_value = data[b * n_features + split_dim]; + return a_value == b_value ? a < b : a_value < b_value; + } + }; + + template<class D, class I> + void partition_node_indices_inner( + const D *data, + I *node_indices, + const I &split_dim, + const I &split_index, + const I &n_features, + const I &n_points) { + IndexComparator<D, I> index_comparator(data, split_dim, n_features); + std::nth_element( + node_indices, + node_indices + split_index, + node_indices + n_points, + index_comparator); + } + """ void partition_node_indices_inner[D, I]( - D *data, - I *node_indices, - I split_dim, - I split_index, - I n_features, - I n_points) except + + D *data, + I *node_indices, + I split_dim, + I split_index, + I n_features, + I n_points) except + cdef int partition_node_indices( diff --git a/sklearn/neighbors/_nth_element_inner.h b/sklearn/neighbors/_nth_element_inner.h deleted file mode 100644 index 816addcf6b3db..0000000000000 --- a/sklearn/neighbors/_nth_element_inner.h +++ /dev/null @@ -1,33 +0,0 @@ -#include <algorithm> - -template<class D, class I> -class IndexComparator { -private: - const D *data; - I split_dim, n_features; -public: - IndexComparator(const D *data, const I &split_dim, const I &n_features): - data(data), split_dim(split_dim), n_features(n_features) {} - - bool operator()(const I &a, const I &b) const { - D a_value = data[a * n_features + split_dim]; - D b_value = data[b * n_features + split_dim]; - return a_value == b_value ? 
a < b : a_value < b_value; - } -}; - -template -void partition_node_indices_inner( - const D *data, - I *node_indices, - const I &split_dim, - const I &split_index, - const I &n_features, - const I &n_points) { - IndexComparator index_comparator(data, split_dim, n_features); - std::nth_element( - node_indices, - node_indices + split_index, - node_indices + n_points, - index_comparator); -} From 711ea353d06ba201fc1ea9f5f05362407861bc3f Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 17 Feb 2021 00:02:41 +0100 Subject: [PATCH 06/14] Remove script for benchmark The latest asv benchmark script is available online here: https://gist.github.com/jjerphan/6c1f3e4c80908b862ccce6682835d36a --- benchmarks/bench_binary_tree_partition.py | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 benchmarks/bench_binary_tree_partition.py diff --git a/benchmarks/bench_binary_tree_partition.py b/benchmarks/bench_binary_tree_partition.py deleted file mode 100644 index b10e9b65c5657..0000000000000 --- a/benchmarks/bench_binary_tree_partition.py +++ /dev/null @@ -1,23 +0,0 @@ -from sklearn.neighbors import KDTree -import numpy as np -from time import time - - -def main(): - test_cases = [ - ("random", np.random.rand(50000)), - ("ordered", np.arange(50000, dtype=float)), - ("reverse ordered", np.arange(50000, 0, -1, dtype=float)), - ("duplicated", np.zeros([50000], dtype=float)) - ] - for name, case in test_cases: - expanded_case = np.expand_dims(case, -1) - begin = time() - tree = KDTree(expanded_case, leaf_size=1) - end = time() - del tree - print("{name}: {time}s".format(name=name, time=end - begin)) - - -if __name__ == "__main__": - main() From e301099889716ad92e88c067009f5d6ac266c400 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 3 Mar 2021 21:28:42 +0100 Subject: [PATCH 07/14] Add whats_new entry for #19473 --- doc/whats_new/v1.0.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/doc/whats_new/v1.0.rst 
b/doc/whats_new/v1.0.rst index 024fefe3fd825..b3af8b21c14a6 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -164,6 +164,15 @@ Changelog Use ``var_`` instead. :pr:`18842` by :user:`Hong Shao Yang `. +:mod:`sklearn.neighbors` +.......................... + +- |Enhancement| Improve :class:`neighbors.KDTree` worst-case time complexity + from :math:`\mathcal{O}(n^2)` to :math:`\mathcal{O}(n \log n)` on sorted and + on duplicate inputs. + :pr:`19473` by :user:`jiefangxuanyan ` and + :user:`Julien Jerphanion `. + :mod:`sklearn.preprocessing` ............................ From d69f06de32f2b4731d330a648b2ef7d0b28b2291 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 4 Mar 2021 07:57:13 +0100 Subject: [PATCH 08/14] fixup! Add whats_new entry for #19473 --- doc/whats_new/v1.0.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index b3af8b21c14a6..fb4ec4350740a 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -168,7 +168,7 @@ Changelog .......................... - |Enhancement| Improve :class:`neighbors.KDTree` worst-case time complexity - from :math:`\mathcal{O}(n^2)` to :math:`\mathcal{O}(n \log n)` on sorted and + from :math:`\mathcal{O}(n^2)` to :math:`\mathcal{O}(n)` on sorted and on duplicate inputs. :pr:`19473` by :user:`jiefangxuanyan ` and :user:`Julien Jerphanion `. From 788e7278c13904c7afd7a6d567e0cb3444e139a1 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 7 Apr 2021 13:23:23 +0200 Subject: [PATCH 09/14] fixup! Use std::nth_element from C++ standard library. Reintroduce the docstring of partition_node_indices. Co-authored-by: "Thomas J. 
Fan" --- sklearn/neighbors/_nth_element.pyx | 37 ++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/sklearn/neighbors/_nth_element.pyx b/sklearn/neighbors/_nth_element.pyx index 06f331c44089a..43a1cdb8c3f9c 100644 --- a/sklearn/neighbors/_nth_element.pyx +++ b/sklearn/neighbors/_nth_element.pyx @@ -52,6 +52,43 @@ cdef int partition_node_indices( ITYPE_t split_index, ITYPE_t n_features, ITYPE_t n_points) except -1: + """Partition points in the node into two equal-sized groups. + + Upon return, the values in node_indices will be rearranged such that + (assuming numpy-style indexing): + + data[node_indices[0:split_index], split_dim] + <= data[node_indices[split_index], split_dim] + + and + + data[node_indices[split_index], split_dim] + <= data[node_indices[split_index:n_points], split_dim] + + The algorithm is essentially a partial in-place quicksort around a + set pivot. + + Parameters + ---------- + data : double pointer + Pointer to a 2D array of the training data, of shape [N, n_features]. + N must be greater than any of the values in node_indices. + node_indices : int pointer + Pointer to a 1D array of length n_points. This lists the indices of + each of the points within the current node. This will be modified + in-place. + split_dim : int + the dimension on which to split. This will usually be computed via + the routine ``find_node_split_dim`` + split_index : int + the index within node_indices around which to split the points. + + Returns + ------- + status : int + integer exit status. On return, the contents of node_indices are + modified as noted above. + """ partition_node_indices_inner( data, node_indices, From 9d62ce7d3abad08819e3190438e74748f91c4f57 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 7 Apr 2021 13:53:21 +0200 Subject: [PATCH 10/14] Use a clearer name for KDTrees' nodes' paritioning submodule Co-authored-by: "Thomas J. 
Fan" --- sklearn/neighbors/{_nth_element.pxd => _partition_nodes.pxd} | 0 sklearn/neighbors/{_nth_element.pyx => _partition_nodes.pyx} | 0 sklearn/neighbors/setup.py | 4 ++-- 3 files changed, 2 insertions(+), 2 deletions(-) rename sklearn/neighbors/{_nth_element.pxd => _partition_nodes.pxd} (100%) rename sklearn/neighbors/{_nth_element.pyx => _partition_nodes.pyx} (100%) diff --git a/sklearn/neighbors/_nth_element.pxd b/sklearn/neighbors/_partition_nodes.pxd similarity index 100% rename from sklearn/neighbors/_nth_element.pxd rename to sklearn/neighbors/_partition_nodes.pxd diff --git a/sklearn/neighbors/_nth_element.pyx b/sklearn/neighbors/_partition_nodes.pyx similarity index 100% rename from sklearn/neighbors/_nth_element.pyx rename to sklearn/neighbors/_partition_nodes.pyx diff --git a/sklearn/neighbors/setup.py b/sklearn/neighbors/setup.py index fac1d63a75eb1..996b855d2d45a 100644 --- a/sklearn/neighbors/setup.py +++ b/sklearn/neighbors/setup.py @@ -20,8 +20,8 @@ def configuration(parent_package='', top_path=None): include_dirs=[numpy.get_include()], libraries=libraries) - config.add_extension('_nth_element', - sources=['_nth_element.pyx'], + config.add_extension('_partition_nodes', + sources=['_partition_nodes.pyx'], include_dirs=[numpy.get_include()], language="c++", libraries=libraries) From d93e5b67cfc4bdfbecdbd233ed8014313d38845c Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 7 Apr 2021 13:56:44 +0200 Subject: [PATCH 11/14] Add comment motivating the use of C++ --- sklearn/neighbors/_partition_nodes.pyx | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sklearn/neighbors/_partition_nodes.pyx b/sklearn/neighbors/_partition_nodes.pyx index 43a1cdb8c3f9c..771343b9d1586 100644 --- a/sklearn/neighbors/_partition_nodes.pyx +++ b/sklearn/neighbors/_partition_nodes.pyx @@ -1,5 +1,23 @@ # distutils : language = c++ +# KDTrees rely on partial sorts to partition their nodes. 
+# +# The C++ std library exposes nth_element, an efficient partial sort for this +# situation which has a linear time complexity as well as the best performances. +# +# To use std::algorithm::nth_element, a few fixture are defined using Cython: +# - partition_node_indices, a cython function used within KDTree, that calls +# - partition_node_indices_inner, a C++ function that wraps nth_element and uses +# - an IndexComparator to state how to compare KDTrees' indices +# +# IndexComparator has been defined so that partial sorts are stable with +# respect to the nodes initial indices. +# +# See for reference: +# - https://en.cppreference.com/w/cpp/algorithm/nth_element. +# - https://github.com/scikit-learn/scikit-learn/pull/11103 +# - https://github.com/scikit-learn/scikit-learn/pull/19473 + cdef extern from *: """ #include From 809aa8be377e8123d34aa705173884b023ce8b17 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 7 Apr 2021 14:12:21 +0200 Subject: [PATCH 12/14] fixup! Use a clearer name for KDTrees' nodes' paritioning submodule --- sklearn/neighbors/_binary_tree.pxi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/neighbors/_binary_tree.pxi b/sklearn/neighbors/_binary_tree.pxi index 3d9be84f38a40..cabad951c4975 100755 --- a/sklearn/neighbors/_binary_tree.pxi +++ b/sklearn/neighbors/_binary_tree.pxi @@ -159,7 +159,7 @@ from ._typedefs import DTYPE, ITYPE from ._dist_metrics cimport (DistanceMetric, euclidean_dist, euclidean_rdist, euclidean_dist_to_rdist, euclidean_rdist_to_dist) -from ._nth_element cimport partition_node_indices +from ._partition_nodes cimport partition_node_indices cdef extern from "numpy/arrayobject.h": void PyArray_ENABLEFLAGS(np.ndarray arr, int flags) From 4e3a217ac93cdcef5b46f01bd1b26e651a99e5f4 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Wed, 7 Apr 2021 15:17:50 +0200 Subject: [PATCH 13/14] Adapt the documentation to mention BallTree As the changes relate to both BinaryTree's subclasses. 
--- doc/whats_new/v1.0.rst | 6 +++--- sklearn/neighbors/_partition_nodes.pyx | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index eadf0eee024bb..ce7da3139d140 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -282,9 +282,9 @@ Changelog :mod:`sklearn.neighbors` .......................... -- |Enhancement| Improve :class:`neighbors.KDTree` worst-case time complexity - from :math:`\mathcal{O}(n^2)` to :math:`\mathcal{O}(n)` on sorted and - on duplicate inputs. +- |Enhancement| The creation of :class:`neighbors.KDTree` and + :class:`neighbors.BallTree` has been improved for their worst-case time + complexity from :math:`\mathcal{O}(n^2)` to :math:`\mathcal{O}(n)`. :pr:`19473` by :user:`jiefangxuanyan <jiefangxuanyan>` and :user:`Julien Jerphanion <jjerphan>`. diff --git a/sklearn/neighbors/_partition_nodes.pyx b/sklearn/neighbors/_partition_nodes.pyx index 771343b9d1586..0073db81372db 100644 --- a/sklearn/neighbors/_partition_nodes.pyx +++ b/sklearn/neighbors/_partition_nodes.pyx @@ -1,12 +1,13 @@ # distutils : language = c++ -# KDTrees rely on partial sorts to partition their nodes. +# BinaryTrees rely on partial sorts to partition their nodes during their +# initialisation. # # The C++ std library exposes nth_element, an efficient partial sort for this # situation which has a linear time complexity as well as the best performances. 
# # To use std::algorithm::nth_element, a few fixture are defined using Cython: -# - partition_node_indices, a cython function used within KDTree, that calls +# - partition_node_indices, a Cython function used in BinaryTrees, that calls # - partition_node_indices_inner, a C++ function that wraps nth_element and uses # - an IndexComparator to state how to compare KDTrees' indices # From e8d4488abebf4ae57093870381b4c9d4e18d0de2 Mon Sep 17 00:00:00 2001 From: Julien Jerphanion Date: Thu, 8 Apr 2021 09:29:35 +0200 Subject: [PATCH 14/14] Complete partition_node_indices' docstring --- sklearn/neighbors/_partition_nodes.pyx | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sklearn/neighbors/_partition_nodes.pyx b/sklearn/neighbors/_partition_nodes.pyx index 0073db81372db..508e9560ae8c2 100644 --- a/sklearn/neighbors/_partition_nodes.pyx +++ b/sklearn/neighbors/_partition_nodes.pyx @@ -98,10 +98,14 @@ cdef int partition_node_indices( in-place. split_dim : int the dimension on which to split. This will usually be computed via - the routine ``find_node_split_dim`` + the routine ``find_node_split_dim``. split_index : int the index within node_indices around which to split the points. - + n_features : int + the number of features (i.e. columns) in the 2D array pointed to by data. + n_points : int + the length of node_indices. This is also the number of points in + the current node. Returns ------- status : int