Fitting a NearestNeighbors model fails with sparse input and a callable as metric #9199

Closed
@tttthomasssss

Description

Fitting a NearestNeighbors model raises a ValueError when both of the following hold: a) the distance metric is a callable, and b) the input to the NearestNeighbors model is sparse.

Steps/Code to Reproduce

from scipy import sparse
from sklearn.neighbors import NearestNeighbors

def sparse_metric(x, y): # Some metric accepting sparse input
    return x.count_nonzero() / y.count_nonzero()

A = sparse.random(10, 5, density=0.3, format='csr')

nn = NearestNeighbors(algorithm='brute', metric=sparse_metric).fit(A)

Expected Results

No error is raised when a callable is passed as metric with sparse input.

Actual Results

ValueError                                Traceback (most recent call last)
<ipython-input-2-a9d2fd7f843b> in <module>()
      7 A = sparse.random(10, 5, density=0.3, format='csr')
      8 
----> 9 nn = NearestNeighbors(algorithm='brute', metric=sparse_metric).fit(A)

/Volumes/LocalDataHD/thk22/.virtualenvs/nlpy3/lib/python3.5/site-packages/sklearn/neighbors/base.py in fit(self, X, y)
    797             or [n_samples, n_samples] if metric='precomputed'.
    798         """
--> 799         return self._fit(X)

/Volumes/LocalDataHD/thk22/.virtualenvs/nlpy3/lib/python3.5/site-packages/sklearn/neighbors/base.py in _fit(self, X)
    213             if self.effective_metric_ not in VALID_METRICS_SPARSE['brute']:
    214                 raise ValueError("metric '%s' not valid for sparse input"
--> 215                                  % self.effective_metric_)
    216             self._fit_X = X.copy()
    217             self._tree = None

ValueError: metric '<function sparse_metric at 0x1097d0378>' not valid for sparse input

Some Analysis/Wild Speculation

The problem seems to be that, for sparse input, _fit only checks whether the given metric is in VALID_METRICS_SPARSE['brute'], the list of metrics accepting sparse input. It never checks whether the metric is a string or a callable, so a callable always fails the membership test and is rejected: https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/neighbors/base.py#L210
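To make the speculation above concrete, here is a minimal standalone sketch of a check that lets callables through. It is not the actual scikit-learn code or its eventual fix: the function name check_sparse_metric and the set VALID_SPARSE_BRUTE_METRICS are hypothetical, and the set's contents are placeholder values standing in for VALID_METRICS_SPARSE['brute'].

```python
# Hypothetical stand-in for sklearn's VALID_METRICS_SPARSE['brute'].
# The entries are illustrative placeholders, not the library's actual list.
VALID_SPARSE_BRUTE_METRICS = {"euclidean", "manhattan", "cosine"}


def check_sparse_metric(metric, valid_metrics=VALID_SPARSE_BRUTE_METRICS):
    """Validate a metric for sparse input (hypothetical helper).

    Callables are accepted as-is; string metrics must be in valid_metrics.
    """
    if callable(metric):
        # Assumption: a user-supplied callable is trusted to handle
        # sparse rows itself, so it skips the membership check.
        return metric
    if metric not in valid_metrics:
        raise ValueError("metric '%s' not valid for sparse input" % metric)
    return metric
```

With a check shaped like this, check_sparse_metric(sparse_metric) would return the callable unchanged, while an unknown string metric would still raise the same ValueError as in the traceback above.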

Versions

Darwin-15.6.0-x86_64-i386-64bit
Python 3.5.1 (default, Dec  8 2015, 06:00:08) 
[GCC 4.2.1 Compatible Apple LLVM 6.0 (clang-600.0.57)]
NumPy 1.12.1
SciPy 0.19.0
Scikit-Learn 0.18.2
