-
-
Notifications
You must be signed in to change notification settings - Fork 26k
MNT: Remove duplicated data validation done in internally used BinaryTrees #19418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
011c8d2
59b4a8e
f6a1357
bddc33f
e7de9a0
02e738f
99a4ce1
966498d
df291d9
1828e00
e6959e3
d9c900c
d88970c
aae2ad3
32a5cff
bba6465
ba93b54
ea516c4
6fe0619
91d18c1
8e733cb
8dd54e4
7cec6bd
22a6a2f
156eff7
dcc9995
ff32212
903125f
c882688
9cb0d6b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
|
||
from ._ball_tree import BallTree | ||
from ._kd_tree import KDTree | ||
from .. import config_context | ||
from ..base import BaseEstimator, MultiOutputMixin | ||
from ..base import is_classifier | ||
from ..metrics import pairwise_distances_chunked | ||
|
@@ -542,24 +543,28 @@ def _fit(self, X, y=None): | |
else: | ||
self._fit_method = "brute" | ||
|
||
if self._fit_method == "ball_tree": | ||
self._tree = BallTree( | ||
X, | ||
self.leaf_size, | ||
metric=self.effective_metric_, | ||
**self.effective_metric_params_, | ||
) | ||
elif self._fit_method == "kd_tree": | ||
self._tree = KDTree( | ||
X, | ||
self.leaf_size, | ||
metric=self.effective_metric_, | ||
**self.effective_metric_params_, | ||
) | ||
elif self._fit_method == "brute": | ||
self._tree = None | ||
else: | ||
raise ValueError("algorithm = '%s' not recognized" % self.algorithm) | ||
with config_context(assume_finite=True): | ||
# In the following cases, we remove the validation of X done at | ||
# the beginning of the BinaryTree's constructors as X already got | ||
# validated when calling this method, NeighborsBase._fit. | ||
Comment on lines
+546
to
+549
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The running time of the checks (represented by the small spike at the beginning) is negligible here.
Script:
# profile.py
from sklearn.datasets import make_regression
from sklearn.neighbors import NearestNeighbors
def main(args=None):
    """Fit a KD-tree ``NearestNeighbors`` model and query it once.

    This is the workload of the profiling script: two synthetic regression
    datasets are generated, one to fit on and one to query with.

    ``args`` is accepted but not used.
    """
    train_points, _ = make_regression(n_samples=100_000, n_features=10)
    query_points, _ = make_regression(n_samples=100_000, n_features=10)

    # Tree construction happens in fit(); the query happens in kneighbors().
    estimator = NearestNeighbors(algorithm='kd_tree')
    fitted = estimator.fit(train_points)
    fitted.kneighbors(query_points, n_neighbors=2)
if __name__ == "__main__":
    main()

giltracer --state-detect profile.py |
||
if self._fit_method == "ball_tree": | ||
self._tree = BallTree( | ||
X, | ||
self.leaf_size, | ||
metric=self.effective_metric_, | ||
**self.effective_metric_params_, | ||
) | ||
elif self._fit_method == "kd_tree": | ||
self._tree = KDTree( | ||
X, | ||
self.leaf_size, | ||
metric=self.effective_metric_, | ||
**self.effective_metric_params_, | ||
) | ||
elif self._fit_method == "brute": | ||
self._tree = None | ||
else: | ||
raise ValueError("algorithm = '%s' not recognized" % self.algorithm) | ||
|
||
if self.n_neighbors is not None: | ||
if self.n_neighbors <= 0: | ||
|
@@ -770,12 +775,18 @@ class from an array representing our data set and ask who's | |
parallel_kwargs = {"backend": "threading"} | ||
else: | ||
parallel_kwargs = {"prefer": "threads"} | ||
chunked_results = Parallel(n_jobs, **parallel_kwargs)( | ||
delayed(_tree_query_parallel_helper)( | ||
self._tree, X[s], n_neighbors, return_distance | ||
|
||
with config_context(assume_finite=True): | ||
# We remove the validation of the query points | ||
# (in *parallel_kwargs) done at the beginning of | ||
# BinaryTree.query as those points already got | ||
# validated in the caller. | ||
Comment on lines
+779
to
+783
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The running time of the checks (represented by the small spike at the beginning) is negligible here.
Script:
# profile.py
from sklearn.datasets import make_regression
from sklearn.neighbors import NearestNeighbors
def main(args=None):
    """Fit a KD-tree ``NearestNeighbors`` model and query it once.

    This is the workload of the profiling script: two synthetic regression
    datasets are generated, one to fit on and one to query with.

    ``args`` is accepted but not used.
    """
    train_points, _ = make_regression(n_samples=100_000, n_features=10)
    query_points, _ = make_regression(n_samples=100_000, n_features=10)

    # Tree construction happens in fit(); the query happens in kneighbors().
    estimator = NearestNeighbors(algorithm='kd_tree')
    fitted = estimator.fit(train_points)
    fitted.kneighbors(query_points, n_neighbors=2)
if __name__ == "__main__":
    main()

giltracer --state-detect profile.py |
||
chunked_results = Parallel(n_jobs, **parallel_kwargs)( | ||
delayed(_tree_query_parallel_helper)( | ||
self._tree, X[s], n_neighbors, return_distance | ||
) | ||
for s in gen_even_slices(X.shape[0], n_jobs) | ||
) | ||
for s in gen_even_slices(X.shape[0], n_jobs) | ||
) | ||
else: | ||
raise ValueError("internal: _fit_method not recognized") | ||
|
||
|
@@ -1108,12 +1119,17 @@ class from an array representing our data set and ask who's | |
else: | ||
parallel_kwargs = {"prefer": "threads"} | ||
|
||
chunked_results = Parallel(n_jobs, **parallel_kwargs)( | ||
delayed_query( | ||
self._tree, X[s], radius, return_distance, sort_results=sort_results | ||
with config_context(assume_finite=True): | ||
# We remove the validation of the query points | ||
# (in *parallel_kwargs) done at the beginning of | ||
# BinaryTree.query_radius as those points already | ||
# got validated in the caller. | ||
Comment on lines
+1122
to
+1126
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Similarly, the running time of the checks is negligible here.
Script:
# profile.py
from sklearn.datasets import make_regression
from sklearn.neighbors import NearestNeighbors
def main(args=None):
    """Fit a KD-tree ``NearestNeighbors`` model and query it once.

    This is the workload of the profiling script: two synthetic regression
    datasets are generated, one to fit on and one to query with.

    ``args`` is accepted but not used.
    """
    train_points, _ = make_regression(n_samples=100_000, n_features=10)
    query_points, _ = make_regression(n_samples=100_000, n_features=10)

    # Tree construction happens in fit(); the query happens in kneighbors().
    estimator = NearestNeighbors(algorithm='kd_tree')
    fitted = estimator.fit(train_points)
    fitted.kneighbors(query_points, n_neighbors=2)
if __name__ == "__main__":
    main()

giltracer --state-detect profile.py |
||
chunked_results = Parallel(n_jobs, **parallel_kwargs)( | ||
delayed_query( | ||
self._tree, X[s], radius, return_distance, sort_results=sort_results | ||
) | ||
for s in gen_even_slices(X.shape[0], n_jobs) | ||
) | ||
for s in gen_even_slices(X.shape[0], n_jobs) | ||
) | ||
if return_distance: | ||
neigh_ind, neigh_dist = tuple(zip(*chunked_results)) | ||
results = np.hstack(neigh_dist), np.hstack(neigh_ind) | ||
|
Uh oh!
There was an error while loading. Please reload this page.