Skip to content

FIX Make DistanceMetrics support readonly buffer attributes #21694

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Nov 19, 2021
9 changes: 9 additions & 0 deletions doc/whats_new/v1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ Changelog
and :class:`decomposition.MiniBatchSparsePCA` to be convex and match the referenced
article. :pr:`19210` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

:mod:`sklearn.metrics`
......................

- |Fix| All :class:`sklearn.metrics.DistanceMetric` subclasses now correctly support
read-only buffer attributes.
This fixes a regression introduced in 1.0.0 with respect to 0.24.2.
:pr:`21694` by :user:`Julien Jerphanion <jjerphan>`.


:mod:`sklearn.preprocessing`
............................

Expand Down
12 changes: 6 additions & 6 deletions sklearn/metrics/_dist_metrics.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ cdef DTYPE_t INF = np.inf

from ..utils._typedefs cimport DTYPE_t, ITYPE_t, DITYPE_t, DTYPECODE
from ..utils._typedefs import DTYPE, ITYPE

from ..utils._readonly_array_wrapper import ReadonlyArrayWrapper

######################################################################
# newObj function
Expand Down Expand Up @@ -214,8 +214,8 @@ cdef class DistanceMetric:
set state for pickling
"""
self.p = state[0]
self.vec = state[1]
self.mat = state[2]
self.vec = ReadonlyArrayWrapper(state[1])
self.mat = ReadonlyArrayWrapper(state[2])
if self.__class__.__name__ == "PyFuncDistance":
self.func = state[3]
self.kwargs = state[4]
Expand Down Expand Up @@ -444,7 +444,7 @@ cdef class SEuclideanDistance(DistanceMetric):
D(x, y) = \sqrt{ \sum_i \frac{ (x_i - y_i) ^ 2}{V_i} }
"""
def __init__(self, V):
self.vec = np.asarray(V, dtype=DTYPE)
self.vec = ReadonlyArrayWrapper(np.asarray(V, dtype=DTYPE))
self.size = self.vec.shape[0]
self.p = 2

Expand Down Expand Up @@ -605,7 +605,7 @@ cdef class WMinkowskiDistance(DistanceMetric):
raise ValueError("WMinkowskiDistance requires finite p. "
"For p=inf, use ChebyshevDistance.")
self.p = p
self.vec = np.asarray(w, dtype=DTYPE)
self.vec = ReadonlyArrayWrapper(np.asarray(w, dtype=DTYPE))
self.size = self.vec.shape[0]

def _validate_data(self, X):
Expand Down Expand Up @@ -665,7 +665,7 @@ cdef class MahalanobisDistance(DistanceMetric):
if VI.ndim != 2 or VI.shape[0] != VI.shape[1]:
raise ValueError("V/VI must be square")

self.mat = np.asarray(VI, dtype=float, order='C')
self.mat = ReadonlyArrayWrapper(np.asarray(VI, dtype=float, order='C'))

self.size = self.mat.shape[0]

Expand Down
24 changes: 23 additions & 1 deletion sklearn/metrics/tests/test_dist_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,16 @@ def check_pdist_bool(metric, D_true):
assert_array_almost_equal(D12, D_true)


@pytest.mark.parametrize("use_read_only_kwargs", [True, False])
@pytest.mark.parametrize("metric", METRICS_DEFAULT_PARAMS)
def test_pickle(use_read_only_kwargs, metric):
    """Run ``check_pickle`` on every default parameter combination of a metric.

    Parameters
    ----------
    use_read_only_kwargs : bool
        When True, every ``np.ndarray`` parameter value is handed to the
        metric as a read-only array.
    metric : str
        Key of ``METRICS_DEFAULT_PARAMS`` selecting the parameter grid.
    """
    argdict = METRICS_DEFAULT_PARAMS[metric]
    keys = argdict.keys()
    for vals in itertools.product(*argdict.values()):
        if use_read_only_kwargs:
            # Freeze private copies rather than the arrays stored in
            # METRICS_DEFAULT_PARAMS: flipping the flag on the shared
            # fixtures would leave them read-only for the writable
            # parametrization and for every other test using them.
            frozen_vals = []
            for val in vals:
                if isinstance(val, np.ndarray):
                    val = val.copy()
                    val.setflags(write=False)
                frozen_vals.append(val)
            vals = tuple(frozen_vals)
        kwargs = dict(zip(keys, vals))
        check_pickle(metric, kwargs)

Expand Down Expand Up @@ -242,3 +247,20 @@ def custom_metric(x, y):
pyfunc = DistanceMetric.get_metric("pyfunc", func=custom_metric)
eucl = DistanceMetric.get_metric("euclidean")
assert_array_almost_equal(pyfunc.pairwise(X), eucl.pairwise(X) ** 2)


def test_readonly_kwargs():
    """Non-regression test: metrics must accept read-only array parameters.

    See https://github.com/scikit-learn/scikit-learn/issues/21685
    """
    rng = check_random_state(0)

    weights = rng.rand(100)
    VI = rng.rand(10, 10)
    for arr in (weights, VI):
        arr.setflags(write=False)

    # Building each of these distance metrics from read-only buffers
    # must not raise.
    metric_specs = (
        ("seuclidean", {"V": weights}),
        ("wminkowski", {"p": 1, "w": weights}),
        ("mahalanobis", {"VI": VI}),
    )
    for name, params in metric_specs:
        DistanceMetric.get_metric(name, **params)