Skip to content

MAINT Remove ReadonlyArrayWrapper from DistanceMetric #25553

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sklearn/metrics/_dist_metrics.pxd.tp
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ cdef class DistanceMetric{{name_suffix}}:
# Because we don't expect to instantiate a lot of these objects, the
# extra memory overhead of this setup should not be an issue.
cdef DTYPE_t p
cdef DTYPE_t[::1] vec
cdef DTYPE_t[:, ::1] mat
cdef const DTYPE_t[::1] vec
cdef const DTYPE_t[:, ::1] mat
cdef ITYPE_t size
cdef object func
cdef object kwargs
Expand Down
46 changes: 26 additions & 20 deletions sklearn/metrics/_dist_metrics.pyx.tp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ from libc.math cimport fabs, sqrt, exp, pow, cos, sin, asin
from scipy.sparse import csr_matrix, issparse
from ..utils._typedefs cimport DTYPE_t, ITYPE_t, DTYPECODE
from ..utils._typedefs import DTYPE, ITYPE
from ..utils._readonly_array_wrapper import ReadonlyArrayWrapper
from ..utils import check_array
from ..utils.fixes import parse_version, sp_base_version

Expand Down Expand Up @@ -269,8 +268,8 @@ cdef class DistanceMetric{{name_suffix}}:
set state for pickling
"""
self.p = state[0]
self.vec = ReadonlyArrayWrapper(state[1])
self.mat = ReadonlyArrayWrapper(state[2])
self.vec = state[1]
self.mat = state[2]
if self.__class__.__name__ == "PyFuncDistance{{name_suffix}}":
self.func = state[3]
self.kwargs = state[4]
Expand Down Expand Up @@ -979,7 +978,7 @@ cdef class SEuclideanDistance{{name_suffix}}(DistanceMetric{{name_suffix}}):
D(x, y) = \sqrt{ \sum_i \frac{ (x_i - y_i) ^ 2}{V_i} }
"""
def __init__(self, V):
self.vec = ReadonlyArrayWrapper(np.asarray(V, dtype=DTYPE))
self.vec = np.asarray(V, dtype=DTYPE)
self.size = self.vec.shape[0]
self.p = 2

Expand Down Expand Up @@ -1294,10 +1293,10 @@ cdef class MinkowskiDistance{{name_suffix}}(DistanceMetric{{name_suffix}}):
)
if (w_array < 0).any():
raise ValueError("w cannot contain negative weights")
self.vec = ReadonlyArrayWrapper(w_array)
self.vec = w_array
self.size = self.vec.shape[0]
else:
self.vec = ReadonlyArrayWrapper(np.asarray([], dtype=DTYPE))
self.vec = np.asarray([], dtype=DTYPE)
self.size = 0

def _validate_data(self, X):
Expand Down Expand Up @@ -1486,7 +1485,7 @@ cdef class WMinkowskiDistance{{name_suffix}}(DistanceMetric{{name_suffix}}):
raise ValueError("WMinkowskiDistance requires finite p. "
"For p=inf, use ChebyshevDistance.")
self.p = p
self.vec = ReadonlyArrayWrapper(np.asarray(w, dtype=DTYPE))
self.vec = np.asarray(w, dtype=DTYPE)
self.size = self.vec.shape[0]

def _validate_data(self, X):
Expand Down Expand Up @@ -1622,6 +1621,8 @@ cdef class MahalanobisDistance{{name_suffix}}(DistanceMetric{{name_suffix}}):
optionally specify the inverse directly. If VI is passed,
then V is not referenced.
"""
cdef DTYPE_t[::1] buffer

def __init__(self, V=None, VI=None):
if VI is None:
if V is None:
Expand All @@ -1631,12 +1632,17 @@ cdef class MahalanobisDistance{{name_suffix}}(DistanceMetric{{name_suffix}}):
if VI.ndim != 2 or VI.shape[0] != VI.shape[1]:
raise ValueError("V/VI must be square")

self.mat = ReadonlyArrayWrapper(np.asarray(VI, dtype=DTYPE, order='C'))
self.mat = np.asarray(VI, dtype=DTYPE, order='C')

self.size = self.mat.shape[0]

# we need vec as a work buffer
self.vec = np.zeros(self.size, dtype=DTYPE)
# We need to create a buffer to store the vectors' coordinates' differences
self.buffer = np.zeros(self.size, dtype=DTYPE)

def __setstate__(self, state):
super().__setstate__(state)
self.size = self.mat.shape[0]
self.buffer = np.zeros(self.size, dtype=DTYPE)

def _validate_data(self, X):
if X.shape[1] != self.size:
Expand All @@ -1653,13 +1659,13 @@ cdef class MahalanobisDistance{{name_suffix}}(DistanceMetric{{name_suffix}}):

# compute (x1 - x2).T * VI * (x1 - x2)
for i in range(size):
self.vec[i] = x1[i] - x2[i]
self.buffer[i] = x1[i] - x2[i]

for i in range(size):
tmp = 0
for j in range(size):
tmp += self.mat[i, j] * self.vec[j]
d += tmp * self.vec[i]
tmp += self.mat[i, j] * self.buffer[j]
d += tmp * self.buffer[i]
return d

cdef inline DTYPE_t dist(
Expand Down Expand Up @@ -1707,32 +1713,32 @@ cdef class MahalanobisDistance{{name_suffix}}(DistanceMetric{{name_suffix}}):
ix2 = x2_indices[i2]

if ix1 == ix2:
self.vec[ix1] = x1_data[i1] - x2_data[i2]
self.buffer[ix1] = x1_data[i1] - x2_data[i2]
i1 = i1 + 1
i2 = i2 + 1
elif ix1 < ix2:
self.vec[ix1] = x1_data[i1]
self.buffer[ix1] = x1_data[i1]
i1 = i1 + 1
else:
self.vec[ix2] = - x2_data[i2]
self.buffer[ix2] = - x2_data[i2]
i2 = i2 + 1

if i1 == x1_end:
while i2 < x2_end:
ix2 = x2_indices[i2]
self.vec[ix2] = - x2_data[i2]
self.buffer[ix2] = - x2_data[i2]
i2 = i2 + 1
else:
while i1 < x1_end:
ix1 = x1_indices[i1]
self.vec[ix1] = x1_data[i1]
self.buffer[ix1] = x1_data[i1]
i1 = i1 + 1

for i in range(size):
tmp = 0
for j in range(size):
tmp += self.mat[i, j] * self.vec[j]
d += tmp * self.vec[i]
tmp += self.mat[i, j] * self.buffer[j]
d += tmp * self.buffer[i]

return d

Expand Down