Merged
sklearn/neighbors/_base.py: 26 changes (25 additions & 1 deletion)
@@ -535,7 +535,21 @@ def _fit(self, X, y=None):
             ):
                 self._fit_method = "brute"
             else:
-                if self.effective_metric_ in VALID_METRICS["kd_tree"]:
+                if (
+                    self.effective_metric_ == "minkowski"
+                    and self.effective_metric_params_.get("w") is not None
+                ):
+                    # Be consistent with scipy 1.8 conventions: in scipy 1.8,
+                    # 'wminkowski' was removed in favor of passing a
+                    # weight vector directly to 'minkowski'.
+                    #
+                    # 'wminkowski' is not among the valid metrics for KDTree,
+                    # but unweighted 'minkowski' is.
+                    #
+                    # Hence, we detect this case and choose BallTree,
+                    # which supports 'wminkowski'.
+                    self._fit_method = "ball_tree"
+                elif self.effective_metric_ in VALID_METRICS["kd_tree"]:
                     self._fit_method = "kd_tree"
                 elif (
                     callable(self.effective_metric_)
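
With this change, a weighted Minkowski metric under `algorithm="auto"` should be routed to the BallTree backend. A minimal sketch of the expected behavior, assuming a scikit-learn build that includes this patch; the dataset shape and the probe of the private `_fit_method` attribute are illustrative assumptions, not part of the diff:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.RandomState(0)
# Enough samples and few enough features that the "auto" heuristic
# prefers a tree-based backend over brute force.
X = rng.rand(50, 5)

# KDTree cannot handle the weight vector 'w', so "auto" should
# select BallTree for a weighted Minkowski metric.
nn = NearestNeighbors(
    n_neighbors=3,
    algorithm="auto",
    metric="minkowski",
    metric_params={"w": rng.rand(5)},
).fit(X)

assert nn._fit_method == "ball_tree"  # private attribute, probed only for illustration
```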
@@ -553,6 +567,16 @@ def _fit(self, X, y=None):
                 **self.effective_metric_params_,
             )
         elif self._fit_method == "kd_tree":
+            if (
+                self.effective_metric_ == "minkowski"
+                and self.effective_metric_params_.get("w") is not None
+            ):
+                raise ValueError(
+                    "algorithm='kd_tree' is not valid for "
+                    "metric='minkowski' with a weight parameter 'w': "
+                    "try algorithm='ball_tree' "
+                    "or algorithm='brute' instead."
+                )
             self._tree = KDTree(
                 X,
                 self.leaf_size,
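
And requesting `algorithm="kd_tree"` explicitly with a weight vector should now fail fast at fit time. A short sketch of the expected behavior; only the error message is taken from the diff above, the surrounding scaffolding is assumed:

```python
import numpy as np
import pytest
from sklearn.neighbors import NearestNeighbors

rng = np.random.RandomState(0)
X = rng.rand(50, 5)

nn = NearestNeighbors(
    n_neighbors=3,
    algorithm="kd_tree",
    metric="minkowski",
    metric_params={"w": rng.rand(5)},
)

# The ValueError introduced above is raised when fitting.
with pytest.raises(ValueError, match="algorithm='kd_tree' is not valid"):
    nn.fit(X)
```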
sklearn/neighbors/tests/test_neighbors.py: 124 changes (77 additions & 47 deletions)
@@ -22,7 +22,7 @@
 from sklearn.metrics.pairwise import pairwise_distances
 from sklearn.model_selection import cross_val_score
 from sklearn.model_selection import train_test_split
-from sklearn.neighbors import VALID_METRICS_SPARSE, VALID_METRICS
+from sklearn.neighbors import VALID_METRICS_SPARSE
 from sklearn.neighbors._base import _is_sorted_by_data, _check_precomputed
 from sklearn.pipeline import make_pipeline
 from sklearn.utils._testing import assert_array_almost_equal
@@ -58,6 +58,46 @@
 neighbors.radius_neighbors_graph = ignore_warnings(neighbors.radius_neighbors_graph)
 
 
+def _generate_test_params_for(metric: str, n_features: int):
+    """Return a list of dummy DistanceMetric kwargs for tests."""
+
+    # Distinguish between cases to avoid computing unneeded data structures.
+    rng = np.random.RandomState(1)
+    weights = rng.random_sample(n_features)
+
+    if metric == "minkowski":
+        minkowski_kwargs = [dict(p=1.5), dict(p=2), dict(p=3), dict(p=np.inf)]
+        if sp_version >= parse_version("1.8.0.dev0"):
+            # TODO: remove this case once we no longer support scipy < 1.8.0.
+            # Recent scipy versions accept weights in the Minkowski metric
+            # directly.
+            minkowski_kwargs.append(dict(p=3, w=rng.rand(n_features)))
+        return minkowski_kwargs
+
+    # TODO: remove the "wminkowski" case once we no longer support scipy < 1.8.0.
+    if metric == "wminkowski":
+        weights /= weights.sum()
+        wminkowski_kwargs = [dict(p=1.5, w=weights)]
+        if sp_version < parse_version("1.8.0.dev0"):
+            # wminkowski was removed in scipy 1.8.0 but should work for
+            # previous versions.
+            wminkowski_kwargs.append(dict(p=3, w=rng.rand(n_features)))
+        return wminkowski_kwargs
+
+    if metric == "seuclidean":
+        return [dict(V=rng.rand(n_features))]
+
+    if metric == "mahalanobis":
+        A = rng.rand(n_features, n_features)
+        # Make the matrix symmetric positive definite.
+        VI = A + A.T + 3 * np.eye(n_features)
+        return [dict(VI=VI)]
+
+    # Case of "euclidean", "manhattan", "chebyshev", "haversine" or any other
+    # metric: no kwargs are needed.
+    return [{}]
+
+
 def _weight_func(dist):
     """Weight function to replace lambda d: d ** -2.
     The lambda function is not valid because:
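
To make the helper's contract concrete, here is a hypothetical session (array values elided; the weighted Minkowski entry only appears on scipy >= 1.8):

```python
>>> _generate_test_params_for("minkowski", n_features=4)
[{'p': 1.5}, {'p': 2}, {'p': 3}, {'p': inf}, {'p': 3, 'w': array([...])}]
>>> _generate_test_params_for("seuclidean", n_features=4)
[{'V': array([...])}]
>>> _generate_test_params_for("euclidean", n_features=4)
[{}]
```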
@@ -1385,58 +1425,48 @@ def custom_metric(x1, x2):
 
 # TODO: Remove filterwarnings in 1.3 when wminkowski is removed
 @pytest.mark.filterwarnings("ignore:WMinkowskiDistance:FutureWarning:sklearn")
-def test_valid_brute_metric_for_auto_algorithm():
-    X = rng.rand(12, 12)
+@pytest.mark.parametrize("metric", neighbors.VALID_METRICS["brute"])
+def test_valid_brute_metric_for_auto_algorithm(metric, n_samples=20, n_features=12):
+    # Any valid metric for algorithm="brute" must be valid for algorithm="auto":
+    # it is the estimator's responsibility to select the likely most efficient
+    # algorithm from the subset compatible with that metric (and params).
+    # The worst case is to fall back to algorithm="brute".
+    X = rng.rand(n_samples, n_features)
     Xcsr = csr_matrix(X)
 
-    # check that there is a metric that is valid for brute
-    # but not ball_tree (so we actually test something)
-    assert "cosine" in VALID_METRICS["brute"]
-    assert "cosine" not in VALID_METRICS["ball_tree"]
-
-    # Metric which don't required any additional parameter
-    require_params = ["mahalanobis", "wminkowski", "seuclidean"]
-    for metric in VALID_METRICS["brute"]:
-        if metric != "precomputed" and metric not in require_params:
-            nn = neighbors.NearestNeighbors(
-                n_neighbors=3, algorithm="auto", metric=metric
-            )
-            if metric != "haversine":
-                nn.fit(X)
-                nn.kneighbors(X)
-            else:
-                nn.fit(X[:, :2])
-                nn.kneighbors(X[:, :2])
-        elif metric == "precomputed":
-            X_precomputed = rng.random_sample((10, 4))
-            Y_precomputed = rng.random_sample((3, 4))
-            DXX = metrics.pairwise_distances(X_precomputed, metric="euclidean")
-            DYX = metrics.pairwise_distances(
-                Y_precomputed, X_precomputed, metric="euclidean"
-            )
-            nb_p = neighbors.NearestNeighbors(n_neighbors=3)
-            nb_p.fit(DXX)
-            nb_p.kneighbors(DYX)
-
-    for metric in VALID_METRICS_SPARSE["brute"]:
-        if metric != "precomputed" and metric not in require_params:
-            nn = neighbors.NearestNeighbors(
-                n_neighbors=3, algorithm="auto", metric=metric
-            ).fit(Xcsr)
-            nn.kneighbors(Xcsr)
-
-    # Metric with parameter
-    VI = np.dot(X, X.T)
-    list_metrics = [
-        ("seuclidean", dict(V=rng.rand(12))),
-        ("wminkowski", dict(w=rng.rand(12))),
-        ("mahalanobis", dict(VI=VI)),
-    ]
-    for metric, params in list_metrics:
-        nn = neighbors.NearestNeighbors(
-            n_neighbors=3, algorithm="auto", metric=metric, metric_params=params
-        ).fit(X)
-        nn.kneighbors(X)
+    metric_params_list = _generate_test_params_for(metric, n_features)
+
+    if metric == "precomputed":
+        X_precomputed = rng.random_sample((10, 4))
+        Y_precomputed = rng.random_sample((3, 4))
+        DXX = metrics.pairwise_distances(X_precomputed, metric="euclidean")
+        DYX = metrics.pairwise_distances(
+            Y_precomputed, X_precomputed, metric="euclidean"
+        )
+        nb_p = neighbors.NearestNeighbors(n_neighbors=3, metric="precomputed")
+        nb_p.fit(DXX)
+        nb_p.kneighbors(DYX)
+    else:
+        for metric_params in metric_params_list:
+            nn = neighbors.NearestNeighbors(
+                n_neighbors=3,
+                algorithm="auto",
+                metric=metric,
+                metric_params=metric_params,
+            )
+            # Haversine distance only accepts 2D data
+            if metric == "haversine":
+                X = np.ascontiguousarray(X[:, :2])
+
+            nn.fit(X)
+            nn.kneighbors(X)
+
+            if metric in VALID_METRICS_SPARSE["brute"]:
+                nn = neighbors.NearestNeighbors(
+                    n_neighbors=3, algorithm="auto", metric=metric
+                ).fit(Xcsr)
+                nn.kneighbors(Xcsr)
 
 
 def test_metric_params_interface():
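
For reference, the cosine assertions from the removed test still illustrate the contract the parametrized test now checks for every metric: a metric that only the brute-force backend supports must still work under `algorithm="auto"`. A minimal sketch; the data shape and the probe of the private `_fit_method` attribute are illustrative assumptions:

```python
import numpy as np
from sklearn import neighbors

rng = np.random.RandomState(42)
X = rng.rand(20, 12)

# "cosine" is valid for the brute-force backend but not for BallTree ...
assert "cosine" in neighbors.VALID_METRICS["brute"]
assert "cosine" not in neighbors.VALID_METRICS["ball_tree"]

# ... so algorithm="auto" must quietly fall back to brute force.
nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm="auto", metric="cosine")
nn.fit(X)
nn.kneighbors(X)
assert nn._fit_method == "brute"  # private attribute, probed only for illustration
```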