
[WIP] algorithm='auto' should always work for nearest neighbors #7669

Status: Closed · wants to merge 2 commits
16 changes: 9 additions & 7 deletions doc/modules/neighbors.rst
@@ -419,12 +419,14 @@ depends on a number of factors:
a significant fraction of the total cost. If very few query points
will be required, brute force is better than a tree-based method.

-Currently, ``algorithm = 'auto'`` selects ``'kd_tree'`` if :math:`k < N/2`
-and the ``'effective_metric_'`` is in the ``'VALID_METRICS'`` list of
-``'kd_tree'``. It selects ``'ball_tree'`` if :math:`k < N/2` and the
-``'effective_metric_'`` is not in the ``'VALID_METRICS'`` list of
-``'kd_tree'``. It selects ``'brute'`` if :math:`k >= N/2`. This choice is
-based on the assumption that the number of query points is at least the
-same order as the number of training points, and that ``leaf_size`` is
+Currently, ``algorithm = 'auto'`` selects ``'kd_tree'`` if :math:`k < N/2`
+and the ``'effective_metric_'`` is in the ``'VALID_METRICS'`` list of
+``'kd_tree'``. It selects ``'ball_tree'`` if :math:`k < N/2` and the
+``'effective_metric_'`` is in the ``'VALID_METRICS'`` list of
+``'ball_tree'``. It selects ``'brute'`` if :math:`k < N/2` and the
+``'effective_metric_'`` is not in the ``'VALID_METRICS'`` list of
+``'kd_tree'`` or ``'ball_tree'``. It selects ``'brute'`` if :math:`k >= N/2`.
+This choice is based on the assumption that the number of query points is at
+least the same order as the number of training points, and that ``leaf_size`` is
close to its default value of ``30``.
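The fallback described in the updated paragraph can be observed directly. A small sketch (note that ``_fit_method`` is a private attribute, inspected here only for illustration; the exact choice depends on running a scikit-learn version that carries this patch):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.RandomState(0)
X = rng.rand(40, 3)  # N = 40 training points, k = 3, so k < N/2

# 'euclidean' is in the KD-tree's VALID_METRICS, so 'auto' resolves to it.
nn_kd = NearestNeighbors(n_neighbors=3, algorithm='auto',
                         metric='euclidean').fit(X)

# 'cosine' is supported only by the brute-force backend; with the fallback
# described above, 'auto' resolves to 'brute' instead of raising an error.
nn_brute = NearestNeighbors(n_neighbors=3, algorithm='auto',
                            metric='cosine').fit(X)

print(nn_kd._fit_method, nn_brute._fit_method)  # kd_tree brute
```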

Effect of ``leaf_size``
@@ -666,7 +668,7 @@ the :math:`m` nearest neighbors of a point :math:`q`. First, a top-down
traversal is performed using a binary search to identify the leaf having the
longest prefix match (maximum depth) with :math:`q`'s label after subjecting
:math:`q` to the same hash functions. :math:`M >> m` points (total candidates)
-are extracted from the forest, moving up from the previously found maximum
+are extracted from the forest, moving up from the previously found maximum
depth towards the root synchronously across all trees in the bottom-up
traversal. `M` is set to :math:`cl` where :math:`c`, the number of candidates
extracted from each tree, is a constant. Finally, the similarity of each of
9 changes: 7 additions & 2 deletions sklearn/neighbors/base.py
@@ -126,8 +126,10 @@ def _init_params(self, n_neighbors=None, radius=None,
        if algorithm == 'auto':
            if metric == 'precomputed':
                alg_check = 'brute'
-            else:
+            elif callable(metric) or metric in VALID_METRICS['ball_tree']:
                alg_check = 'ball_tree'
+            else:
+                alg_check = 'brute'
        else:
            alg_check = algorithm

@@ -229,8 +231,11 @@ def _fit(self, X):
                    self.metric != 'precomputed'):
                if self.effective_metric_ in VALID_METRICS['kd_tree']:
                    self._fit_method = 'kd_tree'
-                else:
+                elif (callable(self.effective_metric_) or
+                      self.effective_metric_ in VALID_METRICS['ball_tree']):
                    self._fit_method = 'ball_tree'
+                else:
+                    self._fit_method = 'brute'
            else:
                self._fit_method = 'brute'

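The control flow in ``_fit`` above amounts to a three-way fallback: KD-tree first, ball tree next, brute force last. A minimal standalone sketch of the same decision, with ``VALID_METRICS`` reduced to a few representative entries (illustrative subsets only, not sklearn's full tables, and a hypothetical helper name):

```python
# Illustrative subsets of sklearn's VALID_METRICS tables.
VALID_METRICS = {
    'kd_tree': ['euclidean', 'manhattan', 'chebyshev', 'minkowski'],
    'ball_tree': ['euclidean', 'manhattan', 'haversine', 'minkowski'],
}


def choose_fit_method(effective_metric, n_neighbors, n_samples):
    """Mirror the 'auto' fallback introduced by this patch."""
    # A tree only pays off when k is small relative to the training set.
    if n_neighbors >= n_samples / 2:
        return 'brute'
    if effective_metric in VALID_METRICS['kd_tree']:
        return 'kd_tree'
    if callable(effective_metric) or effective_metric in VALID_METRICS['ball_tree']:
        return 'ball_tree'
    # Metrics supported by neither tree (e.g. 'cosine') now fall back to
    # brute force instead of raising an error.
    return 'brute'


print(choose_fit_method('euclidean', 5, 100))   # kd_tree
print(choose_fit_method('haversine', 5, 100))   # ball_tree
print(choose_fit_method('cosine', 5, 100))      # brute
print(choose_fit_method('euclidean', 60, 100))  # brute
```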
24 changes: 22 additions & 2 deletions sklearn/neighbors/tests/test_neighbors.py
@@ -16,6 +16,8 @@
from sklearn.utils.testing import assert_greater
from sklearn.utils.validation import check_random_state
from sklearn.metrics.pairwise import pairwise_distances
+from sklearn.neighbors.base import (VALID_METRICS,
+                                    VALID_METRICS_SPARSE)
from sklearn import neighbors, datasets
from sklearn.exceptions import DataConversionWarning

@@ -149,7 +151,7 @@ def test_precomputed(random_state=42):
                neighbors.RadiusNeighborsClassifier,
                neighbors.KNeighborsRegressor,
                neighbors.RadiusNeighborsRegressor):
-        print(Est)
+
        est = Est(metric='euclidean')
        est.radius = est.n_neighbors = 1
        pred_X = est.fit(X, target).predict(Y)
@@ -984,10 +986,28 @@ def custom_metric(x1, x2):

    dist1, ind1 = nbrs1.kneighbors(X)
    dist2, ind2 = nbrs2.kneighbors(X)
+
+    assert_true(nbrs1._fit_method, 'ball_tree')
+    assert_true(nbrs2._fit_method, 'ball_tree')
+
    assert_array_almost_equal(dist1, dist2)
+
+
+def test_algo_auto_metrics():
+    X = rng.rand(12, 3)
+    Xcsr = csr_matrix(X)
+
+    for metric in VALID_METRICS['brute']:
+        nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto',
+                                        metric=metric).fit(X)
+        assert_true(nn._fit_method, 'brute')
Review comment (Member): shouldn't it be enough to make sure no error is raised? I don't understand this test. The other metrics might also support this metric.


+    for metric in VALID_METRICS_SPARSE['brute']:
+        nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto',
+                                        metric=metric).fit(Xcsr)
+        assert_true(nn._fit_method, 'brute')


def test_metric_params_interface():
    assert_warns(SyntaxWarning, neighbors.KNeighborsClassifier,
                 metric_params={'p': 3})
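As the ``custom_metric`` assertions in the test diff suggest, a user-supplied callable metric is routed to the ball tree under ``algorithm='auto'``. A small sketch (again inspecting the private ``_fit_method`` attribute purely for illustration, on a scikit-learn version carrying this patch):

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors


def manhattan(a, b):
    # User-defined callable metric; it receives two 1-D sample vectors.
    return float(np.abs(a - b).sum())


X = np.random.RandomState(1).rand(30, 2)

# Callable metrics cannot use the KD-tree, so 'auto' picks the ball tree,
# which supports arbitrary Python-level distance functions.
nn = NearestNeighbors(n_neighbors=2, algorithm='auto', metric=manhattan).fit(X)
print(nn._fit_method)  # ball_tree
```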