From 43d860cf95daf100563c63554aa087f291b79130 Mon Sep 17 00:00:00 2001 From: Omar Salman Date: Fri, 14 Jun 2024 13:02:23 +0500 Subject: [PATCH] Fix array api integration for additive_chi2_kernel with torch mps --- sklearn/metrics/pairwise.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index b0e4b6f2a9738..afc046e588708 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -30,6 +30,7 @@ from ..utils._array_api import ( _find_matching_floating_dtype, _is_numpy_namespace, + device, get_namespace, ) from ..utils._chunking import get_chunk_n_rows @@ -1763,12 +1764,13 @@ def additive_chi2_kernel(X, Y=None): return result else: dtype = _find_matching_floating_dtype(X, Y, xp=xp) + device_ = device(X, Y) xb = X[:, None, :] yb = Y[None, :, :] nom = -((xb - yb) ** 2) denom = xb + yb - nom = xp.where(denom == 0, xp.asarray(0, dtype=dtype), nom) - denom = xp.where(denom == 0, xp.asarray(1, dtype=dtype), denom) + nom = xp.where(denom == 0, xp.asarray(0, dtype=dtype, device=device_), nom) + denom = xp.where(denom == 0, xp.asarray(1, dtype=dtype, device=device_), denom) return xp.sum(nom / denom, axis=2)