@@ -231,31 +231,6 @@ def test_pairwise_precomputed_non_negative():
231
231
metric = 'precomputed' )
232
232
233
233
234
- def check_pairwise_parallel (func , metric , kwds ):
235
- rng = np .random .RandomState (0 )
236
- for make_data in (np .array , csr_matrix ):
237
- X = make_data (rng .random_sample ((5 , 4 )))
238
- Y = make_data (rng .random_sample ((3 , 4 )))
239
-
240
- try :
241
- S = func (X , metric = metric , n_jobs = 1 , ** kwds )
242
- except (TypeError , ValueError ) as exc :
243
- # Not all metrics support sparse input
244
- # ValueError may be triggered by bad callable
245
- if make_data is csr_matrix :
246
- assert_raises (type (exc ), func , X , metric = metric ,
247
- n_jobs = 2 , ** kwds )
248
- continue
249
- else :
250
- raise
251
- S2 = func (X , metric = metric , n_jobs = 2 , ** kwds )
252
- assert_array_almost_equal (S , S2 )
253
-
254
- S = func (X , Y , metric = metric , n_jobs = 1 , ** kwds )
255
- S2 = func (X , Y , metric = metric , n_jobs = 2 , ** kwds )
256
- assert_array_almost_equal (S , S2 )
257
-
258
-
259
234
_wminkowski_kwds = {'w' : np .arange (1 , 5 ).astype ('double' , copy = False ), 'p' : 1 }
260
235
261
236
@@ -272,8 +247,30 @@ def callable_rbf_kernel(x, y, **kwds):
272
247
(pairwise_distances , 'wminkowski' , _wminkowski_kwds ),
273
248
(pairwise_kernels , 'polynomial' , {'degree' : 1 }),
274
249
(pairwise_kernels , callable_rbf_kernel , {'gamma' : .1 })])
275
- def test_pairwise_parallel (func , metric , kwds ):
276
- check_pairwise_parallel (func , metric , kwds )
250
+ @pytest .mark .parametrize ('array_constr' , [np .array , csr_matrix ])
251
+ @pytest .mark .parametrize ('dtype' , [np .float64 , int ])
252
+ def test_pairwise_parallel (func , metric , kwds , array_constr , dtype ):
253
+ rng = np .random .RandomState (0 )
254
+ X = array_constr (5 * rng .random_sample ((5 , 4 )), dtype = dtype )
255
+ Y = array_constr (5 * rng .random_sample ((3 , 4 )), dtype = dtype )
256
+
257
+ try :
258
+ S = func (X , metric = metric , n_jobs = 1 , ** kwds )
259
+ except (TypeError , ValueError ) as exc :
260
+ # Not all metrics support sparse input
261
+ # ValueError may be triggered by bad callable
262
+ if array_constr is csr_matrix :
263
+ with pytest .raises (type (exc )):
264
+ func (X , metric = metric , n_jobs = 2 , ** kwds )
265
+ return
266
+ else :
267
+ raise
268
+ S2 = func (X , metric = metric , n_jobs = 2 , ** kwds )
269
+ assert_allclose (S , S2 )
270
+
271
+ S = func (X , Y , metric = metric , n_jobs = 1 , ** kwds )
272
+ S2 = func (X , Y , metric = metric , n_jobs = 2 , ** kwds )
273
+ assert_allclose (S , S2 )
277
274
278
275
279
276
def test_pairwise_callable_nonstrict_metric ():
@@ -551,6 +548,16 @@ def test_pairwise_distances_chunked_diagonal(metric):
551
548
assert_array_almost_equal (np .diag (np .vstack (chunks )), 0 , decimal = 10 )
552
549
553
550
551
+ @pytest .mark .parametrize (
552
+ 'metric' ,
553
+ ('euclidean' , 'l2' , 'sqeuclidean' ))
554
+ def test_parallel_pairwise_distances_diagonal (metric ):
555
+ rng = np .random .RandomState (0 )
556
+ X = rng .normal (size = (1000 , 10 ), scale = 1e10 )
557
+ distances = pairwise_distances (X , metric = metric , n_jobs = 2 )
558
+ assert_allclose (np .diag (distances ), 0 , atol = 1e-10 )
559
+
560
+
554
561
@ignore_warnings
555
562
def test_pairwise_distances_chunked ():
556
563
# Test the pairwise_distance helper function.
0 commit comments