Skip to content

Commit cccf7b4

Browse files
EmilyXinyilucyleeowogrisel
authored
Array API support for pairwise kernels (scikit-learn#29822)
Co-authored-by: Lucy Liu <jliu176@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent f27a26d commit cccf7b4

File tree

6 files changed

+171
-29
lines changed

6 files changed

+171
-29
lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ Metrics
158158
- :func:`sklearn.metrics.pairwise.linear_kernel`
159159
- :func:`sklearn.metrics.pairwise.paired_cosine_distances`
160160
- :func:`sklearn.metrics.pairwise.paired_euclidean_distances`
161+
- :func:`sklearn.metrics.pairwise.pairwise_kernels` (supports all metrics except :func:`sklearn.metrics.pairwise.laplacian_kernel`)
161162
- :func:`sklearn.metrics.pairwise.polynomial_kernel`
162163
- :func:`sklearn.metrics.pairwise.rbf_kernel` (see :ref:`device_support_for_float64`)
163164
- :func:`sklearn.metrics.pairwise.sigmoid_kernel`
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- :func:`metrics.pairwise.pairwise_kernels` now supports Array API
2+
compatible inputs, when the underling `metric` does (the only metric NOT currently
3+
supported is :func:`sklearn.metrics.pairwise.laplacian_kernel`).
4+
By :user:`Emily Chen <EmilyXinyi>` and :user:`Lucy Liu <lucyleeow>`.

sklearn/metrics/pairwise.py

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,15 @@ def _return_float_dtype(X, Y):
7272
return X, Y, dtype
7373

7474

75+
def _find_floating_dtype_allow_sparse(X, Y, xp=None):
76+
"""Find matching floating type, allowing for sparse input."""
77+
if any([issparse(X), issparse(Y)]) or _is_numpy_namespace(xp):
78+
X, Y, dtype_float = _return_float_dtype(X, Y)
79+
else:
80+
dtype_float = _find_matching_floating_dtype(X, Y, xp=xp)
81+
return X, Y, dtype_float
82+
83+
7584
def check_pairwise_arrays(
7685
X,
7786
Y,
@@ -177,10 +186,7 @@ def check_pairwise_arrays(
177186
ensure_all_finite = _deprecate_force_all_finite(force_all_finite, ensure_all_finite)
178187

179188
xp, _ = get_namespace(X, Y)
180-
if any([issparse(X), issparse(Y)]) or _is_numpy_namespace(xp):
181-
X, Y, dtype_float = _return_float_dtype(X, Y)
182-
else:
183-
dtype_float = _find_matching_floating_dtype(X, Y, xp=xp)
189+
X, Y, dtype_float = _find_floating_dtype_allow_sparse(X, Y, xp=xp)
184190

185191
estimator = "check_pairwise_arrays"
186192
if dtype == "infer_float":
@@ -433,7 +439,7 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
433439
# Ensure that distances between vectors and themselves are set to 0.0.
434440
# This may not be the case due to floating point rounding errors.
435441
if X is Y:
436-
_fill_or_add_to_diagonal(distances, 0, xp=xp, add_value=False)
442+
distances = _fill_or_add_to_diagonal(distances, 0, xp=xp, add_value=False)
437443

438444
if squared:
439445
return distances
@@ -1171,7 +1177,7 @@ def cosine_distances(X, Y=None):
11711177
if X is Y or Y is None:
11721178
# Ensure that distances between vectors and themselves are set to 0.0.
11731179
# This may not be the case due to floating point rounding errors.
1174-
_fill_or_add_to_diagonal(S, 0.0, xp, add_value=False)
1180+
S = _fill_or_add_to_diagonal(S, 0.0, xp, add_value=False)
11751181
return S
11761182

11771183

@@ -1943,40 +1949,48 @@ def distance_metrics():
19431949
return PAIRWISE_DISTANCE_FUNCTIONS
19441950

19451951

1946-
def _dist_wrapper(dist_func, dist_matrix, slice_, *args, **kwargs):
1952+
def _transposed_dist_wrapper(dist_func, dist_matrix, slice_, *args, **kwargs):
19471953
"""Write in-place to a slice of a distance matrix."""
1948-
dist_matrix[:, slice_] = dist_func(*args, **kwargs)
1954+
dist_matrix[slice_, ...] = dist_func(*args, **kwargs).T
19491955

19501956

19511957
def _parallel_pairwise(X, Y, func, n_jobs, **kwds):
19521958
"""Break the pairwise matrix in n_jobs even slices
19531959
and compute them using multithreading."""
1960+
xp, _, device = get_namespace_and_device(X, Y)
1961+
X, Y, dtype_float = _find_floating_dtype_allow_sparse(X, Y, xp=xp)
19541962

19551963
if Y is None:
19561964
Y = X
1957-
X, Y, dtype = _return_float_dtype(X, Y)
19581965

19591966
if effective_n_jobs(n_jobs) == 1:
19601967
return func(X, Y, **kwds)
19611968

19621969
# enforce a threading backend to prevent data communication overhead
1963-
fd = delayed(_dist_wrapper)
1964-
ret = np.empty((X.shape[0], Y.shape[0]), dtype=dtype, order="F")
1970+
fd = delayed(_transposed_dist_wrapper)
1971+
# Transpose `ret` such that a given thread writes its ouput to a contiguous chunk.
1972+
# Note `order` (i.e. F/C-contiguous) is not included in array API standard, see
1973+
# https://github.com/data-apis/array-api/issues/571 for details.
1974+
# We assume that currently (April 2025) all array API compatible namespaces
1975+
# allocate 2D arrays using the C-contiguity convention by default.
1976+
ret = xp.empty((X.shape[0], Y.shape[0]), device=device, dtype=dtype_float).T
19651977
Parallel(backend="threading", n_jobs=n_jobs)(
1966-
fd(func, ret, s, X, Y[s], **kwds)
1978+
fd(func, ret, s, X, Y[s, ...], **kwds)
19671979
for s in gen_even_slices(_num_samples(Y), effective_n_jobs(n_jobs))
19681980
)
19691981

19701982
if (X is Y or Y is None) and func is euclidean_distances:
19711983
# zeroing diagonal for euclidean norm.
19721984
# TODO: do it also for other norms.
1973-
np.fill_diagonal(ret, 0)
1985+
ret = _fill_or_add_to_diagonal(ret, 0, xp=xp, add_value=False)
19741986

1975-
return ret
1987+
# Transform output back
1988+
return ret.T
19761989

19771990

19781991
def _pairwise_callable(X, Y, metric, ensure_all_finite=True, **kwds):
19791992
"""Handle the callable case for pairwise_{distances,kernels}."""
1993+
xp, _, device = get_namespace_and_device(X)
19801994
X, Y = check_pairwise_arrays(
19811995
X,
19821996
Y,
@@ -1985,16 +1999,28 @@ def _pairwise_callable(X, Y, metric, ensure_all_finite=True, **kwds):
19851999
# No input dimension checking done for custom metrics (left to user)
19862000
ensure_2d=False,
19872001
)
2002+
_, _, dtype_float = _find_floating_dtype_allow_sparse(X, Y, xp=xp)
2003+
2004+
def _get_slice(array, index):
2005+
# TODO: below 2 lines can be removed once min scipy >= 1.14. Support for
2006+
# 1D shapes in scipy sparse arrays (COO, DOK and CSR formats) only
2007+
# added in 1.14. We must return 2D array until min scipy 1.14.
2008+
if issparse(array):
2009+
return array[[index], :]
2010+
# When `metric` is a callable, 1D input arrays allowed, in which case
2011+
# scalar should be returned.
2012+
if array.ndim == 1:
2013+
return array[index]
2014+
else:
2015+
return array[index, ...]
19882016

19892017
if X is Y:
19902018
# Only calculate metric for upper triangle
1991-
out = np.zeros((X.shape[0], Y.shape[0]), dtype="float")
2019+
out = xp.zeros((X.shape[0], Y.shape[0]), dtype=dtype_float, device=device)
19922020
iterator = itertools.combinations(range(X.shape[0]), 2)
19932021
for i, j in iterator:
1994-
# scipy has not yet implemented 1D sparse slices; once implemented this can
1995-
# be removed and `arr[ind]` can be simply used.
1996-
x = X[[i], :] if issparse(X) else X[i]
1997-
y = Y[[j], :] if issparse(Y) else Y[j]
2022+
x = _get_slice(X, i)
2023+
y = _get_slice(Y, j)
19982024
out[i, j] = metric(x, y, **kwds)
19992025

20002026
# Make symmetric
@@ -2004,20 +2030,16 @@ def _pairwise_callable(X, Y, metric, ensure_all_finite=True, **kwds):
20042030
# Calculate diagonal
20052031
# NB: nonzero diagonals are allowed for both metrics and kernels
20062032
for i in range(X.shape[0]):
2007-
# scipy has not yet implemented 1D sparse slices; once implemented this can
2008-
# be removed and `arr[ind]` can be simply used.
2009-
x = X[[i], :] if issparse(X) else X[i]
2033+
x = _get_slice(X, i)
20102034
out[i, i] = metric(x, x, **kwds)
20112035

20122036
else:
20132037
# Calculate all cells
2014-
out = np.empty((X.shape[0], Y.shape[0]), dtype="float")
2038+
out = xp.empty((X.shape[0], Y.shape[0]), dtype=dtype_float)
20152039
iterator = itertools.product(range(X.shape[0]), range(Y.shape[0]))
20162040
for i, j in iterator:
2017-
# scipy has not yet implemented 1D sparse slices; once implemented this can
2018-
# be removed and `arr[ind]` can be simply used.
2019-
x = X[[i], :] if issparse(X) else X[i]
2020-
y = Y[[j], :] if issparse(Y) else Y[j]
2041+
x = _get_slice(X, i)
2042+
y = _get_slice(Y, j)
20212043
out[i, j] = metric(x, y, **kwds)
20222044

20232045
return out

sklearn/metrics/tests/test_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
linear_kernel,
6666
paired_cosine_distances,
6767
paired_euclidean_distances,
68+
pairwise_kernels,
6869
polynomial_kernel,
6970
rbf_kernel,
7071
sigmoid_kernel,
@@ -2277,6 +2278,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
22772278
check_array_api_regression_metric_multioutput,
22782279
],
22792280
sigmoid_kernel: [check_array_api_metric_pairwise],
2281+
pairwise_kernels: [check_array_api_metric_pairwise],
22802282
roc_curve: [
22812283
check_array_api_binary_classification_metric,
22822284
],

sklearn/metrics/tests/test_pairwise.py

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,15 @@
4848
sigmoid_kernel,
4949
)
5050
from sklearn.preprocessing import normalize
51+
from sklearn.utils._array_api import (
52+
_convert_to_numpy,
53+
_get_namespace_device_dtype_ids,
54+
get_namespace,
55+
xpx,
56+
yield_namespace_device_dtype_combinations,
57+
)
5158
from sklearn.utils._testing import (
59+
_array_api_for_tests,
5260
assert_allclose,
5361
assert_almost_equal,
5462
assert_array_equal,
@@ -295,10 +303,18 @@ def test_pairwise_precomputed_non_negative():
295303

296304

297305
def callable_rbf_kernel(x, y, **kwds):
306+
xp, _ = get_namespace(x, y)
298307
# Callable version of pairwise.rbf_kernel.
299-
K = rbf_kernel(np.atleast_2d(x), np.atleast_2d(y), **kwds)
308+
K = rbf_kernel(
309+
xpx.atleast_nd(x, ndim=2, xp=xp), xpx.atleast_nd(y, ndim=2, xp=xp), **kwds
310+
)
300311
# unpack the output since this is a scalar packed in a 0-dim array
301-
return K.item()
312+
# Note below is array API version of numpys `item()`
313+
if K.ndim > 0:
314+
K_flat = xp.reshape(K, (-1,))
315+
if K_flat.shape == (1,):
316+
return float(K_flat[0])
317+
raise ValueError("can only convert an array of size 1 to a Python scalar")
302318

303319

304320
@pytest.mark.parametrize(
@@ -334,6 +350,53 @@ def test_pairwise_parallel(func, metric, kwds, dtype):
334350
assert_allclose(S, S2)
335351

336352

353+
@pytest.mark.parametrize(
354+
"array_namespace, device, dtype_name",
355+
yield_namespace_device_dtype_combinations(),
356+
ids=_get_namespace_device_dtype_ids,
357+
)
358+
@pytest.mark.parametrize(
359+
"func, metric, kwds",
360+
[
361+
(pairwise_distances, "euclidean", {}),
362+
(pairwise_kernels, "polynomial", {"degree": 1}),
363+
(pairwise_kernels, callable_rbf_kernel, {"gamma": 0.1}),
364+
],
365+
)
366+
def test_pairwise_parallel_array_api(
367+
func, metric, kwds, array_namespace, device, dtype_name
368+
):
369+
xp = _array_api_for_tests(array_namespace, device)
370+
rng = np.random.RandomState(0)
371+
# Why 5 and not more? this seems to still result in a lot of 0 vaules?
372+
X_np = np.array(5 * rng.random_sample((5, 4)), dtype=dtype_name)
373+
Y_np = np.array(5 * rng.random_sample((3, 4)), dtype=dtype_name)
374+
X_xp = xp.asarray(X_np, device=device)
375+
Y_xp = xp.asarray(Y_np, device=device)
376+
377+
with config_context(array_api_dispatch=True):
378+
for y_val in (None, "not none"):
379+
Y_xp = None if y_val is None else Y_xp
380+
Y_np = None if y_val is None else Y_np
381+
382+
n_job1_xp = func(X_xp, Y_xp, metric=metric, n_jobs=1, **kwds)
383+
n_job1_xp_np = _convert_to_numpy(n_job1_xp, xp=xp)
384+
assert get_namespace(n_job1_xp)[0].__name__ == xp.__name__
385+
assert n_job1_xp.device == X_xp.device
386+
assert n_job1_xp.dtype == X_xp.dtype
387+
388+
n_job2_xp = func(X_xp, Y_xp, metric=metric, n_jobs=2, **kwds)
389+
n_job2_xp_np = _convert_to_numpy(n_job2_xp, xp=xp)
390+
assert get_namespace(n_job2_xp)[0].__name__ == xp.__name__
391+
assert n_job2_xp.device == X_xp.device
392+
assert n_job2_xp.dtype == X_xp.dtype
393+
394+
n_job2_np = func(X_np, metric=metric, n_jobs=2, **kwds)
395+
396+
assert_allclose(n_job1_xp_np, n_job2_xp_np)
397+
assert_allclose(n_job2_xp_np, n_job2_np)
398+
399+
337400
def test_pairwise_callable_nonstrict_metric():
338401
# paired_distances should allow callable metric where metric(x, x) != 0
339402
# Knowing that the callable is a strict metric would allow the diagonal to
@@ -378,6 +441,52 @@ def test_pairwise_kernels(metric, csr_container):
378441
assert_allclose(K1, K2)
379442

380443

444+
@pytest.mark.parametrize(
445+
"array_namespace, device, dtype_name",
446+
yield_namespace_device_dtype_combinations(),
447+
ids=_get_namespace_device_dtype_ids,
448+
)
449+
@pytest.mark.parametrize(
450+
"metric",
451+
["rbf", "sigmoid", "polynomial", "linear", "chi2", "additive_chi2"],
452+
)
453+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
454+
def test_pairwise_kernels_array_api(
455+
metric, csr_container, array_namespace, device, dtype_name
456+
):
457+
# Test array API support in pairwise_kernels.
458+
xp = _array_api_for_tests(array_namespace, device)
459+
460+
rng = np.random.RandomState(0)
461+
X_np = 10 * rng.random_sample((5, 4))
462+
X_np = X_np.astype(dtype_name, copy=False)
463+
Y_np = 10 * rng.random_sample((2, 4))
464+
Y_np = Y_np.astype(dtype_name, copy=False)
465+
X_xp = xp.asarray(X_np, device=device)
466+
Y_xp = xp.asarray(Y_np, device=device)
467+
468+
with config_context(array_api_dispatch=True):
469+
# Test with Y=None
470+
K_xp = pairwise_kernels(X_xp, metric=metric)
471+
K_xp_np = _convert_to_numpy(K_xp, xp=xp)
472+
assert get_namespace(K_xp)[0].__name__ == xp.__name__
473+
assert K_xp.device == X_xp.device
474+
assert K_xp.dtype == X_xp.dtype
475+
476+
K_np = pairwise_kernels(X_np, metric=metric)
477+
assert_allclose(K_xp_np, K_np)
478+
479+
# Test with Y=Y_np/Y_xp
480+
K_xp = pairwise_kernels(X_xp, Y=Y_xp, metric=metric)
481+
K_xp_np = _convert_to_numpy(K_xp, xp=xp)
482+
assert get_namespace(K_xp)[0].__name__ == xp.__name__
483+
assert K_xp.device == X_xp.device
484+
assert K_xp.dtype == X_xp.dtype
485+
486+
K_np = pairwise_kernels(X_np, Y=Y_np, metric=metric)
487+
assert_allclose(K_xp_np, K_np)
488+
489+
381490
def test_pairwise_kernels_callable():
382491
# Test the pairwise_kernels helper function
383492
# with a callable function, with given keywords.

sklearn/utils/_array_api.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,10 @@ def _fill_or_add_to_diagonal(array, value, xp, add_value=True, wrap=False):
557557
array_flat[:end:step] += value
558558
else:
559559
array_flat[:end:step] = value
560+
# `array_flat` is not always a view on `array` (e.g. for certain array types that
561+
# were filled via parallel processing i.e., in `_parallel_pairwise`), thus we need
562+
# to return reshaped `array_flat`.
563+
return xp.reshape(array_flat, array.shape)
560564

561565

562566
def _is_xp_namespace(xp, name):

0 commit comments

Comments
 (0)