From a6639aa102c23ca79f947a28532592c2965b65b2 Mon Sep 17 00:00:00 2001 From: DessyVV Date: Sat, 23 Oct 2021 20:03:36 +0300 Subject: [PATCH 1/5] Accelarate KDE test. --- sklearn/neighbors/_kde.py | 17 +++++++---------- sklearn/neighbors/tests/test_kde.py | 5 +++-- sklearn/tests/test_common.py | 1 - 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/sklearn/neighbors/_kde.py b/sklearn/neighbors/_kde.py index 9ae25e611463f..a785fcd86939f 100644 --- a/sklearn/neighbors/_kde.py +++ b/sklearn/neighbors/_kde.py @@ -135,16 +135,6 @@ def __init__( self.leaf_size = leaf_size self.metric_params = metric_params - # run the choose algorithm code so that exceptions will happen here - # we're using clone() in the GenerativeBayes classifier, - # so we can't do this kind of logic in __init__ - self._choose_algorithm(self.algorithm, self.metric) - - if bandwidth <= 0: - raise ValueError("bandwidth must be positive") - if kernel not in VALID_KERNELS: - raise ValueError("invalid kernel: '{0}'".format(kernel)) - def _choose_algorithm(self, algorithm, metric): # given the algorithm string + metric string, choose the optimal # algorithm to compute the result. @@ -188,7 +178,14 @@ def fit(self, X, y=None, sample_weight=None): self : object Returns the instance itself. """ + algorithm = self._choose_algorithm(self.algorithm, self.metric) + + if self.bandwidth <= 0: + raise ValueError("bandwidth must be positive") + if self.kernel not in VALID_KERNELS: + raise ValueError("invalid kernel: '{0}'".format(self.kernel)) + X = self._validate_data(X, order="C", dtype=DTYPE) if sample_weight is not None: diff --git a/sklearn/neighbors/tests/test_kde.py b/sklearn/neighbors/tests/test_kde.py index d4fb775c44826..f29ba27d47cde 100644 --- a/sklearn/neighbors/tests/test_kde.py +++ b/sklearn/neighbors/tests/test_kde.py @@ -107,11 +107,12 @@ def test_kde_algorithm_metric_choice(algorithm, metric): X = rng.randn(10, 2) # 2 features required for haversine dist. Y = rng.randn(10, 2) + kde = KernelDensity(algorithm=algorithm, metric=metric) + if algorithm == "kd_tree" and metric not in KDTree.valid_metrics: with pytest.raises(ValueError): - KernelDensity(algorithm=algorithm, metric=metric) + kde.fit(X) else: - kde = KernelDensity(algorithm=algorithm, metric=metric) kde.fit(X) y_dens = kde.score_samples(Y) assert y_dens.shape == Y.shape[:1] diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 11dc95cc19f3a..084ccadeeef5a 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -413,7 +413,6 @@ def test_transformers_get_feature_names_out(transformer): "FeatureUnion", "GridSearchCV", "HalvingGridSearchCV", - "KernelDensity", "KernelPCA", "LabelBinarizer", "NuSVC", From c486eb8d808a9c7b52e68a1ac21e8bbe549204ec Mon Sep 17 00:00:00 2001 From: DessyVV Date: Sat, 23 Oct 2021 20:54:07 +0300 Subject: [PATCH 2/5] Change test_kde_badargs to test the fit method. --- sklearn/neighbors/tests/test_kde.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sklearn/neighbors/tests/test_kde.py b/sklearn/neighbors/tests/test_kde.py index f29ba27d47cde..b2e33019c20e2 100644 --- a/sklearn/neighbors/tests/test_kde.py +++ b/sklearn/neighbors/tests/test_kde.py @@ -127,16 +127,17 @@ def test_kde_score(n_samples=100, n_features=3): def test_kde_badargs(): + X = np.random.random((200, 10)) with pytest.raises(ValueError): - KernelDensity(algorithm="blah") + KernelDensity(algorithm="blah").fit(X) with pytest.raises(ValueError): - KernelDensity(bandwidth=0) + KernelDensity(bandwidth=0).fit(X) with pytest.raises(ValueError): - KernelDensity(kernel="blah") + KernelDensity(kernel="blah").fit(X) with pytest.raises(ValueError): - KernelDensity(metric="blah") + KernelDensity(metric="blah").fit(X) with pytest.raises(ValueError): - KernelDensity(algorithm="kd_tree", metric="blah") + KernelDensity(algorithm="kd_tree", metric="blah").fit(X) kde = KernelDensity() with pytest.raises(ValueError): kde.fit(np.random.random((200, 10)), sample_weight=np.random.random((200, 10))) From 474a445970596322408211fe0efcba7ee5517bbd Mon Sep 17 00:00:00 2001 From: DessyVV Date: Sun, 24 Oct 2021 09:21:05 +0300 Subject: [PATCH 3/5] Give co-authorship to @LucyJimenez MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Lucy Jiménez lucy.jimenez.chem@gmail.com From 33477bc1d6422b11827589d717088f18d609fd12 Mon Sep 17 00:00:00 2001 From: DessyVV Date: Sun, 24 Oct 2021 09:39:20 +0300 Subject: [PATCH 4/5] Adding entry to the changelog for KDE validation changes. --- doc/whats_new/v1.1.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index e6e309cd5fac4..3f59b6a73c9ef 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -88,6 +88,13 @@ Changelog - |Enhancement| :func:`utils.estimator_html_repr` shows a more helpful error message when running in a jupyter notebook that is not trusted. :pr:`21316` by `Thomas Fan`_. +....................... + +- |Fix| :class:`neighbors.KernelDensity` now validates input parameters in `fit` + instead of `__init__`. :pr:`21430` by :user:`Desislava Vasileva ` and + :user:`Lucy Jimenez `. + +:mod:`sklearn.neighbors` Code and Documentation Contributors ----------------------------------- From bf35d4dff1245157838311a5166f62c07441517b Mon Sep 17 00:00:00 2001 From: DessyVV Date: Sun, 24 Oct 2021 10:01:57 +0300 Subject: [PATCH 5/5] Fix formatting for new KDE validation changelog entry. --- doc/whats_new/v1.1.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 3f59b6a73c9ef..363a97199d294 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -88,14 +88,14 @@ Changelog - |Enhancement| :func:`utils.estimator_html_repr` shows a more helpful error message when running in a jupyter notebook that is not trusted. :pr:`21316` by `Thomas Fan`_. -....................... + +:mod:`sklearn.neighbors` +........................ - |Fix| :class:`neighbors.KernelDensity` now validates input parameters in `fit` instead of `__init__`. :pr:`21430` by :user:`Desislava Vasileva ` and :user:`Lucy Jimenez `. -:mod:`sklearn.neighbors` - Code and Documentation Contributors -----------------------------------