From 83a538c146c00a6861c4afe5229abc91ba2027a7 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 9 Oct 2020 00:02:50 -0400 Subject: [PATCH 1/5] ENH Adds n_features_in_ checks to linear module --- sklearn/ensemble/tests/test_stacking.py | 5 ++--- sklearn/linear_model/_base.py | 5 +++-- sklearn/linear_model/_glm/glm.py | 13 ++++++------- sklearn/linear_model/_logistic.py | 2 +- sklearn/linear_model/_ransac.py | 2 +- sklearn/linear_model/_stochastic_gradient.py | 19 ++++++++++--------- sklearn/svm/_base.py | 9 +++------ sklearn/tests/test_common.py | 2 -- 8 files changed, 26 insertions(+), 31 deletions(-) diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py index d6b4c385b9073..815fdd44558ec 100644 --- a/sklearn/ensemble/tests/test_stacking.py +++ b/sklearn/ensemble/tests/test_stacking.py @@ -483,9 +483,8 @@ def test_stacking_without_n_features_in(make_dataset, Stacking, Estimator): class MyEstimator(Estimator): """Estimator without n_features_in_""" - def fit(self, X, y): - super().fit(X, y) - del self.n_features_in_ + def _check_n_features(self, X, reset): + pass X, y = make_dataset(random_state=0, n_samples=100) stacker = Stacking(estimators=[('lr', MyEstimator())]) diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py index 2399e1216238f..be0e7653634e0 100644 --- a/sklearn/linear_model/_base.py +++ b/sklearn/linear_model/_base.py @@ -217,7 +217,8 @@ def fit(self, X, y): def _decision_function(self, X): check_is_fitted(self) - X = check_array(X, accept_sparse=['csr', 'csc', 'coo']) + X = self._validate_data(X, accept_sparse=['csr', 'csc', 'coo'], + reset=False) return safe_sparse_dot(X, self.coef_.T, dense_output=True) + self.intercept_ @@ -281,7 +282,7 @@ class would be predicted. """ check_is_fitted(self) - X = check_array(X, accept_sparse='csr') + X = self._validate_data(X, accept_sparse='csr', reset=False) n_features = self.coef_.shape[1] if X.shape[1] != n_features: diff --git a/sklearn/linear_model/_glm/glm.py b/sklearn/linear_model/_glm/glm.py index 8559ef306b3a4..88088be5ae997 100644 --- a/sklearn/linear_model/_glm/glm.py +++ b/sklearn/linear_model/_glm/glm.py @@ -12,7 +12,6 @@ import scipy.optimize from ...base import BaseEstimator, RegressorMixin -from ...utils import check_array, check_X_y from ...utils.optimize import _check_optimize_result from ...utils.validation import check_is_fitted, _check_sample_weight from ..._loss.glm_distribution import ( @@ -221,9 +220,9 @@ def fit(self, X, y, sample_weight=None): family = self._family_instance link = self._link_instance - X, y = check_X_y(X, y, accept_sparse=['csc', 'csr'], - dtype=[np.float64, np.float32], - y_numeric=True, multi_output=False) + X, y = self._validate_data(X, y, accept_sparse=['csc', 'csr'], + dtype=[np.float64, np.float32], + y_numeric=True, multi_output=False) weights = _check_sample_weight(sample_weight, X) @@ -311,9 +310,9 @@ def _linear_predictor(self, X): Returns predicted values of linear predictor. 
""" check_is_fitted(self) - X = check_array(X, accept_sparse=['csr', 'csc', 'coo'], - dtype=[np.float64, np.float32], ensure_2d=True, - allow_nd=False) + X = self._validate_data(X, accept_sparse=['csr', 'csc', 'coo'], + dtype=[np.float64, np.float32], ensure_2d=True, + allow_nd=False, reset=False) return X @ self.coef_ + self.intercept_ def predict(self, X): diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 1afa06637b04a..b40f723899d93 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -973,6 +973,7 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10, max_squared_sum=max_squared_sum, sample_weight=sample_weight) log_reg = LogisticRegression(solver=solver, multi_class=multi_class) + log_reg.n_features_in_ = X.shape[1] # The score method of Logistic Regression has a classes_ attribute. if multi_class == 'ovr': @@ -2084,7 +2085,6 @@ def score(self, X, y, sample_weight=None): """ scoring = self.scoring or 'accuracy' scoring = get_scorer(scoring) - return scoring(self, X, y, sample_weight=sample_weight) def _more_tags(self): diff --git a/sklearn/linear_model/_ransac.py b/sklearn/linear_model/_ransac.py index c9246c121c387..390b247e719f5 100644 --- a/sklearn/linear_model/_ransac.py +++ b/sklearn/linear_model/_ransac.py @@ -478,7 +478,7 @@ def predict(self, X): Returns predicted values. """ check_is_fitted(self) - + X = self._validate_data(X, accept_sparse='csr', reset=False) return self.estimator_.predict(X) def score(self, X, y): diff --git a/sklearn/linear_model/_stochastic_gradient.py b/sklearn/linear_model/_stochastic_gradient.py index e99116ca4f3e3..498bd9a5caa95 100644 --- a/sklearn/linear_model/_stochastic_gradient.py +++ b/sklearn/linear_model/_stochastic_gradient.py @@ -15,7 +15,7 @@ from ._base import LinearClassifierMixin, SparseCoefMixin from ._base import make_dataset from ..base import BaseEstimator, RegressorMixin -from ..utils import check_array, check_random_state, check_X_y +from ..utils import check_random_state from ..utils.extmath import safe_sparse_dot from ..utils.multiclass import _check_partial_fit_first_call from ..utils.validation import check_is_fitted, _check_sample_weight @@ -55,6 +55,7 @@ def __init__(self, estimator, X_val, y_val, sample_weight_val, classes=None): self.estimator = clone(estimator) self.estimator.t_ = 1 # to pass check_is_fitted + self.estimator.n_features_in_ = X_val.shape[1] if classes is not None: self.estimator.classes_ = classes self.X_val = X_val @@ -488,8 +489,10 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate, max_iter, classes, sample_weight, coef_init, intercept_init): - X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64, - order="C", accept_large_sparse=False) + first_call = not hasattr(self, "classes_") + X, y = self._validate_data(X, y, accept_sparse='csr', dtype=np.float64, + order="C", accept_large_sparse=False, + reset=first_call) n_samples, n_features = X.shape @@ -1135,9 +1138,10 @@ def __init__(self, loss="squared_loss", *, penalty="l2", alpha=0.0001, def _partial_fit(self, X, y, alpha, C, loss, learning_rate, max_iter, sample_weight, coef_init, intercept_init): + first_call = getattr(self, "coef_", None) is None X, y = self._validate_data(X, y, accept_sparse="csr", copy=False, order='C', dtype=np.float64, - accept_large_sparse=False) + accept_large_sparse=False, reset=first_call) y = y.astype(np.float64, copy=False) n_samples, n_features = X.shape @@ -1145,12 +1149,9 @@ def _partial_fit(self, X, y, alpha, C, 
loss, learning_rate, sample_weight = _check_sample_weight(sample_weight, X) # Allocate datastructures from input arguments - if getattr(self, "coef_", None) is None: + if first_call: self._allocate_parameter_mem(1, n_features, coef_init, intercept_init) - elif n_features != self.coef_.shape[-1]: - raise ValueError("Number of features %d does not match previous " - "data %d." % (n_features, self.coef_.shape[-1])) if self.average > 0 and getattr(self, "_average_coef", None) is None: self._average_coef = np.zeros(n_features, dtype=np.float64, @@ -1266,7 +1267,7 @@ def _decision_function(self, X): """ check_is_fitted(self) - X = check_array(X, accept_sparse='csr') + X = self._validate_data(X, accept_sparse='csr', reset=False) scores = safe_sparse_dot(X, self.coef_.T, dense_output=True) + self.intercept_ diff --git a/sklearn/svm/_base.py b/sklearn/svm/_base.py index c5196a5801607..bf57f057182cf 100644 --- a/sklearn/svm/_base.py +++ b/sklearn/svm/_base.py @@ -471,8 +471,9 @@ def _validate_for_predict(self, X): check_is_fitted(self) if not callable(self.kernel): - X = check_array(X, accept_sparse='csr', dtype=np.float64, - order="C", accept_large_sparse=False) + X = self._validate_data(X, accept_sparse='csr', dtype=np.float64, + order="C", accept_large_sparse=False, + reset=False) if self._sparse and not sp.isspmatrix(X): X = sp.csr_matrix(X) @@ -489,10 +490,6 @@ def _validate_for_predict(self, X): raise ValueError("X.shape[1] = %d should be equal to %d, " "the number of samples at training time" % (X.shape[1], self.shape_fit_[0])) - elif not callable(self.kernel) and X.shape[1] != self.shape_fit_[1]: - raise ValueError("X.shape[1] = %d should be equal to %d, " - "the number of features at training time" % - (X.shape[1], self.shape_fit_[1])) return X @property diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index b84b66d1fb919..a1ad846b48cca 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -298,7 +298,6 @@ def test_strict_mode_parametrize_with_checks(estimator, check): 'isotonic', 'kernel_approximation', 'kernel_ridge', - 'linear_model', 'manifold', 'mixture', 'model_selection', @@ -310,7 +309,6 @@ def test_strict_mode_parametrize_with_checks(estimator, check): 'preprocessing', 'random_projection', 'semi_supervised', - 'svm', 'tree', } From 61ba6d5b2a80f2fbbc8c34d093e36ee6dad6abd5 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Fri, 9 Oct 2020 00:04:59 -0400 Subject: [PATCH 2/5] REV Reduces diff --- sklearn/linear_model/_logistic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index b40f723899d93..d0083faa1f246 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -2085,6 +2085,7 @@ def score(self, X, y, sample_weight=None): """ scoring = self.scoring or 'accuracy' scoring = get_scorer(scoring) + return scoring(self, X, y, sample_weight=sample_weight) def _more_tags(self): From 6b4481135e35db667c8abfab0f67b7e34e42a758 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 13 Oct 2020 16:49:51 -0400 Subject: [PATCH 3/5] CLN Remove unreachable code --- sklearn/linear_model/_base.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py index be0e7653634e0..56fc3e9f71edf 100644 --- a/sklearn/linear_model/_base.py +++ b/sklearn/linear_model/_base.py @@ -283,12 +283,6 @@ class would be predicted. 
check_is_fitted(self) X = self._validate_data(X, accept_sparse='csr', reset=False) - - n_features = self.coef_.shape[1] - if X.shape[1] != n_features: - raise ValueError("X has %d features per sample; expecting %d" - % (X.shape[1], n_features)) - scores = safe_sparse_dot(X, self.coef_.T, dense_output=True) + self.intercept_ return scores.ravel() if scores.shape[1] == 1 else scores From 7201313bba799e104b583196f84a06c66129fd22 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Tue, 13 Oct 2020 16:52:06 -0400 Subject: [PATCH 4/5] REV Reduces diff --- sklearn/ensemble/tests/test_stacking.py | 5 +++-- sklearn/linear_model/_logistic.py | 1 - sklearn/linear_model/_ransac.py | 2 +- sklearn/linear_model/_stochastic_gradient.py | 1 - 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py index 815fdd44558ec..d6b4c385b9073 100644 --- a/sklearn/ensemble/tests/test_stacking.py +++ b/sklearn/ensemble/tests/test_stacking.py @@ -483,8 +483,9 @@ def test_stacking_without_n_features_in(make_dataset, Stacking, Estimator): class MyEstimator(Estimator): """Estimator without n_features_in_""" - def _check_n_features(self, X, reset): - pass + def fit(self, X, y): + super().fit(X, y) + del self.n_features_in_ X, y = make_dataset(random_state=0, n_samples=100) stacker = Stacking(estimators=[('lr', MyEstimator())]) diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index d0083faa1f246..1afa06637b04a 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -973,7 +973,6 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10, max_squared_sum=max_squared_sum, sample_weight=sample_weight) log_reg = LogisticRegression(solver=solver, multi_class=multi_class) - log_reg.n_features_in_ = X.shape[1] # The score method of Logistic Regression has a classes_ attribute. if multi_class == 'ovr': diff --git a/sklearn/linear_model/_ransac.py b/sklearn/linear_model/_ransac.py index 390b247e719f5..c9246c121c387 100644 --- a/sklearn/linear_model/_ransac.py +++ b/sklearn/linear_model/_ransac.py @@ -478,7 +478,7 @@ def predict(self, X): Returns predicted values. 
""" check_is_fitted(self) - X = self._validate_data(X, accept_sparse='csr', reset=False) + return self.estimator_.predict(X) def score(self, X, y): diff --git a/sklearn/linear_model/_stochastic_gradient.py b/sklearn/linear_model/_stochastic_gradient.py index 498bd9a5caa95..ed96d066370ff 100644 --- a/sklearn/linear_model/_stochastic_gradient.py +++ b/sklearn/linear_model/_stochastic_gradient.py @@ -55,7 +55,6 @@ def __init__(self, estimator, X_val, y_val, sample_weight_val, classes=None): self.estimator = clone(estimator) self.estimator.t_ = 1 # to pass check_is_fitted - self.estimator.n_features_in_ = X_val.shape[1] if classes is not None: self.estimator.classes_ = classes self.X_val = X_val From 81b3fc748eff3fadb118145ad8b23f983a9b6092 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 14 Oct 2020 15:52:55 +0200 Subject: [PATCH 5/5] Expect base estimator names --- sklearn/utils/estimator_checks.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index cb1df0ad95ff3..4d178eb6402d1 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -3143,6 +3143,14 @@ def check_requires_y_none(name, estimator_orig, strict_mode=True): warnings.warn(warning_msg, FutureWarning) +def _accumulate_estimator_names(estimator, names): + for attribute_name in ["estimator_", "base_estimator_"]: + base_estimator = getattr(estimator, attribute_name, None) + if base_estimator is not None: + names.append(base_estimator.__class__.__name__) + _accumulate_estimator_names(base_estimator, names) + + def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True): # Make sure that n_features_in are checked after fitting tags = estimator_orig._get_tags() @@ -3173,8 +3181,10 @@ def check_n_features_in_after_fitting(name, estimator_orig, strict_mode=True): check_methods = ["predict", "transform", "decision_function", "predict_proba"] X_bad = X[:, [1]] - - msg = (f"X has 1 features, but {name} is expecting {X.shape[1]} " + names = [name] + _accumulate_estimator_names(estimator, names) + names = "|".join(names) + msg = (f"X has 1 features, but ({names}) is expecting {X.shape[1]} " "features as input") for method in check_methods: if not hasattr(estimator, method):