Skip to content

Commit 93fa00c

Browse files
Fix error in euclidean_distances when X is float64 and X_norm_squared is float32 (scikit-learn#27624)
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
1 parent 5c85b58 commit 93fa00c

File tree

3 files changed

+33
-17
lines changed

3 files changed

+33
-17
lines changed

doc/whats_new/v1.4.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,11 @@ Changelog
342342
:mod:`sklearn.metrics`
343343
......................
344344

345+
- |Fix| Computing pairwise distances with :func:`metrics.pairwise.euclidean_distances` no longer
346+
raises an exception when `X` is provided as a `float64` array and
347+
`X_norm_squared` as a `float32` array. :pr:`27624` by
348+
:user:`Jérôme Dockès <jeromedockes>`.
349+
345350
- |Efficiency| Computing pairwise distances via :class:`metrics.DistanceMetric`
346351
for CSR × CSR, Dense × CSR, and CSR × Dense datasets is now 1.5x faster.
347352
:pr:`26765` by :user:`Meekail Zain <micky774>`.

sklearn/metrics/pairwise.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -356,30 +356,24 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
356356
float32, norms need to be recomputed on upcast chunks.
357357
TODO: use a float64 accumulator in row_norms to avoid the latter.
358358
"""
359-
if X_norm_squared is not None:
360-
if X_norm_squared.dtype == np.float32:
361-
XX = None
362-
else:
363-
XX = X_norm_squared.reshape(-1, 1)
364-
elif X.dtype == np.float32:
365-
XX = None
366-
else:
359+
if X_norm_squared is not None and X_norm_squared.dtype != np.float32:
360+
XX = X_norm_squared.reshape(-1, 1)
361+
elif X.dtype != np.float32:
367362
XX = row_norms(X, squared=True)[:, np.newaxis]
363+
else:
364+
XX = None
368365

369366
if Y is X:
370367
YY = None if XX is None else XX.T
371368
else:
372-
if Y_norm_squared is not None:
373-
if Y_norm_squared.dtype == np.float32:
374-
YY = None
375-
else:
376-
YY = Y_norm_squared.reshape(1, -1)
377-
elif Y.dtype == np.float32:
378-
YY = None
379-
else:
369+
if Y_norm_squared is not None and Y_norm_squared.dtype != np.float32:
370+
YY = Y_norm_squared.reshape(1, -1)
371+
elif Y.dtype != np.float32:
380372
YY = row_norms(Y, squared=True)[np.newaxis, :]
373+
else:
374+
YY = None
381375

382-
if X.dtype == np.float32:
376+
if X.dtype == np.float32 or Y.dtype == np.float32:
383377
# To minimize precision issues with float32, we compute the distance
384378
# matrix on chunks of X and Y upcast to float64
385379
distances = _euclidean_distances_upcast(X, XX, Y, YY)

sklearn/metrics/tests/test_pairwise.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,6 +848,23 @@ def test_euclidean_distances_with_norms(global_dtype, y_array_constr):
848848
assert_allclose(wrong_D, D1)
849849

850850

851+
@pytest.mark.parametrize("symmetric", [True, False])
852+
def test_euclidean_distances_float32_norms(global_random_seed, symmetric):
853+
# Non-regression test for #27621
854+
rng = np.random.RandomState(global_random_seed)
855+
X = rng.random_sample((10, 10))
856+
Y = X if symmetric else rng.random_sample((20, 10))
857+
X_norm_sq = (X.astype(np.float32) ** 2).sum(axis=1).reshape(1, -1)
858+
Y_norm_sq = (Y.astype(np.float32) ** 2).sum(axis=1).reshape(1, -1)
859+
D1 = euclidean_distances(X, Y)
860+
D2 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq)
861+
D3 = euclidean_distances(X, Y, Y_norm_squared=Y_norm_sq)
862+
D4 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq, Y_norm_squared=Y_norm_sq)
863+
assert_allclose(D2, D1)
864+
assert_allclose(D3, D1)
865+
assert_allclose(D4, D1)
866+
867+
851868
def test_euclidean_distances_norm_shapes():
852869
# Check all accepted shapes for the norms or appropriate error messages.
853870
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)