From e8adbe75ec029f00021476cff0eb00cf5f652e3a Mon Sep 17 00:00:00 2001 From: ROMEEZHOU Date: Mon, 3 Apr 2023 09:20:05 +0800 Subject: [PATCH] MAINT Parameters validation for sklearn.preprocessing.add_dummy_feature --- sklearn/preprocessing/_data.py | 8 +++++++- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sklearn/preprocessing/_data.py b/sklearn/preprocessing/_data.py index a363db0bbc8c2..a08fe82bb3fe1 100644 --- a/sklearn/preprocessing/_data.py +++ b/sklearn/preprocessing/_data.py @@ -24,7 +24,7 @@ ClassNamePrefixFeaturesOutMixin, ) from ..utils import check_array -from ..utils._param_validation import Interval, StrOptions +from ..utils._param_validation import validate_params, Interval, StrOptions from ..utils.extmath import _incremental_mean_and_var, row_norms from ..utils.sparsefuncs_fast import ( inplace_csr_row_normalize_l1, @@ -2322,6 +2322,12 @@ def _more_tags(self): return {"pairwise": True} +@validate_params( + { + "X": ["array-like", "sparse matrix"], + "value": [Interval(Real, None, None, closed="neither")], + } +) def add_dummy_feature(X, value=1.0): """Augment dataset with an additional dummy feature. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 9e1f95c4d057a..fd7d674535818 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -197,6 +197,7 @@ def _check_function_param_validation( "sklearn.metrics.top_k_accuracy_score", "sklearn.metrics.zero_one_loss", "sklearn.model_selection.train_test_split", + "sklearn.preprocessing.add_dummy_feature", "sklearn.random_projection.johnson_lindenstrauss_min_dim", "sklearn.svm.l1_min_c", "sklearn.tree.export_text",