Skip to content

Commit 0c9b984

Browse files
jeremiedbbjnothman
authored andcommitted
Some fixes for parallel pairwise distances when n_jobs > 1 (scikit-learn#13877)
1 parent 2993b7f commit 0c9b984

File tree

3 files changed

+47
-28
lines changed

3 files changed

+47
-28
lines changed

doc/whats_new/v0.21.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ Changelog
2323
``Y == None``.
2424
:issue:`13864` by :user:`Paresh Mathur <rick2047>`.
2525

26+
- |Fix| Fixed two bugs in :class:`metrics.pairwise_distances` when
27+
``n_jobs > 1``. First it used to return a distance matrix with same dtype as
28+
input, even for integer dtype. Then the diagonal was not zeros for euclidean
29+
metric when ``Y`` is ``X``. :issue:`13877` by
30+
:user:`Jérémie du Boisberranger <jeremiedbb>`.
31+
2632
:mod:`sklearn.neighbors`
2733
......................
2834

sklearn/metrics/pairwise.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1198,17 +1198,23 @@ def _parallel_pairwise(X, Y, func, n_jobs, **kwds):
11981198

11991199
if Y is None:
12001200
Y = X
1201+
X, Y, dtype = _return_float_dtype(X, Y)
12011202

12021203
if effective_n_jobs(n_jobs) == 1:
12031204
return func(X, Y, **kwds)
12041205

12051206
# enforce a threading backend to prevent data communication overhead
12061207
fd = delayed(_dist_wrapper)
1207-
ret = np.empty((X.shape[0], Y.shape[0]), dtype=X.dtype, order='F')
1208+
ret = np.empty((X.shape[0], Y.shape[0]), dtype=dtype, order='F')
12081209
Parallel(backend="threading", n_jobs=n_jobs)(
12091210
fd(func, ret, s, X, Y[s], **kwds)
12101211
for s in gen_even_slices(_num_samples(Y), effective_n_jobs(n_jobs)))
12111212

1213+
if (X is Y or Y is None) and func is euclidean_distances:
1214+
# zeroing diagonal for euclidean norm.
1215+
# TODO: do it also for other norms.
1216+
np.fill_diagonal(ret, 0)
1217+
12121218
return ret
12131219

12141220

sklearn/metrics/tests/test_pairwise.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -231,31 +231,6 @@ def test_pairwise_precomputed_non_negative():
231231
metric='precomputed')
232232

233233

234-
def check_pairwise_parallel(func, metric, kwds):
235-
rng = np.random.RandomState(0)
236-
for make_data in (np.array, csr_matrix):
237-
X = make_data(rng.random_sample((5, 4)))
238-
Y = make_data(rng.random_sample((3, 4)))
239-
240-
try:
241-
S = func(X, metric=metric, n_jobs=1, **kwds)
242-
except (TypeError, ValueError) as exc:
243-
# Not all metrics support sparse input
244-
# ValueError may be triggered by bad callable
245-
if make_data is csr_matrix:
246-
assert_raises(type(exc), func, X, metric=metric,
247-
n_jobs=2, **kwds)
248-
continue
249-
else:
250-
raise
251-
S2 = func(X, metric=metric, n_jobs=2, **kwds)
252-
assert_array_almost_equal(S, S2)
253-
254-
S = func(X, Y, metric=metric, n_jobs=1, **kwds)
255-
S2 = func(X, Y, metric=metric, n_jobs=2, **kwds)
256-
assert_array_almost_equal(S, S2)
257-
258-
259234
_wminkowski_kwds = {'w': np.arange(1, 5).astype('double', copy=False), 'p': 1}
260235

261236

@@ -272,8 +247,30 @@ def callable_rbf_kernel(x, y, **kwds):
272247
(pairwise_distances, 'wminkowski', _wminkowski_kwds),
273248
(pairwise_kernels, 'polynomial', {'degree': 1}),
274249
(pairwise_kernels, callable_rbf_kernel, {'gamma': .1})])
275-
def test_pairwise_parallel(func, metric, kwds):
276-
check_pairwise_parallel(func, metric, kwds)
250+
@pytest.mark.parametrize('array_constr', [np.array, csr_matrix])
251+
@pytest.mark.parametrize('dtype', [np.float64, int])
252+
def test_pairwise_parallel(func, metric, kwds, array_constr, dtype):
253+
rng = np.random.RandomState(0)
254+
X = array_constr(5 * rng.random_sample((5, 4)), dtype=dtype)
255+
Y = array_constr(5 * rng.random_sample((3, 4)), dtype=dtype)
256+
257+
try:
258+
S = func(X, metric=metric, n_jobs=1, **kwds)
259+
except (TypeError, ValueError) as exc:
260+
# Not all metrics support sparse input
261+
# ValueError may be triggered by bad callable
262+
if array_constr is csr_matrix:
263+
with pytest.raises(type(exc)):
264+
func(X, metric=metric, n_jobs=2, **kwds)
265+
return
266+
else:
267+
raise
268+
S2 = func(X, metric=metric, n_jobs=2, **kwds)
269+
assert_allclose(S, S2)
270+
271+
S = func(X, Y, metric=metric, n_jobs=1, **kwds)
272+
S2 = func(X, Y, metric=metric, n_jobs=2, **kwds)
273+
assert_allclose(S, S2)
277274

278275

279276
def test_pairwise_callable_nonstrict_metric():
@@ -551,6 +548,16 @@ def test_pairwise_distances_chunked_diagonal(metric):
551548
assert_array_almost_equal(np.diag(np.vstack(chunks)), 0, decimal=10)
552549

553550

551+
@pytest.mark.parametrize(
552+
'metric',
553+
('euclidean', 'l2', 'sqeuclidean'))
554+
def test_parallel_pairwise_distances_diagonal(metric):
555+
rng = np.random.RandomState(0)
556+
X = rng.normal(size=(1000, 10), scale=1e10)
557+
distances = pairwise_distances(X, metric=metric, n_jobs=2)
558+
assert_allclose(np.diag(distances), 0, atol=1e-10)
559+
560+
554561
@ignore_warnings
555562
def test_pairwise_distances_chunked():
556563
# Test the pairwise_distance helper function.

0 commit comments

Comments
 (0)