Skip to content

Commit 6e02fee

Browse files
rth and thomasjpfan
authored and committed
TST Common tests between KDTree and BallTree (scikit-learn#15148)
1 parent 377462f commit 6e02fee

File tree

3 files changed

+99
-107
lines changed

3 files changed

+99
-107
lines changed

sklearn/neighbors/tests/test_ball_tree.py

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pickle
21
import itertools
32

43
import numpy as np
@@ -45,27 +44,6 @@ def brute_force_neighbors(X, Y, k, metric, **kwargs):
4544
return dist, ind
4645

4746

48-
@pytest.mark.parametrize('metric', METRICS)
49-
@pytest.mark.parametrize('k', (1, 3, 5))
50-
@pytest.mark.parametrize('dualtree', (True, False))
51-
@pytest.mark.parametrize('breadth_first', (True, False))
52-
def test_ball_tree_query(metric, k, dualtree, breadth_first):
53-
rng = check_random_state(0)
54-
X = rng.random_sample((40, DIMENSION))
55-
Y = rng.random_sample((10, DIMENSION))
56-
57-
kwargs = METRICS[metric]
58-
59-
bt = BallTree(X, leaf_size=1, metric=metric, **kwargs)
60-
dist1, ind1 = bt.query(Y, k, dualtree=dualtree,
61-
breadth_first=breadth_first)
62-
dist2, ind2 = brute_force_neighbors(X, Y, k, metric, **kwargs)
63-
64-
# don't check indices here: if there are any duplicate distances,
65-
# the indices may not match. Distances should not have this problem.
66-
assert_array_almost_equal(dist1, dist2)
67-
68-
6947
@pytest.mark.parametrize('metric',
7048
itertools.chain(BOOLEAN_METRICS, DISCRETE_METRICS))
7149
def test_ball_tree_query_metrics(metric):
@@ -201,37 +179,6 @@ def check_two_point(r, dualtree):
201179
check_two_point(r, dualtree)
202180

203181

204-
def test_ball_tree_pickle():
205-
rng = check_random_state(0)
206-
X = rng.random_sample((10, 3))
207-
208-
bt1 = BallTree(X, leaf_size=1)
209-
# Test if BallTree with callable metric is picklable
210-
bt1_pyfunc = BallTree(X, metric=dist_func, leaf_size=1, p=2)
211-
212-
ind1, dist1 = bt1.query(X)
213-
ind1_pyfunc, dist1_pyfunc = bt1_pyfunc.query(X)
214-
215-
def check_pickle_protocol(protocol):
216-
s = pickle.dumps(bt1, protocol=protocol)
217-
bt2 = pickle.loads(s)
218-
219-
s_pyfunc = pickle.dumps(bt1_pyfunc, protocol=protocol)
220-
bt2_pyfunc = pickle.loads(s_pyfunc)
221-
222-
ind2, dist2 = bt2.query(X)
223-
ind2_pyfunc, dist2_pyfunc = bt2_pyfunc.query(X)
224-
225-
assert_array_almost_equal(ind1, ind2)
226-
assert_array_almost_equal(dist1, dist2)
227-
228-
assert_array_almost_equal(ind1_pyfunc, ind2_pyfunc)
229-
assert_array_almost_equal(dist1_pyfunc, dist2_pyfunc)
230-
231-
assert isinstance(bt2, BallTree)
232-
233-
for protocol in (0, 1, 2):
234-
check_pickle_protocol(protocol)
235182

236183

237184
def test_neighbors_heap(n_pts=5, n_nbrs=10):

sklearn/neighbors/tests/test_kd_tree.py

Lines changed: 3 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@
1010
from sklearn.utils import check_random_state
1111
from sklearn.utils.testing import assert_allclose
1212

13-
rng = np.random.RandomState(42)
14-
V = rng.random_sample((3, 3))
15-
V = np.dot(V, V.T)
16-
1713
DIMENSION = 3
1814

1915
METRICS = {'euclidean': {},
@@ -22,37 +18,6 @@
2218
'minkowski': dict(p=3)}
2319

2420

25-
def brute_force_neighbors(X, Y, k, metric, **kwargs):
26-
D = DistanceMetric.get_metric(metric, **kwargs).pairwise(Y, X)
27-
ind = np.argsort(D, axis=1)[:, :k]
28-
dist = D[np.arange(Y.shape[0])[:, None], ind]
29-
return dist, ind
30-
31-
32-
def check_neighbors(dualtree, breadth_first, k, metric, X, Y, kwargs):
33-
kdt = KDTree(X, leaf_size=1, metric=metric, **kwargs)
34-
dist1, ind1 = kdt.query(Y, k, dualtree=dualtree,
35-
breadth_first=breadth_first)
36-
dist2, ind2 = brute_force_neighbors(X, Y, k, metric, **kwargs)
37-
38-
# don't check indices here: if there are any duplicate distances,
39-
# the indices may not match. Distances should not have this problem.
40-
assert_array_almost_equal(dist1, dist2)
41-
42-
43-
@pytest.mark.parametrize('metric', METRICS)
44-
@pytest.mark.parametrize('k', (1, 3, 5))
45-
@pytest.mark.parametrize('dualtree', (True, False))
46-
@pytest.mark.parametrize('breadth_first', (True, False))
47-
def test_kd_tree_query(metric, k, dualtree, breadth_first):
48-
rng = check_random_state(0)
49-
X = rng.random_sample((40, DIMENSION))
50-
Y = rng.random_sample((10, DIMENSION))
51-
52-
kwargs = METRICS[metric]
53-
check_neighbors(dualtree, breadth_first, k, metric, X, Y, kwargs)
54-
55-
5621
def test_kd_tree_query_radius(n_samples=100, n_features=10):
5722
rng = check_random_state(0)
5823
X = 2 * rng.random_sample(size=(n_samples, n_features)) - 1
@@ -173,27 +138,9 @@ def test_kd_tree_two_point(dualtree):
173138
assert_array_almost_equal(counts, counts_true)
174139

175140

176-
@pytest.mark.parametrize('protocol', (0, 1, 2))
177-
def test_kd_tree_pickle(protocol):
178-
import pickle
179-
rng = check_random_state(0)
180-
X = rng.random_sample((10, 3))
181-
kdt1 = KDTree(X, leaf_size=1)
182-
ind1, dist1 = kdt1.query(X)
183-
184-
def check_pickle_protocol(protocol):
185-
s = pickle.dumps(kdt1, protocol=protocol)
186-
kdt2 = pickle.loads(s)
187-
ind2, dist2 = kdt2.query(X)
188-
assert_array_almost_equal(ind1, ind2)
189-
assert_array_almost_equal(dist1, dist2)
190-
assert isinstance(kdt2, KDTree)
191-
192-
check_pickle_protocol(protocol)
193-
194-
195141
def test_neighbors_heap(n_pts=5, n_nbrs=10):
196142
heap = NeighborsHeap(n_pts, n_nbrs)
143+
rng = np.random.RandomState(42)
197144

198145
for row in range(n_pts):
199146
d_in = rng.random_sample(2 * n_nbrs).astype(DTYPE, copy=False)
@@ -212,6 +159,7 @@ def test_neighbors_heap(n_pts=5, n_nbrs=10):
212159

213160

214161
def test_node_heap(n_nodes=50):
162+
rng = np.random.RandomState(42)
215163
vals = rng.random_sample(n_nodes).astype(DTYPE, copy=False)
216164

217165
i1 = np.argsort(vals)
@@ -222,6 +170,7 @@ def test_node_heap(n_nodes=50):
222170

223171

224172
def test_simultaneous_sort(n_rows=10, n_pts=201):
173+
rng = np.random.RandomState(42)
225174
dist = rng.random_sample((n_rows, n_pts)).astype(DTYPE, copy=False)
226175
ind = (np.arange(n_pts) + np.zeros((n_rows, 1))).astype(ITYPE, copy=False)
227176

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# License: BSD 3 clause
2+
3+
import pickle
4+
import itertools
5+
6+
import numpy as np
7+
import pytest
8+
9+
from sklearn.neighbors.dist_metrics import DistanceMetric
10+
from sklearn.neighbors.ball_tree import BallTree
11+
from sklearn.neighbors.kd_tree import KDTree
12+
13+
from sklearn.utils import check_random_state
14+
from numpy.testing import assert_array_almost_equal
15+
16+
# Module-level RNG with a fixed seed so the metric parameters drawn below
# are the same on every run. The draws happen in statement order, so
# reordering these statements would change the parameter values.
rng = np.random.RandomState(42)
17+
V_mahalanobis = rng.rand(3, 3)
18+
# Multiplying the random matrix by its own transpose yields a symmetric
# matrix, which is passed as the ``V`` argument of the 'mahalanobis' metric.
V_mahalanobis = np.dot(V_mahalanobis, V_mahalanobis.T)
19+
20+
# Number of features; sizes the per-feature metric parameters below and the
# random data generated in test_nn_tree_query.
DIMENSION = 3
21+
22+
# Maps a metric name to the extra keyword arguments forwarded both to the
# tree constructor and to DistanceMetric.get_metric.
METRICS = {'euclidean': {},
23+
'manhattan': {},
24+
'minkowski': dict(p=3),
25+
'chebyshev': {},
26+
'seuclidean': dict(V=rng.random_sample(DIMENSION)),
27+
'wminkowski': dict(p=3, w=rng.random_sample(DIMENSION)),
28+
'mahalanobis': dict(V=V_mahalanobis)}
29+
30+
# Metric names paired with KDTree in test_nn_tree_query (a subset of METRICS).
KD_TREE_METRICS = ['euclidean', 'manhattan', 'chebyshev', 'minkowski']
31+
# BallTree is paired with every key of METRICS.
BALL_TREE_METRICS = list(METRICS)
32+
33+
34+
def dist_func(x1, x2, p):
    """Return the Minkowski distance of order ``p`` between two points.

    Parameters
    ----------
    x1, x2 : ndarray
        Coordinate arrays that support elementwise subtraction.
    p : float
        Order of the norm.

    Returns
    -------
    float
        ``sum(|x1 - x2| ** p) ** (1 / p)``.
    """
    # Use absolute differences: for odd ``p`` a signed difference lets
    # opposite-sign terms cancel, or leaves a negative sum whose fractional
    # power is NaN. For even ``p`` the result is unchanged.
    return np.sum(np.abs(x1 - x2) ** p) ** (1. / p)
36+
37+
38+
def brute_force_neighbors(X, Y, k, metric, **kwargs):
    """Find the ``k`` nearest rows of ``X`` for each row of ``Y`` exhaustively.

    The full pairwise distance matrix from ``Y`` to ``X`` is computed with
    ``DistanceMetric.get_metric(metric, **kwargs)`` and each row is sorted.

    Returns
    -------
    dist, ind : ndarray
        For each row of ``Y``, the ``k`` smallest distances and the column
        indices (rows of ``X``) at which they occur, in ascending order.
    """
    pairwise = DistanceMetric.get_metric(metric, **kwargs).pairwise(Y, X)
    nearest = np.argsort(pairwise, axis=1)[:, :k]
    # Column vector of row numbers so that fancy indexing picks, for every
    # query row, the distances at that row's nearest indices.
    row_ids = np.arange(Y.shape[0])[:, None]
    return pairwise[row_ids, nearest], nearest
43+
44+
45+
@pytest.mark.parametrize(
    'Cls, metric',
    [(KDTree, name) for name in KD_TREE_METRICS]
    + [(BallTree, name) for name in BALL_TREE_METRICS])
@pytest.mark.parametrize('k', (1, 3, 5))
@pytest.mark.parametrize('dualtree', (True, False))
@pytest.mark.parametrize('breadth_first', (True, False))
def test_nn_tree_query(Cls, metric, k, dualtree, breadth_first):
    """Compare tree k-neighbor distances against an exhaustive search."""
    rng = check_random_state(0)
    X = rng.random_sample((40, DIMENSION))
    Y = rng.random_sample((10, DIMENSION))

    metric_params = METRICS[metric]
    tree = Cls(X, leaf_size=1, metric=metric, **metric_params)

    tree_dist, _ = tree.query(Y, k, dualtree=dualtree,
                              breadth_first=breadth_first)
    brute_dist, _ = brute_force_neighbors(X, Y, k, metric, **metric_params)

    # Only distances are compared: when several points are equidistant the
    # returned indices may legitimately differ between the two methods.
    assert_array_almost_equal(tree_dist, brute_dist)
68+
69+
70+
@pytest.mark.parametrize(
    "Cls, metric",
    [(KDTree, 'euclidean'), (BallTree, 'euclidean'),
     (BallTree, dist_func)])
@pytest.mark.parametrize('protocol', (0, 1, 2))
def test_pickle(Cls, metric, protocol):
    """Check that a pickled and reloaded tree answers queries identically."""
    rng = check_random_state(0)
    X = rng.random_sample((10, 3))

    # The callable metric takes a ``p`` argument, which is forwarded through
    # the tree constructor; the named metric needs no extra parameters.
    kwargs = {'p': 2} if callable(metric) else {}

    tree1 = Cls(X, leaf_size=1, metric=metric, **kwargs)

    # ``query`` returns (distances, indices), in that order.
    dist1, ind1 = tree1.query(X)

    tree2 = pickle.loads(pickle.dumps(tree1, protocol=protocol))
    assert isinstance(tree2, Cls)

    dist2, ind2 = tree2.query(X)

    assert_array_almost_equal(ind1, ind2)
    assert_array_almost_equal(dist1, dist2)

0 commit comments

Comments
 (0)