From 89ae918c0b68bcbfd6177a9c3046e4e8114fb516 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 1 Oct 2020 19:43:34 -0400 Subject: [PATCH 01/19] ENH Enables validate_data for non-fit methods --- sklearn/base.py | 2 +- sklearn/impute/_base.py | 10 +--- sklearn/impute/_iterative.py | 11 +++- sklearn/impute/_knn.py | 11 ++-- .../neural_network/_multilayer_perceptron.py | 22 ++++--- sklearn/neural_network/_rbm.py | 7 ++- sklearn/tests/test_common.py | 49 ++++++++++++++- sklearn/utils/estimator_checks.py | 59 +++++++++++++++++++ 8 files changed, 139 insertions(+), 32 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 42bda7dba0913..348c6a7127419 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -410,7 +410,7 @@ def _validate_data(self, X, y=None, reset=True, """ if y is None: - if self._get_tags()['requires_y']: + if reset and self._get_tags()['requires_y']: raise ValueError( f"This {self.__class__.__name__} estimator " f"requires y to be passed, but the target y is None." diff --git a/sklearn/impute/_base.py b/sklearn/impute/_base.py index 20b22224d53c7..ddfbc9b8f6285 100644 --- a/sklearn/impute/_base.py +++ b/sklearn/impute/_base.py @@ -428,10 +428,6 @@ def transform(self, X): X = self._validate_input(X, in_fit=False) statistics = self.statistics_ - if X.shape[1] != statistics.shape[0]: - raise ValueError("X has %d features per sample, expected %d" - % (X.shape[1], self.statistics_.shape[0])) - # compute mask before eliminating invalid features missing_mask = _get_mask(X, self.missing_values) @@ -793,16 +789,12 @@ def transform(self, X): # Need not validate X again as it would have already been validated # in the Imputer calling MissingIndicator if not self._precomputed: - X = self._validate_input(X, in_fit=True) + X = self._validate_input(X, in_fit=False) else: if not (hasattr(X, 'dtype') and X.dtype.kind == 'b'): raise ValueError("precomputed is True but the input data is " "not a mask") - if X.shape[1] != self._n_features: - raise ValueError("X has a different number of features " - "than during fitting.") - imputer_mask, features = self._get_missing_features_info(X) if self.features == "missing-only": diff --git a/sklearn/impute/_iterative.py b/sklearn/impute/_iterative.py index 325a484244143..e5a2be328e7ce 100644 --- a/sklearn/impute/_iterative.py +++ b/sklearn/impute/_iterative.py @@ -468,7 +468,7 @@ def _get_abs_corr_mat(self, X_filled, tolerance=1e-6): abs_corr_mat = normalize(abs_corr_mat, norm='l1', axis=0, copy=False) return abs_corr_mat - def _initial_imputation(self, X): + def _initial_imputation(self, X, in_fit=True): """Perform initial imputation for input X. Parameters @@ -477,6 +477,9 @@ def _initial_imputation(self, X): Input data, where "n_samples" is the number of samples and "n_features" is the number of features. + in_fit : bool, default=True + Whether the imputation is done in fit. + Returns ------- Xt : ndarray, shape (n_samples, n_features) @@ -501,7 +504,8 @@ def _initial_imputation(self, X): force_all_finite = True X = self._validate_data(X, dtype=FLOAT_DTYPES, order="F", - force_all_finite=force_all_finite) + force_all_finite=force_all_finite, + reset=in_fit) _check_inputs_dtype(X, self.missing_values) X_missing_mask = _get_mask(X, self.missing_values) @@ -695,7 +699,8 @@ def transform(self, X): """ check_is_fitted(self) - X, Xt, mask_missing_values, complete_mask = self._initial_imputation(X) + X, Xt, mask_missing_values, complete_mask = ( + self._initial_imputation(X, in_fit=False)) X_indicator = super()._transform_indicator(complete_mask) diff --git a/sklearn/impute/_knn.py b/sklearn/impute/_knn.py index df66e4a20aff6..dfd5bc8852182 100644 --- a/sklearn/impute/_knn.py +++ b/sklearn/impute/_knn.py @@ -10,7 +10,6 @@ from ..metrics.pairwise import _NAN_METRICS from ..neighbors._base import _get_weights from ..neighbors._base import _check_weights -from ..utils import check_array from ..utils import is_scalar_nan from ..utils._mask import _get_mask from ..utils.validation import check_is_fitted @@ -213,12 +212,10 @@ def transform(self, X): force_all_finite = True else: force_all_finite = "allow-nan" - X = check_array(X, accept_sparse=False, dtype=FLOAT_DTYPES, - force_all_finite=force_all_finite, copy=self.copy) - - if X.shape[1] != self._fit_X.shape[1]: - raise ValueError("Incompatible dimension between the fitted " - "dataset and the one to be transformed") + X = self._validate_data( + X, accept_sparse=False, dtype=FLOAT_DTYPES, + force_all_finite=force_all_finite, copy=self.copy, + reset=False) mask = _get_mask(X, self.missing_values) mask_fit_X = self._mask_fit_X diff --git a/sklearn/neural_network/_multilayer_perceptron.py b/sklearn/neural_network/_multilayer_perceptron.py index 2937e59e3f0ec..505b20067f38a 100644 --- a/sklearn/neural_network/_multilayer_perceptron.py +++ b/sklearn/neural_network/_multilayer_perceptron.py @@ -22,7 +22,7 @@ from ..utils import gen_batches, check_random_state from ..utils import shuffle from ..utils import _safe_indexing -from ..utils import check_array, column_or_1d +from ..utils import column_or_1d from ..exceptions import ConvergenceWarning from ..utils.extmath import safe_sparse_dot from ..utils.validation import check_is_fitted, _deprecate_positional_args @@ -131,7 +131,8 @@ def _forward_pass_fast(self, X): y_pred : ndarray of shape (n_samples,) or (n_samples, n_outputs) The decision function of the samples for each class in the model. """ - X = check_array(X, accept_sparse=['csr', 'csc']) + X = self._validate_data(X, accept_sparse=['csr', 'csc'], + reset=False) # Initialize first layer activation = X @@ -358,8 +359,10 @@ def _fit(self, X, y, incremental=False): if np.any(np.array(hidden_layer_sizes) <= 0): raise ValueError("hidden_layer_sizes must be > 0, got %s." % hidden_layer_sizes) + first_pass = (not hasattr(self, 'coefs_') or + (not self.warm_start and not incremental)) - X, y = self._validate_input(X, y, incremental) + X, y = self._validate_input(X, y, incremental, reset=first_pass) n_samples, n_features = X.shape @@ -375,8 +378,7 @@ def _fit(self, X, y, incremental=False): # check random state self._random_state = check_random_state(self.random_state) - if not hasattr(self, 'coefs_') or (not self.warm_start and not - incremental): + if first_pass: # First time training the model self._initialize(y, layer_units, X.dtype) @@ -970,10 +972,11 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu", *, beta_1=beta_1, beta_2=beta_2, epsilon=epsilon, n_iter_no_change=n_iter_no_change, max_fun=max_fun) - def _validate_input(self, X, y, incremental): + def _validate_input(self, X, y, incremental, reset): X, y = self._validate_data(X, y, accept_sparse=['csr', 'csc'], multi_output=True, - dtype=(np.float64, np.float32)) + dtype=(np.float64, np.float32), + reset=reset) if y.ndim == 2 and y.shape[1] == 1: y = column_or_1d(y, warn=True) @@ -1416,10 +1419,11 @@ def predict(self, X): return y_pred.ravel() return y_pred - def _validate_input(self, X, y, incremental): + def _validate_input(self, X, y, incremental, reset): X, y = self._validate_data(X, y, accept_sparse=['csr', 'csc'], multi_output=True, y_numeric=True, - dtype=(np.float64, np.float32)) + dtype=(np.float64, np.float32), + reset=reset) if y.ndim == 2 and y.shape[1] == 1: y = column_or_1d(y, warn=True) return X, y diff --git a/sklearn/neural_network/_rbm.py b/sklearn/neural_network/_rbm.py index d1028911f4185..97cdbb0132ebf 100644 --- a/sklearn/neural_network/_rbm.py +++ b/sklearn/neural_network/_rbm.py @@ -131,7 +131,8 @@ def transform(self, X): """ check_is_fitted(self) - X = check_array(X, accept_sparse='csr', dtype=(np.float64, np.float32)) + X = self._validate_data(X, accept_sparse='csr', reset=False, + dtype=(np.float64, np.float32)) return self._mean_hiddens(X) def _mean_hiddens(self, v): @@ -243,7 +244,9 @@ def partial_fit(self, X, y=None): self : BernoulliRBM The fitted model. """ - X = check_array(X, accept_sparse='csr', dtype=np.float64) + first_pass = not hasattr(self, 'components_') + X = self._validate_data(X, accept_sparse='csr', dtype=np.float64, + reset=first_pass) if not hasattr(self, 'random_state_'): self.random_state_ = check_random_state(self.random_state) if not hasattr(self, 'components_'): diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index d63484c10f2a5..86b11846bb8cc 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -37,7 +37,8 @@ _set_checking_parameters, _get_check_estimator_ids, check_class_weight_balanced_linear_classifier, - parametrize_with_checks) + parametrize_with_checks, + check_n_features_in_after_fitting) def test_all_estimator_no_base_class(): @@ -270,3 +271,49 @@ def test_strict_mode_check_estimator(): def test_strict_mode_parametrize_with_checks(estimator, check): # Ideally we should assert that the strict checks are Xfailed... check(estimator) + + +# TODO: When more modules get added, we can remove it from this list to make +# sure it gets tested. After we finish each module we can move the checks +# into check_estimator +N_FEATURES_IN_AFTER_FIT_MODULES_TO_IGNORE = { + 'calibration', + 'cluster', + 'compose', + 'covariance', + 'cross_decomposition', + 'decomposition', + 'discriminant_analysis', + 'ensemble', + 'feature_extraction', + 'feature_selection', + 'gaussian_process', + 'isotonic', + 'kernel_approximation', + 'kernel_ridge', + 'linear_model', + 'manifold', + 'mixture', + 'model_selection', + 'multiclass', + 'multioutput', + 'naive_bayes', + 'neighbors', + 'pipeline', + 'preprocessing', + 'random_projection', + 'semi_supervised', + 'svm', + 'tree', +} + +N_FEATURES_IN_AFTER_FIT_ESTIMATORS = [ + est for est in _tested_estimators() if est.__module__.split('.')[1] not in + N_FEATURES_IN_AFTER_FIT_MODULES_TO_IGNORE +] + + +@pytest.mark.parametrize("estimator", N_FEATURES_IN_AFTER_FIT_ESTIMATORS, + ids=_get_check_estimator_ids) +def test_check_n_features_in_after_fitting(estimator): + check_n_features_in_after_fitting(estimator.__class__.__name__, estimator) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 5b99e8e56c420..6515efc4ee586 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -3121,6 +3121,65 @@ def check_requires_y_none(name, estimator_orig, strict_mode=True): warnings.warn(warning_msg, FutureWarning) +def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True): + # Make sure that n_features_in are checked after fitting + tags = estimator_orig._get_tags() + + if "2darray" not in tags["X_types"] or tags["no_validation"]: + return + + rng = np.random.RandomState(0) + + estimator = clone(estimator_orig) + set_random_state(estimator) + if 'warm_start' in estimator.get_params(): + estimator.set_params(warm_start=False) + + n_samples = 100 + X = rng.normal(loc=100, size=(n_samples, 2)) + X = _pairwise_estimator_convert_X(X, estimator) + if is_regressor(estimator_orig): + y = rng.normal(size=n_samples) + else: + y = rng.randint(low=0, high=2, size=n_samples) + y = _enforce_estimator_tags_y(estimator, y) + + estimator.fit(X, y) + assert estimator.n_features_in_ == X.shape[1] + + # check methods will check n_features_in_ + check_methods = ["predict", "transform", "decision_function", + "predict_proba"] + X_bad = X[:, [1]] + for method in check_methods: + if not hasattr(estimator, method): + continue + + msg = f"X has 1 features, but {name} is expecting 2 features as input" + with raises(ValueError, match=msg): + getattr(estimator, method)(X_bad) + + # partial_fit will check in the second call + if not hasattr(estimator, "partial_fit"): + return + + estimator = clone(estimator_orig) + + has_classes = 'classes' in signature(estimator.partial_fit).parameters + if has_classes: + estimator.partial_fit(X, y, classes=np.unique(y)) + else: + estimator.partial_fit(X, y) + assert estimator.n_features_in_ == X.shape[1] + + msg = f"X has 1 features, but {name} is expecting 2 features as input" + with raises(ValueError, match=msg): + if has_classes: + estimator.partial_fit(X_bad, y, classes=np.unique(y)) + else: + estimator.partial_fit(X_bad, y) + + # set of checks that are completely strict, i.e. they have no non-strict part _FULLY_STRICT_CHECKS = set([ 'check_n_features_in', From ba22deefce680b1f29658ac5175327b33926ab64 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 1 Oct 2020 19:50:09 -0400 Subject: [PATCH 02/19] REV Less diffs --- sklearn/impute/_base.py | 10 +++++++++- sklearn/impute/_iterative.py | 11 +++-------- sklearn/impute/_knn.py | 11 +++++++---- sklearn/tests/test_common.py | 1 + 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/sklearn/impute/_base.py b/sklearn/impute/_base.py index ddfbc9b8f6285..20b22224d53c7 100644 --- a/sklearn/impute/_base.py +++ b/sklearn/impute/_base.py @@ -428,6 +428,10 @@ def transform(self, X): X = self._validate_input(X, in_fit=False) statistics = self.statistics_ + if X.shape[1] != statistics.shape[0]: + raise ValueError("X has %d features per sample, expected %d" + % (X.shape[1], self.statistics_.shape[0])) + # compute mask before eliminating invalid features missing_mask = _get_mask(X, self.missing_values) @@ -789,12 +793,16 @@ def transform(self, X): # Need not validate X again as it would have already been validated # in the Imputer calling MissingIndicator if not self._precomputed: - X = self._validate_input(X, in_fit=False) + X = self._validate_input(X, in_fit=True) else: if not (hasattr(X, 'dtype') and X.dtype.kind == 'b'): raise ValueError("precomputed is True but the input data is " "not a mask") + if X.shape[1] != self._n_features: + raise ValueError("X has a different number of features " + "than during fitting.") + imputer_mask, features = self._get_missing_features_info(X) if self.features == "missing-only": diff --git a/sklearn/impute/_iterative.py b/sklearn/impute/_iterative.py index e5a2be328e7ce..325a484244143 100644 --- a/sklearn/impute/_iterative.py +++ b/sklearn/impute/_iterative.py @@ -468,7 +468,7 @@ def _get_abs_corr_mat(self, X_filled, tolerance=1e-6): abs_corr_mat = normalize(abs_corr_mat, norm='l1', axis=0, copy=False) return abs_corr_mat - def _initial_imputation(self, X, in_fit=True): + def _initial_imputation(self, X): """Perform initial imputation for input X. Parameters @@ -477,9 +477,6 @@ def _initial_imputation(self, X, in_fit=True): Input data, where "n_samples" is the number of samples and "n_features" is the number of features. - in_fit : bool, default=True - Whether the imputation is done in fit. - Returns ------- Xt : ndarray, shape (n_samples, n_features) @@ -504,8 +501,7 @@ def _initial_imputation(self, X, in_fit=True): force_all_finite = True X = self._validate_data(X, dtype=FLOAT_DTYPES, order="F", - force_all_finite=force_all_finite, - reset=in_fit) + force_all_finite=force_all_finite) _check_inputs_dtype(X, self.missing_values) X_missing_mask = _get_mask(X, self.missing_values) @@ -699,8 +695,7 @@ def transform(self, X): """ check_is_fitted(self) - X, Xt, mask_missing_values, complete_mask = ( - self._initial_imputation(X, in_fit=False)) + X, Xt, mask_missing_values, complete_mask = self._initial_imputation(X) X_indicator = super()._transform_indicator(complete_mask) diff --git a/sklearn/impute/_knn.py b/sklearn/impute/_knn.py index dfd5bc8852182..df66e4a20aff6 100644 --- a/sklearn/impute/_knn.py +++ b/sklearn/impute/_knn.py @@ -10,6 +10,7 @@ from ..metrics.pairwise import _NAN_METRICS from ..neighbors._base import _get_weights from ..neighbors._base import _check_weights +from ..utils import check_array from ..utils import is_scalar_nan from ..utils._mask import _get_mask from ..utils.validation import check_is_fitted @@ -212,10 +213,12 @@ def transform(self, X): force_all_finite = True else: force_all_finite = "allow-nan" - X = self._validate_data( - X, accept_sparse=False, dtype=FLOAT_DTYPES, - force_all_finite=force_all_finite, copy=self.copy, - reset=False) + X = check_array(X, accept_sparse=False, dtype=FLOAT_DTYPES, + force_all_finite=force_all_finite, copy=self.copy) + + if X.shape[1] != self._fit_X.shape[1]: + raise ValueError("Incompatible dimension between the fitted " + "dataset and the one to be transformed") mask = _get_mask(X, self.missing_values) mask_fit_X = self._mask_fit_X diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 86b11846bb8cc..06b06ffb3a440 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -288,6 +288,7 @@ def test_strict_mode_parametrize_with_checks(estimator, check): 'feature_extraction', 'feature_selection', 'gaussian_process', + 'impute', 'isotonic', 'kernel_approximation', 'kernel_ridge', From e180abd64234251393074ea2601bc74bcac71236 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 1 Oct 2020 19:56:49 -0400 Subject: [PATCH 03/19] TST Improves test --- sklearn/utils/estimator_checks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 6515efc4ee586..18ca57a4b9c53 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -3155,7 +3155,8 @@ def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True): if not hasattr(estimator, method): continue - msg = f"X has 1 features, but {name} is expecting 2 features as input" + msg = (f"X has 1 features, but {name} is expecting {X.shape[1]} " + "features as input") with raises(ValueError, match=msg): getattr(estimator, method)(X_bad) @@ -3172,7 +3173,6 @@ def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True): estimator.partial_fit(X, y) assert estimator.n_features_in_ == X.shape[1] - msg = f"X has 1 features, but {name} is expecting 2 features as input" with raises(ValueError, match=msg): if has_classes: estimator.partial_fit(X_bad, y, classes=np.unique(y)) From fd03030b8a3ca34f4efd1887d511e39fcff02cb6 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 1 Oct 2020 19:57:46 -0400 Subject: [PATCH 04/19] TST Improves test --- sklearn/utils/estimator_checks.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 18ca57a4b9c53..6402f022fd75c 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -3151,14 +3151,15 @@ def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True): check_methods = ["predict", "transform", "decision_function", "predict_proba"] X_bad = X[:, [1]] + + msg = (f"X has 1 features, but {name} is expecting {X.shape[1]} " + "features as input") for method in check_methods: - if not hasattr(estimator, method): + func = getattr(estimator, method, None) + if func is None: continue - - msg = (f"X has 1 features, but {name} is expecting {X.shape[1]} " - "features as input") with raises(ValueError, match=msg): - getattr(estimator, method)(X_bad) + func(X_bad) # partial_fit will check in the second call if not hasattr(estimator, "partial_fit"): From c40d37ebc5d6e3996bcedbd94389d5a68ef42bee Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Thu, 1 Oct 2020 20:05:38 -0400 Subject: [PATCH 05/19] DOC Adds docs --- sklearn/base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sklearn/base.py b/sklearn/base.py index 348c6a7127419..34c3990831958 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -360,6 +360,10 @@ def _check_n_features(self, X, reset): If True, the `n_features_in_` attribute is set to `X.shape[1]`. Else, the attribute must already exist and the function checks that it is equal to `X.shape[1]`. + .. note:: + It is recommended to call reset=True in `fit` and in the first + call to `partial_fit`. All other methods that validates `X` + should set `reset=False`. """ n_features = X.shape[1] @@ -394,6 +398,10 @@ def _validate_data(self, X, y=None, reset=True, Whether to reset the `n_features_in_` attribute. If False, the input will be checked for consistency with data provided when reset was last True. + .. note:: + It is recommended to call reset=True in `fit` and in the first + call to `partial_fit`. All other methods that validates `X` + should set `reset=False`. validate_separately : False or tuple of dicts, default=False Only used if y is not None. If False, call validate_X_y(). Else, it must be a tuple of kwargs From f8ffd8812054a66dc7b5f6ab840b7e923439ee6f Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 2 Oct 2020 12:05:49 -0400 Subject: [PATCH 06/19] TST Update with more feature setting --- sklearn/neural_network/tests/test_mlp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/neural_network/tests/test_mlp.py b/sklearn/neural_network/tests/test_mlp.py index 72f347966355d..bdadf37c39902 100644 --- a/sklearn/neural_network/tests/test_mlp.py +++ b/sklearn/neural_network/tests/test_mlp.py @@ -94,6 +94,7 @@ def test_fit(): mlp.intercepts_[1] = np.array([1.0]) mlp._coef_grads = [] * 2 mlp._intercept_grads = [] * 2 + mlp.n_features_in_ = 3 # Initialize parameters mlp.n_iter_ = 0 From 18454c1bbcb63ec245f01289fd001d8f52698a2f Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 2 Oct 2020 12:20:00 -0400 Subject: [PATCH 07/19] TST Fixes tests --- sklearn/base.py | 4 ++-- sklearn/tests/test_common.py | 2 +- sklearn/utils/estimator_checks.py | 11 +++-------- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 34c3990831958..0e6dcd64b04b0 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -362,7 +362,7 @@ def _check_n_features(self, X, reset): that it is equal to `X.shape[1]`. .. note:: It is recommended to call reset=True in `fit` and in the first - call to `partial_fit`. All other methods that validates `X` + call to `partial_fit`. All other methods that validate `X` should set `reset=False`. """ n_features = X.shape[1] @@ -400,7 +400,7 @@ def _validate_data(self, X, y=None, reset=True, provided when reset was last True. .. note:: It is recommended to call reset=True in `fit` and in the first - call to `partial_fit`. All other methods that validates `X` + call to `partial_fit`. All other methods that validate `X` should set `reset=False`. validate_separately : False or tuple of dicts, default=False Only used if y is not None. diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 06b06ffb3a440..9bcecbe8dc0c9 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -275,7 +275,7 @@ def test_strict_mode_parametrize_with_checks(estimator, check): # TODO: When more modules get added, we can remove it from this list to make # sure it gets tested. After we finish each module we can move the checks -# into check_estimator +# into sklearn.utils.estimator_checks.check_n_features_in N_FEATURES_IN_AFTER_FIT_MODULES_TO_IGNORE = { 'calibration', 'cluster', diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 6402f022fd75c..af3b8496bb687 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -3138,7 +3138,7 @@ def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True): n_samples = 100 X = rng.normal(loc=100, size=(n_samples, 2)) X = _pairwise_estimator_convert_X(X, estimator) - if is_regressor(estimator_orig): + if is_regressor(estimator): y = rng.normal(size=n_samples) else: y = rng.randint(low=0, high=2, size=n_samples) @@ -3166,19 +3166,14 @@ def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True): return estimator = clone(estimator_orig) - - has_classes = 'classes' in signature(estimator.partial_fit).parameters - if has_classes: + if is_classifier(estimator): estimator.partial_fit(X, y, classes=np.unique(y)) else: estimator.partial_fit(X, y) assert estimator.n_features_in_ == X.shape[1] with raises(ValueError, match=msg): - if has_classes: - estimator.partial_fit(X_bad, y, classes=np.unique(y)) - else: - estimator.partial_fit(X_bad, y) + estimator.partial_fit(X_bad, y) # set of checks that are completely strict, i.e. they have no non-strict part From 0c2ea46b1181a5e2e24ab17d60d73694b0025a52 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 2 Oct 2020 12:20:15 -0400 Subject: [PATCH 08/19] DOC Adds comment --- sklearn/tests/test_common.py | 6 +++++- sklearn/utils/estimator_checks.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 9bcecbe8dc0c9..56d2cde3e5b97 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -275,7 +275,11 @@ def test_strict_mode_parametrize_with_checks(estimator, check): # TODO: When more modules get added, we can remove it from this list to make # sure it gets tested. After we finish each module we can move the checks -# into sklearn.utils.estimator_checks.check_n_features_in +# into sklearn.utils.estimator_checks.check_n_features_in. +# sklearn.utils.estimator_checks.check_estimators_partial_fit_n_features +# can either be removed or updated with the two more assertions: +# 1. `n_features_in_` is set during the first call to `partial_fit`. +# 2. More strict when it comes to the error message. N_FEATURES_IN_AFTER_FIT_MODULES_TO_IGNORE = { 'calibration', 'cluster', diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index af3b8496bb687..e5d7fbb43d836 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -3153,7 +3153,7 @@ def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True): X_bad = X[:, [1]] msg = (f"X has 1 features, but {name} is expecting {X.shape[1]} " - "features as input") + "features as input") for method in check_methods: func = getattr(estimator, method, None) if func is None: From 907dce9c483f554ccae92170a028508331628c1f Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 2 Oct 2020 12:21:06 -0400 Subject: [PATCH 09/19] CLN Uses hasattr --- sklearn/utils/estimator_checks.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index e5d7fbb43d836..2187c55a2cbc1 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -3155,11 +3155,10 @@ def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True): msg = (f"X has 1 features, but {name} is expecting {X.shape[1]} " "features as input") for method in check_methods: - func = getattr(estimator, method, None) - if func is None: + if not hasattr(estimator, method): continue with raises(ValueError, match=msg): - func(X_bad) + getattr(estimator, method)(X_bad) # partial_fit will check in the second call if not hasattr(estimator, "partial_fit"): From 8e8d11d8da43de0b294029dd2ee0a4b07e891fc4 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 2 Oct 2020 13:00:49 -0400 Subject: [PATCH 10/19] CLN Adds requires_y kwargs --- sklearn/base.py | 14 ++++++++++++-- sklearn/neural_network/_multilayer_perceptron.py | 2 +- sklearn/neural_network/_rbm.py | 3 ++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 0e6dcd64b04b0..1fbca54651277 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -383,7 +383,8 @@ def _check_n_features(self, X, reset): ) def _validate_data(self, X, y=None, reset=True, - validate_separately=False, **check_params): + validate_separately=False, + requires_y='auto', **check_params): """Validate input data and set or check the `n_features_in_` attribute. Parameters @@ -406,6 +407,14 @@ def _validate_data(self, X, y=None, reset=True, Only used if y is not None. If False, call validate_X_y(). Else, it must be a tuple of kwargs to be used for calling check_array() on X and y respectively. + requires_y : bool or 'auto', default='auto' + If 'auto', the the 'requires_y' tag will be used to decide if `y` + is required. If bool, then the caller decides if `y` is required. + .. note:: + It is recommended to leave requires_y='auto' in `fit and + in the first call to `partial-fit. All other methods that + validate `X` and does not require `y` should set + `requires_y=False`. **check_params : kwargs Parameters passed to :func:`sklearn.utils.check_array` or :func:`sklearn.utils.check_X_y`. Ignored if validate_separately @@ -418,7 +427,8 @@ def _validate_data(self, X, y=None, reset=True, """ if y is None: - if reset and self._get_tags()['requires_y']: + if ((requires_y == 'auto' and self._get_tags()['requires_y']) or + (isinstance(requires_y, bool) and requires_y)): raise ValueError( f"This {self.__class__.__name__} estimator " f"requires y to be passed, but the target y is None." diff --git a/sklearn/neural_network/_multilayer_perceptron.py b/sklearn/neural_network/_multilayer_perceptron.py index 505b20067f38a..80baaa4d278a1 100644 --- a/sklearn/neural_network/_multilayer_perceptron.py +++ b/sklearn/neural_network/_multilayer_perceptron.py @@ -132,7 +132,7 @@ def _forward_pass_fast(self, X): The decision function of the samples for each class in the model. """ X = self._validate_data(X, accept_sparse=['csr', 'csc'], - reset=False) + reset=False, requires_y=False) # Initialize first layer activation = X diff --git a/sklearn/neural_network/_rbm.py b/sklearn/neural_network/_rbm.py index 97cdbb0132ebf..90b066360d999 100644 --- a/sklearn/neural_network/_rbm.py +++ b/sklearn/neural_network/_rbm.py @@ -132,7 +132,8 @@ def transform(self, X): check_is_fitted(self) X = self._validate_data(X, accept_sparse='csr', reset=False, - dtype=(np.float64, np.float32)) + dtype=(np.float64, np.float32), + requires_y=False) return self._mean_hiddens(X) def _mean_hiddens(self, v): From bd150f632c48ffafcad8ab2fd454bf419ddc92a3 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 2 Oct 2020 13:04:41 -0400 Subject: [PATCH 11/19] CLN Uses requires_y='use_tag' --- sklearn/base.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 1fbca54651277..01a28a7ff21f4 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -384,7 +384,7 @@ def _check_n_features(self, X, reset): def _validate_data(self, X, y=None, reset=True, validate_separately=False, - requires_y='auto', **check_params): + requires_y='use_tag', **check_params): """Validate input data and set or check the `n_features_in_` attribute. Parameters @@ -407,11 +407,11 @@ def _validate_data(self, X, y=None, reset=True, Only used if y is not None. If False, call validate_X_y(). Else, it must be a tuple of kwargs to be used for calling check_array() on X and y respectively. - requires_y : bool or 'auto', default='auto' - If 'auto', the the 'requires_y' tag will be used to decide if `y` + requires_y : bool or 'use_tag', default='use_tag' + If 'use_tag', the 'requires_y' tag will be used to decide if `y` is required. If bool, then the caller decides if `y` is required. .. note:: - It is recommended to leave requires_y='auto' in `fit and + It is recommended to leave requires_y='use_tag' in `fit and in the first call to `partial-fit. All other methods that validate `X` and does not require `y` should set `requires_y=False`. @@ -427,7 +427,7 @@ def _validate_data(self, X, y=None, reset=True, """ if y is None: - if ((requires_y == 'auto' and self._get_tags()['requires_y']) or + if ((requires_y == 'use_tag' and self._get_tags()['requires_y']) or (isinstance(requires_y, bool) and requires_y)): raise ValueError( f"This {self.__class__.__name__} estimator " From 1ed5421ca2883d5817aa6fa285ae6e7eff95e1cb Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 2 Oct 2020 13:11:34 -0400 Subject: [PATCH 12/19] DOC add check_classifiers_train to docs --- sklearn/tests/test_common.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 56d2cde3e5b97..50e5ed814103c 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -276,10 +276,13 @@ def test_strict_mode_parametrize_with_checks(estimator, check): # TODO: When more modules get added, we can remove it from this list to make # sure it gets tested. After we finish each module we can move the checks # into sklearn.utils.estimator_checks.check_n_features_in. -# sklearn.utils.estimator_checks.check_estimators_partial_fit_n_features -# can either be removed or updated with the two more assertions: +# +# check_estimators_partial_fit_n_features can either be removed or updated +# with the two more assertions: # 1. `n_features_in_` is set during the first call to `partial_fit`. # 2. More strict when it comes to the error message. +# +# check_classifiers_train would need to be updated with the error message N_FEATURES_IN_AFTER_FIT_MODULES_TO_IGNORE = { 'calibration', 'cluster', From 3735ba07322e8c03486636898e7e3276dd4d27ea Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 5 Oct 2020 15:33:40 -0400 Subject: [PATCH 13/19] CLN Change signature of --- sklearn/base.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 01a28a7ff21f4..fa6b0cd5f7f26 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -156,6 +156,8 @@ class BaseEstimator: at the class level in their ``__init__`` as explicit keyword arguments (no ``*args`` or ``**kwargs``). """ + # used by _validate_data when `y` is not validated + __NO_Y = '__NO_Y' @classmethod def _get_param_names(cls): @@ -382,9 +384,8 @@ def _check_n_features(self, X, reset): self.n_features_in_) ) - def _validate_data(self, X, y=None, reset=True, - validate_separately=False, - requires_y='use_tag', **check_params): + def _validate_data(self, X, y=__NO_Y, reset=True, + validate_separately=False, **check_params): """Validate input data and set or check the `n_features_in_` attribute. Parameters @@ -392,9 +393,19 @@ def _validate_data(self, X, y=None, reset=True, X : {array-like, sparse matrix, dataframe} of shape \ (n_samples, n_features) The input samples. - y : array-like of shape (n_samples,), default=None - The targets. If None, `check_array` is called on `X` and - `check_X_y` is called otherwise. + y : array-like of shape (n_samples,), default=__NO_Y + The targets. + + - If `None`, `check_array` is called on `X`. If the estimator's + requires_y tag is True, then an error will be raised. + - If `__NO_Y`, `check_array` is called on `X` and the estimator's + requires_y tag is ignored. + - Otherwise, both `X` and `y` are checked with either `check_array` + or `check_X_y`. + + .. note:: + Be sure to set `y` to `None` in `fit`. + reset : bool, default=True Whether to reset the `n_features_in_` attribute. If False, the input will be checked for consistency with data @@ -407,14 +418,6 @@ def _validate_data(self, X, y=None, reset=True, Only used if y is not None. If False, call validate_X_y(). Else, it must be a tuple of kwargs to be used for calling check_array() on X and y respectively. - requires_y : bool or 'use_tag', default='use_tag' - If 'use_tag', the 'requires_y' tag will be used to decide if `y` - is required. If bool, then the caller decides if `y` is required. - .. note:: - It is recommended to leave requires_y='use_tag' in `fit and - in the first call to `partial-fit. All other methods that - validate `X` and does not require `y` should set - `requires_y=False`. **check_params : kwargs Parameters passed to :func:`sklearn.utils.check_array` or :func:`sklearn.utils.check_X_y`. Ignored if validate_separately @@ -427,14 +430,16 @@ def _validate_data(self, X, y=None, reset=True, """ if y is None: - if ((requires_y == 'use_tag' and self._get_tags()['requires_y']) or - (isinstance(requires_y, bool) and requires_y)): + if self._get_tags()['requires_y']: raise ValueError( f"This {self.__class__.__name__} estimator " f"requires y to be passed, but the target y is None." ) X = check_array(X, **check_params) out = X + elif y is self.__NO_Y: + X = check_array(X, **check_params) + out = X else: if validate_separately: # We need this because some estimators validate X and y From a865265c09281fcb746580b587930736b9d83669 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 5 Oct 2020 15:45:00 -0400 Subject: [PATCH 14/19] DOC Removes note --- sklearn/base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index fa6b0cd5f7f26..6994f381ab41f 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -403,9 +403,6 @@ def _validate_data(self, X, y=__NO_Y, reset=True, - Otherwise, both `X` and `y` are checked with either `check_array` or `check_X_y`. - .. note:: - Be sure to set `y` to `None` in `fit`. - reset : bool, default=True Whether to reset the `n_features_in_` attribute. If False, the input will be checked for consistency with data From 4ec2b6245cab83de4c515328ea7d1d70d083af0e Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 5 Oct 2020 16:07:00 -0400 Subject: [PATCH 15/19] CLN Fully removes requires_y --- sklearn/neural_network/_multilayer_perceptron.py | 3 +-- sklearn/neural_network/_rbm.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sklearn/neural_network/_multilayer_perceptron.py b/sklearn/neural_network/_multilayer_perceptron.py index 148e88d3789c9..ae06502d3ce1a 100644 --- a/sklearn/neural_network/_multilayer_perceptron.py +++ b/sklearn/neural_network/_multilayer_perceptron.py @@ -131,8 +131,7 @@ def _forward_pass_fast(self, X): y_pred : ndarray of shape (n_samples,) or (n_samples, n_outputs) The decision function of the samples for each class in the model. """ - X = self._validate_data(X, accept_sparse=['csr', 'csc'], - reset=False, requires_y=False) + X = self._validate_data(X, accept_sparse=['csr', 'csc'], reset=False) # Initialize first layer activation = X diff --git a/sklearn/neural_network/_rbm.py b/sklearn/neural_network/_rbm.py index 90b066360d999..97cdbb0132ebf 100644 --- a/sklearn/neural_network/_rbm.py +++ b/sklearn/neural_network/_rbm.py @@ -132,8 +132,7 @@ def transform(self, X): check_is_fitted(self) X = self._validate_data(X, accept_sparse='csr', reset=False, - dtype=(np.float64, np.float32), - requires_y=False) + dtype=(np.float64, np.float32)) return self._mean_hiddens(X) def _mean_hiddens(self, v): From 24636d46d63e2eebc045dc0f55aafc91da392fc8 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Mon, 5 Oct 2020 16:10:16 -0400 Subject: [PATCH 16/19] DOC Improves docs --- sklearn/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/base.py b/sklearn/base.py index 6994f381ab41f..4efdcd7e4dcaa 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -399,7 +399,8 @@ def _validate_data(self, X, y=__NO_Y, reset=True, - If `None`, `check_array` is called on `X`. If the estimator's requires_y tag is True, then an error will be raised. - If `__NO_Y`, `check_array` is called on `X` and the estimator's - requires_y tag is ignored. + requires_y tag is ignored. This is a default placeholder and is + never meant to be explicitly set. - Otherwise, both `X` and `y` are checked with either `check_array` or `check_X_y`. From e397a74bfbda1185cace3a6f4a0bfea2f3096a26 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 6 Oct 2020 12:45:35 -0400 Subject: [PATCH 17/19] CLN uses 'no_validation' --- sklearn/base.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/sklearn/base.py b/sklearn/base.py index 4efdcd7e4dcaa..486b946ee00f2 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -156,8 +156,6 @@ class BaseEstimator: at the class level in their ``__init__`` as explicit keyword arguments (no ``*args`` or ``**kwargs``). """ - # used by _validate_data when `y` is not validated - __NO_Y = '__NO_Y' @classmethod def _get_param_names(cls): @@ -384,7 +382,7 @@ def _check_n_features(self, X, reset): self.n_features_in_) ) - def _validate_data(self, X, y=__NO_Y, reset=True, + def _validate_data(self, X, y='no_validation', reset=True, validate_separately=False, **check_params): """Validate input data and set or check the `n_features_in_` attribute. @@ -398,11 +396,11 @@ def _validate_data(self, X, y=__NO_Y, reset=True, - If `None`, `check_array` is called on `X`. If the estimator's requires_y tag is True, then an error will be raised. - - If `__NO_Y`, `check_array` is called on `X` and the estimator's - requires_y tag is ignored. This is a default placeholder and is - never meant to be explicitly set. + - If `'no_validation'`, `check_array` is called on `X` and the + estimator's requires_y tag is ignored. This is a default + placeholder and is never meant to be explicitly set. - Otherwise, both `X` and `y` are checked with either `check_array` - or `check_X_y`. + or `check_X_y` depending on `validate_separately`. reset : bool, default=True Whether to reset the `n_features_in_` attribute. @@ -435,7 +433,7 @@ def _validate_data(self, X, y=__NO_Y, reset=True, ) X = check_array(X, **check_params) out = X - elif y is self.__NO_Y: + elif y == 'no_validation': X = check_array(X, **check_params) out = X else: From 64b78fceab83aa2df964dea74b52814ba6205408 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 6 Oct 2020 12:50:59 -0400 Subject: [PATCH 18/19] DOC Update docstring to no_validation --- sklearn/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/base.py b/sklearn/base.py index 486b946ee00f2..05f47988189f0 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -391,7 +391,7 @@ def _validate_data(self, X, y='no_validation', reset=True, X : {array-like, sparse matrix, dataframe} of shape \ (n_samples, n_features) The input samples. - y : array-like of shape (n_samples,), default=__NO_Y + y : array-like of shape (n_samples,), default='no_validation' The targets. - If `None`, `check_array` is called on `X`. If the estimator's From c21e02da095e3aa732880736d8107a8e5d717e8d Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 6 Oct 2020 17:28:15 -0400 Subject: [PATCH 19/19] FIX Check for string first --- sklearn/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/base.py b/sklearn/base.py index 05f47988189f0..1e86d812436c4 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -433,7 +433,7 @@ def _validate_data(self, X, y='no_validation', reset=True, ) X = check_array(X, **check_params) out = X - elif y == 'no_validation': + elif isinstance(y, str) and y == 'no_validation': X = check_array(X, **check_params) out = X else: