Skip to content

Commit b969bf6

Browse files
committed
issue #4931
1 parent 46ede61 commit b969bf6

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

doc/modules/neighbors.rst

+9-7
Original file line numberDiff line numberDiff line change
@@ -419,12 +419,14 @@ depends on a number of factors:
419419
a significant fraction of the total cost. If very few query points
420420
will be required, brute force is better than a tree-based method.
421421

422-
Currently, ``algorithm = 'auto'`` selects ``'kd_tree'`` if :math:`k < N/2`
423-
and the ``'effective_metric_'`` is in the ``'VALID_METRICS'`` list of
424-
``'kd_tree'``. It selects ``'ball_tree'`` if :math:`k < N/2` and the
425-
``'effective_metric_'`` is not in the ``'VALID_METRICS'`` list of
426-
``'kd_tree'``. It selects ``'brute'`` if :math:`k >= N/2`. This choice is based on the assumption that the number of query points is at least the
427-
same order as the number of training points, and that ``leaf_size`` is
422+
Currently, ``algorithm = 'auto'`` selects ``'kd_tree'`` if :math:`k < N/2`
423+
and the ``'effective_metric_'`` is in the ``'VALID_METRICS'`` list of
424+
``'kd_tree'``. It selects ``'ball_tree'`` if :math:`k < N/2` and the
425+
``'effective_metric_'`` is in the ``'VALID_METRICS'`` list of
426+
``'ball_tree'``. It selects ``'brute'`` if :math:`k < N/2` and the
427+
``'effective_metric_'`` is not in the ``'VALID_METRICS'`` list of
428+
``'kd_tree'`` or ``'ball_tree'``. It selects ``'brute'`` if :math:`k >= N/2`. This choice is based on the assumption that the number of query points is at least the
429+
same order as the number of training points, and that ``leaf_size`` is
428430
close to its default value of ``30``.
429431

430432
Effect of ``leaf_size``
@@ -666,7 +668,7 @@ the :math:`m` nearest neighbors of a point :math:`q`. First, a top-down
666668
traversal is performed using a binary search to identify the leaf having the
667669
longest prefix match (maximum depth) with :math:`q`'s label after subjecting
668670
:math:`q` to the same hash functions. :math:`M >> m` points (total candidates)
669-
are extracted from the forest, moving up from the previously found maximum
671+
are extracted from the forest, moving up from the previously found maximum
670672
depth towards the root synchronously across all trees in the bottom-up
671673
traversal. `M` is set to :math:`cl` where :math:`c`, the number of candidates
672674
extracted from each tree, is a constant. Finally, the similarity of each of

sklearn/neighbors/base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ def _init_params(self, n_neighbors=None, radius=None,
145145
alg_check = 'brute'
146146
else:
147147
alg_check = 'ball_tree'
148+
if metric not in VALID_METRICS[alg_check]:
149+
alg_check = 'brute'
148150
else:
149151
alg_check = algorithm
150152

@@ -246,8 +248,10 @@ def _fit(self, X):
246248
self.metric != 'precomputed'):
247249
if self.effective_metric_ in VALID_METRICS['kd_tree']:
248250
self._fit_method = 'kd_tree'
249-
else:
251+
elif self.effective_metric_ in VALID_METRICS['ball_tree']:
250252
self._fit_method = 'ball_tree'
253+
else:
254+
self._fit_method = 'brute'
251255
else:
252256
self._fit_method = 'brute'
253257

sklearn/neighbors/tests/test_neighbors.py

+11
Original file line numberDiff line numberDiff line change
@@ -975,6 +975,17 @@ def test_callable_metric():
975975
assert_array_almost_equal(dist1, dist2)
976976

977977

978+
def test_algo_auto_metrics():
979+
X = rng.rand(12, 3)
980+
Xcsr = csr_matrix(X)
981+
nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto',
982+
metric='cosine').fit(X)
983+
assert_true(nn._fit_method, 'brute')
984+
nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto',
985+
metric='cosine').fit(Xcsr)
986+
assert_true(nn._fit_method, 'brute')
987+
988+
978989
def test_metric_params_interface():
979990
assert_warns(DeprecationWarning, neighbors.KNeighborsClassifier,
980991
metric='wminkowski', w=np.ones(10))

0 commit comments

Comments
 (0)