diff --git a/sklearn/base.py b/sklearn/base.py
index ca957898c42ff..36a849d2e2c1e 100644
--- a/sklearn/base.py
+++ b/sklearn/base.py
@@ -14,6 +14,8 @@
 from . import __version__
 from .utils import _IS_32BIT
+from .utils.validation import check_X_y
+from .utils.validation import check_array

 _DEFAULT_TAGS = {
     'non_deterministic': False,
@@ -323,6 +325,31 @@ def _get_tags(self):
             collected_tags.update(more_tags)
         return collected_tags

+    def _validate_n_features(self, X, check_n_features):
+        if check_n_features:
+            if not hasattr(self, 'n_features_in_'):
+                raise RuntimeError(
+                    "check_n_features is True but there is no n_features_in_ "
+                    "attribute."
+                )
+            if X.shape[1] != self.n_features_in_:
+                raise ValueError(
+                    'X has {} features, but this {} is expecting {} features '
+                    'as input.'.format(X.shape[1], self.__class__.__name__,
+                                       self.n_features_in_)
+                )
+        else:
+            self.n_features_in_ = X.shape[1]
+
+    def _validate_X(self, X, check_n_features=False, **check_array_params):
+        X = check_array(X, **check_array_params)
+        self._validate_n_features(X, check_n_features)
+        return X
+
+    def _validate_X_y(self, X, y, check_n_features=False, **check_X_y_params):
+        X, y = check_X_y(X, y, **check_X_y_params)
+        self._validate_n_features(X, check_n_features)
+        return X, y
+

 class ClassifierMixin:
     """Mixin class for all classifiers in scikit-learn."""
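
Side note: a minimal sketch (not part of the patch, assuming the hunk above is applied) of how the new BaseEstimator helpers behave; ToyEstimator is a hypothetical subclass. The first _validate_X call records X.shape[1] as n_features_in_; a later call with check_n_features=True raises the ValueError defined above if the feature count changes.

    import numpy as np
    from sklearn.base import BaseEstimator

    class ToyEstimator(BaseEstimator):
        def fit(self, X, y=None):
            X = self._validate_X(X)  # first call: sets self.n_features_in_
            return self

        def predict(self, X):
            # compares X.shape[1] against the stored n_features_in_
            X = self._validate_X(X, check_n_features=True)
            return np.zeros(X.shape[0])

    est = ToyEstimator().fit(np.zeros((10, 3)))
    print(est.n_features_in_)       # 3
    est.predict(np.zeros((5, 3)))   # passes the check
    # est.predict(np.zeros((5, 4)))  would raise:
    # ValueError: X has 4 features, but this ToyEstimator is expecting
    # 3 features as input.
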
""" - X = check_array(X, accept_sparse='csr') + X = self._validate_X(X, accept_sparse='csr') if not self.eps > 0.0: raise ValueError("eps must be positive.") diff --git a/sklearn/cluster/hierarchical.py b/sklearn/cluster/hierarchical.py index 36ccf95253e96..6413c7c11b6a5 100644 --- a/sklearn/cluster/hierarchical.py +++ b/sklearn/cluster/hierarchical.py @@ -790,7 +790,7 @@ def fit(self, X, y=None): ------- self """ - X = check_array(X, ensure_min_samples=2, estimator=self) + X = self._validate_X(X, ensure_min_samples=2, estimator=self) memory = check_memory(self.memory) if self.n_clusters is not None and self.n_clusters <= 0: @@ -1034,9 +1034,14 @@ def fit(self, X, y=None, **params): ------- self """ - X = check_array(X, accept_sparse=['csr', 'csc', 'coo'], - ensure_min_features=2, estimator=self) - return AgglomerativeClustering.fit(self, X.T, **params) + X = self._validate_X(X, accept_sparse=['csr', 'csc', 'coo'], + ensure_min_features=2, estimator=self) + n_features_in_ = self.n_features_in_ + AgglomerativeClustering.fit(self, X.T, **params) + # Need to restore n_features_in_ attribute that was overridden in + # AgglomerativeClustering since we passed it X.T. + self.n_features_in_ = n_features_in_ + return self @property def fit_predict(self): diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py index 8af8cc6873011..0f15a6f6dc698 100644 --- a/sklearn/cluster/k_means_.py +++ b/sklearn/cluster/k_means_.py @@ -852,8 +852,9 @@ def fit(self, X, y=None, sample_weight=None): # avoid forcing order when copy_x=False order = "C" if self.copy_x else None - X = check_array(X, accept_sparse='csr', dtype=[np.float64, np.float32], - order=order, copy=self.copy_x) + X = self._validate_X(X, accept_sparse='csr', + dtype=[np.float64, np.float32], + order=order, copy=self.copy_x) # verify that the number of samples given is larger than k if _num_samples(X) < self.n_clusters: raise ValueError("n_samples=%d should be >= n_clusters=%d" % ( @@ -1497,8 +1498,8 @@ def fit(self, X, y=None, sample_weight=None): """ random_state = check_random_state(self.random_state) - X = check_array(X, accept_sparse="csr", order='C', - dtype=[np.float64, np.float32]) + X = self._validate_X(X, accept_sparse="csr", order='C', + dtype=[np.float64, np.float32]) n_samples, n_features = X.shape if n_samples < self.n_clusters: raise ValueError("n_samples=%d should be >= n_clusters=%d" diff --git a/sklearn/cluster/mean_shift_.py b/sklearn/cluster/mean_shift_.py index 6cccff6bddf18..9c8c95a1c53f6 100644 --- a/sklearn/cluster/mean_shift_.py +++ b/sklearn/cluster/mean_shift_.py @@ -414,7 +414,7 @@ def fit(self, X, y=None): y : Ignored """ - X = check_array(X) + X = self._validate_X(X) self.cluster_centers_, self.labels_ = \ mean_shift(X, bandwidth=self.bandwidth, seeds=self.seeds, min_bin_freq=self.min_bin_freq, diff --git a/sklearn/cluster/optics_.py b/sklearn/cluster/optics_.py index 46df91683863d..8211e50d4cb1d 100755 --- a/sklearn/cluster/optics_.py +++ b/sklearn/cluster/optics_.py @@ -233,7 +233,7 @@ def fit(self, X, y=None): self : instance of OPTICS The instance. 
""" - X = check_array(X, dtype=np.float) + X = self._validate_X(X, dtype=np.float) if self.cluster_method not in ['dbscan', 'xi']: raise ValueError("cluster_method should be one of" diff --git a/sklearn/cluster/spectral.py b/sklearn/cluster/spectral.py index b89f2bc37d65d..588742613938d 100644 --- a/sklearn/cluster/spectral.py +++ b/sklearn/cluster/spectral.py @@ -474,8 +474,8 @@ def fit(self, X, y=None): self """ - X = check_array(X, accept_sparse=['csr', 'csc', 'coo'], - dtype=np.float64, ensure_min_samples=2) + X = self._validate_X(X, accept_sparse=['csr', 'csc', 'coo'], + dtype=np.float64, ensure_min_samples=2) allow_squared = self.affinity in ["precomputed", "precomputed_nearest_neighbors"] if X.shape[0] == X.shape[1] and not allow_squared: diff --git a/sklearn/cluster/tests/test_bicluster.py b/sklearn/cluster/tests/test_bicluster.py index 1d88769f238aa..5057480572a6b 100644 --- a/sklearn/cluster/tests/test_bicluster.py +++ b/sklearn/cluster/tests/test_bicluster.py @@ -256,3 +256,14 @@ def test_wrong_shape(): data = np.arange(27).reshape((3, 3, 3)) with pytest.raises(ValueError): model.fit(data) + + +@pytest.mark.parametrize('est', + (SpectralBiclustering(), SpectralCoclustering())) +def test_n_features_in_(est): + + X, _, _ = make_biclusters((3, 3), 3, random_state=0) + + assert not hasattr(est, 'n_features_in_') + est.fit(X) + assert est.n_features_in_ == 3 diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index 6335fd7a4b20d..4fd3dcab780f8 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -506,6 +506,8 @@ def fit_transform(self, X, y=None): else: self._feature_names_in = None X = _check_X(X) + # set n_features_in_ attribute + self._validate_n_features(X, check_n_features=False) self._validate_transformers() self._validate_column_callables(X) self._validate_remainder(X) @@ -579,6 +581,7 @@ def transform(self, X): 'and for transform when using the ' 'remainder keyword') + # TODO: also call _validate_n_features(check_n_features=True) in 0.24 self._validate_features(X.shape[1], X_feature_names) Xs = self._fit_transform(X, None, _transform_one, fitted=True) self._validate_output(Xs) diff --git a/sklearn/compose/_target.py b/sklearn/compose/_target.py index 8fc02462257c0..46f5598e215c7 100644 --- a/sklearn/compose/_target.py +++ b/sklearn/compose/_target.py @@ -10,6 +10,7 @@ from ..utils.validation import check_is_fitted from ..utils import check_array, safe_indexing from ..preprocessing import FunctionTransformer +from ..exceptions import NotFittedError __all__ = ['TransformedTargetRegressor'] @@ -234,3 +235,17 @@ def predict(self, X): def _more_tags(self): return {'poor_score': True, 'no_validation': True} + + @property + def n_features_in_(self): + # For consistency with other estimators we raise a AttributeError so + # that hasattr() fails if the estimator isn't fitted. + try: + check_is_fitted(self) + except NotFittedError as nfe: + raise AttributeError( + "{} object has no n_features_in_ attribute." 
+                .format(self.__class__.__name__)
+            ) from nfe
+
+        return self.regressor_.n_features_in_
diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py
index 094b2769de369..0f2d9fbcd30fe 100644
--- a/sklearn/compose/tests/test_column_transformer.py
+++ b/sklearn/compose/tests/test_column_transformer.py
@@ -1180,3 +1180,15 @@ def test_column_transformer_mask_indexing(array_type):
     )
     X_trans = column_transformer.fit_transform(X)
     assert X_trans.shape == (3, 2)
+
+
+def test_n_features_in():
+    # make sure n_features_in_ is what is passed as input to the column
+    # transformer.
+
+    X = [[1, 2], [3, 4], [5, 6]]
+    ct = ColumnTransformer([('a', DoubleTrans(), [0]),
+                            ('b', DoubleTrans(), [1])])
+    assert not hasattr(ct, 'n_features_in_')
+    ct.fit(X)
+    assert ct.n_features_in_ == 2
diff --git a/sklearn/covariance/empirical_covariance_.py b/sklearn/covariance/empirical_covariance_.py
index 924f7edd7ffee..aa78d788142e2 100644
--- a/sklearn/covariance/empirical_covariance_.py
+++ b/sklearn/covariance/empirical_covariance_.py
@@ -191,7 +191,7 @@ def fit(self, X, y=None):
         self : object

         """
-        X = check_array(X)
+        X = self._validate_X(X)
         if self.assume_centered:
             self.location_ = np.zeros(X.shape[1])
         else:
diff --git a/sklearn/covariance/graph_lasso_.py b/sklearn/covariance/graph_lasso_.py
index e78950bd60421..874b5d5576c50 100644
--- a/sklearn/covariance/graph_lasso_.py
+++ b/sklearn/covariance/graph_lasso_.py
@@ -378,8 +378,8 @@ def fit(self, X, y=None):
         y : (ignored)
         """
         # Covariance does not make sense for a single feature
-        X = check_array(X, ensure_min_features=2, ensure_min_samples=2,
-                        estimator=self)
+        X = self._validate_X(X, ensure_min_features=2, ensure_min_samples=2,
+                             estimator=self)

         if self.assume_centered:
             self.location_ = np.zeros(X.shape[1])
@@ -645,7 +645,7 @@ def fit(self, X, y=None):
         y : (ignored)
         """
         # Covariance does not make sense for a single feature
-        X = check_array(X, ensure_min_features=2, estimator=self)
+        X = self._validate_X(X, ensure_min_features=2, estimator=self)
         if self.assume_centered:
             self.location_ = np.zeros(X.shape[1])
         else:
diff --git a/sklearn/covariance/robust_covariance.py b/sklearn/covariance/robust_covariance.py
index 0c38a38e99bd1..34fc7ce481855 100644
--- a/sklearn/covariance/robust_covariance.py
+++ b/sklearn/covariance/robust_covariance.py
@@ -636,7 +636,7 @@ def fit(self, X, y=None):
         self : object

         """
-        X = check_array(X, ensure_min_samples=2, estimator='MinCovDet')
+        X = self._validate_X(X, ensure_min_samples=2, estimator='MinCovDet')
         random_state = check_random_state(self.random_state)
         n_samples, n_features = X.shape
         # check that the empirical covariance is full rank
diff --git a/sklearn/covariance/shrunk_covariance_.py b/sklearn/covariance/shrunk_covariance_.py
index 6a0c80d2e4ff6..26b8ce237cbb5 100644
--- a/sklearn/covariance/shrunk_covariance_.py
+++ b/sklearn/covariance/shrunk_covariance_.py
@@ -143,7 +143,7 @@ def fit(self, X, y=None):
         self : object

         """
-        X = check_array(X)
+        X = self._validate_X(X)
         # Not calling the parent object to fit, to avoid a potential
         # matrix inversion when setting the precision
         if self.assume_centered:
@@ -419,7 +419,7 @@ def fit(self, X, y=None):
         """
         # Not calling the parent object to fit, to avoid computing the
         # covariance matrix (and potentially the precision)
-        X = check_array(X)
+        X = self._validate_X(X)
         if self.assume_centered:
             self.location_ = np.zeros(X.shape[1])
         else:
@@ -572,7 +572,7 @@ def fit(self, X, y=None):
         self : object

         """
-        X = check_array(X)
+        X = self._validate_X(X)
         # Not calling the parent object to fit, to avoid computing the
         # covariance matrix (and potentially the precision)
         if self.assume_centered:
diff --git a/sklearn/cross_decomposition/pls_.py b/sklearn/cross_decomposition/pls_.py
index c1eb72df11607..f570d93bc404a 100644
--- a/sklearn/cross_decomposition/pls_.py
+++ b/sklearn/cross_decomposition/pls_.py
@@ -252,8 +252,8 @@ def fit(self, X, Y):

         # copy since this will contains the residuals (deflated) matrices
         check_consistent_length(X, Y)
-        X = check_array(X, dtype=np.float64, copy=self.copy,
-                        ensure_min_samples=2)
+        X = self._validate_X(X, dtype=np.float64, copy=self.copy,
+                             ensure_min_samples=2)
         Y = check_array(Y, dtype=np.float64, copy=self.copy, ensure_2d=False)
         if Y.ndim == 1:
             Y = Y.reshape(-1, 1)
@@ -828,8 +828,8 @@ def fit(self, X, Y):
         """
         # copy since this will contains the centered data
         check_consistent_length(X, Y)
-        X = check_array(X, dtype=np.float64, copy=self.copy,
-                        ensure_min_samples=2)
+        X = self._validate_X(X, dtype=np.float64, copy=self.copy,
+                             ensure_min_samples=2)
         Y = check_array(Y, dtype=np.float64, copy=self.copy, ensure_2d=False)
         if Y.ndim == 1:
             Y = Y.reshape(-1, 1)
diff --git a/sklearn/decomposition/dict_learning.py b/sklearn/decomposition/dict_learning.py
index 05f06edc05934..501b259422533 100644
--- a/sklearn/decomposition/dict_learning.py
+++ b/sklearn/decomposition/dict_learning.py
@@ -1044,6 +1044,10 @@ def fit(self, X, y=None):
         """
         return self

+    @property
+    def n_features_in_(self):
+        return self.components_.shape[1]
+

 class DictionaryLearning(SparseCodingMixin, BaseEstimator):
     """Dictionary learning
@@ -1217,7 +1221,7 @@ def fit(self, X, y=None):
             Returns the object itself
         """
         random_state = check_random_state(self.random_state)
-        X = check_array(X)
+        X = self._validate_X(X)
         if self.n_components is None:
             n_components = X.shape[1]
         else:
@@ -1423,7 +1427,7 @@ def fit(self, X, y=None):
             Returns the instance itself.
         """
         random_state = check_random_state(self.random_state)
-        X = check_array(X)
+        X = self._validate_X(X)
         U, (A, B), self.n_iter_ = dict_learning_online(
             X, self.n_components, self.alpha,
diff --git a/sklearn/decomposition/factor_analysis.py b/sklearn/decomposition/factor_analysis.py
index 4fa48d5d0d88f..61d1f479f573e 100644
--- a/sklearn/decomposition/factor_analysis.py
+++ b/sklearn/decomposition/factor_analysis.py
@@ -167,7 +167,7 @@ def fit(self, X, y=None):
         -------
         self
         """
-        X = check_array(X, copy=self.copy, dtype=np.float64)
+        X = self._validate_X(X, copy=self.copy, dtype=np.float64)

         n_samples, n_features = X.shape
         n_components = self.n_components
diff --git a/sklearn/decomposition/fastica_.py b/sklearn/decomposition/fastica_.py
index dffce0dc0d8bc..7815ccd2b0ae8 100644
--- a/sklearn/decomposition/fastica_.py
+++ b/sklearn/decomposition/fastica_.py
@@ -501,6 +501,11 @@ def _fit(self, X, compute_sources=False):
         -------
         X_new : array-like, shape (n_samples, n_components)
         """
+
+        # This validates twice but there is no clean way to avoid validation
+        # in fastica(). Please see issue 14897.
+        self._validate_X(X, copy=self.whiten, dtype=FLOAT_DTYPES,
+                         ensure_min_samples=2).T
         fun_args = {} if self.fun_args is None else self.fun_args
         whitening, unmixing, sources, X_mean, self.n_iter_ = fastica(
             X=X, n_components=self.n_components, algorithm=self.algorithm,
diff --git a/sklearn/decomposition/incremental_pca.py b/sklearn/decomposition/incremental_pca.py
index c6d611dcd5fea..815a912f92f5d 100644
--- a/sklearn/decomposition/incremental_pca.py
+++ b/sklearn/decomposition/incremental_pca.py
@@ -192,8 +192,8 @@ def fit(self, X, y=None):
         self.singular_values_ = None
         self.noise_variance_ = None

-        X = check_array(X, accept_sparse=['csr', 'csc', 'lil'],
-                        copy=self.copy, dtype=[np.float64, np.float32])
+        X = self._validate_X(X, accept_sparse=['csr', 'csc', 'lil'],
+                             copy=self.copy, dtype=[np.float64, np.float32])
         n_samples, n_features = X.shape

         if self.batch_size is None:
diff --git a/sklearn/decomposition/kernel_pca.py b/sklearn/decomposition/kernel_pca.py
index 1429106495a6e..54e44f7f94131 100644
--- a/sklearn/decomposition/kernel_pca.py
+++ b/sklearn/decomposition/kernel_pca.py
@@ -271,7 +271,7 @@ def fit(self, X, y=None):
         self : object
             Returns the instance itself.
         """
-        X = check_array(X, accept_sparse='csr', copy=self.copy_X)
+        X = self._validate_X(X, accept_sparse='csr', copy=self.copy_X)
         self._centerer = KernelCenterer()
         K = self._get_kernel(X)
         self._fit_transform(K)
diff --git a/sklearn/decomposition/nmf.py b/sklearn/decomposition/nmf.py
index 0cf663e123861..ab16e257e2345 100644
--- a/sklearn/decomposition/nmf.py
+++ b/sklearn/decomposition/nmf.py
@@ -1268,7 +1268,7 @@ def fit_transform(self, X, y=None, W=None, H=None):
         W : array, shape (n_samples, n_components)
             Transformed data.
         """
-        X = check_array(X, accept_sparse=('csr', 'csc'), dtype=float)
+        X = self._validate_X(X, accept_sparse=('csr', 'csc'), dtype=float)

         W, H, n_iter_ = non_negative_factorization(
             X=X, W=W, H=H, n_components=self.n_components, init=self.init,
diff --git a/sklearn/decomposition/online_lda.py b/sklearn/decomposition/online_lda.py
index 862635c65500b..3c0bcb9372bd9 100644
--- a/sklearn/decomposition/online_lda.py
+++ b/sklearn/decomposition/online_lda.py
@@ -469,7 +469,7 @@ def _em_step(self, X, total_samples, batch_update, parallel=None):
     def _more_tags(self):
         return {'requires_positive_X': True}

-    def _check_non_neg_array(self, X, whom):
+    def _check_non_neg_array(self, X, check_n_features, whom):
         """check X format

         check X format and make sure no negative value in X.
@@ -479,7 +479,8 @@ def _check_non_neg_array(self, X, whom):
         X : array-like or sparse matrix

         """
-        X = check_array(X, accept_sparse='csr')
+        X = self._validate_X(X, check_n_features=check_n_features,
+                             accept_sparse='csr')
         check_non_negative(X, whom)
         return X

@@ -498,13 +499,20 @@ def partial_fit(self, X, y=None):
         self
         """
         self._check_params()
-        X = self._check_non_neg_array(X,
+        first_time = not hasattr(self, 'components_')
+        # deactivating check for now (specific tests about error message would
+        # break)
+        # TODO: uncomment when addressing check_n_features in
+        # predict/transform/etc.
+        # check_n_features = not in_fit
+        check_n_features = False
+        X = self._check_non_neg_array(X, check_n_features,
                                       "LatentDirichletAllocation.partial_fit")
         n_samples, n_features = X.shape
         batch_size = self.batch_size

         # initialize parameters or check
-        if not hasattr(self, 'components_'):
+        if first_time:
             self._init_latent_vars(n_features)

         if n_features != self.components_.shape[1]:
@@ -542,7 +550,8 @@ def fit(self, X, y=None):
         self
         """
         self._check_params()
-        X = self._check_non_neg_array(X, "LatentDirichletAllocation.fit")
+        X = self._check_non_neg_array(X, check_n_features=False,
+                                      whom="LatentDirichletAllocation.fit")
         n_samples, n_features = X.shape
         max_iter = self.max_iter
         evaluate_every = self.evaluate_every
@@ -611,7 +620,9 @@ def _unnormalized_transform(self, X):
         check_is_fitted(self)

         # make sure feature size is the same in fitted model and in X
-        X = self._check_non_neg_array(X, "LatentDirichletAllocation.transform")
+        X = self._check_non_neg_array(
+            X, check_n_features=False,
+            whom="LatentDirichletAllocation.transform")
         n_samples, n_features = X.shape
         if n_features != self.components_.shape[1]:
             raise ValueError(
@@ -735,7 +746,8 @@ def score(self, X, y=None):
         score : float
             Use approximate bound as score.
         """
-        X = self._check_non_neg_array(X, "LatentDirichletAllocation.score")
+        X = self._check_non_neg_array(X, check_n_features=False,
+                                      whom="LatentDirichletAllocation.score")
         doc_topic_distr = self._unnormalized_transform(X)
         score = self._approx_bound(X, doc_topic_distr, sub_sampling=False)
@@ -764,8 +776,9 @@ def _perplexity_precomp_distr(self, X, doc_topic_distr=None,
         """
         check_is_fitted(self)

-        X = self._check_non_neg_array(X,
-                                      "LatentDirichletAllocation.perplexity")
+        X = self._check_non_neg_array(
+            X, check_n_features=False,
+            whom="LatentDirichletAllocation.perplexity")
         if doc_topic_distr is None:
             doc_topic_distr = self._unnormalized_transform(X)
diff --git a/sklearn/decomposition/pca.py b/sklearn/decomposition/pca.py
index 1bf3d6e6b19e6..001bf5b0c3953 100644
--- a/sklearn/decomposition/pca.py
+++ b/sklearn/decomposition/pca.py
@@ -385,8 +385,8 @@ def _fit(self, X):
             raise TypeError('PCA does not support sparse input. See '
                             'TruncatedSVD for a possible alternative.')

-        X = check_array(X, dtype=[np.float64, np.float32], ensure_2d=True,
-                        copy=self.copy)
+        X = self._validate_X(X, dtype=[np.float64, np.float32], ensure_2d=True,
+                             copy=self.copy)

         # Handle n_components==None
         if self.n_components is None:
diff --git a/sklearn/decomposition/sparse_pca.py b/sklearn/decomposition/sparse_pca.py
index 50f869fa4b1e8..3cef1005f1bc8 100644
--- a/sklearn/decomposition/sparse_pca.py
+++ b/sklearn/decomposition/sparse_pca.py
@@ -166,7 +166,7 @@ def fit(self, X, y=None):
             Returns the instance itself.
         """
         random_state = check_random_state(self.random_state)
-        X = check_array(X)
+        X = self._validate_X(X)

         _check_normalize_components(
             self.normalize_components, self.__class__.__name__
@@ -364,7 +364,7 @@ def fit(self, X, y=None):
             Returns the instance itself.
""" random_state = check_random_state(self.random_state) - X = check_array(X) + X = self._validate_X(X) _check_normalize_components( self.normalize_components, self.__class__.__name__ diff --git a/sklearn/decomposition/tests/test_dict_learning.py b/sklearn/decomposition/tests/test_dict_learning.py index 54c5ece561f18..af8a1869626f3 100644 --- a/sklearn/decomposition/tests/test_dict_learning.py +++ b/sklearn/decomposition/tests/test_dict_learning.py @@ -498,3 +498,9 @@ def test_sparse_coder_parallel_mmap(): sc = SparseCoder(init_dict, transform_algorithm='omp', n_jobs=2) sc.fit_transform(data) + + +def test_sparse_coder_n_features_in(): + d = np.array([[1, 2, 3], [1, 2, 3]]) + sc = SparseCoder(d) + assert sc.n_features_in_ == d.shape[1] diff --git a/sklearn/decomposition/truncated_svd.py b/sklearn/decomposition/truncated_svd.py index 13511cb7066b7..351d9421d1727 100644 --- a/sklearn/decomposition/truncated_svd.py +++ b/sklearn/decomposition/truncated_svd.py @@ -156,8 +156,8 @@ def fit_transform(self, X, y=None): X_new : array, shape (n_samples, n_components) Reduced version of X. This will always be a dense array. """ - X = check_array(X, accept_sparse=['csr', 'csc'], - ensure_min_features=2) + X = self._validate_X(X, accept_sparse=['csr', 'csc'], + ensure_min_features=2) random_state = check_random_state(self.random_state) if self.algorithm == "arpack": diff --git a/sklearn/discriminant_analysis.py b/sklearn/discriminant_analysis.py index f6d442fa91bdf..953c4f7303103 100644 --- a/sklearn/discriminant_analysis.py +++ b/sklearn/discriminant_analysis.py @@ -424,8 +424,8 @@ def fit(self, X, y): Target values. """ # FIXME: Future warning to be removed in 0.23 - X, y = check_X_y(X, y, ensure_min_samples=2, estimator=self, - dtype=[np.float64, np.float32]) + X, y = self._validate_X_y(X, y, ensure_min_samples=2, estimator=self, + dtype=[np.float64, np.float32]) self.classes_ = unique_labels(y) n_samples, _ = X.shape n_classes = len(self.classes_) @@ -656,7 +656,7 @@ def fit(self, X, y): y : array, shape = [n_samples] Target values (integers) """ - X, y = check_X_y(X, y) + X, y = self._validate_X_y(X, y) check_classification_targets(y) self.classes_, y = np.unique(y, return_inverse=True) n_samples, n_features = X.shape diff --git a/sklearn/dummy.py b/sklearn/dummy.py index 233dc27aec076..6dc524b778e45 100644 --- a/sklearn/dummy.py +++ b/sklearn/dummy.py @@ -127,6 +127,7 @@ def fit(self, X, y, sample_weight=None): self.n_outputs_ = y.shape[1] check_consistent_length(X, y, sample_weight) + self.n_features_in_ = None # No input validation is done for X if self.strategy == "constant": if self.constant is None: @@ -434,6 +435,7 @@ def fit(self, X, y, sample_weight=None): % (self.strategy, allowed_strategies)) y = check_array(y, ensure_2d=False) + self.n_features_in_ = None # No input validation is done for X if len(y) == 0: raise ValueError("y must not be empty.") diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index b5c2d2b77f841..d5f90ef77371f 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -101,7 +101,8 @@ def fit(self, X, y): acc_compute_hist_time = 0. # time spent computing histograms # time spent predicting X for gradient and hessians update acc_prediction_time = 0. 
-        X, y = check_X_y(X, y, dtype=[X_DTYPE], force_all_finite=False)
+        X, y = self._validate_X_y(X, y, dtype=[X_DTYPE],
+                                  force_all_finite=False)
         y = self._encode_y(y)

         # The rng state must be preserved if warm_start is True
diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py
index c2a09c54b4622..dbf51fdcfbcc7 100644
--- a/sklearn/ensemble/_stacking.py
+++ b/sklearn/ensemble/_stacking.py
@@ -182,6 +182,7 @@ def fit(self, X, y, sample_weight=None):
             delayed(_parallel_fit_estimator)(clone(est), X, y, sample_weight)
             for est in all_estimators if est != 'drop'
         )
+        self.n_features_in_ = self.estimators_[0].n_features_in_

         self.named_estimators_ = Bunch()
         est_fitted_idx = 0
diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py
index 215caa5d4a334..33681ff42df1f 100644
--- a/sklearn/ensemble/bagging.py
+++ b/sklearn/ensemble/bagging.py
@@ -277,9 +277,9 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
         random_state = check_random_state(self.random_state)

         # Convert data (X is required to be 2d and indexable)
-        X, y = check_X_y(
-            X, y, ['csr', 'csc'], dtype=None, force_all_finite=False,
-            multi_output=True
+        X, y = self._validate_X_y(
+            X, y, accept_sparse=['csr', 'csc'], dtype=None,
+            force_all_finite=False, multi_output=True
         )
         if sample_weight is not None:
             sample_weight = check_array(sample_weight, ensure_2d=False)
diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py
index 4726a0dabcedf..c509ae657aa52 100644
--- a/sklearn/ensemble/forest.py
+++ b/sklearn/ensemble/forest.py
@@ -240,7 +240,7 @@ def fit(self, X, y, sample_weight=None):
         self : object
         """
         # Validate or convert input data
-        X = check_array(X, accept_sparse="csc", dtype=DTYPE)
+        X = self._validate_X(X, accept_sparse="csc", dtype=DTYPE)
         y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
         if sample_weight is not None:
             sample_weight = check_array(sample_weight, ensure_2d=False)
diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py
index 207090e64c18d..9ac8b5fe44ea9 100644
--- a/sklearn/ensemble/gradient_boosting.py
+++ b/sklearn/ensemble/gradient_boosting.py
@@ -1438,7 +1438,8 @@ def fit(self, X, y, sample_weight=None, monitor=None):
         # Check input
         # Since check_array converts both X and y to the same dtype, but the
         # trees use different types for X and y, checking them separately.
-        X = check_array(X, accept_sparse=['csr', 'csc', 'coo'], dtype=DTYPE)
+        X = self._validate_X(X, accept_sparse=['csr', 'csc', 'coo'],
+                             dtype=DTYPE)
         n_samples, self.n_features_ = X.shape

         sample_weight_is_none = sample_weight is None
diff --git a/sklearn/ensemble/tests/test_voting.py b/sklearn/ensemble/tests/test_voting.py
index 84fafc4f74eb3..e2fce1eb2e918 100644
--- a/sklearn/ensemble/tests/test_voting.py
+++ b/sklearn/ensemble/tests/test_voting.py
@@ -528,3 +528,23 @@ def test_check_estimators_voting_estimator(estimator):
     # their testing parameters (for required parameters).
     check_estimator(estimator)
     check_no_attributes_set_in_init(estimator.__class__.__name__, estimator)
+
+
+@pytest.mark.parametrize(
+    "est",
+    [VotingRegressor(
+        estimators=[('lr', LinearRegression()),
+                    ('tree', DecisionTreeRegressor(random_state=0))]),
+     VotingClassifier(
+         estimators=[('lr', LogisticRegression(random_state=0)),
+                     ('tree', DecisionTreeClassifier(random_state=0))])],
+    ids=['VotingRegressor', 'VotingClassifier']
+)
+def test_n_features_in(est):
+
+    X = [[1, 2], [3, 4], [5, 6]]
+    y = [0, 1, 2]
+
+    assert not hasattr(est, 'n_features_in_')
+    est.fit(X, y)
+    assert est.n_features_in_ == 2
diff --git a/sklearn/ensemble/voting.py b/sklearn/ensemble/voting.py
index e3d05c33de042..dbc8a2b7bff93 100644
--- a/sklearn/ensemble/voting.py
+++ b/sklearn/ensemble/voting.py
@@ -30,6 +30,7 @@
 from ..utils.metaestimators import _BaseComposition
 from ..utils.multiclass import check_classification_targets
 from ..utils.validation import column_or_1d
+from ..exceptions import NotFittedError


 class _BaseVoting(TransformerMixin, _BaseComposition):
@@ -128,6 +129,20 @@ def get_params(self, deep=True):
         """
         return self._get_params('estimators', deep=deep)

+    @property
+    def n_features_in_(self):
+        # For consistency with other estimators we raise an AttributeError so
+        # that hasattr() fails if the estimator isn't fitted.
+        try:
+            check_is_fitted(self)
+        except NotFittedError as nfe:
+            raise AttributeError(
+                "{} object has no n_features_in_ attribute."
+                .format(self.__class__.__name__)
+            ) from nfe
+
+        return self.estimators_[0].n_features_in_
+

 class VotingClassifier(ClassifierMixin, _BaseVoting):
     """Soft Voting/Majority Rule classifier for unfitted estimators.
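
Side note: a rough illustration (not part of the patch, assuming it is applied) of why the property above converts NotFittedError into AttributeError — hasattr() only treats AttributeError as "attribute missing", so an unfitted voting estimator simply reports False instead of leaking a NotFittedError.

    from sklearn.ensemble import VotingClassifier
    from sklearn.linear_model import LogisticRegression
    from sklearn.tree import DecisionTreeClassifier

    clf = VotingClassifier(estimators=[('lr', LogisticRegression()),
                                       ('tree', DecisionTreeClassifier())])
    print(hasattr(clf, 'n_features_in_'))  # False while unfitted
    clf.fit([[1, 2], [3, 4], [5, 6]], [0, 1, 1])
    print(clf.n_features_in_)  # 2, delegated to clf.estimators_[0]
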
diff --git a/sklearn/ensemble/weight_boosting.py b/sklearn/ensemble/weight_boosting.py
index 6f95ace2a668d..77f970926553c 100644
--- a/sklearn/ensemble/weight_boosting.py
+++ b/sklearn/ensemble/weight_boosting.py
@@ -70,25 +70,9 @@ def __init__(self,
         self.learning_rate = learning_rate
         self.random_state = random_state

-    def _validate_data(self, X, y=None):
-
-        # Accept or convert to these sparse matrix formats so we can
-        # use safe_indexing
-        accept_sparse = ['csr', 'csc']
-        if y is None:
-            ret = check_array(X,
-                              accept_sparse=accept_sparse,
-                              ensure_2d=False,
-                              allow_nd=True,
-                              dtype=None)
-        else:
-            ret = check_X_y(X, y,
-                            accept_sparse=accept_sparse,
-                            ensure_2d=False,
-                            allow_nd=True,
-                            dtype=None,
-                            y_numeric=is_regressor(self))
-        return ret
+    def _validate_data(self, X):
+        return check_array(X, accept_sparse=['csr', 'csc'], ensure_2d=True,
+                           allow_nd=True, dtype=None)

     def fit(self, X, y, sample_weight=None):
         """Build a boosted classifier/regressor from the training set (X, y).
@@ -115,7 +99,12 @@ def fit(self, X, y, sample_weight=None):
         if self.learning_rate <= 0:
             raise ValueError("learning_rate must be greater than zero")

-        X, y = self._validate_data(X, y)
+        X, y = self._validate_X_y(X, y,
+                                  accept_sparse=['csr', 'csc'],
+                                  ensure_2d=True,
+                                  allow_nd=True,
+                                  dtype=None,
+                                  y_numeric=is_regressor(self))

         if sample_weight is None:
             # Initialize weights to 1 / n_samples
diff --git a/sklearn/feature_extraction/tests/test_dict_vectorizer.py b/sklearn/feature_extraction/tests/test_dict_vectorizer.py
index 7e7481a369646..a65feb2d7590b 100644
--- a/sklearn/feature_extraction/tests/test_dict_vectorizer.py
+++ b/sklearn/feature_extraction/tests/test_dict_vectorizer.py
@@ -110,3 +110,13 @@ def test_deterministic_vocabulary():
     v_2 = DictVectorizer().fit([d_shuffled])

     assert v_1.vocabulary_ == v_2.vocabulary_
+
+
+def test_n_features_in():
+    # For vectorizers, n_features_in_ does not make sense, so the attribute
+    # is never set
+    dv = DictVectorizer()
+    assert not hasattr(dv, 'n_features_in_')
+    d = [{'foo': 1, 'bar': 2}, {'foo': 3, 'baz': 1}]
+    dv.fit(d)
+    assert not hasattr(dv, 'n_features_in_')
diff --git a/sklearn/feature_extraction/tests/test_text.py b/sklearn/feature_extraction/tests/test_text.py
index 7b7697ff47fff..f775589fb6a8a 100644
--- a/sklearn/feature_extraction/tests/test_text.py
+++ b/sklearn/feature_extraction/tests/test_text.py
@@ -1343,3 +1343,15 @@ def test_unused_parameters_warn(Vectorizer, stop_words,
     )
     with pytest.warns(UserWarning, match=msg):
         vect.fit(train_data)
+
+
+@pytest.mark.parametrize('Vectorizer, X', (
+    (HashingVectorizer, [{'foo': 1, 'bar': 2}, {'foo': 3, 'baz': 1}]),
+    (CountVectorizer, JUNK_FOOD_DOCS))
+)
+def test_n_features_in(Vectorizer, X):
+    # For vectorizers, n_features_in_ does not make sense
+    vectorizer = Vectorizer()
+    assert not hasattr(vectorizer, 'n_features_in_')
+    vectorizer.fit(X)
+    assert not hasattr(vectorizer, 'n_features_in_')
diff --git a/sklearn/feature_selection/from_model.py b/sklearn/feature_selection/from_model.py
index 6d732d0e43dfd..86e7b2caea2f8 100644
--- a/sklearn/feature_selection/from_model.py
+++ b/sklearn/feature_selection/from_model.py
@@ -6,6 +6,7 @@
 from .base import SelectorMixin
 from ..base import BaseEstimator, clone, MetaEstimatorMixin
+from ..utils.validation import check_is_fitted
 from ..exceptions import NotFittedError
 from ..utils.metaestimators import if_delegate_has_method

@@ -227,3 +228,17 @@ def partial_fit(self, X, y=None, **fit_params):
             self.estimator_ = clone(self.estimator)
         self.estimator_.partial_fit(X, y, **fit_params)
         return self
+
+    @property
+    def n_features_in_(self):
+        # For consistency with other estimators we raise an AttributeError so
+        # that hasattr() fails if the estimator isn't fitted.
+        try:
+            check_is_fitted(self)
+        except NotFittedError as nfe:
+            raise AttributeError(
+                "{} object has no n_features_in_ attribute."
+                .format(self.__class__.__name__)
+            ) from nfe
+
+        return self.estimator_.n_features_in_
diff --git a/sklearn/feature_selection/rfe.py b/sklearn/feature_selection/rfe.py
index 0ebddc8e702b6..370a6d3e26458 100644
--- a/sklearn/feature_selection/rfe.py
+++ b/sklearn/feature_selection/rfe.py
@@ -150,7 +150,8 @@ def _fit(self, X, y, step_score=None):
         # and is used when implementing RFECV
         # self.scores_ will not be calculated when calling _fit through fit

-        X, y = check_X_y(X, y, "csc", ensure_min_features=2)
+        X, y = self._validate_X_y(X, y, accept_sparse="csc",
+                                  ensure_min_features=2)
         # Initialization
         n_features = X.shape[1]
         if self.n_features_to_select is None:
@@ -479,7 +480,8 @@ def fit(self, X, y, groups=None):
             train/test set. Only used in conjunction with a "Group" :term:`cv`
             instance (e.g., :class:`~sklearn.model_selection.GroupKFold`).
         """
-        X, y = check_X_y(X, y, "csr", ensure_min_features=2)
+        X, y = self._validate_X_y(X, y, accept_sparse="csr",
+                                  ensure_min_features=2)

         # Initialization
         cv = check_cv(self.cv, y, is_classifier(self.estimator))
diff --git a/sklearn/feature_selection/univariate_selection.py b/sklearn/feature_selection/univariate_selection.py
index 5921e3494469b..6e9e9d6c90b99 100644
--- a/sklearn/feature_selection/univariate_selection.py
+++ b/sklearn/feature_selection/univariate_selection.py
@@ -338,7 +338,8 @@ def fit(self, X, y):
         -------
         self : object
         """
-        X, y = check_X_y(X, y, ['csr', 'csc'], multi_output=True)
+        X, y = self._validate_X_y(X, y, accept_sparse=['csr', 'csc'],
+                                  multi_output=True)

         if not callable(self.score_func):
             raise TypeError("The score function should be a callable, %s (%s) "
diff --git a/sklearn/feature_selection/variance_threshold.py b/sklearn/feature_selection/variance_threshold.py
index 62323f1ff2ec8..e25ae9a0d3a5e 100644
--- a/sklearn/feature_selection/variance_threshold.py
+++ b/sklearn/feature_selection/variance_threshold.py
@@ -61,7 +61,7 @@ def fit(self, X, y=None):
         -------
         self
         """
-        X = check_array(X, ('csr', 'csc'), dtype=np.float64)
+        X = self._validate_X(X, accept_sparse=('csr', 'csc'), dtype=np.float64)

         if hasattr(X, "toarray"):   # sparse matrix
             _, self.variances_ = mean_variance_axis(X, axis=0)
diff --git a/sklearn/gaussian_process/gpc.py b/sklearn/gaussian_process/gpc.py
index 129f6c97aced9..00c255880a7da 100644
--- a/sklearn/gaussian_process/gpc.py
+++ b/sklearn/gaussian_process/gpc.py
@@ -612,7 +612,7 @@ def fit(self, X, y):
         -------
         self : returns an instance of self.
""" - X, y = check_X_y(X, y, multi_output=False) + X, y = self._validate_X_y(X, y, multi_output=False) self.base_estimator_ = _BinaryGaussianProcessClassifierLaplace( self.kernel, self.optimizer, self.n_restarts_optimizer, diff --git a/sklearn/gaussian_process/gpr.py b/sklearn/gaussian_process/gpr.py index 7d131c757bc78..1010061fa1289 100644 --- a/sklearn/gaussian_process/gpr.py +++ b/sklearn/gaussian_process/gpr.py @@ -182,7 +182,7 @@ def fit(self, X, y): self._rng = check_random_state(self.random_state) - X, y = check_X_y(X, y, multi_output=True, y_numeric=True) + X, y = self._validate_X_y(X, y, multi_output=True, y_numeric=True) # Normalize target value if self.normalize_y: diff --git a/sklearn/impute/_base.py b/sklearn/impute/_base.py index 8c8b83878bae3..50310c44fc3a6 100644 --- a/sklearn/impute/_base.py +++ b/sklearn/impute/_base.py @@ -160,7 +160,7 @@ def __init__(self, missing_values=np.nan, strategy="mean", self.copy = copy self.add_indicator = add_indicator - def _validate_input(self, X): + def _validate_input(self, X, in_fit): allowed_strategies = ["mean", "median", "most_frequent", "constant"] if self.strategy not in allowed_strategies: raise ValueError("Can only use these strategies: {0} " @@ -178,8 +178,11 @@ def _validate_input(self, X): force_all_finite = "allow-nan" try: - X = check_array(X, accept_sparse='csc', dtype=dtype, - force_all_finite=force_all_finite, copy=self.copy) + check_n_features = not in_fit + X = self._validate_X(X, check_n_features=check_n_features, + accept_sparse='csc', dtype=dtype, + force_all_finite=force_all_finite, + copy=self.copy) except ValueError as ve: if "could not convert" in str(ve): raise ValueError("Cannot use {0} strategy with non-numeric " @@ -212,7 +215,7 @@ def fit(self, X, y=None): ------- self : SimpleImputer """ - X = self._validate_input(X) + X = self._validate_input(X, in_fit=True) # default fill_value is 0 for numerical input and "missing_value" # otherwise @@ -357,7 +360,7 @@ def transform(self, X): """ check_is_fitted(self) - X = self._validate_input(X) + X = self._validate_input(X, in_fit=False) statistics = self.statistics_ @@ -543,13 +546,15 @@ def _get_missing_features_info(self, X): return imputer_mask, features_indices - def _validate_input(self, X): + def _validate_input(self, X, in_fit): if not is_scalar_nan(self.missing_values): force_all_finite = True else: force_all_finite = "allow-nan" - X = check_array(X, accept_sparse=('csc', 'csr'), dtype=None, - force_all_finite=force_all_finite) + check_n_features = not in_fit + X = self._validate_X(X, check_n_features=check_n_features, + accept_sparse=('csc', 'csr'), dtype=None, + force_all_finite=force_all_finite) _check_inputs_dtype(X, self.missing_values) if X.dtype.kind not in ("i", "u", "f", "O"): raise ValueError("MissingIndicator does not support data with " @@ -584,7 +589,7 @@ def _fit(self, X, y=None): The imputer mask of the original data. 
""" - X = self._validate_input(X) + X = self._validate_input(X, in_fit=True) self._n_features = X.shape[1] if self.features not in ('missing-only', 'all'): @@ -636,7 +641,7 @@ def transform(self, X): """ check_is_fitted(self) - X = self._validate_input(X) + X = self._validate_input(X, in_fit=False) if X.shape[1] != self._n_features: raise ValueError("X has a different number of features " diff --git a/sklearn/impute/_iterative.py b/sklearn/impute/_iterative.py index d870f6ca11f1c..668ea4c47ca24 100644 --- a/sklearn/impute/_iterative.py +++ b/sklearn/impute/_iterative.py @@ -484,8 +484,8 @@ def _initial_imputation(self, X): else: force_all_finite = True - X = check_array(X, dtype=FLOAT_DTYPES, order="F", - force_all_finite=force_all_finite) + X = self._validate_X(X, dtype=FLOAT_DTYPES, order="F", + force_all_finite=force_all_finite) _check_inputs_dtype(X, self.missing_values) mask_missing_values = _get_mask(X, self.missing_values) diff --git a/sklearn/impute/_knn.py b/sklearn/impute/_knn.py index 37c0b2d6cb754..6178ab39a8972 100644 --- a/sklearn/impute/_knn.py +++ b/sklearn/impute/_knn.py @@ -158,8 +158,9 @@ def fit(self, X, y=None): raise ValueError( "Expected n_neighbors > 0. Got {}".format(self.n_neighbors)) - X = check_array(X, accept_sparse=False, dtype=FLOAT_DTYPES, - force_all_finite=force_all_finite, copy=self.copy) + X = self._validate_X(X, accept_sparse=False, dtype=FLOAT_DTYPES, + force_all_finite=force_all_finite, + copy=self.copy) _check_weights(self.weights) self._fit_X = X diff --git a/sklearn/kernel_approximation.py b/sklearn/kernel_approximation.py index 248f9595c5b95..49d01453a0bdf 100644 --- a/sklearn/kernel_approximation.py +++ b/sklearn/kernel_approximation.py @@ -91,7 +91,7 @@ def fit(self, X, y=None): Returns the transformer. """ - X = check_array(X, accept_sparse='csr') + X = self._validate_X(X, accept_sparse='csr') random_state = check_random_state(self.random_state) n_features = X.shape[1] @@ -197,7 +197,7 @@ def fit(self, X, y=None): Returns the transformer. """ - X = check_array(X) + X = self._validate_X(X) random_state = check_random_state(self.random_state) n_features = X.shape[1] uniform = random_state.uniform(size=(n_features, self.n_components)) @@ -324,7 +324,7 @@ def fit(self, X, y=None): self : object Returns the transformer. """ - check_array(X, accept_sparse='csr') + self._validate_X(X, accept_sparse='csr') if self.sample_interval is None: # See reference, figure 2 c) if self.sample_steps == 1: @@ -540,7 +540,7 @@ def fit(self, X, y=None): X : array-like, shape=(n_samples, n_feature) Training data. """ - X = check_array(X, accept_sparse='csr') + X = self._validate_X(X, accept_sparse='csr') rnd = check_random_state(self.random_state) n_samples = X.shape[0] diff --git a/sklearn/kernel_ridge.py b/sklearn/kernel_ridge.py index fef571056c945..f6bbc6a2bec49 100644 --- a/sklearn/kernel_ridge.py +++ b/sklearn/kernel_ridge.py @@ -148,8 +148,8 @@ def fit(self, X, y=None, sample_weight=None): self : returns an instance of self. 
""" # Convert data - X, y = check_X_y(X, y, accept_sparse=("csr", "csc"), multi_output=True, - y_numeric=True) + X, y = self._validate_X_y(X, y, accept_sparse=("csr", "csc"), + multi_output=True, y_numeric=True) if sample_weight is not None and not isinstance(sample_weight, float): sample_weight = check_array(sample_weight, ensure_2d=False) diff --git a/sklearn/linear_model/base.py b/sklearn/linear_model/base.py index c554c8a921d9e..0134b30a34bd7 100644 --- a/sklearn/linear_model/base.py +++ b/sklearn/linear_model/base.py @@ -464,8 +464,8 @@ def fit(self, X, y, sample_weight=None): """ n_jobs_ = self.n_jobs - X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'], - y_numeric=True, multi_output=True) + X, y = self._validate_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'], + y_numeric=True, multi_output=True) if sample_weight is not None and np.atleast_1d(sample_weight).ndim > 1: raise ValueError("Sample weights must be 1D array or scalar") diff --git a/sklearn/linear_model/bayes.py b/sklearn/linear_model/bayes.py index 88ba6fecee6a4..00686d543cc60 100644 --- a/sklearn/linear_model/bayes.py +++ b/sklearn/linear_model/bayes.py @@ -189,7 +189,7 @@ def fit(self, X, y, sample_weight=None): raise ValueError('n_iter should be greater than or equal to 1.' ' Got {!r}.'.format(self.n_iter)) - X, y = check_X_y(X, y, dtype=np.float64, y_numeric=True) + X, y = self._validate_X_y(X, y, dtype=np.float64, y_numeric=True) X, y, X_offset_, y_offset_, X_scale_ = self._preprocess_data( X, y, self.fit_intercept, self.normalize, self.copy_X, sample_weight=sample_weight) @@ -520,8 +520,8 @@ def fit(self, X, y): ------- self : returns an instance of self. """ - X, y = check_X_y(X, y, dtype=np.float64, y_numeric=True, - ensure_min_samples=2) + X, y = self._validate_X_y(X, y, dtype=np.float64, y_numeric=True, + ensure_min_samples=2) n_samples, n_features = X.shape coef_ = np.zeros(n_features) diff --git a/sklearn/linear_model/coordinate_descent.py b/sklearn/linear_model/coordinate_descent.py index 4ad8f759dd4be..07e08b7a28767 100644 --- a/sklearn/linear_model/coordinate_descent.py +++ b/sklearn/linear_model/coordinate_descent.py @@ -696,9 +696,11 @@ def fit(self, X, y, check_input=True): # when bypassing checks if check_input: X_copied = self.copy_X and self.fit_intercept - X, y = check_X_y(X, y, accept_sparse='csc', - order='F', dtype=[np.float64, np.float32], - copy=X_copied, multi_output=True, y_numeric=True) + X, y = self._validate_X_y(X, y, accept_sparse='csc', + order='F', + dtype=[np.float64, np.float32], + copy=X_copied, multi_output=True, + y_numeric=True) y = check_array(y, order='F', copy=False, dtype=X.dtype.type, ensure_2d=False) @@ -1112,7 +1114,7 @@ def fit(self, X, y): # Let us not impose fortran ordering so far: it is # not useful for the cross-validation loop and will be done # by the model fitting itself - X = check_array(X, 'csc', copy=False) + X = self._validate_X(X, accept_sparse='csc', copy=False) if sparse.isspmatrix(X): if (hasattr(reference_to_old_X, "data") and not np.may_share_memory(reference_to_old_X.data, X.data)): @@ -1123,8 +1125,9 @@ def fit(self, X, y): copy_X = False del reference_to_old_X else: - X = check_array(X, 'csc', dtype=[np.float64, np.float32], - order='F', copy=copy_X) + X = self._validate_X(X, accept_sparse='csc', + dtype=[np.float64, np.float32], order='F', + copy=copy_X) copy_X = False if X.shape[0] != y.shape[0]: @@ -1744,8 +1747,8 @@ def fit(self, X, y): To avoid memory re-allocation it is advised to allocate the initial data in memory directly using that 
         """
-        X = check_array(X, dtype=[np.float64, np.float32], order='F',
-                        copy=self.copy_X and self.fit_intercept)
+        X = self._validate_X(X, dtype=[np.float64, np.float32], order='F',
+                             copy=self.copy_X and self.fit_intercept)
         y = check_array(y, dtype=X.dtype.type, ensure_2d=False)

         if hasattr(self, 'l1_ratio'):
diff --git a/sklearn/linear_model/huber.py b/sklearn/linear_model/huber.py
index e518feae29b78..7ac5b3d0e19c4 100644
--- a/sklearn/linear_model/huber.py
+++ b/sklearn/linear_model/huber.py
@@ -252,7 +252,7 @@ def fit(self, X, y, sample_weight=None):
         -------
         self : object
         """
-        X, y = check_X_y(
+        X, y = self._validate_X_y(
             X, y, copy=False, accept_sparse=['csr'], y_numeric=True,
             dtype=[np.float64, np.float32])
diff --git a/sklearn/linear_model/least_angle.py b/sklearn/linear_model/least_angle.py
index 6fa3ae3008a35..3affc443b1cea 100644
--- a/sklearn/linear_model/least_angle.py
+++ b/sklearn/linear_model/least_angle.py
@@ -954,7 +954,7 @@ def fit(self, X, y, Xy=None):
         self : object
             returns an instance of self.
         """
-        X, y = check_X_y(X, y, y_numeric=True, multi_output=True)
+        X, y = self._validate_X_y(X, y, y_numeric=True, multi_output=True)

         alpha = getattr(self, 'alpha', 0.)
         if hasattr(self, 'n_nonzero_coefs'):
@@ -1374,7 +1374,7 @@ def fit(self, X, y):
         self : object
             returns an instance of self.
         """
-        X, y = check_X_y(X, y, y_numeric=True)
+        X, y = self._validate_X_y(X, y, y_numeric=True)
         X = as_float_array(X, copy=self.copy_X)
         y = as_float_array(y, copy=self.copy_X)
@@ -1752,7 +1752,7 @@ def fit(self, X, y, copy_X=None):
         """
         if copy_X is None:
             copy_X = self.copy_X
-        X, y = check_X_y(X, y, y_numeric=True)
+        X, y = self._validate_X_y(X, y, y_numeric=True)

         X, y, Xmean, ymean, Xstd = LinearModel._preprocess_data(
             X, y, self.fit_intercept, self.normalize, copy_X)
diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py
index 9a1293ae9ab39..f45e4524d8c61 100644
--- a/sklearn/linear_model/logistic.py
+++ b/sklearn/linear_model/logistic.py
@@ -1511,8 +1511,9 @@ def fit(self, X, y, sample_weight=None):
         else:
             _dtype = [np.float64, np.float32]

-        X, y = check_X_y(X, y, accept_sparse='csr', dtype=_dtype, order="C",
-                         accept_large_sparse=solver != 'liblinear')
+        X, y = self._validate_X_y(X, y, accept_sparse='csr', dtype=_dtype,
+                                  order="C",
+                                  accept_large_sparse=solver != 'liblinear')
         check_classification_targets(y)
         self.classes_ = np.unique(y)
         n_samples, n_features = X.shape
@@ -1981,9 +1982,9 @@ def fit(self, X, y, sample_weight=None):
                 "LogisticRegressionCV."
             )

-        X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64,
-                         order="C",
-                         accept_large_sparse=solver != 'liblinear')
+        X, y = self._validate_X_y(X, y, accept_sparse='csr', dtype=np.float64,
+                                  order="C",
+                                  accept_large_sparse=solver != 'liblinear')
         check_classification_targets(y)

         class_weight = self.class_weight
diff --git a/sklearn/linear_model/omp.py b/sklearn/linear_model/omp.py
index 3215b107aa9bf..fe9df77e2a61b 100644
--- a/sklearn/linear_model/omp.py
+++ b/sklearn/linear_model/omp.py
@@ -641,7 +641,7 @@ def fit(self, X, y):
         self : object
             returns an instance of self.
         """
-        X, y = check_X_y(X, y, multi_output=True, y_numeric=True)
+        X, y = self._validate_X_y(X, y, multi_output=True, y_numeric=True)
         n_features = X.shape[1]

         X, y, X_offset, y_offset, X_scale, Gram, Xy = \
@@ -879,8 +879,8 @@ def fit(self, X, y):
         self : object
             returns an instance of self.
""" - X, y = check_X_y(X, y, y_numeric=True, ensure_min_features=2, - estimator=self) + X, y = self._validate_X_y(X, y, y_numeric=True, ensure_min_features=2, + estimator=self) X = as_float_array(X, copy=False, force_all_finite=False) cv = check_cv(self.cv, classifier=False) max_iter = (min(max(int(0.1 * X.shape[1]), 5), X.shape[1]) diff --git a/sklearn/linear_model/ransac.py b/sklearn/linear_model/ransac.py index 3d390c5c67e61..c5326753468ec 100644 --- a/sklearn/linear_model/ransac.py +++ b/sklearn/linear_model/ransac.py @@ -251,7 +251,7 @@ def fit(self, X, y, sample_weight=None): `max_trials` randomly chosen sub-samples. """ - X = check_array(X, accept_sparse='csr') + X = self._validate_X(X, accept_sparse='csr') y = check_array(y, ensure_2d=False) check_consistent_length(X, y) diff --git a/sklearn/linear_model/ridge.py b/sklearn/linear_model/ridge.py index 64e2fc6bdfb41..06f7f3e77a6ba 100644 --- a/sklearn/linear_model/ridge.py +++ b/sklearn/linear_model/ridge.py @@ -541,10 +541,10 @@ def fit(self, X, y, sample_weight=None): _dtype = [np.float64, np.float32] _accept_sparse = _get_valid_accept_sparse(sparse.issparse(X), self.solver) - X, y = check_X_y(X, y, - accept_sparse=_accept_sparse, - dtype=_dtype, - multi_output=True, y_numeric=True) + X, y = self._validate_X_y(X, y, + accept_sparse=_accept_sparse, + dtype=_dtype, + multi_output=True, y_numeric=True) if sparse.issparse(X) and self.fit_intercept: if self.solver not in ['auto', 'sparse_cg', 'sag']: raise ValueError( @@ -921,7 +921,8 @@ def fit(self, X, y, sample_weight=None): """ _accept_sparse = _get_valid_accept_sparse(sparse.issparse(X), self.solver) - check_X_y(X, y, accept_sparse=_accept_sparse, multi_output=True) + self._validate_X_y(X, y, accept_sparse=_accept_sparse, + multi_output=True) self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1) Y = self._label_binarizer.fit_transform(y) @@ -1418,9 +1419,9 @@ def fit(self, X, y, sample_weight=None): ------- self : object """ - X, y = check_X_y(X, y, ['csr', 'csc', 'coo'], - dtype=[np.float64], - multi_output=True, y_numeric=True) + X, y = self._validate_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'], + dtype=[np.float64], + multi_output=True, y_numeric=True) if np.any(self.alphas <= 0): raise ValueError( @@ -1574,6 +1575,7 @@ def fit(self, X, y, sample_weight=None): self.coef_ = estimator.coef_ self.intercept_ = estimator.intercept_ + self.n_features_in_ = estimator.n_features_in_ return self @@ -1830,8 +1832,8 @@ def fit(self, X, y, sample_weight=None): ------- self : object """ - check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'], - multi_output=True) + self._validate_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'], + multi_output=True) self._label_binarizer = LabelBinarizer(pos_label=1, neg_label=-1) Y = self._label_binarizer.fit_transform(y) diff --git a/sklearn/linear_model/stochastic_gradient.py b/sklearn/linear_model/stochastic_gradient.py index 6a11c4a97ee2f..6812a5d75eb84 100644 --- a/sklearn/linear_model/stochastic_gradient.py +++ b/sklearn/linear_model/stochastic_gradient.py @@ -509,8 +509,9 @@ def _fit(self, X, y, alpha, C, loss, learning_rate, coef_init=None, if hasattr(self, "classes_"): self.classes_ = None - X, y = check_X_y(X, y, 'csr', dtype=np.float64, order="C", - accept_large_sparse=False) + X, y = self._validate_X_y(X, y, accept_sparse='csr', + dtype=np.float64, order="C", + accept_large_sparse=False) # labels can be encoded as float, int, or string literals # np.unique sorts in asc order; largest class id is positive class @@ -1079,8 +1080,9 @@ def 
__init__(self, loss="squared_loss", penalty="l2", alpha=0.0001, def _partial_fit(self, X, y, alpha, C, loss, learning_rate, max_iter, sample_weight, coef_init, intercept_init): - X, y = check_X_y(X, y, "csr", copy=False, order='C', dtype=np.float64, - accept_large_sparse=False) + X, y = self._validate_X_y(X, y, accept_sparse="csr", copy=False, + order='C', dtype=np.float64, + accept_large_sparse=False) y = y.astype(np.float64, copy=False) n_samples, n_features = X.shape diff --git a/sklearn/linear_model/theil_sen.py b/sklearn/linear_model/theil_sen.py index 3468e904c3538..793a19ea79fcb 100644 --- a/sklearn/linear_model/theil_sen.py +++ b/sklearn/linear_model/theil_sen.py @@ -358,7 +358,7 @@ def fit(self, X, y): self : returns an instance of self. """ random_state = check_random_state(self.random_state) - X, y = check_X_y(X, y, y_numeric=True) + X, y = self._validate_X_y(X, y, y_numeric=True) n_samples, n_features = X.shape n_subsamples, self.n_subpopulation_ = self._check_subparams(n_samples, n_features) diff --git a/sklearn/manifold/isomap.py b/sklearn/manifold/isomap.py index a1fe5243c6ca2..93d7a17eca9db 100644 --- a/sklearn/manifold/isomap.py +++ b/sklearn/manifold/isomap.py @@ -140,13 +140,13 @@ def __init__(self, n_neighbors=5, n_components=2, eigen_solver='auto', self.metric_params = metric_params def _fit_transform(self, X): - self.nbrs_ = NearestNeighbors(n_neighbors=self.n_neighbors, algorithm=self.neighbors_algorithm, metric=self.metric, p=self.p, metric_params=self.metric_params, n_jobs=self.n_jobs) self.nbrs_.fit(X) + self.n_features_in_ = self.nbrs_.n_features_in_ self.kernel_pca_ = KernelPCA(n_components=self.n_components, kernel="precomputed", diff --git a/sklearn/manifold/locally_linear.py b/sklearn/manifold/locally_linear.py index 4b7140d6b5f23..dc77cff0f9da5 100644 --- a/sklearn/manifold/locally_linear.py +++ b/sklearn/manifold/locally_linear.py @@ -656,7 +656,7 @@ def _fit_transform(self, X): n_jobs=self.n_jobs) random_state = check_random_state(self.random_state) - X = check_array(X, dtype=float) + X = self._validate_X(X, dtype=float) self.nbrs_.fit(X) self.embedding_, self.reconstruction_error_ = \ locally_linear_embedding( diff --git a/sklearn/manifold/mds.py b/sklearn/manifold/mds.py index 5238c67e93dfd..0ddf8dda7f31c 100644 --- a/sklearn/manifold/mds.py +++ b/sklearn/manifold/mds.py @@ -414,7 +414,7 @@ def fit_transform(self, X, y=None, init=None): algorithm. By default, the algorithm is initialized with a randomly chosen array. """ - X = check_array(X) + X = self._validate_X(X) if X.shape[0] == X.shape[1] and self.dissimilarity != "precomputed": warnings.warn("The MDS API has changed. ``fit`` now constructs an" " dissimilarity matrix from data. To use a custom " diff --git a/sklearn/manifold/spectral_embedding_.py b/sklearn/manifold/spectral_embedding_.py index 9d52a9787425c..1052aeec9c955 100644 --- a/sklearn/manifold/spectral_embedding_.py +++ b/sklearn/manifold/spectral_embedding_.py @@ -535,8 +535,8 @@ def fit(self, X, y=None): Returns the instance itself. 
""" - X = check_array(X, accept_sparse='csr', ensure_min_samples=2, - estimator=self) + X = self._validate_X(X, accept_sparse='csr', ensure_min_samples=2, + estimator=self) random_state = check_random_state(self.random_state) if isinstance(self.affinity, str): diff --git a/sklearn/manifold/t_sne.py b/sklearn/manifold/t_sne.py index 5d031584d6c7f..598b820263776 100644 --- a/sklearn/manifold/t_sne.py +++ b/sklearn/manifold/t_sne.py @@ -644,11 +644,12 @@ def _fit(self, X, skip_num_points=0): if self.angle < 0.0 or self.angle > 1.0: raise ValueError("'angle' must be between 0.0 - 1.0") if self.method == 'barnes_hut': - X = check_array(X, accept_sparse=['csr'], ensure_min_samples=2, - dtype=[np.float32, np.float64]) + X = self._validate_X(X, accept_sparse=['csr'], + ensure_min_samples=2, + dtype=[np.float32, np.float64]) else: - X = check_array(X, accept_sparse=['csr', 'csc', 'coo'], - dtype=[np.float32, np.float64]) + X = self._validate_X(X, accept_sparse=['csr', 'csc', 'coo'], + dtype=[np.float32, np.float64]) if self.metric == "precomputed": if isinstance(self.init, str) and self.init == 'pca': raise ValueError("The parameter init=\"pca\" cannot be " diff --git a/sklearn/mixture/base.py b/sklearn/mixture/base.py index 4bb98a1d54e4a..56f3649f2b11c 100644 --- a/sklearn/mixture/base.py +++ b/sklearn/mixture/base.py @@ -217,6 +217,7 @@ def fit_predict(self, X, y=None): Component labels. """ X = _check_X(X, self.n_components, ensure_min_samples=2) + self._validate_n_features(X, check_n_features=False) self._check_initial_parameters(X) # if we enable warm_start, we will have a unique initialisation diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 80e78e6b7f913..259fe89e712f7 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -563,6 +563,20 @@ def inverse_transform(self, Xt): self._check_is_fitted('inverse_transform') return self.best_estimator_.inverse_transform(Xt) + @property + def n_features_in_(self): + # For consistency with other estimators we raise a AttributeError so + # that hasattr() fails if the search estimator isn't fitted. + try: + check_is_fitted(self) + except NotFittedError as nfe: + raise AttributeError( + "{} object has no n_features_in_ attribute." + .format(self.__class__.__name__) + ) from nfe + + return self.best_estimator_.n_features_in_ + @property def classes_(self): self._check_is_fitted("classes_") diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index db69c66fe06dc..3ca0cf4f4cc5a 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -64,6 +64,8 @@ from sklearn.impute import SimpleImputer from sklearn.pipeline import Pipeline from sklearn.linear_model import Ridge, SGDClassifier +from sklearn.experimental import enable_hist_gradient_boosting # noqa +from sklearn.ensemble import HistGradientBoostingClassifier from sklearn.model_selection.tests.common import OneTimeSplitter @@ -1775,3 +1777,20 @@ def get_n_splits(self, *args, **kw): 'inconsistent results. 
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index db69c66fe06dc..3ca0cf4f4cc5a 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -64,6 +64,8 @@
 from sklearn.impute import SimpleImputer
 from sklearn.pipeline import Pipeline
 from sklearn.linear_model import Ridge, SGDClassifier
+from sklearn.experimental import enable_hist_gradient_boosting  # noqa
+from sklearn.ensemble import HistGradientBoostingClassifier

 from sklearn.model_selection.tests.common import OneTimeSplitter

@@ -1775,3 +1777,20 @@ def get_n_splits(self, *args, **kw):
                        'inconsistent results. Expected \\d+ '
                        'splits, got \\d+'):
         ridge.fit(X[:train_size], y[:train_size])
+
+
+def test_n_features_in():
+    # make sure grid search and random search delegate n_features_in to the
+    # best estimator
+    n_features = 4
+    X, y = make_classification(n_features=n_features)
+    gbdt = HistGradientBoostingClassifier()
+    param_grid = {'max_iter': [3, 4]}
+    gs = GridSearchCV(gbdt, param_grid)
+    rs = RandomizedSearchCV(gbdt, param_grid, n_iter=1)
+    assert not hasattr(gs, 'n_features_in_')
+    assert not hasattr(rs, 'n_features_in_')
+    gs.fit(X, y)
+    rs.fit(X, y)
+    assert gs.n_features_in_ == n_features
+    assert rs.n_features_in_ == n_features
diff --git a/sklearn/multiclass.py b/sklearn/multiclass.py
index 9cee9661489b6..421e103e96de0 100644
--- a/sklearn/multiclass.py
+++ b/sklearn/multiclass.py
@@ -52,6 +52,7 @@
     check_classification_targets, _ovr_decision_function)
 from .utils.metaestimators import _safe_split, if_delegate_has_method
+from .exceptions import NotFittedError

 from joblib import Parallel, delayed

@@ -415,6 +416,19 @@ def _pairwise(self):
     def _first_estimator(self):
         return self.estimators_[0]

+    @property
+    def n_features_in_(self):
+        # For consistency with other estimators we raise an AttributeError so
+        # that hasattr() fails if the OVR estimator isn't fitted.
+        try:
+            check_is_fitted(self)
+        except NotFittedError as nfe:
+            raise AttributeError(
+                "{} object has no n_features_in_ attribute."
+                .format(self.__class__.__name__)
+            ) from nfe
+        return self.estimators_[0].n_features_in_
+

 def _fit_ovo_binary(estimator, X, y, i, j):
     """Fit a single binary estimator (one-vs-one)."""

@@ -503,7 +517,7 @@ def fit(self, X, y):
         -------
         self
         """
-        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
+        X, y = self._validate_X_y(X, y, accept_sparse=['csr', 'csc'])
         check_classification_targets(y)

         self.classes_ = np.unique(y)

@@ -730,7 +744,7 @@ def fit(self, X, y):
         -------
         self
         """
-        X, y = check_X_y(X, y)
+        X, y = self._validate_X_y(X, y)
         if self.code_size <= 0:
             raise ValueError("code_size should be greater than 0, got {0}"
                              "".format(self.code_size))
diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py
index 93eb87e81cb5e..5d485a1b689b0 100644
--- a/sklearn/multioutput.py
+++ b/sklearn/multioutput.py
@@ -148,9 +148,7 @@ def fit(self, X, y, sample_weight=None):
             raise ValueError("The base estimator should implement"
                              " a fit method")

-        X, y = check_X_y(X, y,
-                         multi_output=True,
-                         accept_sparse=True)
+        X, y = self._validate_X_y(X, y, multi_output=True, accept_sparse=True)

         if is_classifier(self):
             check_classification_targets(y)

@@ -432,7 +430,7 @@ def fit(self, X, Y):
         -------
         self : object
         """
-        X, Y = check_X_y(X, Y, multi_output=True, accept_sparse=True)
+        X, Y = self._validate_X_y(X, Y, multi_output=True, accept_sparse=True)

         random_state = check_random_state(self.random_state)
         check_array(X, accept_sparse=True)
diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py
index d1bb360986c22..c2aaaf77e2070 100644
--- a/sklearn/naive_bayes.py
+++ b/sklearn/naive_bayes.py
@@ -192,7 +192,7 @@ def fit(self, X, y, sample_weight=None):
         -------
         self : object
         """
-        X, y = check_X_y(X, y)
+        X, y = self._validate_X_y(X, y)
         return self._partial_fit(X, y, np.unique(y), _refit=True,
                                  sample_weight=sample_weight)

@@ -591,7 +591,7 @@ def fit(self, X, y, sample_weight=None):
         -------
         self : object
         """
-        X, y = check_X_y(X, y, 'csr')
+        X, y = self._validate_X_y(X, y, accept_sparse='csr')
         _, n_features = X.shape

         labelbin = LabelBinarizer()
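Reviewer note: the multiclass change mirrors the search estimators: OvR/OvO
now validate through _validate_X_y, and the read-only property delegates to
the first fitted binary estimator. A sketch of the resulting behaviour on
this branch:

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.multiclass import OneVsRestClassifier

    X, y = make_classification(n_features=6, n_classes=3, n_informative=4)
    ovr = OneVsRestClassifier(LogisticRegression())
    assert not hasattr(ovr, 'n_features_in_')  # property raises AttributeError
    ovr.fit(X, y)
    assert ovr.n_features_in_ == 6  # taken from ovr.estimators_[0]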
diff --git a/sklearn/neighbors/base.py b/sklearn/neighbors/base.py
index ac117911fbb7c..28d0483ac9b5b 100644
--- a/sklearn/neighbors/base.py
+++ b/sklearn/neighbors/base.py
@@ -393,8 +393,9 @@ def _fit(self, X):

         if self.effective_metric_ == 'precomputed':
             X = _check_precomputed(X)
+            self.n_features_in_ = X.shape[1]
         else:
-            X = check_array(X, accept_sparse='csr')
+            X = self._validate_X(X, accept_sparse='csr')

         n_samples = X.shape[0]
         if n_samples == 0:
diff --git a/sklearn/neighbors/kde.py b/sklearn/neighbors/kde.py
index be5002e579423..baac14518954a 100644
--- a/sklearn/neighbors/kde.py
+++ b/sklearn/neighbors/kde.py
@@ -125,7 +125,7 @@ def fit(self, X, y=None, sample_weight=None):
             List of sample weights attached to the data X.
         """
         algorithm = self._choose_algorithm(self.algorithm, self.metric)
-        X = check_array(X, order='C', dtype=DTYPE)
+        X = self._validate_X(X, order='C', dtype=DTYPE)

         if sample_weight is not None:
             sample_weight = check_array(sample_weight, order='C', dtype=DTYPE,
diff --git a/sklearn/neighbors/nca.py b/sklearn/neighbors/nca.py
index ae8e143ae0d1d..a4a89bd75b912 100644
--- a/sklearn/neighbors/nca.py
+++ b/sklearn/neighbors/nca.py
@@ -297,7 +297,7 @@ def _validate_params(self, X, y):
         """

         # Validate the inputs X and y, and converts y to numerical classes.
-        X, y = check_X_y(X, y, ensure_min_samples=2)
+        X, y = self._validate_X_y(X, y, ensure_min_samples=2)
         check_classification_targets(y)
         y = LabelEncoder().fit_transform(y)
diff --git a/sklearn/neighbors/nearest_centroid.py b/sklearn/neighbors/nearest_centroid.py
index 3967e772bf1bb..2e10d84613ede 100644
--- a/sklearn/neighbors/nearest_centroid.py
+++ b/sklearn/neighbors/nearest_centroid.py
@@ -104,9 +104,9 @@ def fit(self, X, y):
         # If X is sparse and the metric is "manhattan", store it in a csc
         # format is easier to calculate the median.
         if self.metric == 'manhattan':
-            X, y = check_X_y(X, y, ['csc'])
+            X, y = self._validate_X_y(X, y, accept_sparse=['csc'])
         else:
-            X, y = check_X_y(X, y, ['csr', 'csc'])
+            X, y = self._validate_X_y(X, y, accept_sparse=['csr', 'csc'])
         is_X_sparse = sp.issparse(X)
         if is_X_sparse and self.shrink_threshold:
             raise ValueError("threshold shrinking not supported"
diff --git a/sklearn/neural_network/_multilayer_perceptron.py b/sklearn/neural_network/_multilayer_perceptron.py
index b6367d32e57a9..7819ff15d7ef7 100644
--- a/sklearn/neural_network/_multilayer_perceptron.py
+++ b/sklearn/neural_network/_multilayer_perceptron.py
@@ -928,8 +928,8 @@ def __init__(self, hidden_layer_sizes=(100,), activation="relu",
                          n_iter_no_change=n_iter_no_change, max_fun=max_fun)

     def _validate_input(self, X, y, incremental):
-        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
-                         multi_output=True)
+        X, y = self._validate_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
+                                  multi_output=True)
         if y.ndim == 2 and y.shape[1] == 1:
             y = column_or_1d(y, warn=True)

@@ -1336,8 +1336,8 @@ def predict(self, X):
         return y_pred

     def _validate_input(self, X, y, incremental):
-        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
-                         multi_output=True, y_numeric=True)
+        X, y = self._validate_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
+                                  multi_output=True, y_numeric=True)
         if y.ndim == 2 and y.shape[1] == 1:
             y = column_or_1d(y, warn=True)
         return X, y
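Reviewer note: in neighbors/base.py the precomputed branch sets
n_features_in_ by hand because _check_precomputed bypasses _validate_X; for a
square precomputed distance matrix the attribute therefore equals the number
of training samples. A sketch of the ordinary (non-precomputed) path on this
branch:

    import numpy as np
    from sklearn.neighbors import KernelDensity, NearestNeighbors

    X = np.random.RandomState(0).rand(30, 4)
    nn = NearestNeighbors(n_neighbors=2).fit(X)
    kde = KernelDensity().fit(X)
    assert nn.n_features_in_ == kde.n_features_in_ == 4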
""" - X = check_array(X, accept_sparse='csr', dtype=np.float64) + X = self._validate_X(X, accept_sparse='csr', dtype=np.float64) n_samples = X.shape[0] rng = check_random_state(self.random_state) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index a58979142ae7c..6be2bbc8e044f 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -623,6 +623,10 @@ def _pairwise(self): # check if first estimator expects pairwise input return getattr(self.steps[0][1], '_pairwise', False) + @property + def n_features_in_(self): + return self.steps[0][1].n_features_in_ + def _name_estimators(estimators): """Generate names for estimators.""" @@ -981,6 +985,11 @@ def _update_transformer_list(self, transformers): else next(transformers)) for name, old in self.transformer_list] + @property + def n_features_in_(self): + # X is passed to all transformers so we just delegate to the first one + return self.transformer_list[0][1].n_features_in_ + def make_union(*transformers, **kwargs): """Construct a FeatureUnion from the given transformers. diff --git a/sklearn/preprocessing/_discretization.py b/sklearn/preprocessing/_discretization.py index 94fcd50f0270b..17c566cb03bf9 100644 --- a/sklearn/preprocessing/_discretization.py +++ b/sklearn/preprocessing/_discretization.py @@ -133,7 +133,7 @@ def fit(self, X, y=None): ------- self """ - X = check_array(X, dtype='numeric') + X = self._validate_X(X, dtype='numeric') valid_encode = ('onehot', 'onehot-dense', 'ordinal') if self.encode not in valid_encode: diff --git a/sklearn/preprocessing/_function_transformer.py b/sklearn/preprocessing/_function_transformer.py index d7ed64b8369bd..3f2c738c1ff7f 100644 --- a/sklearn/preprocessing/_function_transformer.py +++ b/sklearn/preprocessing/_function_transformer.py @@ -83,7 +83,7 @@ def __init__(self, func=None, inverse_func=None, validate=False, def _check_input(self, X): if self.validate: - return check_array(X, accept_sparse=self.accept_sparse) + return self._validate_X(X, accept_sparse=self.accept_sparse) return X def _check_inverse_transform(self, X): diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py index 4a2c5a4eedbe9..1e4c421e68ae8 100644 --- a/sklearn/preprocessing/data.py +++ b/sklearn/preprocessing/data.py @@ -354,17 +354,17 @@ def partial_fit(self, X, y=None): raise TypeError("MinMaxScaler does no support sparse input. 
" "You may consider to use MaxAbsScaler instead.") - X = check_array(X, - estimator=self, dtype=FLOAT_DTYPES, - force_all_finite="allow-nan") + first_pass = not hasattr(self, 'n_samples_seen_') + check_n_features = not first_pass + X = self._validate_X(X, check_n_features=check_n_features, + estimator=self, dtype=FLOAT_DTYPES, + force_all_finite="allow-nan") data_min = np.nanmin(X, axis=0) data_max = np.nanmax(X, axis=0) - # First pass - if not hasattr(self, 'n_samples_seen_'): + if first_pass: self.n_samples_seen_ = X.shape[0] - # Next steps else: data_min = np.minimum(self.data_min_, data_min) data_max = np.maximum(self.data_max_, data_max) @@ -664,9 +664,9 @@ def partial_fit(self, X, y=None): y Ignored """ - X = check_array(X, accept_sparse=('csr', 'csc'), - estimator=self, dtype=FLOAT_DTYPES, - force_all_finite='allow-nan') + X = self._validate_X(X, accept_sparse=('csr', 'csc'), + estimator=self, dtype=FLOAT_DTYPES, + force_all_finite='allow-nan') # Even in the case of `with_mean=False`, we update the mean anyway # This is needed for the incremental computation of the var @@ -759,9 +759,10 @@ def transform(self, X, copy=None): check_is_fitted(self) copy = copy if copy is not None else self.copy - X = check_array(X, accept_sparse='csr', copy=copy, - estimator=self, dtype=FLOAT_DTYPES, - force_all_finite='allow-nan') + X = self._validate_X(X, check_n_features=True, + accept_sparse='csr', copy=copy, + estimator=self, dtype=FLOAT_DTYPES, + force_all_finite='allow-nan') if sparse.issparse(X): if self.with_mean: @@ -927,9 +928,11 @@ def partial_fit(self, X, y=None): y Ignored """ - X = check_array(X, accept_sparse=('csr', 'csc'), - estimator=self, dtype=FLOAT_DTYPES, - force_all_finite='allow-nan') + first_pass = not hasattr(self, 'n_samples_seen_') + check_n_features = not first_pass + X = self._validate_X(X, check_n_features=check_n_features, + accept_sparse=('csr', 'csc'), estimator=self, + dtype=FLOAT_DTYPES, force_all_finite='allow-nan') if sparse.issparse(X): mins, maxs = min_max_axis(X, axis=0, ignore_nan=True) @@ -937,10 +940,8 @@ def partial_fit(self, X, y=None): else: max_abs = np.nanmax(np.abs(X), axis=0) - # First pass - if not hasattr(self, 'n_samples_seen_'): + if first_pass: self.n_samples_seen_ = X.shape[0] - # Next passes else: max_abs = np.maximum(self.max_abs_, max_abs) self.n_samples_seen_ += X.shape[0] @@ -1158,8 +1159,8 @@ def fit(self, X, y=None): """ # at fit, convert sparse matrices to csc for optimized computation of # the quantiles - X = check_array(X, accept_sparse='csc', estimator=self, - dtype=FLOAT_DTYPES, force_all_finite='allow-nan') + X = self._validate_X(X, accept_sparse='csc', estimator=self, + dtype=FLOAT_DTYPES, force_all_finite='allow-nan') q_min, q_max = self.quantile_range if not 0 <= q_min <= q_max <= 100: @@ -1467,7 +1468,7 @@ def fit(self, X, y=None): ------- self : instance """ - n_samples, n_features = check_array(X, accept_sparse=True).shape + n_samples, n_features = self._validate_X(X, accept_sparse=True).shape combinations = self._combinations(n_features, self.degree, self.interaction_only, self.include_bias) @@ -1773,7 +1774,7 @@ def fit(self, X, y=None): ---------- X : array-like """ - check_array(X, accept_sparse='csr') + self._validate_X(X, accept_sparse='csr') return self def transform(self, X, copy=None): @@ -1907,7 +1908,7 @@ def fit(self, X, y=None): ---------- X : array-like """ - check_array(X, accept_sparse='csr') + self._validate_X(X, accept_sparse='csr') return self def transform(self, X, copy=None): @@ -1987,7 +1988,7 @@ def 
@@ -1987,7 +1988,7 @@ def fit(self, K, y=None):
             self : returns an instance of self.
         """
-        K = check_array(K, dtype=FLOAT_DTYPES)
+        K = self._validate_X(K, dtype=FLOAT_DTYPES)

         if K.shape[0] != K.shape[1]:
             raise ValueError("Kernel matrix must be a square matrix."
@@ -2297,7 +2298,7 @@ def fit(self, X, y=None):
                              " and {} samples.".format(self.n_quantiles,
                                                        self.subsample))

-        X = self._check_inputs(X, copy=False)
+        X = self._check_inputs(X, in_fit=True, copy=False)
         n_samples = X.shape[0]

         if self.n_quantiles > n_samples:
@@ -2388,11 +2389,19 @@ def _transform_col(self, X_col, quantiles, inverse):

         return X_col

-    def _check_inputs(self, X, accept_sparse_negative=False, copy=False):
+    def _check_inputs(self, X, in_fit, accept_sparse_negative=False,
+                      copy=False):
         """Check inputs before fit and transform"""
-        X = check_array(X, accept_sparse='csc', copy=copy,
-                        dtype=FLOAT_DTYPES,
-                        force_all_finite='allow-nan')
+        # deactivating check for now (specific tests about error message would
+        # break)
+        # TODO: uncomment when addressing check_n_features in
+        # predict/transform/etc.
+        # check_n_features = not in_fit
+        check_n_features = False
+
+        X = self._validate_X(X, check_n_features=check_n_features,
+                             accept_sparse='csc', copy=copy,
+                             dtype=FLOAT_DTYPES, force_all_finite='allow-nan')
         # we only accept positive sparse matrix when ignore_implicit_zeros is
         # false and that we call fit or transform.
         with np.errstate(invalid='ignore'):  # hide NaN comparison warnings
@@ -2468,7 +2477,7 @@ def transform(self, X):
         Xt : ndarray or sparse matrix, shape (n_samples, n_features)
             The projected data.
         """
-        X = self._check_inputs(X, copy=self.copy)
+        X = self._check_inputs(X, in_fit=False, copy=self.copy)
         self._check_is_fitted(X)

         return self._transform(X, inverse=False)
@@ -2489,7 +2498,8 @@ def inverse_transform(self, X):
         Xt : ndarray or sparse matrix, shape (n_samples, n_features)
             The projected data.
         """
-        X = self._check_inputs(X, accept_sparse_negative=True, copy=self.copy)
+        X = self._check_inputs(X, in_fit=False, accept_sparse_negative=True,
+                               copy=self.copy)
         self._check_is_fitted(X)

         return self._transform(X, inverse=True)
@@ -2745,7 +2755,8 @@ def fit_transform(self, X, y=None):
         return self._fit(X, y, force_transform=True)

     def _fit(self, X, y=None, force_transform=False):
-        X = self._check_input(X, check_positive=True, check_method=True)
+        X = self._check_input(X, in_fit=True, check_positive=True,
+                              check_method=True)

         if not self.copy and not force_transform:  # if call from fit()
             X = X.copy()  # force copy so that fit does not change X inplace
@@ -2787,7 +2798,8 @@ def transform(self, X):
             The transformed data.
         """
         check_is_fitted(self)
-        X = self._check_input(X, check_positive=True, check_shape=True)
+        X = self._check_input(X, in_fit=False, check_positive=True,
+                              check_shape=True)

         transform_function = {'box-cox': boxcox,
                               'yeo-johnson': self._yeo_johnson_transform
@@ -2833,7 +2845,7 @@ def inverse_transform(self, X):
             The original data
         """
         check_is_fitted(self)
-        X = self._check_input(X, check_shape=True)
+        X = self._check_input(X, in_fit=False, check_shape=True)

         if self.standardize:
             X = self._scaler.inverse_transform(X)
@@ -2938,7 +2950,7 @@ def _neg_log_likelihood(lmbda):
         # choosing bracket -2, 2 like for boxcox
         return optimize.brent(_neg_log_likelihood, brack=(-2, 2))

-    def _check_input(self, X, check_positive=False, check_shape=False,
+    def _check_input(self, X, in_fit, check_positive=False, check_shape=False,
                      check_method=False):
         """Validate the input before fit and transform.
@@ -2956,8 +2968,15 @@ def _check_input(self, X, check_positive=False, check_shape=False,
         check_method : bool
             If True, check that the transformation method is valid.
         """
-        X = check_array(X, ensure_2d=True, dtype=FLOAT_DTYPES, copy=self.copy,
-                        force_all_finite='allow-nan')
+        # deactivating check for now (specific tests about error message would
+        # break)
+        # TODO: uncomment when addressing check_n_features in
+        # predict/transform/etc.
+        # check_n_features = not in_fit
+        check_n_features = False
+        X = self._validate_X(X, check_n_features=check_n_features,
+                             ensure_2d=True, dtype=FLOAT_DTYPES,
+                             copy=self.copy, force_all_finite='allow-nan')

         with np.warnings.catch_warnings():
             np.warnings.filterwarnings(
diff --git a/sklearn/random_projection.py b/sklearn/random_projection.py
index 97597dd330e31..7397066a79ef1 100644
--- a/sklearn/random_projection.py
+++ b/sklearn/random_projection.py
@@ -341,7 +341,7 @@ def fit(self, X, y=None):
             self

         """
-        X = check_array(X, accept_sparse=['csr', 'csc'])
+        X = self._validate_X(X, accept_sparse=['csr', 'csc'])

         n_samples, n_features = X.shape
diff --git a/sklearn/semi_supervised/label_propagation.py b/sklearn/semi_supervised/label_propagation.py
index 0cbc59e3e69d8..e51df502c3f59 100644
--- a/sklearn/semi_supervised/label_propagation.py
+++ b/sklearn/semi_supervised/label_propagation.py
@@ -220,7 +220,7 @@ def fit(self, X, y):
         -------
         self : returns an instance of self.
         """
-        X, y = check_X_y(X, y)
+        X, y = self._validate_X_y(X, y)
         self.X_ = X
         check_classification_targets(y)
diff --git a/sklearn/svm/base.py b/sklearn/svm/base.py
index a0459708f3288..c863a7554ad56 100644
--- a/sklearn/svm/base.py
+++ b/sklearn/svm/base.py
@@ -142,9 +142,9 @@ def fit(self, X, y, sample_weight=None):
             raise TypeError("Sparse precomputed kernels are not supported.")

         self._sparse = sparse and not callable(self.kernel)

-        X, y = check_X_y(X, y, dtype=np.float64,
-                         order='C', accept_sparse='csr',
-                         accept_large_sparse=False)
+        X, y = self._validate_X_y(X, y, dtype=np.float64,
+                                  order='C', accept_sparse='csr',
+                                  accept_large_sparse=False)
         y = self._validate_targets(y)

         sample_weight = np.asarray([]
diff --git a/sklearn/svm/classes.py b/sklearn/svm/classes.py
index 0c98d9ffb5d3e..c7c49666c38ce 100644
--- a/sklearn/svm/classes.py
+++ b/sklearn/svm/classes.py
@@ -230,9 +230,9 @@ def fit(self, X, y, sample_weight=None):
             raise ValueError("Penalty term must be positive; got (C=%r)"
                              % self.C)

-        X, y = check_X_y(X, y, accept_sparse='csr',
-                         dtype=np.float64, order="C",
-                         accept_large_sparse=False)
+        X, y = self._validate_X_y(X, y, accept_sparse='csr',
+                                  dtype=np.float64, order="C",
+                                  accept_large_sparse=False)
         check_classification_targets(y)
         self.classes_ = np.unique(y)

@@ -419,9 +419,9 @@ def fit(self, X, y, sample_weight=None):
             raise ValueError("Penalty term must be positive; got (C=%r)"
                              % self.C)

-        X, y = check_X_y(X, y, accept_sparse='csr',
-                         dtype=np.float64, order="C",
-                         accept_large_sparse=False)
+        X, y = self._validate_X_y(X, y, accept_sparse='csr',
+                                  dtype=np.float64, order="C",
+                                  accept_large_sparse=False)
         penalty = 'l2'  # SVR only accepts l2 penalty
         self.coef_, self.intercept_, self.n_iter_ = _fit_liblinear(
             X, y, self.C, self.fit_intercept, self.intercept_scaling,
diff --git a/sklearn/tests/test_base.py b/sklearn/tests/test_base.py
index d83c9c99e2105..0d365b9ba882a 100644
--- a/sklearn/tests/test_base.py
+++ b/sklearn/tests/test_base.py
@@ -513,6 +513,14 @@ def test_regressormixin_score_multioutput():
     assert_warns_message(FutureWarning, msg, reg.score, X, y)


+def test_validate_X_bad_kwargs():
+
+    est = BaseEstimator()
+    with pytest.raises(TypeError,
+                       match="got an unexpected keyword"):
+        est._validate_X([1], bad_param=4)
+
+
 def test_warns_on_get_params_non_attribute():
     class MyEstimator(BaseEstimator):
         def __init__(self, param=5):
diff --git a/sklearn/tests/test_dummy.py b/sklearn/tests/test_dummy.py
index 88b2d16fba46e..01f48d463a01e 100644
--- a/sklearn/tests/test_dummy.py
+++ b/sklearn/tests/test_dummy.py
@@ -757,6 +757,16 @@ def test_dtype_of_classifier_probas(strategy):
     assert probas.dtype == np.float64


+@pytest.mark.parametrize('Dummy', (DummyRegressor, DummyClassifier))
+def test_n_features_in_(Dummy):
+    X = [[1, 2]]
+    y = [0]
+    d = Dummy()
+    assert not hasattr(d, 'n_features_in_')
+    d.fit(X, y)
+    assert d.n_features_in_ is None
+
+
 @pytest.mark.parametrize("Dummy", (DummyRegressor, DummyClassifier))
 def test_outputs_2d_deprecation(Dummy):
     X = [[1, 2]]
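Reviewer note: DummyClassifier and DummyRegressor never look at the contents
of X, so (in the dummy.py change elsewhere on this branch) they set
n_features_in_ to None rather than validating X; the test above pins down
that convention. In other words:

    from sklearn.dummy import DummyClassifier

    d = DummyClassifier(strategy='prior')
    d.fit([[1, 2]], [0])
    assert d.n_features_in_ is None  # present after fit, but deliberately None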
diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py
index e02b5ef96b7b0..4fffcc0a4dc70 100644
--- a/sklearn/tests/test_pipeline.py
+++ b/sklearn/tests/test_pipeline.py
@@ -34,6 +34,8 @@
 from sklearn.datasets import load_iris
 from sklearn.preprocessing import StandardScaler
 from sklearn.feature_extraction.text import CountVectorizer
+from sklearn.experimental import enable_hist_gradient_boosting  # noqa
+from sklearn.ensemble import HistGradientBoostingClassifier


 JUNK_FOOD_DOCS = (

@@ -1161,3 +1163,46 @@ def test_verbose(est, method, pattern, capsys):
     est.set_params(verbose=True)
     func(X, y)
     assert re.match(pattern, capsys.readouterr().out)
+
+
+def test_n_features_in_pipeline():
+    # make sure pipelines delegate n_features_in to the first step
+
+    X = [[1, 2], [3, 4], [5, 6]]
+    y = [0, 1, 2]
+
+    ss = StandardScaler()
+    gbdt = HistGradientBoostingClassifier()
+    pipe = make_pipeline(ss, gbdt)
+    assert not hasattr(pipe, 'n_features_in_')
+    pipe.fit(X, y)
+    assert pipe.n_features_in_ == ss.n_features_in_ == 2
+
+    # if the first step has the n_features_in_ attribute then the pipeline
+    # also has it, even though it isn't fitted.
+    ss = StandardScaler()
+    gbdt = HistGradientBoostingClassifier()
+    pipe = make_pipeline(ss, gbdt)
+    ss.fit(X, y)
+    assert pipe.n_features_in_ == ss.n_features_in_ == 2
+    assert not hasattr(gbdt, 'n_features_in_')
+
+
+def test_n_features_in_feature_union():
+    # make sure FeatureUnion delegates n_features_in to the first transformer
+
+    X = [[1, 2], [3, 4], [5, 6]]
+    y = [0, 1, 2]
+
+    ss = StandardScaler()
+    fu = make_union(ss)
+    assert not hasattr(fu, 'n_features_in_')
+    fu.fit(X, y)
+    assert fu.n_features_in_ == ss.n_features_in_ == 2
+
+    # if the first transformer has the n_features_in_ attribute then the
+    # feature union also has it, even though it isn't fitted.
+    ss = StandardScaler()
+    fu = make_union(ss)
+    ss.fit(X, y)
+    assert fu.n_features_in_ == ss.n_features_in_ == 2
diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py
index c862a09d893c6..3a515c4513f81 100644
--- a/sklearn/tree/tree.py
+++ b/sklearn/tree/tree.py
@@ -135,7 +135,7 @@ def fit(self, X, y, sample_weight=None, check_input=True,
             raise ValueError("ccp_alpha must be greater than or equal to 0")

         if check_input:
-            X = check_array(X, dtype=DTYPE, accept_sparse="csc")
+            X = self._validate_X(X, dtype=DTYPE, accept_sparse="csc")
             y = check_array(y, ensure_2d=False, dtype=None)
             if issparse(X):
                 X.sort_indices()
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 5a96a4260ceb9..c3a498c6f42f8 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -272,6 +272,8 @@ def _yield_all_checks(name, estimator):
     yield check_dict_unchanged
     yield check_dont_overwrite_parameters
     yield check_fit_idempotent
+    if not tags["no_validation"]:
+        yield check_n_features_in

     if tags["requires_positive_X"]:
         yield check_fit_non_negative

@@ -2672,3 +2674,28 @@ def check_fit_idempotent(name, estimator_orig):
             atol=max(tol, 1e-9), rtol=max(tol, 1e-7),
             err_msg="Idempotency check failed for method {}".format(method)
         )
+
+
+def check_n_features_in(name, estimator_orig):
+    # Make sure that n_features_in_ attribute doesn't exist until fit is
+    # called, and that its value is correct.
+
+    rng = np.random.RandomState(0)
+
+    estimator = clone(estimator_orig)
+    set_random_state(estimator)
+    if 'warm_start' in estimator.get_params().keys():
+        estimator.set_params(warm_start=False)
+
+    n_samples = 100
+    X = rng.normal(loc=100, size=(n_samples, 2))
+    X = pairwise_estimator_convert_X(X, estimator)
+    if is_regressor(estimator_orig):
+        y = rng.normal(size=n_samples)
+    else:
+        y = rng.randint(low=0, high=2, size=n_samples)
+    y = _enforce_estimator_tags_y(estimator, y)
+
+    assert not hasattr(estimator, 'n_features_in_')
+    estimator.fit(X, y)
+    assert estimator.n_features_in_ == X.shape[1]
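Reviewer note: check_n_features_in only calls fit, so any estimator that
routes its validation through _validate_X / _validate_X_y passes it for free.
A sketch of running the new common check directly against a toy estimator on
this branch (the class name is illustrative, not part of the patch):

    from sklearn.base import BaseEstimator
    from sklearn.utils.estimator_checks import check_n_features_in

    class ToyEstimator(BaseEstimator):
        def fit(self, X, y):
            # _validate_X_y sets n_features_in_ as a side effect
            X, y = self._validate_X_y(X, y)
            self.is_fitted_ = True
            return self

    check_n_features_in('ToyEstimator', ToyEstimator())  # passes silently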
diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py
index e26a508566871..8c3a2a0bd4bf1 100644
--- a/sklearn/utils/tests/test_estimator_checks.py
+++ b/sklearn/utils/tests/test_estimator_checks.py
@@ -56,7 +56,7 @@ def __init__(self, key=0):
         self.key = key

     def fit(self, X, y=None):
-        X, y = check_X_y(X, y)
+        X, y = self._validate_X_y(X, y)
         return self

     def predict(self, X):

@@ -71,7 +71,7 @@ def __init__(self, acceptable_key=0):

     def fit(self, X, y=None):
         self.wrong_attribute = 0
-        X, y = check_X_y(X, y)
+        X, y = self._validate_X_y(X, y)
         return self


@@ -81,14 +81,14 @@ def __init__(self, wrong_attribute=0):

     def fit(self, X, y=None):
         self.wrong_attribute = 1
-        X, y = check_X_y(X, y)
+        X, y = self._validate_X_y(X, y)
         return self


 class ChangesUnderscoreAttribute(BaseEstimator):
     def fit(self, X, y=None):
         self._good_attribute = 1
-        X, y = check_X_y(X, y)
+        X, y = self._validate_X_y(X, y)
         return self


@@ -105,7 +105,7 @@ def set_params(self, **kwargs):
         return super().set_params(**kwargs)

     def fit(self, X, y=None):
-        X, y = check_X_y(X, y)
+        X, y = self._validate_X_y(X, y)
         return self


@@ -122,7 +122,7 @@ def set_params(self, **kwargs):
         return super().set_params(**kwargs)

     def fit(self, X, y=None):
-        X, y = check_X_y(X, y)
+        X, y = self._validate_X_y(X, y)
         return self


@@ -141,19 +141,19 @@ def set_params(self, **kwargs):
         return super().set_params(**kwargs)

     def fit(self, X, y=None):
-        X, y = check_X_y(X, y)
+        X, y = self._validate_X_y(X, y)
         return self


 class NoCheckinPredict(BaseBadClassifier):
     def fit(self, X, y):
-        X, y = check_X_y(X, y)
+        X, y = self._validate_X_y(X, y)
         return self


 class NoSparseClassifier(BaseBadClassifier):
     def fit(self, X, y):
-        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
+        X, y = self._validate_X_y(X, y, accept_sparse=['csr', 'csc'])
         if sp.issparse(X):
             raise ValueError("Nonsensical Error")
         return self

@@ -165,7 +165,7 @@ def predict(self, X):

 class CorrectNotFittedErrorClassifier(BaseBadClassifier):
     def fit(self, X, y):
-        X, y = check_X_y(X, y)
+        X, y = self._validate_X_y(X, y)
         self.coef_ = np.ones(X.shape[1])
         return self

@@ -178,10 +178,11 @@ def predict(self, X):
 class NoSampleWeightPandasSeriesType(BaseEstimator):
     def fit(self, X, y, sample_weight=None):
         # Convert data
-        X, y = check_X_y(X, y,
-                         accept_sparse=("csr", "csc"),
-                         multi_output=True,
-                         y_numeric=True)
+        X, y = self._validate_X_y(
+            X, y,
+            accept_sparse=("csr", "csc"),
+            multi_output=True,
+            y_numeric=True)
         # Function is only called after we verify that pandas is installed
         from pandas import Series
         if isinstance(sample_weight, Series):

@@ -218,7 +219,7 @@ def fit(self, X, y):

 class BadTransformerWithoutMixin(BaseEstimator):
     def fit(self, X, y=None):
-        X = check_array(X)
+        X = self._validate_X(X)
         return self

     def transform(self, X):

@@ -229,10 +230,11 @@ def transform(self, X):
 class NotInvariantPredict(BaseEstimator):
     def fit(self, X, y):
         # Convert data
-        X, y = check_X_y(X, y,
-                         accept_sparse=("csr", "csc"),
-                         multi_output=True,
-                         y_numeric=True)
+        X, y = self._validate_X_y(
+            X, y,
+            accept_sparse=("csr", "csc"),
+            multi_output=True,
+            y_numeric=True)
         return self

     def predict(self, X):

@@ -245,11 +247,12 @@ def predict(self, X):

 class LargeSparseNotSupportedClassifier(BaseEstimator):
     def fit(self, X, y):
-        X, y = check_X_y(X, y,
-                         accept_sparse=("csr", "csc", "coo"),
-                         accept_large_sparse=True,
-                         multi_output=True,
-                         y_numeric=True)
+        X, y = self._validate_X_y(
+            X, y,
+            accept_sparse=("csr", "csc", "coo"),
+            accept_large_sparse=True,
+            multi_output=True,
+            y_numeric=True)
         if sp.issparse(X):
             if X.getformat() == "coo":
                 if X.row.dtype == "int64" or X.col.dtype == "int64":

@@ -265,7 +268,7 @@ def fit(self, X, y):

 class SparseTransformer(BaseEstimator):
     def fit(self, X, y=None):
-        self.X_shape_ = check_array(X).shape
+        self.X_shape_ = self._validate_X(X).shape
         return self

     def fit_transform(self, X, y=None):

@@ -296,7 +299,7 @@ def _more_tags(self):

 class RequiresPositiveYRegressor(LinearRegression):

     def fit(self, X, y):
-        X, y = check_X_y(X, y)
+        X, y = self._validate_X_y(X, y)
         if (y <= 0).any():
             raise ValueError('negative y values not supported!')
         return super().fit(X, y)