Skip to content

Commit 70cab6d

Browse files
committed
ENH: support arbitrary dtype in kNN classifiers
Fixes scikit-learn#1224
1 parent f2c52f1 commit 70cab6d

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

sklearn/neighbors/classification.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def predict(self, X):
138138
else:
139139
mode, _ = weighted_mode(pred_labels, weights, axis=1)
140140

141-
return mode.flatten().astype(np.int)
141+
return mode.flatten().astype(self._y.dtype)
142142

143143
def predict_proba(self, X):
144144
"""Return probability estimates for the test data X.
@@ -312,11 +312,11 @@ def predict(self, X):
312312
weights = _get_weights(neigh_dist, self.weights)
313313

314314
if weights is None:
315-
mode = np.asarray([stats.mode(pl)[0] for pl in pred_labels],
316-
dtype=np.int)
315+
mode = np.array([stats.mode(pl)[0] for pl in pred_labels],
316+
dtype=self._y.dtype)
317317
else:
318-
mode = np.asarray([weighted_mode(pl, w)[0]
319-
for (pl, w) in zip(pred_labels, weights)],
320-
dtype=np.int)
318+
mode = np.array([weighted_mode(pl, w)[0]
319+
for (pl, w) in zip(pred_labels, weights)],
320+
dtype=self._y.dtype)
321321

322-
return mode.flatten().astype(np.int)
322+
return mode.flatten()

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ def test_kneighbors_classifier(n_samples=40,
175175
rng = np.random.RandomState(random_state)
176176
X = 2 * rng.rand(n_samples, n_features) - 1
177177
y = ((X ** 2).sum(axis=1) < .5).astype(np.int)
178+
y_str = y.astype(str)
178179

179180
weight_func = _weight_func
180181

@@ -187,6 +188,10 @@ def test_kneighbors_classifier(n_samples=40,
187188
epsilon = 1e-5 * (2 * rng.rand(1, n_features) - 1)
188189
y_pred = knn.predict(X[:n_test_pts] + epsilon)
189190
assert_array_equal(y_pred, y[:n_test_pts])
191+
# Test prediction with y_str
192+
knn.fit(X, y_str)
193+
y_pred = knn.predict(X[:n_test_pts] + epsilon)
194+
assert_array_equal(y_pred, y_str[:n_test_pts])
190195

191196

192197
def test_kneighbors_classifier_predict_proba():
@@ -219,6 +224,7 @@ def test_radius_neighbors_classifier(n_samples=40,
219224
rng = np.random.RandomState(random_state)
220225
X = 2 * rng.rand(n_samples, n_features) - 1
221226
y = ((X ** 2).sum(axis=1) < .5).astype(np.int)
227+
y_str = y.astype(str)
222228

223229
weight_func = _weight_func
224230

@@ -231,6 +237,9 @@ def test_radius_neighbors_classifier(n_samples=40,
231237
epsilon = 1e-5 * (2 * rng.rand(1, n_features) - 1)
232238
y_pred = neigh.predict(X[:n_test_pts] + epsilon)
233239
assert_array_equal(y_pred, y[:n_test_pts])
240+
neigh.fit(X, y_str)
241+
y_pred = neigh.predict(X[:n_test_pts] + epsilon)
242+
assert_array_equal(y_pred, y_str[:n_test_pts])
234243

235244

236245
def test_radius_neighbors_classifier_when_no_neighbors():

0 commit comments

Comments
 (0)