Skip to content

Commit 80673ed

Browse files
author
Fabian Pedregosa
committed
Rename strategy --> algorithm in Neighbors*.
1 parent 62bcf5b commit 80673ed

File tree

3 files changed

+24
-23
lines changed

3 files changed

+24
-23
lines changed

doc/modules/neighbors.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ The :class:`NeighborsClassifier` implements the nearest-neighbors
2020
classification method using a vote heuristic: the class most present
2121
in the k nearest neighbors of a point is assigned to this point.
2222

23-
It is possible to use different nearest neighbor search strategies by
24-
using the keyword ``strategy``. Possible values are ``'auto'``,
23+
It is possible to use different nearest neighbor search algorithms by
24+
using the keyword ``algorithm``. Possible values are ``'auto'``,
2525
``'ball_tree'``, ``'brute'`` and ``'brute_inplace'``. ``'ball_tree'``
2626
will create an instance of :class:`BallTree` to conduct the search,
2727
which is usually very efficient in low-dimensional spaces. In higher

scikits/learn/neighbors.py

+16-15
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
2222
window_size : int, optional
2323
Window size passed to BallTree
2424
25-
strategy : {'auto', 'ball_tree', 'brute', 'brute_inplace'}, optional
26-
Strategy used to compute the nearest neighbors. 'ball_tree'
25+
algorithm : {'auto', 'ball_tree', 'brute', 'brute_inplace'}, optional
26+
Algorithm used to compute the nearest neighbors. 'ball_tree'
2727
will construct a BallTree, 'brute' and 'brute_inplace' will
2828
perform brute-force search.'auto' will guess the most
2929
appropriate based on current dataset.
@@ -35,7 +35,7 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
3535
>>> from scikits.learn.neighbors import NeighborsClassifier
3636
>>> neigh = NeighborsClassifier(n_neighbors=1)
3737
>>> neigh.fit(samples, labels)
38-
NeighborsClassifier(n_neighbors=1, window_size=1, strategy='auto')
38+
NeighborsClassifier(n_neighbors=1, window_size=1, algorithm='auto')
3939
>>> print neigh.predict([[0,0,0]])
4040
[1]
4141
@@ -48,10 +48,10 @@ class NeighborsClassifier(BaseEstimator, ClassifierMixin):
4848
http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm
4949
"""
5050

51-
def __init__(self, n_neighbors=5, strategy='auto', window_size=1):
51+
def __init__(self, n_neighbors=5, algorithm='auto', window_size=1):
5252
self.n_neighbors = n_neighbors
5353
self.window_size = window_size
54-
self.strategy = strategy
54+
self.algorithm = algorithm
5555

5656

5757
def fit(self, X, Y, **params):
@@ -73,8 +73,8 @@ def fit(self, X, Y, **params):
7373
self._y = np.asanyarray(Y)
7474
self._set_params(**params)
7575

76-
if self.strategy == 'ball_tree' or \
77-
(self.strategy == 'auto' and X.shape[1] < 20):
76+
if self.algorithm == 'ball_tree' or \
77+
(self.algorithm == 'auto' and X.shape[1] < 20):
7878
self.ball_tree = BallTree(X, self.window_size)
7979
else:
8080
self.ball_tree = None
@@ -119,7 +119,7 @@ class from an array representing our data set and ask who's
119119
>>> from scikits.learn.neighbors import NeighborsClassifier
120120
>>> neigh = NeighborsClassifier(n_neighbors=1)
121121
>>> neigh.fit(samples, labels)
122-
NeighborsClassifier(n_neighbors=1, window_size=1, strategy='auto')
122+
NeighborsClassifier(n_neighbors=1, window_size=1, algorithm='auto')
123123
>>> print neigh.kneighbors([1., 1., 1.])
124124
(array([ 0.5]), array([2]))
125125
@@ -160,7 +160,7 @@ def predict(self, X, **params):
160160

161161
# .. get neighbors ..
162162
if self.ball_tree is None:
163-
if self.strategy == 'brute_inplace':
163+
if self.algorithm == 'brute_inplace':
164164
neigh_ind = knn_brute(self._fit_X, X, self.n_neighbors)
165165
else:
166166
from .metrics import euclidean_distances
@@ -203,8 +203,8 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
203203
mode : {'mean', 'barycenter'}, optional
204204
Weights to apply to labels.
205205
206-
strategy : {'auto', 'ball_tree', 'brute', 'brute_inplace'}, optional
207-
Strategy used to compute the nearest neighbors. 'ball_tree'
206+
algorithm : {'auto', 'ball_tree', 'brute', 'brute_inplace'}, optional
207+
Algorithm used to compute the nearest neighbors. 'ball_tree'
208208
will construct a BallTree, 'brute' and 'brute_inplace' will
209209
perform brute-force search.'auto' will guess the most
210210
appropriate based on current dataset.
@@ -216,7 +216,8 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
216216
>>> from scikits.learn.neighbors import NeighborsRegressor
217217
>>> neigh = NeighborsRegressor(n_neighbors=2)
218218
>>> neigh.fit(X, y)
219-
NeighborsRegressor(n_neighbors=2, window_size=1, mode='mean', strategy='auto')
219+
NeighborsRegressor(n_neighbors=2, window_size=1, mode='mean',
220+
algorithm='auto')
220221
>>> print neigh.predict([[1.5]])
221222
[ 0.5]
222223
@@ -226,12 +227,12 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
226227
"""
227228

228229

229-
def __init__(self, n_neighbors=5, mode='mean', strategy='auto',
230+
def __init__(self, n_neighbors=5, mode='mean', algorithm='auto',
230231
window_size=1):
231232
self.n_neighbors = n_neighbors
232233
self.window_size = window_size
233234
self.mode = mode
234-
self.strategy = strategy
235+
self.algorithm = algorithm
235236

236237

237238
def predict(self, X, **params):
@@ -256,7 +257,7 @@ def predict(self, X, **params):
256257

257258
# .. get neighbors ..
258259
if self.ball_tree is None:
259-
if self.strategy == 'brute_inplace':
260+
if self.algorithm == 'brute_inplace':
260261
neigh_ind = knn_brute(self._fit_X, X, self.n_neighbors)
261262
else:
262263
from .metrics.pairwise import euclidean_distances

scikits/learn/tests/test_neighbors.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,19 @@ def test_neighbors_1D():
2424

2525
for s in ('auto', 'ball_tree', 'brute', 'inplace'):
2626
# n_neighbors = 1
27-
knn = neighbors.NeighborsClassifier(n_neighbors=1, strategy=s)
27+
knn = neighbors.NeighborsClassifier(n_neighbors=1, algorithm=s)
2828
knn.fit(X, Y)
2929
test = [[i + 0.01] for i in range(0, n/2)] + \
3030
[[i - 0.01] for i in range(n/2, n)]
3131
assert_array_equal(knn.predict(test), [0]*3 + [1]*3)
3232

3333
# n_neighbors = 2
34-
knn = neighbors.NeighborsClassifier(n_neighbors=2, strategy=s)
34+
knn = neighbors.NeighborsClassifier(n_neighbors=2, algorithm=s)
3535
knn.fit(X, Y)
3636
assert_array_equal(knn.predict(test), [0]*4 + [1]*2)
3737

3838
# n_neighbors = 3
39-
knn = neighbors.NeighborsClassifier(n_neighbors=3, strategy=s)
39+
knn = neighbors.NeighborsClassifier(n_neighbors=3, algorithm=s)
4040
knn.fit(X, Y)
4141
assert_array_equal(knn.predict([[i +0.01] for i in range(0, n/2)]),
4242
[0 for i in range(n/2)])
@@ -54,15 +54,15 @@ def test_neighbors_iris():
5454

5555
for s in ('auto', 'ball_tree', 'brute', 'inplace'):
5656
clf = neighbors.NeighborsClassifier()
57-
clf.fit(iris.data, iris.target, n_neighbors=1, strategy=s)
57+
clf.fit(iris.data, iris.target, n_neighbors=1, algorithm=s)
5858
assert_array_equal(clf.predict(iris.data), iris.target)
5959

60-
clf.fit(iris.data, iris.target, n_neighbors=9, strategy=s)
60+
clf.fit(iris.data, iris.target, n_neighbors=9, algorithm=s)
6161
assert_(np.mean(clf.predict(iris.data)== iris.target) > 0.95)
6262

6363
for m in ('barycenter', 'mean'):
6464
rgs = neighbors.NeighborsRegressor()
65-
rgs.fit(iris.data, iris.target, mode=m, strategy=s)
65+
rgs.fit(iris.data, iris.target, mode=m, algorithm=s)
6666
assert_(np.mean(
6767
rgs.predict(iris.data).round() == iris.target) > 0.95)
6868

0 commit comments

Comments
 (0)