Skip to content

ENH Improve the creation of KDTree and BallTree on their worst-case time complexity #19473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Apr 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions doc/whats_new/v1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,16 @@ Changelog
Use ``var_`` instead.
:pr:`18842` by :user:`Hong Shao Yang <hongshaoyang>`.


:mod:`sklearn.neighbors`
..........................

- |Enhancement| The creation of :class:`neighbors.KDTree` and
  :class:`neighbors.BallTree` has been improved for their worst-case time
  complexity from :math:`\mathcal{O}(n^2)` to :math:`\mathcal{O}(n)`.
:pr:`19473` by :user:`jiefangxuanyan <jiefangxuanyan>` and
:user:`Julien Jerphanion <jjerphan>`.

:mod:`sklearn.pipeline`
.......................

Expand Down
69 changes: 2 additions & 67 deletions sklearn/neighbors/_binary_tree.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ from ._typedefs import DTYPE, ITYPE
from ._dist_metrics cimport (DistanceMetric, euclidean_dist, euclidean_rdist,
euclidean_dist_to_rdist, euclidean_rdist_to_dist)

from ._partition_nodes cimport partition_node_indices

cdef extern from "numpy/arrayobject.h":
void PyArray_ENABLEFLAGS(np.ndarray arr, int flags)

Expand Down Expand Up @@ -776,73 +778,6 @@ cdef ITYPE_t find_node_split_dim(DTYPE_t* data,
return j_max


cdef int partition_node_indices(DTYPE_t* data,
                                ITYPE_t* node_indices,
                                ITYPE_t split_dim,
                                ITYPE_t split_index,
                                ITYPE_t n_features,
                                ITYPE_t n_points) except -1:
    """Partition points in the node into two equal-sized groups.

    Upon return, the values in node_indices will be rearranged such that
    (assuming numpy-style indexing):

        data[node_indices[0:split_index], split_dim]
          <= data[node_indices[split_index], split_dim]

    and

        data[node_indices[split_index], split_dim]
          <= data[node_indices[split_index:n_points], split_dim]

    The algorithm is essentially a partial in-place quicksort around a
    set pivot (a quickselect): worst-case quadratic in n_points when the
    values along split_dim are already (almost) sorted.

    Parameters
    ----------
    data : double pointer
        Pointer to a 2D array of the training data, of shape [N, n_features].
        N must be greater than any of the values in node_indices.
    node_indices : int pointer
        Pointer to a 1D array of length n_points. This lists the indices of
        each of the points within the current node. This will be modified
        in-place.
    split_dim : int
        the dimension on which to split. This will usually be computed via
        the routine ``find_node_split_dim``
    split_index : int
        the index within node_indices around which to split the points.
    n_features : int
        the number of features (i.e. columns) in the 2D array pointed to
        by data.
    n_points : int
        the length of node_indices.

    Returns
    -------
    status : int
        integer exit status. On return, the contents of node_indices are
        modified as noted above.
    """
    cdef ITYPE_t left, right, midindex, i
    cdef DTYPE_t d1, d2
    # Quickselect over the sub-range [left, right] of node_indices,
    # narrowing the range until the pivot lands exactly on split_index.
    left = 0
    right = n_points - 1

    while True:
        # Lomuto-style partition pass using the value at node_indices[right]
        # (along split_dim) as the pivot: indices with strictly smaller
        # values are swapped to the front, midindex ends up at the pivot's
        # final position.
        midindex = left
        for i in range(left, right):
            d1 = data[node_indices[i] * n_features + split_dim]
            d2 = data[node_indices[right] * n_features + split_dim]
            if d1 < d2:
                swap(node_indices, i, midindex)
                midindex += 1
        # Move the pivot into its final sorted position.
        swap(node_indices, midindex, right)
        if midindex == split_index:
            # Pivot landed on the requested split point: done.
            break
        elif midindex < split_index:
            # Requested split point is in the right sub-range.
            left = midindex + 1
        else:
            # Requested split point is in the left sub-range.
            right = midindex - 1

    return 0


######################################################################
# NodeHeap : min-heap used to keep track of nodes during
# breadth-first query
Expand Down
9 changes: 9 additions & 0 deletions sklearn/neighbors/_partition_nodes.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ._typedefs cimport DTYPE_t, ITYPE_t

# Rearrange node_indices in-place around split_index along split_dim
# (see the full docstring in _partition_nodes.pyx). Declared with
# ``except -1`` so a -1 return propagates a Python exception to callers.
cdef int partition_node_indices(
        DTYPE_t *data,
        ITYPE_t *node_indices,
        ITYPE_t split_dim,
        ITYPE_t split_index,
        ITYPE_t n_features,
        ITYPE_t n_points) except -1
122 changes: 122 additions & 0 deletions sklearn/neighbors/_partition_nodes.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# distutils: language = c++

# BinaryTrees rely on partial sorts to partition their nodes during their
# initialisation.
#
# The C++ std library exposes nth_element, an efficient partial sort for this
# situation which has a linear time complexity as well as the best performance.
#
# To use std::nth_element (declared in <algorithm>), a few fixtures are
# defined using Cython:
# - partition_node_indices, a Cython function used in BinaryTrees, that calls
# - partition_node_indices_inner, a C++ function that wraps nth_element and uses
# - an IndexComparator to state how to compare KDTrees' indices
#
# IndexComparator has been defined so that partial sorts are stable with
# respect to the nodes initial indices.
#
# See for reference:
# - https://en.cppreference.com/w/cpp/algorithm/nth_element.
# - https://github.com/scikit-learn/scikit-learn/pull/11103
# - https://github.com/scikit-learn/scikit-learn/pull/19473

cdef extern from *:
    """
    #include <algorithm>

    // Compares two indices into `data` by their value along feature
    // `split_dim`, breaking ties on the indices themselves so that the
    // partial sort is stable with respect to the nodes' initial indices.
    template<class D, class I>
    class IndexComparator {
    private:
        const D *data;
        I split_dim, n_features;
    public:
        IndexComparator(const D *data, const I &split_dim, const I &n_features):
            data(data), split_dim(split_dim), n_features(n_features) {}

        bool operator()(const I &a, const I &b) const {
            D a_value = data[a * n_features + split_dim];
            D b_value = data[b * n_features + split_dim];
            // Tie-break on the index to keep the partition deterministic.
            return a_value == b_value ? a < b : a_value < b_value;
        }
    };

    // Rearranges node_indices[0:n_points] so that the element at
    // split_index sits at its sorted position along split_dim, with no
    // larger element before it and no smaller element after it
    // (the std::nth_element contract).
    template<class D, class I>
    void partition_node_indices_inner(
        const D *data,
        I *node_indices,
        const I &split_dim,
        const I &split_index,
        const I &n_features,
        const I &n_points) {
        IndexComparator<D, I> index_comparator(data, split_dim, n_features);
        std::nth_element(
            node_indices,
            node_indices + split_index,
            node_indices + n_points,
            index_comparator);
    }
    """
    # ``except +`` makes Cython translate any C++ exception thrown by the
    # call into a Python exception instead of aborting the process.
    void partition_node_indices_inner[D, I](
        D *data,
        I *node_indices,
        I split_dim,
        I split_index,
        I n_features,
        I n_points) except +


cdef int partition_node_indices(
        DTYPE_t *data,
        ITYPE_t *node_indices,
        ITYPE_t split_dim,
        ITYPE_t split_index,
        ITYPE_t n_features,
        ITYPE_t n_points) except -1:
    """Partition points in the node into two equal-sized groups.

    Upon return, the values in node_indices will be rearranged such that
    (assuming numpy-style indexing):

        data[node_indices[0:split_index], split_dim]
          <= data[node_indices[split_index], split_dim]

    and

        data[node_indices[split_index], split_dim]
          <= data[node_indices[split_index:n_points], split_dim]

    The algorithm is essentially a partial in-place quicksort around a
    set pivot.

    Parameters
    ----------
    data : double pointer
        Pointer to a 2D array of the training data, of shape [N, n_features].
        N must be greater than any of the values in node_indices.
    node_indices : int pointer
        Pointer to a 1D array of length n_points. This lists the indices of
        each of the points within the current node. This will be modified
        in-place.
    split_dim : int
        the dimension on which to split. This will usually be computed via
        the routine ``find_node_split_dim``.
    split_index : int
        the index within node_indices around which to split the points.
    n_features : int
        the number of features (i.e. columns) in the 2D array pointed to
        by data.
    n_points : int
        the length of node_indices. This is also the number of points in
        the original dataset.

    Returns
    -------
    status : int
        integer exit status. On return, the contents of node_indices are
        modified as noted above.
    """
    # Delegate to the C++ wrapper around std::nth_element, which performs
    # exactly the rearrangement documented above.
    partition_node_indices_inner(
        data,
        node_indices,
        split_dim,
        split_index,
        n_features,
        n_points)
    return 0
6 changes: 6 additions & 0 deletions sklearn/neighbors/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ def configuration(parent_package='', top_path=None):
include_dirs=[numpy.get_include()],
libraries=libraries)

config.add_extension('_partition_nodes',
sources=['_partition_nodes.pyx'],
include_dirs=[numpy.get_include()],
language="c++",
libraries=libraries)

config.add_extension('_dist_metrics',
sources=['_dist_metrics.pyx'],
include_dirs=[numpy.get_include(),
Expand Down