From e5f2848d7eb68d079c0db02985668d15c5189afc Mon Sep 17 00:00:00 2001 From: zeeshan Date: Tue, 28 Feb 2023 01:39:08 +0530 Subject: [PATCH 1/2] Added parameter validation for chi2 and tested it --- sklearn/feature_selection/_univariate_selection.py | 8 +++++++- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sklearn/feature_selection/_univariate_selection.py b/sklearn/feature_selection/_univariate_selection.py index a2a53b6c116dd..11e6bc8e092c7 100644 --- a/sklearn/feature_selection/_univariate_selection.py +++ b/sklearn/feature_selection/_univariate_selection.py @@ -17,7 +17,7 @@ from ..utils import as_float_array, check_array, check_X_y, safe_sqr, safe_mask from ..utils.extmath import safe_sparse_dot, row_norms from ..utils.validation import check_is_fitted -from ..utils._param_validation import Interval, StrOptions +from ..utils._param_validation import Interval, StrOptions, validate_params from ._base import SelectorMixin @@ -167,6 +167,12 @@ def _chisquare(f_obs, f_exp): return chisq, special.chdtrc(k - 1, chisq) +@validate_params( + { + "X":["array-like","sparse matrix"], + "y":["array-like"], + } +) def chi2(X, y): """Compute chi-squared stats between each non-negative feature and class. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index dae1fdb2e6164..5e2afca424a84 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -144,6 +144,7 @@ def _check_function_param_validation( "sklearn.model_selection.train_test_split", "sklearn.random_projection.johnson_lindenstrauss_min_dim", "sklearn.svm.l1_min_c", + "sklearn.feature_selection.chi2", ] From b1c514bede64627c719bcd309783825ebaac73ec Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Tue, 28 Feb 2023 11:50:31 +0100 Subject: [PATCH 2/2] lint --- sklearn/feature_selection/_univariate_selection.py | 8 ++++---- sklearn/tests/test_public_functions.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/feature_selection/_univariate_selection.py b/sklearn/feature_selection/_univariate_selection.py index 11e6bc8e092c7..5521a62649c81 100644 --- a/sklearn/feature_selection/_univariate_selection.py +++ b/sklearn/feature_selection/_univariate_selection.py @@ -168,10 +168,10 @@ def _chisquare(f_obs, f_exp): @validate_params( - { - "X":["array-like","sparse matrix"], - "y":["array-like"], - } + { + "X": ["array-like", "sparse matrix"], + "y": ["array-like"], + } ) def chi2(X, y): """Compute chi-squared stats between each non-negative feature and class. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 5e2afca424a84..4fe337e9a95ea 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -112,6 +112,7 @@ def _check_function_param_validation( "sklearn.feature_extraction.img_to_graph", "sklearn.feature_extraction.image.extract_patches_2d", "sklearn.feature_extraction.image.reconstruct_from_patches_2d", + "sklearn.feature_selection.chi2", "sklearn.metrics.accuracy_score", "sklearn.metrics.auc", "sklearn.metrics.average_precision_score", @@ -144,7 +145,6 @@ def _check_function_param_validation( "sklearn.model_selection.train_test_split", "sklearn.random_projection.johnson_lindenstrauss_min_dim", "sklearn.svm.l1_min_c", - "sklearn.feature_selection.chi2", ]