Skip to content

[WIP] Pre-compute norms to speed-up NearestNeighbors.kneighbors algorith='brute' metric='euclidean' #10212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions sklearn/neighbors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
from .kd_tree import KDTree
from ..base import BaseEstimator
from ..metrics import pairwise_distances
from ..metrics.pairwise import PAIRWISE_DISTANCE_FUNCTIONS
from ..metrics.pairwise import PAIRWISE_DISTANCE_FUNCTIONS, euclidean_distances
from ..utils import check_X_y, check_array, _get_n_jobs, gen_even_slices
from ..utils.multiclass import check_classification_targets
from ..utils.validation import check_is_fitted
from ..utils.extmath import row_norms
from ..externals import six
from ..externals.joblib import Parallel, delayed
from ..exceptions import NotFittedError
from ..exceptions import DataConversionWarning
from .. import config_context

VALID_METRICS = dict(ball_tree=BallTree.valid_metrics,
kd_tree=KDTree.valid_metrics,
Expand Down Expand Up @@ -106,7 +108,6 @@ class NeighborsBase(six.with_metaclass(ABCMeta, BaseEstimator)):
def __init__(self, n_neighbors=None, radius=None,
algorithm='auto', leaf_size=30, metric='minkowski',
p=2, metric_params=None, n_jobs=1):

self.n_neighbors = n_neighbors
self.radius = radius
self.algorithm = algorithm
Expand Down Expand Up @@ -185,6 +186,8 @@ def _fit(self, X):
self._fit_X = X._fit_X
self._tree = X._tree
self._fit_method = X._fit_method
if hasattr(X, '_fit_X_norms_squared'):
self._fit_X_norms_squared = X._fit_X_norms_squared
return self

elif isinstance(X, BallTree):
Expand All @@ -199,12 +202,15 @@ def _fit(self, X):
self._fit_method = 'kd_tree'
return self

X = check_array(X, accept_sparse='csr')
X = check_array(X, accept_sparse='csr', dtype=[np.float64, np.float32])

n_samples = X.shape[0]
if n_samples == 0:
raise ValueError("n_samples must be greater than 0")

self._fit_method = self.algorithm
self._fit_X = X

if issparse(X):
if self.algorithm not in ('auto', 'brute'):
warnings.warn("cannot use tree with sparse input: "
Expand All @@ -215,12 +221,7 @@ def _fit(self, X):
raise ValueError("metric '%s' not valid for sparse input"
% self.effective_metric_)
self._fit_X = X.copy()
self._tree = None
self._fit_method = 'brute'
return self

self._fit_method = self.algorithm
self._fit_X = X

if self._fit_method == 'auto':
# A tree approach is better for small number of neighbors,
Expand Down Expand Up @@ -248,6 +249,8 @@ def _fit(self, X):
**self.effective_metric_params_)
elif self._fit_method == 'brute':
self._tree = None
if self.effective_metric_ == 'euclidean':
self._fit_X_norms_squared = row_norms(X, squared=True)
else:
raise ValueError("algorithm = '%s' not recognized"
% self.algorithm)
Expand Down Expand Up @@ -352,8 +355,10 @@ class from an array representing our data set and ask who's
if self._fit_method == 'brute':
# for efficiency, use squared euclidean distances
if self.effective_metric_ == 'euclidean':
dist = pairwise_distances(X, self._fit_X, 'euclidean',
n_jobs=n_jobs, squared=True)
Y_norm_squared = self._fit_X_norms_squared.reshape(1, -1)
with config_context(assume_finite=True):
dist = euclidean_distances(X, self._fit_X, squared=True,
Y_norm_squared=Y_norm_squared)
else:
dist = pairwise_distances(
X, self._fit_X, self.effective_metric_, n_jobs=n_jobs,
Expand Down