Skip to content

Commit f9beef1

Browse files
jjerphanjeremiedbb
andcommitted
MAINT Use _VALID_METRICS for PairwiseDistancesReductions
Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com>
1 parent c9df9ec commit f9beef1

File tree

3 files changed

+35
-33
lines changed

3 files changed

+35
-33
lines changed

sklearn/metrics/_pairwise_distances_reduction/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,13 @@
9191
ArgKmin,
9292
RadiusNeighbors,
9393
sqeuclidean_row_norms,
94+
_VALID_METRICS,
9495
)
9596

9697
__all__ = [
9798
"BaseDistancesReductionDispatcher",
9899
"ArgKmin",
99100
"RadiusNeighbors",
100101
"sqeuclidean_row_norms",
102+
"_VALID_METRICS",
101103
]

sklearn/metrics/_pairwise_distances_reduction/_dispatcher.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from scipy.sparse import isspmatrix_csr
88

9-
from .._dist_metrics import BOOL_METRICS, METRIC_MAPPING
9+
from .._dist_metrics import BOOL_METRICS
1010

1111
from ._base import (
1212
_sqeuclidean_row_norms64,
@@ -23,6 +23,36 @@
2323

2424
from ... import get_config
2525

26+
_VALID_METRICS = [
27+
"euclidean",
28+
"l2",
29+
"l1",
30+
"manhattan",
31+
"cityblock",
32+
"braycurtis",
33+
"canberra",
34+
"chebyshev",
35+
"correlation",
36+
"cosine",
37+
"dice",
38+
"hamming",
39+
"jaccard",
40+
"kulsinski",
41+
"mahalanobis",
42+
"matching",
43+
"minkowski",
44+
"rogerstanimoto",
45+
"russellrao",
46+
"seuclidean",
47+
"sokalmichener",
48+
"sokalsneath",
49+
"sqeuclidean",
50+
"yule",
51+
"wminkowski",
52+
"nan_euclidean",
53+
"haversine",
54+
]
55+
2656

2757
def sqeuclidean_row_norms(X, num_threads):
2858
"""Compute the squared euclidean norm of the rows of X in parallel.
@@ -73,7 +103,7 @@ def valid_metrics(cls) -> List[str]:
73103
"hamming",
74104
*BOOL_METRICS,
75105
}
76-
return sorted(({"sqeuclidean"} | set(METRIC_MAPPING.keys())) - excluded)
106+
return sorted(set(_VALID_METRICS) - excluded)
77107

78108
@classmethod
79109
def is_usable_for(cls, X, Y, metric) -> bool:

sklearn/metrics/pairwise.py

+1-31
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from ..utils.fixes import delayed
3131
from ..utils.fixes import sp_version, parse_version
3232

33-
from ._pairwise_distances_reduction import ArgKmin
33+
from ._pairwise_distances_reduction import ArgKmin, _VALID_METRICS
3434
from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan
3535
from ..exceptions import DataConversionWarning
3636

@@ -1619,36 +1619,6 @@ def _pairwise_callable(X, Y, metric, force_all_finite=True, **kwds):
16191619
return out
16201620

16211621

1622-
_VALID_METRICS = [
1623-
"euclidean",
1624-
"l2",
1625-
"l1",
1626-
"manhattan",
1627-
"cityblock",
1628-
"braycurtis",
1629-
"canberra",
1630-
"chebyshev",
1631-
"correlation",
1632-
"cosine",
1633-
"dice",
1634-
"hamming",
1635-
"jaccard",
1636-
"kulsinski",
1637-
"mahalanobis",
1638-
"matching",
1639-
"minkowski",
1640-
"rogerstanimoto",
1641-
"russellrao",
1642-
"seuclidean",
1643-
"sokalmichener",
1644-
"sokalsneath",
1645-
"sqeuclidean",
1646-
"yule",
1647-
"wminkowski",
1648-
"nan_euclidean",
1649-
"haversine",
1650-
]
1651-
16521622
_NAN_METRICS = ["nan_euclidean"]
16531623

16541624

0 commit comments

Comments
 (0)