From 8dec90617f751fee53e24199f1736dd8a0790e09 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Fri, 31 Mar 2023 22:31:53 +0800 Subject: [PATCH 1/3] MAINT Parameters validation for preprocessing.scale --- sklearn/preprocessing/_data.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sklearn/preprocessing/_data.py b/sklearn/preprocessing/_data.py index d72b0294fa4f4..fadb3637fc9f5 100644 --- a/sklearn/preprocessing/_data.py +++ b/sklearn/preprocessing/_data.py @@ -120,6 +120,15 @@ def _handle_zeros_in_scale(scale, copy=True, constant_mask=None): return scale +@validate_params( + { + "X": ["array-like", "sparse matrix"], + "axis": [Options(int, {0, 1})], + "with_mean": ["boolean"], + "with_std": ["boolean"], + "copy": ["boolean"], + } +) def scale(X, *, axis=0, with_mean=True, with_std=True, copy=True): """Standardize a dataset along any axis. From 44ec47eba43437b4236de2143a16c60485237ecc Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Fri, 31 Mar 2023 22:36:15 +0800 Subject: [PATCH 2/3] updated validation and docstring --- sklearn/preprocessing/_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/_data.py b/sklearn/preprocessing/_data.py index fadb3637fc9f5..03a999ffba49e 100644 --- a/sklearn/preprocessing/_data.py +++ b/sklearn/preprocessing/_data.py @@ -123,7 +123,7 @@ def _handle_zeros_in_scale(scale, copy=True, constant_mask=None): @validate_params( { "X": ["array-like", "sparse matrix"], - "axis": [Options(int, {0, 1})], + "axis": [Options(Integral, {0, 1})], "with_mean": ["boolean"], "with_std": ["boolean"], "copy": ["boolean"], @@ -141,7 +141,7 @@ def scale(X, *, axis=0, with_mean=True, with_std=True, copy=True): X : {array-like, sparse matrix} of shape (n_samples, n_features) The data to center and scale. - axis : int, default=0 + axis : {0, 1}, default=0 Axis used to compute the means and standard deviations along. If 0, independently standardize each feature, otherwise (if 1) standardize each sample. From 841e72354115eb7fb513d6e44010dd17dac56ff9 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Fri, 31 Mar 2023 22:37:53 +0800 Subject: [PATCH 3/3] updated common param validation test list --- sklearn/tests/test_public_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index e127369072828..9581de7630788 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -210,6 +210,7 @@ def _check_function_param_validation( "sklearn.metrics.top_k_accuracy_score", "sklearn.metrics.zero_one_loss", "sklearn.model_selection.train_test_split", + "sklearn.preprocessing.scale", "sklearn.random_projection.johnson_lindenstrauss_min_dim", "sklearn.svm.l1_min_c", "sklearn.tree.export_text",