diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index e6e309cd5fac4..363a97199d294 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -89,6 +89,13 @@ Changelog message when running in a jupyter notebook that is not trusted. :pr:`21316` by `Thomas Fan`_. +:mod:`sklearn.neighbors` +........................ + +- |Fix| :class:`neighbors.KernelDensity` now validates input parameters in `fit` + instead of `__init__`. :pr:`21430` by :user:`Desislava Vasileva <DessyVV>` and + :user:`Lucy Jimenez <LucyJimenez>`. + Code and Documentation Contributors ----------------------------------- diff --git a/sklearn/neighbors/_kde.py b/sklearn/neighbors/_kde.py index 9ae25e611463f..a785fcd86939f 100644 --- a/sklearn/neighbors/_kde.py +++ b/sklearn/neighbors/_kde.py @@ -135,16 +135,6 @@ def __init__( self.leaf_size = leaf_size self.metric_params = metric_params - # run the choose algorithm code so that exceptions will happen here - # we're using clone() in the GenerativeBayes classifier, - # so we can't do this kind of logic in __init__ - self._choose_algorithm(self.algorithm, self.metric) - - if bandwidth <= 0: - raise ValueError("bandwidth must be positive") - if kernel not in VALID_KERNELS: - raise ValueError("invalid kernel: '{0}'".format(kernel)) - def _choose_algorithm(self, algorithm, metric): # given the algorithm string + metric string, choose the optimal # algorithm to compute the result. @@ -188,7 +178,14 @@ def fit(self, X, y=None, sample_weight=None): self : object Returns the instance itself. 
""" + algorithm = self._choose_algorithm(self.algorithm, self.metric) + + if self.bandwidth <= 0: + raise ValueError("bandwidth must be positive") + if self.kernel not in VALID_KERNELS: + raise ValueError("invalid kernel: '{0}'".format(self.kernel)) + X = self._validate_data(X, order="C", dtype=DTYPE) if sample_weight is not None: diff --git a/sklearn/neighbors/tests/test_kde.py b/sklearn/neighbors/tests/test_kde.py index d4fb775c44826..b2e33019c20e2 100644 --- a/sklearn/neighbors/tests/test_kde.py +++ b/sklearn/neighbors/tests/test_kde.py @@ -107,11 +107,12 @@ def test_kde_algorithm_metric_choice(algorithm, metric): X = rng.randn(10, 2) # 2 features required for haversine dist. Y = rng.randn(10, 2) + kde = KernelDensity(algorithm=algorithm, metric=metric) + if algorithm == "kd_tree" and metric not in KDTree.valid_metrics: with pytest.raises(ValueError): - KernelDensity(algorithm=algorithm, metric=metric) + kde.fit(X) else: - kde = KernelDensity(algorithm=algorithm, metric=metric) kde.fit(X) y_dens = kde.score_samples(Y) assert y_dens.shape == Y.shape[:1] @@ -126,16 +127,17 @@ def test_kde_score(n_samples=100, n_features=3): def test_kde_badargs(): + X = np.random.random((200, 10)) with pytest.raises(ValueError): - KernelDensity(algorithm="blah") + KernelDensity(algorithm="blah").fit(X) with pytest.raises(ValueError): - KernelDensity(bandwidth=0) + KernelDensity(bandwidth=0).fit(X) with pytest.raises(ValueError): - KernelDensity(kernel="blah") + KernelDensity(kernel="blah").fit(X) with pytest.raises(ValueError): - KernelDensity(metric="blah") + KernelDensity(metric="blah").fit(X) with pytest.raises(ValueError): - KernelDensity(algorithm="kd_tree", metric="blah") + KernelDensity(algorithm="kd_tree", metric="blah").fit(X) kde = KernelDensity() with pytest.raises(ValueError): kde.fit(np.random.random((200, 10)), sample_weight=np.random.random((200, 10))) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 11dc95cc19f3a..084ccadeeef5a 100644 
--- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -413,7 +413,6 @@ def test_transformers_get_feature_names_out(transformer): "FeatureUnion", "GridSearchCV", "HalvingGridSearchCV", - "KernelDensity", "KernelPCA", "LabelBinarizer", "NuSVC",