Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats_new/v0.24.rst
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,11 @@ Changelog
the data intrinsic dimensionality is too high for tree-based methods.
:pr:`17148` by :user:`Geoffrey Bolmier <gbolmier>`.

- |Fix| :class:`neighbors.BinaryTree`
will raise a `ValueError` when fitting on a data array whose points have
different dimensions.
:pr:`18691` by :user:`Chiara Marmo <cmarmo>`.

- |Fix| :class:`neighbors.NearestCentroid` with a numerical `shrink_threshold`
will raise a `ValueError` when fitting on data with all constant features.
:pr:`18370` by :user:`Trevor Waite <trewaite>`.
Expand Down
14 changes: 7 additions & 7 deletions sklearn/neighbors/_binary_tree.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1049,17 +1049,17 @@ cdef class BinaryTree:
def __init__(self, data,
leaf_size=40, metric='minkowski', sample_weight=None, **kwargs):
# validate data
if data.size == 0:
self.data_arr = check_array(data, dtype=DTYPE, order='C')
if self.data_arr.size == 0:
raise ValueError("X is an empty array")

n_samples = self.data_arr.shape[0]
n_features = self.data_arr.shape[1]

if leaf_size < 1:
raise ValueError("leaf_size must be greater than or equal to 1")

n_samples = data.shape[0]
n_features = data.shape[1]

self.data_arr = np.asarray(data, dtype=DTYPE, order='C')
self.leaf_size = leaf_size

self.dist_metric = DistanceMetric.get_metric(metric, **kwargs)
self.euclidean = (self.dist_metric.__class__.__name__
== 'EuclideanDistance')
Expand All @@ -1069,7 +1069,7 @@ cdef class BinaryTree:
raise ValueError('metric {metric} is not valid for '
'{BinaryTree}'.format(metric=metric,
**DOC_DICT))
self.dist_metric._validate_data(data)
self.dist_metric._validate_data(self.data_arr)

# determine number of levels in the tree, and from this
# the number of nodes in the tree. This results in leaf nodes
Expand Down
24 changes: 21 additions & 3 deletions sklearn/neighbors/tests/test_ball_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from sklearn.neighbors._ball_tree import BallTree
from sklearn.neighbors import DistanceMetric
from sklearn.utils import check_random_state
from sklearn.utils.validation import check_array
from sklearn.utils._testing import _convert_container

rng = np.random.RandomState(10)
V_mahalanobis = rng.rand(3, 3)
Expand All @@ -31,22 +33,28 @@


def brute_force_neighbors(X, Y, k, metric, **kwargs):
    """Reference k-nearest-neighbors search using a full distance matrix.

    Both inputs are first passed through ``check_array``. The complete
    pairwise distance matrix between the rows of ``Y`` (queries) and the
    rows of ``X`` (candidates) is then computed with the requested metric,
    and the ``k`` smallest entries of every row are kept.

    Parameters
    ----------
    X : array-like
        Candidate points.
    Y : array-like
        Query points.
    k : int
        Number of neighbors to keep per query point.
    metric : str
        Metric name handed to ``DistanceMetric.get_metric``.
    **kwargs
        Extra keyword arguments forwarded to ``DistanceMetric.get_metric``.

    Returns
    -------
    dist : ndarray
        Distances to the ``k`` nearest candidates, one row per row of ``Y``,
        sorted in increasing order along each row.
    ind : ndarray
        Indices into ``X`` matching ``dist``.
    """
    X = check_array(X)
    Y = check_array(Y)

    dist_metric = DistanceMetric.get_metric(metric, **kwargs)
    all_dist = dist_metric.pairwise(Y, X)

    # Sort every query row and keep the first k columns.
    nearest = np.argsort(all_dist, axis=1)[:, :k]

    # A column vector of row numbers broadcasts against `nearest` so each
    # query row picks its own k distances.
    query_rows = np.arange(Y.shape[0])[:, None]
    return all_dist[query_rows, nearest], nearest


@pytest.mark.parametrize('metric',
itertools.chain(BOOLEAN_METRICS, DISCRETE_METRICS))
def test_ball_tree_query_metrics(metric):
@pytest.mark.parametrize(
'metric',
itertools.chain(BOOLEAN_METRICS, DISCRETE_METRICS)
)
@pytest.mark.parametrize("array_type", ["list", "array"])
def test_ball_tree_query_metrics(metric, array_type):
rng = check_random_state(0)
if metric in BOOLEAN_METRICS:
X = rng.random_sample((40, 10)).round(0)
Y = rng.random_sample((10, 10)).round(0)
elif metric in DISCRETE_METRICS:
X = (4 * rng.random_sample((40, 10))).round(0)
Y = (4 * rng.random_sample((10, 10))).round(0)
X = _convert_container(X, array_type)
Y = _convert_container(Y, array_type)

k = 5

Expand All @@ -65,3 +73,13 @@ def test_query_haversine():

assert_array_almost_equal(dist1, dist2)
assert_array_almost_equal(ind1, ind2)


def test_array_object_type():
    """BallTree must reject a ragged, object-dtype input array."""
    # Rows of unequal length force an object-dtype array; constructing the
    # tree from it is expected to raise a ValueError with this message.
    ragged = np.array([(1, 2, 3), (2, 5), (5, 5, 1, 2)], dtype=object)
    expected_msg = "setting an array element with a sequence"
    with pytest.raises(ValueError, match=expected_msg):
        BallTree(ragged)
15 changes: 15 additions & 0 deletions sklearn/neighbors/tests/test_kd_tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
import numpy as np
import pytest

from sklearn.neighbors._kd_tree import KDTree

# Dimension constant for the KD-tree tests (its uses are outside this
# excerpt -- presumably the number of features of generated data).
DIMENSION = 3

# Metric name -> keyword arguments associated with that metric
# (presumably forwarded when the metric is constructed; verify at call site).
METRICS = {
    'euclidean': {},
    'manhattan': {},
    'chebyshev': {},
    'minkowski': {'p': 3},
}


def test_array_object_type():
    """KDTree must reject a ragged, object-dtype input array."""
    # Rows of unequal length force an object-dtype array; constructing the
    # tree from it is expected to raise a ValueError with this message.
    ragged = np.array([(1, 2, 3), (2, 5), (5, 5, 1, 2)], dtype=object)
    expected_msg = "setting an array element with a sequence"
    with pytest.raises(ValueError, match=expected_msg):
        KDTree(ragged)