Skip to content

Commit d8d5265

Browse files
Shreesha3112, jeremiedbb, Micky774
authored and committed
FIX Allow 0<p<1 for Minkowski metric regardless of X's dtype (scikit-learn#26760)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
Co-authored-by: Meekail Zain <34613774+Micky774@users.noreply.github.com>
1 parent 55eb419 commit d8d5265

File tree

4 files changed

+61
-8
lines changed

4 files changed

+61
-8
lines changed

doc/whats_new/v1.4.rst

+7
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,13 @@ Changelog
132132
object in the parameter grid if it's an estimator. :pr:`26786` by `Adrin
133133
Jalali`_.
134134

135+
:mod:`sklearn.neighbors`
136+
........................
137+
138+
- |Fix| Neighbors based estimators now correctly work when `metric="minkowski"` and the
139+
metric parameter `p` is in the range `0 < p < 1`, regardless of the `dtype` of `X`.
140+
:pr:`26760` by :user:`Shreesha Kumar Bhat <Shreesha3112>`.
141+
135142
:mod:`sklearn.tree`
136143
...................
137144

sklearn/metrics/_dist_metrics.pyx.tp

+13-5
Original file line number | Diff line number | Diff line change
@@ -1396,19 +1396,27 @@ cdef class MinkowskiDistance{{name_suffix}}(DistanceMetric{{name_suffix}}):
13961396

13971397
Parameters
13981398
----------
1399-
p : int
1399+
p : float
14001400
The order of the p-norm of the difference (see above).
1401+
1402+
.. versionchanged:: 1.4.0
1403+
Minkowski distance allows `p` to be `0<p<1`.
1404+
1405+
14011406
w : (N,) array-like (optional)
14021407
The weight vector.
14031408

1404-
Minkowski Distance requires p >= 1 and finite. For p = infinity,
1405-
use ChebyshevDistance.
1409+
Minkowski Distance requires p > 0 and finite.
1410+
When :math:`p \in (0,1)`, it isn't a true metric but is permissible when
1411+
the triangular inequality isn't necessary.
1412+
For p = infinity, use ChebyshevDistance.
14061413
Note that for p=1, ManhattanDistance is more efficient, and for
14071414
p=2, EuclideanDistance is more efficient.
1415+
14081416
"""
14091417
def __init__(self, p, w=None):
1410-
if p < 1:
1411-
raise ValueError("p must be greater than 1")
1418+
if p <= 0:
1419+
raise ValueError("p must be greater than 0")
14121420
elif np.isinf(p):
14131421
raise ValueError("MinkowskiDistance requires finite p. "
14141422
"For p=inf, use ChebyshevDistance.")

sklearn/metrics/tests/test_dist_metrics.py

+22-3
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
)
1616
from sklearn.utils import check_random_state
1717
from sklearn.utils._testing import assert_allclose, create_memmap_backed_data
18+
from sklearn.utils.fixes import parse_version, sp_version
1819

1920

2021
def dist_func(x1, x2, p):
@@ -42,18 +43,17 @@ def dist_func(x1, x2, p):
4243
V = rng.random_sample((d, d))
4344
VI = np.dot(V, V.T)
4445

45-
4646
METRICS_DEFAULT_PARAMS = [
4747
("euclidean", {}),
4848
("cityblock", {}),
49-
("minkowski", dict(p=(1, 1.5, 2, 3))),
49+
("minkowski", dict(p=(0.5, 1, 1.5, 2, 3))),
5050
("chebyshev", {}),
5151
("seuclidean", dict(V=(rng.random_sample(d),))),
5252
("mahalanobis", dict(VI=(VI,))),
5353
("hamming", {}),
5454
("canberra", {}),
5555
("braycurtis", {}),
56-
("minkowski", dict(p=(1, 1.5, 3), w=(rng.random_sample(d),))),
56+
("minkowski", dict(p=(0.5, 1, 1.5, 3), w=(rng.random_sample(d),))),
5757
]
5858

5959

@@ -76,6 +76,13 @@ def test_cdist(metric_param_grid, X, Y):
7676
# with scipy
7777
rtol_dict = {"rtol": 1e-6}
7878

79+
# TODO: Remove when scipy minimum version >= 1.7.0
80+
# scipy supports 0<p<1 for minkowski metric >= 1.7.0
81+
if metric == "minkowski":
82+
p = kwargs["p"]
83+
if sp_version < parse_version("1.7.0") and p < 1:
84+
pytest.skip("scipy does not support 0<p<1 for minkowski metric < 1.7.0")
85+
7986
D_scipy_cdist = cdist(X, Y, metric, **kwargs)
8087

8188
dm = DistanceMetric.get_metric(metric, X.dtype, **kwargs)
@@ -150,6 +157,12 @@ def test_pdist(metric_param_grid, X):
150157
# with scipy
151158
rtol_dict = {"rtol": 1e-6}
152159

160+
# TODO: Remove when scipy minimum version >= 1.7.0
161+
# scipy supports 0<p<1 for minkowski metric >= 1.7.0
162+
if metric == "minkowski":
163+
p = kwargs["p"]
164+
if sp_version < parse_version("1.7.0") and p < 1:
165+
pytest.skip("scipy does not support 0<p<1 for minkowski metric < 1.7.0")
153166
D_scipy_pdist = cdist(X, X, metric, **kwargs)
154167

155168
dm = DistanceMetric.get_metric(metric, X.dtype, **kwargs)
@@ -397,3 +410,9 @@ def test_get_metric_bad_dtype():
397410
msg = r"Unexpected dtype .* provided. Please select a dtype from"
398411
with pytest.raises(ValueError, match=msg):
399412
DistanceMetric.get_metric("manhattan", dtype)
413+
414+
415+
def test_minkowski_metric_validate_bad_p_parameter():
416+
msg = "p must be greater than 0"
417+
with pytest.raises(ValueError, match=msg):
418+
DistanceMetric.get_metric("minkowski", p=0)

sklearn/neighbors/tests/test_neighbors.py

+19
Original file line number | Diff line number | Diff line change
@@ -2207,3 +2207,22 @@ def test_predict_dataframe():
22072207

22082208
knn = neighbors.KNeighborsClassifier(n_neighbors=2).fit(X, y)
22092209
knn.predict(X)
2210+
2211+
2212+
def test_nearest_neighbours_works_with_p_less_than_1():
2213+
"""Check that NearestNeighbors works with :math:`p \\in (0,1)` when `algorithm`
2214+
is `"auto"` or `"brute"` regardless of the dtype of X.
2215+
2216+
Non-regression test for issue #26548
2217+
"""
2218+
X = np.array([[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]])
2219+
neigh = neighbors.NearestNeighbors(
2220+
n_neighbors=3, algorithm="brute", metric_params={"p": 0.5}
2221+
)
2222+
neigh.fit(X)
2223+
2224+
y = neigh.radius_neighbors(X[0].reshape(1, -1), radius=4, return_distance=False)
2225+
assert_allclose(y[0], [0, 1, 2])
2226+
2227+
y = neigh.kneighbors(X[0].reshape(1, -1), return_distance=False)
2228+
assert_allclose(y[0], [0, 1, 2])

0 commit comments

Comments
 (0)