From 247b42b81973f21a806be205f6eedfc8660a9610 Mon Sep 17 00:00:00 2001 From: OmarManzoor Date: Wed, 9 Nov 2022 12:47:04 +0500 Subject: [PATCH 1/4] MAINT Parameters validation for cluster.estimate_bandwidth --- sklearn/cluster/_mean_shift.py | 11 ++++++++++- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sklearn/cluster/_mean_shift.py b/sklearn/cluster/_mean_shift.py index 9dd4c6cc7920b..1f59625a51146 100644 --- a/sklearn/cluster/_mean_shift.py +++ b/sklearn/cluster/_mean_shift.py @@ -20,7 +20,7 @@ from numbers import Integral, Real from collections import defaultdict -from ..utils._param_validation import Interval +from ..utils._param_validation import Interval, validate_params from ..utils.validation import check_is_fitted from ..utils.fixes import delayed from ..utils import check_random_state, gen_batches, check_array @@ -30,6 +30,15 @@ from .._config import config_context +@validate_params( + { + "X": ["array-like"], + "quantile": [Interval(Real, 0, 1, closed="both")], + "n_samples": [Interval(Integral, 1, None, closed="left"), None], + "random_state": ["random_state"], + "n_jobs": [Integral, None], + } +) def estimate_bandwidth(X, *, quantile=0.3, n_samples=None, random_state=0, n_jobs=None): """Estimate the bandwidth to use with the mean-shift algorithm. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index d4e645c052dab..c3ff9f579298c 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -10,6 +10,7 @@ PARAM_VALIDATION_FUNCTION_LIST = [ "sklearn.cluster.kmeans_plusplus", + "sklearn.cluster.estimate_bandwidth", ] From f42caf5ac5b409146d652d1aa288081584105bbb Mon Sep 17 00:00:00 2001 From: OmarManzoor Date: Wed, 9 Nov 2022 15:23:12 +0500 Subject: [PATCH 2/4] MAINT Parameters validation for decomposition.dict_learning --- sklearn/decomposition/_dict_learning.py | 22 +++++++++++++++++++++- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/sklearn/decomposition/_dict_learning.py b/sklearn/decomposition/_dict_learning.py index 1957d5290c4cd..8483295abffba 100644 --- a/sklearn/decomposition/_dict_learning.py +++ b/sklearn/decomposition/_dict_learning.py @@ -18,7 +18,7 @@ from ..base import BaseEstimator, TransformerMixin, ClassNamePrefixFeaturesOutMixin from ..utils import check_array, check_random_state, gen_even_slices, gen_batches from ..utils import deprecated -from ..utils._param_validation import Hidden, Interval, StrOptions +from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params from ..utils.extmath import randomized_svd, row_norms, svd_flip from ..utils.validation import check_is_fitted from ..utils.fixes import delayed @@ -490,6 +490,26 @@ def _update_dict( print(f"{n_unused} unused atoms resampled.") +@validate_params( + { + "X": ["array-like"], + "n_components": [Interval(Integral, 1, None, closed="left")], + "alpha": [Interval(Integral, 0, None, closed="left")], + "max_iter": [Interval(Integral, 0, None, closed="left")], + "tol": [Interval(Real, 0, None, closed="left")], + "method": [StrOptions({"lars", "cd"})], + "n_jobs": [Integral, None], + "dict_init": [np.ndarray, None], + "code_init": [np.ndarray, None], + "callback": [None, callable], + "verbose": ["verbose"], + "random_state": ["random_state", None], + "return_n_iter": ["boolean"], + "positive_dict": ["boolean"], + "positive_code": ["boolean"], + "method_max_iter": [Interval(Integral, 0, None, closed="left")], + } +) def dict_learning( X, n_components, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index c3ff9f579298c..7e3a5b5464461 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -11,6 +11,7 @@ PARAM_VALIDATION_FUNCTION_LIST = [ "sklearn.cluster.kmeans_plusplus", "sklearn.cluster.estimate_bandwidth", + "sklearn.decomposition.dict_learning", ] From ae2cb6e3be0ec05ca581e5b427d5ef5a29499773 Mon Sep 17 00:00:00 2001 From: OmarManzoor Date: Wed, 9 Nov 2022 15:28:27 +0500 Subject: [PATCH 3/4] Revert "MAINT Parameters validation for decomposition.dict_learning" This reverts commit f42caf5ac5b409146d652d1aa288081584105bbb. --- sklearn/decomposition/_dict_learning.py | 22 +--------------------- sklearn/tests/test_public_functions.py | 1 - 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/sklearn/decomposition/_dict_learning.py b/sklearn/decomposition/_dict_learning.py index 8483295abffba..1957d5290c4cd 100644 --- a/sklearn/decomposition/_dict_learning.py +++ b/sklearn/decomposition/_dict_learning.py @@ -18,7 +18,7 @@ from ..base import BaseEstimator, TransformerMixin, ClassNamePrefixFeaturesOutMixin from ..utils import check_array, check_random_state, gen_even_slices, gen_batches from ..utils import deprecated -from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params +from ..utils._param_validation import Hidden, Interval, StrOptions from ..utils.extmath import randomized_svd, row_norms, svd_flip from ..utils.validation import check_is_fitted from ..utils.fixes import delayed @@ -490,26 +490,6 @@ def _update_dict( print(f"{n_unused} unused atoms resampled.") -@validate_params( - { - "X": ["array-like"], - "n_components": [Interval(Integral, 1, None, closed="left")], - "alpha": [Interval(Integral, 0, None, closed="left")], - "max_iter": [Interval(Integral, 0, None, closed="left")], - "tol": [Interval(Real, 0, None, closed="left")], - "method": [StrOptions({"lars", "cd"})], - "n_jobs": [Integral, None], - "dict_init": [np.ndarray, None], - "code_init": [np.ndarray, None], - "callback": [None, callable], - "verbose": ["verbose"], - "random_state": ["random_state", None], - "return_n_iter": ["boolean"], - "positive_dict": ["boolean"], - "positive_code": ["boolean"], - "method_max_iter": [Interval(Integral, 0, None, closed="left")], - } -) def dict_learning( X, n_components, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 7e3a5b5464461..c3ff9f579298c 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -11,7 +11,6 @@ PARAM_VALIDATION_FUNCTION_LIST = [ "sklearn.cluster.kmeans_plusplus", "sklearn.cluster.estimate_bandwidth", - "sklearn.decomposition.dict_learning", ] From eed1f641ff86a5fa5057d90acfbbe12b4a752c8a Mon Sep 17 00:00:00 2001 From: OmarManzoor Date: Thu, 10 Nov 2022 18:53:59 +0500 Subject: [PATCH 4/4] Addressed PR suggestions --- sklearn/cluster/_mean_shift.py | 6 ++++-- sklearn/cluster/tests/test_mean_shift.py | 10 ---------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/sklearn/cluster/_mean_shift.py b/sklearn/cluster/_mean_shift.py index 1f59625a51146..b8bf585b0067d 100644 --- a/sklearn/cluster/_mean_shift.py +++ b/sklearn/cluster/_mean_shift.py @@ -42,8 +42,10 @@ def estimate_bandwidth(X, *, quantile=0.3, n_samples=None, random_state=0, n_jobs=None): """Estimate the bandwidth to use with the mean-shift algorithm. - That this function takes time at least quadratic in n_samples. For large - datasets, it's wise to set that parameter to a small value. + This function takes time at least quadratic in `n_samples`. For large + datasets, it is wise to subsample by setting `n_samples`. Alternatively, + the parameter `bandwidth` can be set to a small value without estimating + it. Parameters ---------- diff --git a/sklearn/cluster/tests/test_mean_shift.py b/sklearn/cluster/tests/test_mean_shift.py index 0f4d1c68d2f6e..db13e4d18650f 100644 --- a/sklearn/cluster/tests/test_mean_shift.py +++ b/sklearn/cluster/tests/test_mean_shift.py @@ -7,8 +7,6 @@ import warnings import pytest -from scipy import sparse - from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import assert_allclose @@ -76,14 +74,6 @@ def test_mean_shift( assert cluster_centers.dtype == global_dtype -def test_estimate_bandwidth_with_sparse_matrix(): - # Test estimate_bandwidth with sparse matrix - X = sparse.lil_matrix((1000, 1000)) - msg = "A sparse matrix was passed, but dense data is required." - with pytest.raises(TypeError, match=msg): - estimate_bandwidth(X) - - def test_parallel(global_dtype): centers = np.array([[1, 1], [-1, -1], [1, -1]]) + 10 X, _ = make_blobs(