-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG+2] algorithm='auto' should always work for nearest neighbors (continuation) #9145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
if metric not in _metrics: | ||
nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto', | ||
metric=metric).fit(Xcsr) | ||
if metric != "precomputed": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is that normal that I can't call kneighbors
method on sparse matrix when metric='precomputed'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That case is not implemented: it would need to be treated a bit differently to the precomputed dense array case; and it doesn't really make sense. An array containing distances which is mostly 0s is pretty useless for nearest neighbors. (Unless we're admitting negative distances, or having a different interpretation of 0s, both of which may be useful.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But I don't understand why precomputed should be in VALID_METRICS_SPARSE
and hence why it's relevant tot his code path.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
otherwise LGTM
if metric not in _metrics: | ||
nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto', | ||
metric=metric).fit(Xcsr) | ||
if metric != "precomputed": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That case is not implemented: it would need to be treated a bit differently to the precomputed dense array case; and it doesn't really make sense. An array containing distances which is mostly 0s is pretty useless for nearest neighbors. (Unless we're admitting negative distances, or having a different interpretation of 0s, both of which may be useful.)
if metric not in _metrics: | ||
nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto', | ||
metric=metric).fit(Xcsr) | ||
if metric != "precomputed": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But I don't understand why precomputed should be in VALID_METRICS_SPARSE
and hence why it's relevant tot his code path.
@@ -988,6 +990,38 @@ def custom_metric(x1, x2): | |||
assert_array_almost_equal(dist1, dist2) | |||
|
|||
|
|||
def test_algo_auto_metrics(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a better name or comment would be appreciated
doc/modules/neighbors.rst
Outdated
``'effective_metric_'`` is in the ``'VALID_METRICS'`` list of | ||
``'ball_tree'``. It selects ``'brute'`` if :math:`k < N/2` and the | ||
``'effective_metric_'`` is not in the ``'VALID_METRICS'`` list of | ||
``'kd_tree'`` or ``'ball_tree'``. It selects ``'brute'`` if :math:`k >= N/2`. This choice is based on the assumption that the number of query points is at least the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you please keep this under 80 chars
Thanks for the review. The change has been done! |
@@ -988,6 +990,46 @@ def custom_metric(x1, x2): | |||
assert_array_almost_equal(dist1, dist2) | |||
|
|||
|
|||
def test_unsupported_metric_for_auto(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know what you mean by unsupported. It's clearly not unsupported if you're testing it works.
X = rng.rand(12, 12) | ||
Xcsr = csr_matrix(X) | ||
|
||
# Metric which don't required any additional parameter |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find the comment a bit misleading. I would rename the variable and say only test metrics that don't require additional arguments
.
And maybe assert that some strange metric is in VALID_METRICS['brute']
.
Checking that the test is non-empty, and more didactic variable name
LGTM, I made minor minor changes to fix my nitpicks, merge on green. |
Thank you! |
…ntinuation) (scikit-learn#9145) * Merge neighbors.rst * issue scikit-learn#4931 * Improve test implementation * Update base.py * Remove unused import * Customize test for precomputed metric * Change test function name * rename _metrics to require_params, add set assert Checking that the test is non-empty, and more didactic variable name * Remove blank line
…ntinuation) (scikit-learn#9145) * Merge neighbors.rst * issue scikit-learn#4931 * Improve test implementation * Update base.py * Remove unused import * Customize test for precomputed metric * Change test function name * rename _metrics to require_params, add set assert Checking that the test is non-empty, and more didactic variable name * Remove blank line
…ntinuation) (scikit-learn#9145) * Merge neighbors.rst * issue scikit-learn#4931 * Improve test implementation * Update base.py * Remove unused import * Customize test for precomputed metric * Change test function name * rename _metrics to require_params, add set assert Checking that the test is non-empty, and more didactic variable name * Remove blank line
…ntinuation) (scikit-learn#9145) * Merge neighbors.rst * issue scikit-learn#4931 * Improve test implementation * Update base.py * Remove unused import * Customize test for precomputed metric * Change test function name * rename _metrics to require_params, add set assert Checking that the test is non-empty, and more didactic variable name * Remove blank line
…ntinuation) (scikit-learn#9145) * Merge neighbors.rst * issue scikit-learn#4931 * Improve test implementation * Update base.py * Remove unused import * Customize test for precomputed metric * Change test function name * rename _metrics to require_params, add set assert Checking that the test is non-empty, and more didactic variable name * Remove blank line
…ntinuation) (scikit-learn#9145) * Merge neighbors.rst * issue scikit-learn#4931 * Improve test implementation * Update base.py * Remove unused import * Customize test for precomputed metric * Change test function name * rename _metrics to require_params, add set assert Checking that the test is non-empty, and more didactic variable name * Remove blank line
…ntinuation) (scikit-learn#9145) * Merge neighbors.rst * issue scikit-learn#4931 * Improve test implementation * Update base.py * Remove unused import * Customize test for precomputed metric * Change test function name * rename _metrics to require_params, add set assert Checking that the test is non-empty, and more didactic variable name * Remove blank line
Reference Issue
Fixes #4931, continuation of #5596
What does this implement/fix? Explain your changes.
Implement test for metric in
['mahalanobis', 'wminkowski', 'seuclidean']
Any other comments?
Should we warn the user when the algorithm is set into
brute
? (instead ofauto
)