FIX Revert {Ball,KD}Tree.valid_metrics to public class attributes #26754

Merged
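
This pull request reverts BallTree.valid_metrics and KDTree.valid_metrics from the classmethods introduced in 1.3 back to public class attributes, and updates the internal call sites (HDBSCAN, KernelDensity, the neighbors documentation, and the related tests) to use attribute access. A minimal sketch of the restored usage, illustrative only and not part of the diff (the exact metric lists depend on the installed version):

>>> from sklearn.neighbors import KDTree, BallTree
>>> 'euclidean' in KDTree.valid_metrics          # plain class attribute, no call needed
True
>>> set(KDTree.valid_metrics) <= set(BallTree.valid_metrics)
True
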
6 changes: 3 additions & 3 deletions doc/modules/neighbors.rst
@@ -139,9 +139,9 @@ including specification of query strategies, distance metrics, etc. For a list
of valid metrics use :meth:`KDTree.valid_metrics` and :meth:`BallTree.valid_metrics`:

>>> from sklearn.neighbors import KDTree, BallTree
>>> KDTree.valid_metrics()
>>> KDTree.valid_metrics
['euclidean', 'l2', 'minkowski', 'p', 'manhattan', 'cityblock', 'l1', 'chebyshev', 'infinity']
>>> BallTree.valid_metrics()
>>> BallTree.valid_metrics
['euclidean', 'l2', 'minkowski', 'p', 'manhattan', 'cityblock', 'l1', 'chebyshev', 'infinity', 'seuclidean', 'mahalanobis', 'hamming', 'canberra', 'braycurtis', 'jaccard', 'dice', 'rogerstanimoto', 'russellrao', 'sokalmichener', 'sokalsneath', 'haversine', 'pyfunc']

.. _classification:
@@ -480,7 +480,7 @@ A list of valid metrics for any of the above algorithms can be obtained by using
``valid_metric`` attribute. For example, valid metrics for ``KDTree`` can be generated by:

>>> from sklearn.neighbors import KDTree
>>> print(sorted(KDTree.valid_metrics()))
>>> print(sorted(KDTree.valid_metrics))
['chebyshev', 'cityblock', 'euclidean', 'infinity', 'l1', 'l2', 'manhattan', 'minkowski', 'p']


17 changes: 17 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -2,6 +2,23 @@

.. currentmodule:: sklearn

.. _changes_1_3_1:

Version 1.3.1
=============

**TODO: set date**

Changelog
---------

:mod:`sklearn.neighbors`
........................

- |Fix| Reintroduce :attr:`sklearn.neighbors.BallTree.valid_metrics` and
:attr:`sklearn.neighbors.KDTree.valid_metrics` as public class attributes.
:pr:`26754` by :user:`Julien Jerphanion <jjerphan>`.

.. _changes_1_3:

Version 1.3.0
10 changes: 4 additions & 6 deletions sklearn/cluster/_hdbscan/hdbscan.py
@@ -56,7 +56,7 @@
from ._reachability import mutual_reachability_graph
from ._tree import HIERARCHY_dtype, labelling_at_cut, tree_to_labels

FAST_METRICS = set(KDTree.valid_metrics() + BallTree.valid_metrics())
FAST_METRICS = set(KDTree.valid_metrics + BallTree.valid_metrics)

# Encodings are arbitrary but must be strictly negative.
# The current encodings are chosen as extensions to the -1 noise label.
@@ -768,14 +768,12 @@ def fit(self, X, y=None):
n_jobs=self.n_jobs,
**self._metric_params,
)
if self.algorithm == "kdtree" and self.metric not in KDTree.valid_metrics():
if self.algorithm == "kdtree" and self.metric not in KDTree.valid_metrics:
raise ValueError(
f"{self.metric} is not a valid metric for a KDTree-based algorithm."
" Please select a different metric."
)
elif (
self.algorithm == "balltree" and self.metric not in BallTree.valid_metrics()
):
elif self.algorithm == "balltree" and self.metric not in BallTree.valid_metrics:
raise ValueError(
f"{self.metric} is not a valid metric for a BallTree-based algorithm."
" Please select a different metric."
@@ -805,7 +803,7 @@ def fit(self, X, y=None):
# We can't do much with sparse matrices ...
mst_func = _hdbscan_brute
kwargs["copy"] = self.copy
elif self.metric in KDTree.valid_metrics():
elif self.metric in KDTree.valid_metrics:
# TODO: Benchmark KD vs Ball Tree efficiency
mst_func = _hdbscan_prims
kwargs["algo"] = "kd_tree"
4 changes: 2 additions & 2 deletions sklearn/cluster/tests/test_hdbscan.py
@@ -165,7 +165,7 @@ def test_hdbscan_algorithms(algo, metric):
metric_params=metric_params,
)

if metric not in ALGOS_TREES[algo].valid_metrics():
if metric not in ALGOS_TREES[algo].valid_metrics:
with pytest.raises(ValueError):
hdb.fit(X)
elif metric == "wminkowski":
@@ -424,7 +424,7 @@ def test_hdbscan_tree_invalid_metric():

# The set of valid metrics for KDTree at the time of writing this test is a
# strict subset of those supported in BallTree
metrics_not_kd = list(set(BallTree.valid_metrics()) - set(KDTree.valid_metrics()))
metrics_not_kd = list(set(BallTree.valid_metrics) - set(KDTree.valid_metrics))
if len(metrics_not_kd) > 0:
with pytest.raises(ValueError, match=msg):
HDBSCAN(algorithm="kdtree", metric=metrics_not_kd[0]).fit(X)
4 changes: 2 additions & 2 deletions sklearn/neighbors/_base.py
@@ -65,8 +65,8 @@
SCIPY_METRICS += ["matching"]

VALID_METRICS = dict(
ball_tree=BallTree._valid_metrics,
kd_tree=KDTree._valid_metrics,
ball_tree=BallTree.valid_metrics,
kd_tree=KDTree.valid_metrics,
# The following list comes from the
# sklearn.metrics.pairwise doc string
brute=sorted(set(PAIRWISE_DISTANCE_FUNCTIONS).union(SCIPY_METRICS)),
22 changes: 6 additions & 16 deletions sklearn/neighbors/_binary_tree.pxi
@@ -236,9 +236,10 @@ metric : str or DistanceMetric64 object, default='minkowski'
Metric to use for distance computation. Default is "minkowski", which
results in the standard Euclidean distance when p = 2.
A list of valid metrics for {BinaryTree} is given by
:meth:`{BinaryTree}.valid_metrics`.
:attr:`{BinaryTree}.valid_metrics`.
See the documentation of `scipy.spatial.distance
<https://docs.scipy.org/doc/scipy/reference/spatial.distance.html>`_ and the metrics listed in :class:`~sklearn.metrics.pairwise.distance_metrics` for
<https://docs.scipy.org/doc/scipy/reference/spatial.distance.html>`_ and
the metrics listed in :class:`~sklearn.metrics.pairwise.distance_metrics` for
more information on any distance metric.

Additional keywords are passed to the distance metric class.
@@ -249,6 +250,8 @@ Attributes
----------
data : memory view
The training data
valid_metrics: list of str
List of valid distance metrics.

Examples
--------
@@ -792,7 +795,7 @@ cdef class BinaryTree:
cdef int n_splits
cdef int n_calls

_valid_metrics = VALID_METRIC_IDS
valid_metrics = VALID_METRIC_IDS

# Use cinit to initialize all arrays to empty: this will prevent memory
# errors and seg-faults in rare cases where __init__ is not called
@@ -979,19 +982,6 @@
self.node_bounds.base,
)

@classmethod
def valid_metrics(cls):
"""Get list of valid distance metrics.

.. versionadded:: 1.3

Returns
-------
valid_metrics: list of str
List of valid distance metrics.
"""
return cls._valid_metrics

cdef inline float64_t dist(self, float64_t* x1, float64_t* x2,
intp_t size) except -1 nogil:
"""Compute the distance between arrays x1 and x2"""
6 changes: 3 additions & 3 deletions sklearn/neighbors/_kde.py
@@ -173,12 +173,12 @@ def _choose_algorithm(self, algorithm, metric):
# algorithm to compute the result.
if algorithm == "auto":
# use KD Tree if possible
if metric in KDTree.valid_metrics():
if metric in KDTree.valid_metrics:
return "kd_tree"
elif metric in BallTree.valid_metrics():
elif metric in BallTree.valid_metrics:
return "ball_tree"
else: # kd_tree or ball_tree
if metric not in TREE_DICT[algorithm].valid_metrics():
if metric not in TREE_DICT[algorithm].valid_metrics:
raise ValueError(
"invalid metric for {0}: '{1}'".format(TREE_DICT[algorithm], metric)
)
4 changes: 2 additions & 2 deletions sklearn/neighbors/tests/test_kde.py
@@ -113,7 +113,7 @@ def test_kde_algorithm_metric_choice(algorithm, metric):

kde = KernelDensity(algorithm=algorithm, metric=metric)

if algorithm == "kd_tree" and metric not in KDTree.valid_metrics():
if algorithm == "kd_tree" and metric not in KDTree.valid_metrics:
with pytest.raises(ValueError, match="invalid metric"):
kde.fit(X)
else:
@@ -164,7 +164,7 @@ def test_kde_sample_weights():
test_points = rng.rand(n_samples_test, d)
for algorithm in ["auto", "ball_tree", "kd_tree"]:
for metric in ["euclidean", "minkowski", "manhattan", "chebyshev"]:
if algorithm != "kd_tree" or metric in KDTree.valid_metrics():
if algorithm != "kd_tree" or metric in KDTree.valid_metrics:
kde = KernelDensity(algorithm=algorithm, metric=metric)

# Test that adding a constant sample weight has no effect
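
The call-site updates in hdbscan.py and _kde.py shown above all reduce to a membership test against the class attribute. A small sketch of that pattern, illustrative only and not taken from the diff (choose_tree is a hypothetical helper, not a scikit-learn function):

>>> from sklearn.neighbors import KDTree
>>> def choose_tree(metric):
...     # Prefer the KD tree when it supports the metric, otherwise fall back to the ball tree.
...     return 'kd_tree' if metric in KDTree.valid_metrics else 'ball_tree'
>>> choose_tree('chebyshev')
'kd_tree'
>>> choose_tree('haversine')
'ball_tree'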