Skip to content

Commit 33c39e5

Browse files
authored
FIX Make DistanceMetrics support readonly buffers attributes (#21694)
1 parent 432ae47 commit 33c39e5

File tree

3 files changed

+38
-7
lines changed

3 files changed

+38
-7
lines changed

doc/whats_new/v1.0.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@ Changelog
2828
and :class:`decomposition.MiniBatchSparsePCA` to be convex and match the referenced
2929
article. :pr:`19210` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
3030

31+
:mod:`sklearn.metrics`
32+
......................
33+
34+
- |Fix| All :class:`sklearn.metrics.DistanceMetric` subclasses now correctly support
35+
read-only buffer attributes.
36+
This fixes a regression introduced in 1.0.0 with respect to 0.24.2.
37+
:pr:`21694` by :user:`Julien Jerphanion <jjerphan>`.
38+
39+
3140
:mod:`sklearn.preprocessing`
3241
............................
3342

sklearn/metrics/_dist_metrics.pyx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ cdef DTYPE_t INF = np.inf
2929

3030
from ..utils._typedefs cimport DTYPE_t, ITYPE_t, DITYPE_t, DTYPECODE
3131
from ..utils._typedefs import DTYPE, ITYPE
32-
32+
from ..utils._readonly_array_wrapper import ReadonlyArrayWrapper
3333

3434
######################################################################
3535
# newObj function
@@ -214,8 +214,8 @@ cdef class DistanceMetric:
214214
set state for pickling
215215
"""
216216
self.p = state[0]
217-
self.vec = state[1]
218-
self.mat = state[2]
217+
self.vec = ReadonlyArrayWrapper(state[1])
218+
self.mat = ReadonlyArrayWrapper(state[2])
219219
if self.__class__.__name__ == "PyFuncDistance":
220220
self.func = state[3]
221221
self.kwargs = state[4]
@@ -444,7 +444,7 @@ cdef class SEuclideanDistance(DistanceMetric):
444444
D(x, y) = \sqrt{ \sum_i \frac{ (x_i - y_i) ^ 2}{V_i} }
445445
"""
446446
def __init__(self, V):
447-
self.vec = np.asarray(V, dtype=DTYPE)
447+
self.vec = ReadonlyArrayWrapper(np.asarray(V, dtype=DTYPE))
448448
self.size = self.vec.shape[0]
449449
self.p = 2
450450

@@ -605,7 +605,7 @@ cdef class WMinkowskiDistance(DistanceMetric):
605605
raise ValueError("WMinkowskiDistance requires finite p. "
606606
"For p=inf, use ChebyshevDistance.")
607607
self.p = p
608-
self.vec = np.asarray(w, dtype=DTYPE)
608+
self.vec = ReadonlyArrayWrapper(np.asarray(w, dtype=DTYPE))
609609
self.size = self.vec.shape[0]
610610

611611
def _validate_data(self, X):
@@ -665,7 +665,7 @@ cdef class MahalanobisDistance(DistanceMetric):
665665
if VI.ndim != 2 or VI.shape[0] != VI.shape[1]:
666666
raise ValueError("V/VI must be square")
667667

668-
self.mat = np.asarray(VI, dtype=float, order='C')
668+
self.mat = ReadonlyArrayWrapper(np.asarray(VI, dtype=float, order='C'))
669669

670670
self.size = self.mat.shape[0]
671671

sklearn/metrics/tests/test_dist_metrics.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,16 @@ def check_pdist_bool(metric, D_true):
158158
assert_array_almost_equal(D12, D_true)
159159

160160

161+
@pytest.mark.parametrize("use_read_only_kwargs", [True, False])
161162
@pytest.mark.parametrize("metric", METRICS_DEFAULT_PARAMS)
162-
def test_pickle(metric):
163+
def test_pickle(use_read_only_kwargs, metric):
163164
argdict = METRICS_DEFAULT_PARAMS[metric]
164165
keys = argdict.keys()
165166
for vals in itertools.product(*argdict.values()):
167+
if use_read_only_kwargs:
168+
for val in vals:
169+
if isinstance(val, np.ndarray):
170+
val.setflags(write=False)
166171
kwargs = dict(zip(keys, vals))
167172
check_pickle(metric, kwargs)
168173

@@ -242,3 +247,20 @@ def custom_metric(x, y):
242247
pyfunc = DistanceMetric.get_metric("pyfunc", func=custom_metric)
243248
eucl = DistanceMetric.get_metric("euclidean")
244249
assert_array_almost_equal(pyfunc.pairwise(X), eucl.pairwise(X) ** 2)
250+
251+
252+
def test_readonly_kwargs():
253+
# Non-regression test for:
254+
# https://github.com/scikit-learn/scikit-learn/issues/21685
255+
256+
rng = check_random_state(0)
257+
258+
weights = rng.rand(100)
259+
VI = rng.rand(10, 10)
260+
weights.setflags(write=False)
261+
VI.setflags(write=False)
262+
263+
# Those distances metrics have to support readonly buffers.
264+
DistanceMetric.get_metric("seuclidean", V=weights)
265+
DistanceMetric.get_metric("wminkowski", p=1, w=weights)
266+
DistanceMetric.get_metric("mahalanobis", VI=VI)

0 commit comments

Comments
 (0)