@@ -72,6 +72,15 @@ def _return_float_dtype(X, Y):
72
72
return X , Y , dtype
73
73
74
74
75
+ def _find_floating_dtype_allow_sparse (X , Y , xp = None ):
76
+ """Find matching floating type, allowing for sparse input."""
77
+ if any ([issparse (X ), issparse (Y )]) or _is_numpy_namespace (xp ):
78
+ X , Y , dtype_float = _return_float_dtype (X , Y )
79
+ else :
80
+ dtype_float = _find_matching_floating_dtype (X , Y , xp = xp )
81
+ return X , Y , dtype_float
82
+
83
+
75
84
def check_pairwise_arrays (
76
85
X ,
77
86
Y ,
@@ -177,10 +186,7 @@ def check_pairwise_arrays(
177
186
ensure_all_finite = _deprecate_force_all_finite (force_all_finite , ensure_all_finite )
178
187
179
188
xp , _ = get_namespace (X , Y )
180
- if any ([issparse (X ), issparse (Y )]) or _is_numpy_namespace (xp ):
181
- X , Y , dtype_float = _return_float_dtype (X , Y )
182
- else :
183
- dtype_float = _find_matching_floating_dtype (X , Y , xp = xp )
189
+ X , Y , dtype_float = _find_floating_dtype_allow_sparse (X , Y , xp = xp )
184
190
185
191
estimator = "check_pairwise_arrays"
186
192
if dtype == "infer_float" :
@@ -433,7 +439,7 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
433
439
# Ensure that distances between vectors and themselves are set to 0.0.
434
440
# This may not be the case due to floating point rounding errors.
435
441
if X is Y :
436
- _fill_or_add_to_diagonal (distances , 0 , xp = xp , add_value = False )
442
+ distances = _fill_or_add_to_diagonal (distances , 0 , xp = xp , add_value = False )
437
443
438
444
if squared :
439
445
return distances
@@ -1171,7 +1177,7 @@ def cosine_distances(X, Y=None):
1171
1177
if X is Y or Y is None :
1172
1178
# Ensure that distances between vectors and themselves are set to 0.0.
1173
1179
# This may not be the case due to floating point rounding errors.
1174
- _fill_or_add_to_diagonal (S , 0.0 , xp , add_value = False )
1180
+ S = _fill_or_add_to_diagonal (S , 0.0 , xp , add_value = False )
1175
1181
return S
1176
1182
1177
1183
@@ -1943,40 +1949,48 @@ def distance_metrics():
1943
1949
return PAIRWISE_DISTANCE_FUNCTIONS
1944
1950
1945
1951
1946
- def _dist_wrapper (dist_func , dist_matrix , slice_ , * args , ** kwargs ):
1952
+ def _transposed_dist_wrapper (dist_func , dist_matrix , slice_ , * args , ** kwargs ):
1947
1953
"""Write in-place to a slice of a distance matrix."""
1948
- dist_matrix [:, slice_ ] = dist_func (* args , ** kwargs )
1954
+ dist_matrix [slice_ , ... ] = dist_func (* args , ** kwargs ). T
1949
1955
1950
1956
1951
1957
def _parallel_pairwise (X , Y , func , n_jobs , ** kwds ):
1952
1958
"""Break the pairwise matrix in n_jobs even slices
1953
1959
and compute them using multithreading."""
1960
+ xp , _ , device = get_namespace_and_device (X , Y )
1961
+ X , Y , dtype_float = _find_floating_dtype_allow_sparse (X , Y , xp = xp )
1954
1962
1955
1963
if Y is None :
1956
1964
Y = X
1957
- X , Y , dtype = _return_float_dtype (X , Y )
1958
1965
1959
1966
if effective_n_jobs (n_jobs ) == 1 :
1960
1967
return func (X , Y , ** kwds )
1961
1968
1962
1969
# enforce a threading backend to prevent data communication overhead
1963
- fd = delayed (_dist_wrapper )
1964
- ret = np .empty ((X .shape [0 ], Y .shape [0 ]), dtype = dtype , order = "F" )
1970
+ fd = delayed (_transposed_dist_wrapper )
1971
+ # Transpose `ret` such that a given thread writes its ouput to a contiguous chunk.
1972
+ # Note `order` (i.e. F/C-contiguous) is not included in array API standard, see
1973
+ # https://github.com/data-apis/array-api/issues/571 for details.
1974
+ # We assume that currently (April 2025) all array API compatible namespaces
1975
+ # allocate 2D arrays using the C-contiguity convention by default.
1976
+ ret = xp .empty ((X .shape [0 ], Y .shape [0 ]), device = device , dtype = dtype_float ).T
1965
1977
Parallel (backend = "threading" , n_jobs = n_jobs )(
1966
- fd (func , ret , s , X , Y [s ], ** kwds )
1978
+ fd (func , ret , s , X , Y [s , ... ], ** kwds )
1967
1979
for s in gen_even_slices (_num_samples (Y ), effective_n_jobs (n_jobs ))
1968
1980
)
1969
1981
1970
1982
if (X is Y or Y is None ) and func is euclidean_distances :
1971
1983
# zeroing diagonal for euclidean norm.
1972
1984
# TODO: do it also for other norms.
1973
- np . fill_diagonal (ret , 0 )
1985
+ ret = _fill_or_add_to_diagonal (ret , 0 , xp = xp , add_value = False )
1974
1986
1975
- return ret
1987
+ # Transform output back
1988
+ return ret .T
1976
1989
1977
1990
1978
1991
def _pairwise_callable (X , Y , metric , ensure_all_finite = True , ** kwds ):
1979
1992
"""Handle the callable case for pairwise_{distances,kernels}."""
1993
+ xp , _ , device = get_namespace_and_device (X )
1980
1994
X , Y = check_pairwise_arrays (
1981
1995
X ,
1982
1996
Y ,
@@ -1985,16 +1999,28 @@ def _pairwise_callable(X, Y, metric, ensure_all_finite=True, **kwds):
1985
1999
# No input dimension checking done for custom metrics (left to user)
1986
2000
ensure_2d = False ,
1987
2001
)
2002
+ _ , _ , dtype_float = _find_floating_dtype_allow_sparse (X , Y , xp = xp )
2003
+
2004
+ def _get_slice (array , index ):
2005
+ # TODO: below 2 lines can be removed once min scipy >= 1.14. Support for
2006
+ # 1D shapes in scipy sparse arrays (COO, DOK and CSR formats) only
2007
+ # added in 1.14. We must return 2D array until min scipy 1.14.
2008
+ if issparse (array ):
2009
+ return array [[index ], :]
2010
+ # When `metric` is a callable, 1D input arrays allowed, in which case
2011
+ # scalar should be returned.
2012
+ if array .ndim == 1 :
2013
+ return array [index ]
2014
+ else :
2015
+ return array [index , ...]
1988
2016
1989
2017
if X is Y :
1990
2018
# Only calculate metric for upper triangle
1991
- out = np .zeros ((X .shape [0 ], Y .shape [0 ]), dtype = "float" )
2019
+ out = xp .zeros ((X .shape [0 ], Y .shape [0 ]), dtype = dtype_float , device = device )
1992
2020
iterator = itertools .combinations (range (X .shape [0 ]), 2 )
1993
2021
for i , j in iterator :
1994
- # scipy has not yet implemented 1D sparse slices; once implemented this can
1995
- # be removed and `arr[ind]` can be simply used.
1996
- x = X [[i ], :] if issparse (X ) else X [i ]
1997
- y = Y [[j ], :] if issparse (Y ) else Y [j ]
2022
+ x = _get_slice (X , i )
2023
+ y = _get_slice (Y , j )
1998
2024
out [i , j ] = metric (x , y , ** kwds )
1999
2025
2000
2026
# Make symmetric
@@ -2004,20 +2030,16 @@ def _pairwise_callable(X, Y, metric, ensure_all_finite=True, **kwds):
2004
2030
# Calculate diagonal
2005
2031
# NB: nonzero diagonals are allowed for both metrics and kernels
2006
2032
for i in range (X .shape [0 ]):
2007
- # scipy has not yet implemented 1D sparse slices; once implemented this can
2008
- # be removed and `arr[ind]` can be simply used.
2009
- x = X [[i ], :] if issparse (X ) else X [i ]
2033
+ x = _get_slice (X , i )
2010
2034
out [i , i ] = metric (x , x , ** kwds )
2011
2035
2012
2036
else :
2013
2037
# Calculate all cells
2014
- out = np .empty ((X .shape [0 ], Y .shape [0 ]), dtype = "float" )
2038
+ out = xp .empty ((X .shape [0 ], Y .shape [0 ]), dtype = dtype_float )
2015
2039
iterator = itertools .product (range (X .shape [0 ]), range (Y .shape [0 ]))
2016
2040
for i , j in iterator :
2017
- # scipy has not yet implemented 1D sparse slices; once implemented this can
2018
- # be removed and `arr[ind]` can be simply used.
2019
- x = X [[i ], :] if issparse (X ) else X [i ]
2020
- y = Y [[j ], :] if issparse (Y ) else Y [j ]
2041
+ x = _get_slice (X , i )
2042
+ y = _get_slice (Y , j )
2021
2043
out [i , j ] = metric (x , y , ** kwds )
2022
2044
2023
2045
return out
0 commit comments