Closed
Description
Fitting a NearestNeighbors model fails when a) the distance metric used is a callable and b) the input to the NearestNeighbors model is sparse.
Steps/Code to Reproduce
from scipy import sparse
from sklearn.neighbors import NearestNeighbors
def sparse_metric(x, y):  # Some metric accepting sparse input
    return x.count_nonzero() / y.count_nonzero()
A = sparse.random(10, 5, density=0.3, format='csr')
nn = NearestNeighbors(algorithm='brute', metric=sparse_metric).fit(A)
Expected Results
No error is thrown when a callable is passed as the metric with sparse input.
Actual Results
ValueError                                Traceback (most recent call last)
<ipython-input-2-a9d2fd7f843b> in <module>()
      7 A = sparse.random(10, 5, density=0.3, format='csr')
      8
----> 9 nn = NearestNeighbors(algorithm='brute', metric=sparse_metric).fit(A)

/Volumes/LocalDataHD/thk22/.virtualenvs/nlpy3/lib/python3.5/site-packages/sklearn/neighbors/base.py in fit(self, X, y)
    797             or [n_samples, n_samples] if metric='precomputed'.
    798         """
--> 799         return self._fit(X)

/Volumes/LocalDataHD/thk22/.virtualenvs/nlpy3/lib/python3.5/site-packages/sklearn/neighbors/base.py in _fit(self, X)
    213             if self.effective_metric_ not in VALID_METRICS_SPARSE['brute']:
    214                 raise ValueError("metric '%s' not valid for sparse input"
--> 215                                  % self.effective_metric_)
    216             self._fit_X = X.copy()
    217             self._tree = None

ValueError: metric '<function sparse_metric at 0x1097d0378>' not valid for sparse input
Some Analysis/Wild Speculation
The problem seems to be that, for sparse input, the code only checks whether the given metric is in the list of metrics accepting sparse input. It never checks whether the metric is a string or a callable, so a callable is always rejected: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/neighbors/base.py#L210
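To illustrate the kind of check I have in mind, here is a minimal standalone sketch. It is not the scikit-learn code: the function name check_sparse_metric and the set SPARSE_BRUTE_METRICS are hypothetical stand-ins (the real list is VALID_METRICS_SPARSE['brute']), and it assumes a user-supplied callable is responsible for handling sparse rows itself.

```python
# Hypothetical stand-in for the sparse-capable brute-force metrics;
# in scikit-learn the real list is VALID_METRICS_SPARSE['brute'].
SPARSE_BRUTE_METRICS = {"cosine", "euclidean", "manhattan"}


def check_sparse_metric(metric, valid_metrics=SPARSE_BRUTE_METRICS):
    """Reject unknown string metrics for sparse input, but let callables through."""
    if callable(metric):
        # Assumption: a user-supplied callable handles sparse input on its own.
        return metric
    if metric not in valid_metrics:
        raise ValueError("metric '%s' not valid for sparse input" % metric)
    return metric
```

With a check along these lines, a callable such as sparse_metric would pass, while an unsupported string metric would still raise the same ValueError as before.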
Versions
Darwin-15.6.0-x86_64-i386-64bit
Python 3.5.1 (default, Dec 8 2015, 06:00:08)
[GCC 4.2.1 Compatible Apple LLVM 6.0 (clang-600.0.57)]
NumPy 1.12.1
SciPy 0.19.0
Scikit-Learn 0.18.2