ENH: Update KDTree, and example documentation #25482

Merged: 13 commits, Feb 17, 2023
12 changes: 8 additions & 4 deletions doc/modules/neighbors.rst
@@ -136,9 +136,13 @@ have the same interface; we'll show an example of using the KD Tree here:
 Refer to the :class:`KDTree` and :class:`BallTree` class documentation
 for more information on the options available for nearest neighbors searches,
 including specification of query strategies, distance metrics, etc. For a list
-of available metrics, see the documentation of the :class:`DistanceMetric` class
-and the metrics listed in `sklearn.metrics.pairwise.PAIRWISE_DISTANCE_FUNCTIONS`.
-Note that the "cosine" metric uses :func:`~sklearn.metrics.pairwise.cosine_distances`.
+of valid metrics use :meth:`KDTree.valid_metrics` and :meth:`BallTree.valid_metrics`:
+
+>>> from sklearn.neighbors import KDTree, BallTree
+>>> KDTree.valid_metrics()
+['euclidean', 'l2', 'minkowski', 'p', 'manhattan', 'cityblock', 'l1', 'chebyshev', 'infinity']
+>>> BallTree.valid_metrics()
+['euclidean', 'l2', 'minkowski', 'p', 'manhattan', 'cityblock', 'l1', 'chebyshev', 'infinity', 'seuclidean', 'mahalanobis', 'wminkowski', 'hamming', 'canberra', 'braycurtis', 'matching', 'jaccard', 'dice', 'rogerstanimoto', 'russellrao', 'sokalmichener', 'sokalsneath', 'haversine', 'pyfunc']
 
 .. _classification:

@@ -476,7 +480,7 @@ A list of valid metrics for any of the above algorithms can be obtained by using
 ``valid_metric`` attribute. For example, valid metrics for ``KDTree`` can be generated by:
 
 >>> from sklearn.neighbors import KDTree
->>> print(sorted(KDTree.valid_metrics))
+>>> print(sorted(KDTree.valid_metrics()))
 ['chebyshev', 'cityblock', 'euclidean', 'infinity', 'l1', 'l2', 'manhattan', 'minkowski', 'p']


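A side note on the new doctest: every metric advertised for `KDTree` also appears in the `BallTree` list. A quick self-contained check of that relationship, with the names copied from the doctest output above (no scikit-learn import needed):

```python
# Metric names copied from the KDTree.valid_metrics() / BallTree.valid_metrics()
# doctest output; BallTree supports a strict superset of the KDTree metrics.
kd_metrics = ['euclidean', 'l2', 'minkowski', 'p', 'manhattan', 'cityblock',
              'l1', 'chebyshev', 'infinity']
ball_only = ['seuclidean', 'mahalanobis', 'wminkowski', 'hamming', 'canberra',
             'braycurtis', 'matching', 'jaccard', 'dice', 'rogerstanimoto',
             'russellrao', 'sokalmichener', 'sokalsneath', 'haversine', 'pyfunc']
ball_metrics = kd_metrics + ball_only

# Every KD-tree metric is valid for a ball tree, but not vice versa.
assert set(kd_metrics) < set(ball_metrics)
```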
4 changes: 2 additions & 2 deletions sklearn/neighbors/_base.py
@@ -66,8 +66,8 @@
 SCIPY_METRICS += ["kulsinski"]
 
 VALID_METRICS = dict(
-    ball_tree=BallTree.valid_metrics,
-    kd_tree=KDTree.valid_metrics,
+    ball_tree=BallTree._valid_metrics,
+    kd_tree=KDTree._valid_metrics,
     # The following list comes from the
     # sklearn.metrics.pairwise doc string
     brute=sorted(set(PAIRWISE_DISTANCE_FUNCTIONS).union(SCIPY_METRICS)),
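The untouched `brute` entry shows how the brute-force metric list is assembled: a set union of the pairwise-distance function names with the SciPy metric names, then sorted. A minimal sketch with stand-in contents (the real lists live in `sklearn.metrics.pairwise` and `scipy.spatial.distance`):

```python
# Stand-in for sklearn.metrics.pairwise.PAIRWISE_DISTANCE_FUNCTIONS (a dict
# keyed by metric name) and the SciPy metric name list; not the real contents.
PAIRWISE_DISTANCE_FUNCTIONS = {"euclidean": None, "manhattan": None, "cosine": None}
SCIPY_METRICS = ["braycurtis", "canberra", "euclidean"]

# set() over a dict iterates its keys; the union drops duplicates such as
# "euclidean" before sorting into a stable, deduplicated list.
brute = sorted(set(PAIRWISE_DISTANCE_FUNCTIONS).union(SCIPY_METRICS))
print(brute)  # ['braycurtis', 'canberra', 'cosine', 'euclidean', 'manhattan']
```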
25 changes: 19 additions & 6 deletions sklearn/neighbors/_binary_tree.pxi
@@ -234,11 +234,11 @@ leaf_size : positive int, default=40
 metric : str or DistanceMetric object, default='minkowski'
     Metric to use for distance computation. Default is "minkowski", which
     results in the standard Euclidean distance when p = 2.
-    {binary_tree}.valid_metrics gives a list of the metrics which are valid for
-    {BinaryTree}. See the documentation of `scipy.spatial.distance
-    <https://docs.scipy.org/doc/scipy/reference/spatial.distance.html>`_ and the
-    metrics listed in :class:`~sklearn.metrics.pairwise.distance_metrics` for
-    more information.
+    A list of valid metrics for {BinaryTree} is given by
+    :meth:`{BinaryTree}.valid_metrics`.
+    See the documentation of `scipy.spatial.distance
+    <https://docs.scipy.org/doc/scipy/reference/spatial.distance.html>`_ and
+    the metrics listed in :class:`~sklearn.metrics.pairwise.distance_metrics`
+    for more information on any distance metric.
 
     Additional keywords are passed to the distance metric class.
     Note: Callable functions in the metric parameter are NOT supported for KDTree
@@ -791,7 +791,7 @@ cdef class BinaryTree:
     cdef int n_splits
     cdef int n_calls
 
-    valid_metrics = VALID_METRIC_IDS
+    _valid_metrics = VALID_METRIC_IDS
 
     # Use cinit to initialize all arrays to empty: this will prevent memory
     # errors and seg-faults in rare cases where __init__ is not called
@@ -979,6 +979,19 @@
             self.node_bounds.base,
         )
 
+    @classmethod
+    def valid_metrics(cls):
+        """Get list of valid distance metrics.
+
+        .. versionadded:: 1.3
+
+        Returns
+        -------
+        valid_metrics: list of str
+            List of valid distance metrics.
+        """
+        return cls._valid_metrics
+
     cdef inline DTYPE_t dist(self, DTYPE_t* x1, DTYPE_t* x2,
                              ITYPE_t size) nogil except -1:
         """Compute the distance between arrays x1 and x2"""
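The hunk above renames the class attribute to `_valid_metrics` and exposes it through a `valid_metrics()` classmethod, which is why every call site in this PR gains parentheses. A minimal pure-Python sketch of the pattern (`TreeSketch` and its metric list are stand-ins, not the real Cython `BinaryTree`):

```python
class TreeSketch:
    # Private class attribute holding the supported metric names (stand-in list).
    _valid_metrics = ["euclidean", "manhattan", "chebyshev"]

    @classmethod
    def valid_metrics(cls):
        """Return the list of valid distance metrics for this tree class."""
        return cls._valid_metrics

# Membership tests now require the call: `"euclidean" in TreeSketch.valid_metrics`
# would raise TypeError, since a bound method is not iterable.
assert "euclidean" in TreeSketch.valid_metrics()
```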
6 changes: 3 additions & 3 deletions sklearn/neighbors/_kde.py
@@ -174,12 +174,12 @@ def _choose_algorithm(self, algorithm, metric):
         # algorithm to compute the result.
         if algorithm == "auto":
             # use KD Tree if possible
-            if metric in KDTree.valid_metrics:
+            if metric in KDTree.valid_metrics():
                 return "kd_tree"
-            elif metric in BallTree.valid_metrics:
+            elif metric in BallTree.valid_metrics():
                 return "ball_tree"
         else:  # kd_tree or ball_tree
-            if metric not in TREE_DICT[algorithm].valid_metrics:
+            if metric not in TREE_DICT[algorithm].valid_metrics():
                 raise ValueError(
                     "invalid metric for {0}: '{1}'".format(TREE_DICT[algorithm], metric)
                 )
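For context, the selection logic touched in `_choose_algorithm` can be sketched in isolation. `KD_METRICS` and `BALL_METRICS` below are stand-in sets, and raising immediately for an unrecognized metric under `"auto"` is a simplification relative to the real method:

```python
KD_METRICS = {"euclidean", "manhattan", "chebyshev", "minkowski"}
BALL_METRICS = KD_METRICS | {"haversine", "mahalanobis"}  # superset of KD_METRICS
TREE_METRICS = {"kd_tree": KD_METRICS, "ball_tree": BALL_METRICS}

def choose_algorithm(algorithm, metric):
    """Pick a tree algorithm for a metric, mirroring the branching above."""
    if algorithm == "auto":
        # Prefer the KD tree when the metric allows it, else fall back to
        # the ball tree, which supports a larger set of metrics.
        if metric in KD_METRICS:
            return "kd_tree"
        elif metric in BALL_METRICS:
            return "ball_tree"
        raise ValueError(f"invalid metric: '{metric}'")
    else:  # kd_tree or ball_tree was requested explicitly
        if metric not in TREE_METRICS[algorithm]:
            raise ValueError(f"invalid metric for {algorithm}: '{metric}'")
        return algorithm
```

For example, `choose_algorithm("auto", "haversine")` falls through to `"ball_tree"` because haversine is absent from the KD-tree set.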
4 changes: 2 additions & 2 deletions sklearn/neighbors/tests/test_kde.py
Expand Up @@ -114,7 +114,7 @@ def test_kde_algorithm_metric_choice(algorithm, metric):

     kde = KernelDensity(algorithm=algorithm, metric=metric)
 
-    if algorithm == "kd_tree" and metric not in KDTree.valid_metrics:
+    if algorithm == "kd_tree" and metric not in KDTree.valid_metrics():
         with pytest.raises(ValueError, match="invalid metric"):
             kde.fit(X)
     else:
@@ -165,7 +165,7 @@ def test_kde_sample_weights():
     test_points = rng.rand(n_samples_test, d)
     for algorithm in ["auto", "ball_tree", "kd_tree"]:
         for metric in ["euclidean", "minkowski", "manhattan", "chebyshev"]:
-            if algorithm != "kd_tree" or metric in KDTree.valid_metrics:
+            if algorithm != "kd_tree" or metric in KDTree.valid_metrics():
                 kde = KernelDensity(algorithm=algorithm, metric=metric)
 
                 # Test that adding a constant sample weight has no effect