diff --git a/sklearn/cluster/birch.py b/sklearn/cluster/birch.py
index 188eff02b6f02..8cb054ce84852 100644
--- a/sklearn/cluster/birch.py
+++ b/sklearn/cluster/birch.py
@@ -14,7 +14,7 @@
 from ..externals.six.moves import xrange
 from ..utils import check_array
 from ..utils.extmath import row_norms, safe_sparse_dot
-from ..utils.validation import check_is_fitted
+from ..utils.validation import check_is_fitted, check_partial_fit_n_features
 from ..exceptions import NotFittedError, ConvergenceWarning
 from .hierarchical import AgglomerativeClustering
 
@@ -546,11 +546,8 @@ def _check_fit(self, X):
         # Should raise an error if one does not fit before predicting.
         if not (is_fitted or has_partial_fit):
             raise NotFittedError("Fit training data before predicting")
-
-        if is_fitted and X.shape[1] != self.subcluster_centers_.shape[1]:
-            raise ValueError(
-                "Training data and predicted data do "
-                "not have same number of features.")
+        if is_fitted:
+            check_partial_fit_n_features(X, self.subcluster_centers_, self)
 
     def predict(self, X):
         """
diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py
index b79db75e0e720..eba3258875218 100644
--- a/sklearn/cluster/k_means_.py
+++ b/sklearn/cluster/k_means_.py
@@ -27,7 +27,7 @@
 from ..utils import check_array
 from ..utils import gen_batches
 from ..utils import check_random_state
-from ..utils.validation import check_is_fitted
+from ..utils.validation import check_is_fitted, check_partial_fit_n_features
 from ..utils.validation import FLOAT_DTYPES
 from ..utils import Parallel
 from ..utils import delayed
@@ -1694,6 +1694,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
             random_reassign = False
             distances = None
         else:
+            check_partial_fit_n_features(X, self.cluster_centers_, self)
             # The lower the minimum count is, the more we do random
             # reassignment, however, we don't want to do random
             # reassignment too often, to allow for building up counts
diff --git a/sklearn/decomposition/dict_learning.py b/sklearn/decomposition/dict_learning.py
index f39e26e083cee..f7b86619710e7 100644
--- a/sklearn/decomposition/dict_learning.py
+++ b/sklearn/decomposition/dict_learning.py
@@ -19,7 +19,7 @@
 from ..utils import (check_array, check_random_state, gen_even_slices,
                      gen_batches)
 from ..utils.extmath import randomized_svd, row_norms
-from ..utils.validation import check_is_fitted
+from ..utils.validation import check_is_fitted, check_partial_fit_n_features
 from ..linear_model import Lasso, orthogonal_mp_gram, LassoLars, Lars
 
 
@@ -1412,6 +1412,7 @@ def partial_fit(self, X, y=None, iter_offset=None):
             self.random_state_ = check_random_state(self.random_state)
         X = check_array(X)
         if hasattr(self, 'components_'):
+            check_partial_fit_n_features(X, self.components_, self)
             dict_init = self.components_
         else:
             dict_init = self.dict_init
diff --git a/sklearn/decomposition/incremental_pca.py b/sklearn/decomposition/incremental_pca.py
index 779ebf42b20f1..a21bf7c078b5a 100644
--- a/sklearn/decomposition/incremental_pca.py
+++ b/sklearn/decomposition/incremental_pca.py
@@ -11,6 +11,7 @@
 from .base import _BasePCA
 from ..utils import check_array, gen_batches
 from ..utils.extmath import svd_flip, _incremental_mean_and_var
+from ..utils.validation import check_partial_fit_n_features
 
 
 class IncrementalPCA(_BasePCA):
@@ -224,6 +225,8 @@ def partial_fit(self, X, y=None, check_input=True):
         n_samples, n_features = X.shape
         if not hasattr(self, 'components_'):
             self.components_ = None
+        elif self.components_ is not None:
+            check_partial_fit_n_features(X, self.components_, self)
 
         if self.n_components is None:
             if self.components_ is None:
@@ -243,7 +246,7 @@ def partial_fit(self, X, y=None, check_input=True):
 
         if (self.components_ is not None) and (self.components_.shape[0] !=
                                                self.n_components_):
-            raise ValueError("Number of input features has changed from %i "
+            raise ValueError("Number of components has changed from %i "
                              "to %i between calls to partial_fit! Try "
                              "setting n_components to a fixed value." %
                              (self.components_.shape[0], self.n_components_))
diff --git a/sklearn/decomposition/online_lda.py b/sklearn/decomposition/online_lda.py
index 4c0f8625771c7..0624258d07c41 100644
--- a/sklearn/decomposition/online_lda.py
+++ b/sklearn/decomposition/online_lda.py
@@ -19,7 +19,7 @@
 from ..utils import (check_random_state, check_array,
                      gen_batches, gen_even_slices)
 from ..utils.fixes import logsumexp
-from ..utils.validation import check_non_negative
+from ..utils.validation import check_non_negative, check_partial_fit_n_features
 from ..utils import Parallel, delayed, effective_n_jobs
 from ..externals.six.moves import xrange
 from ..exceptions import NotFittedError
@@ -493,12 +493,8 @@ def partial_fit(self, X, y=None):
         # initialize parameters or check
         if not hasattr(self, 'components_'):
             self._init_latent_vars(n_features)
-
-        if n_features != self.components_.shape[1]:
-            raise ValueError(
-                "The provided data has %d dimensions while "
-                "the model was trained with feature size %d." %
-                (n_features, self.components_.shape[1]))
+        else:
+            check_partial_fit_n_features(X, self.components_, self)
 
         n_jobs = effective_n_jobs(self.n_jobs)
         with Parallel(n_jobs=n_jobs, verbose=max(0,
diff --git a/sklearn/decomposition/tests/test_online_lda.py b/sklearn/decomposition/tests/test_online_lda.py
index 0abc2efe75ec2..887f9de0cd15c 100644
--- a/sklearn/decomposition/tests/test_online_lda.py
+++ b/sklearn/decomposition/tests/test_online_lda.py
@@ -142,21 +142,6 @@ def test_lda_fit_transform(method):
     assert_array_almost_equal(X_fit, X_trans, 4)
 
 
-def test_lda_partial_fit_dim_mismatch():
-    # test `n_features` mismatch in `partial_fit`
-    rng = np.random.RandomState(0)
-    n_components = rng.randint(3, 6)
-    n_col = rng.randint(6, 10)
-    X_1 = np.random.randint(4, size=(10, n_col))
-    X_2 = np.random.randint(4, size=(10, n_col + 1))
-    lda = LatentDirichletAllocation(n_components=n_components,
-                                    learning_offset=5., total_samples=20,
-                                    random_state=rng)
-    lda.partial_fit(X_1)
-    assert_raises_regexp(ValueError, r"^The provided data has",
-                         lda.partial_fit, X_2)
-
-
 def test_invalid_params():
     # test `_check_params` method
     X = np.ones((5, 10))
@@ -202,7 +187,7 @@ def test_lda_transform_mismatch():
                                       random_state=rng)
     lda.partial_fit(X)
     assert_raises_regexp(ValueError, r"^The provided data has",
-                         lda.partial_fit, X_2)
+                         lda.transform, X_2)
 
 
 @if_safe_multiprocessing_with_blas
diff --git a/sklearn/linear_model/stochastic_gradient.py b/sklearn/linear_model/stochastic_gradient.py
index 146d9623f22e7..f67a49a131ce6 100644
--- a/sklearn/linear_model/stochastic_gradient.py
+++ b/sklearn/linear_model/stochastic_gradient.py
@@ -18,7 +18,7 @@
 from ..utils import check_array, check_random_state, check_X_y
 from ..utils.extmath import safe_sparse_dot
 from ..utils.multiclass import _check_partial_fit_first_call
-from ..utils.validation import check_is_fitted
+from ..utils.validation import check_is_fitted, check_partial_fit_n_features
 from ..exceptions import ConvergenceWarning
 from ..externals import six
 from ..model_selection import StratifiedShuffleSplit, ShuffleSplit
@@ -533,9 +533,8 @@ def _partial_fit(self, X, y, alpha, C,
         if getattr(self, "coef_", None) is None or coef_init is not None:
             self._allocate_parameter_mem(n_classes, n_features,
                                          coef_init, intercept_init)
-        elif n_features != self.coef_.shape[-1]:
-            raise ValueError("Number of features %d does not match previous "
-                             "data %d." % (n_features, self.coef_.shape[-1]))
+        else:
+            check_partial_fit_n_features(X, self.coef_, self)
 
         self.loss_function_ = self._get_loss_function(loss)
         if not hasattr(self, "t_"):
@@ -1144,9 +1143,9 @@ def _partial_fit(self, X, y, alpha, C, loss, learning_rate,
         if getattr(self, "coef_", None) is None:
             self._allocate_parameter_mem(1, n_features,
                                          coef_init, intercept_init)
-        elif n_features != self.coef_.shape[-1]:
-            raise ValueError("Number of features %d does not match previous "
-                             "data %d." % (n_features, self.coef_.shape[-1]))
+        else:
+            check_partial_fit_n_features(X, self.coef_, self)
+
         if self.average > 0 and getattr(self, "average_coef_", None) is None:
             self.average_coef_ = np.zeros(n_features,
                                           dtype=np.float64,
diff --git a/sklearn/naive_bayes.py b/sklearn/naive_bayes.py
index dced4fbdb3dd2..cb650e6593700 100644
--- a/sklearn/naive_bayes.py
+++ b/sklearn/naive_bayes.py
@@ -30,7 +30,7 @@
 from .utils.extmath import safe_sparse_dot
 from .utils.fixes import logsumexp
 from .utils.multiclass import _check_partial_fit_first_call
-from .utils.validation import check_is_fitted
+from .utils.validation import check_is_fitted, check_partial_fit_n_features
 from .externals import six
 
 __all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB', 'ComplementNB']
@@ -382,9 +382,8 @@ def _partial_fit(self, X, y, classes=None, _refit=False,
             self.class_prior_ = np.zeros(len(self.classes_),
                                          dtype=np.float64)
         else:
-            if X.shape[1] != self.theta_.shape[1]:
-                msg = "Number of features %d does not match previous data %d."
-                raise ValueError(msg % (X.shape[1], self.theta_.shape[1]))
+            check_partial_fit_n_features(X, self.theta_, self)
+
             # Put epsilon back in each time
             self.sigma_[:, :] -= self.epsilon_
 
@@ -527,9 +526,8 @@ def partial_fit(self, X, y, classes=None, sample_weight=None):
             self.class_count_ = np.zeros(n_effective_classes, dtype=np.float64)
             self.feature_count_ = np.zeros((n_effective_classes, n_features),
                                            dtype=np.float64)
-        elif n_features != self.coef_.shape[1]:
-            msg = "Number of features %d does not match previous data %d."
-            raise ValueError(msg % (n_features, self.coef_.shape[-1]))
+        else:
+            check_partial_fit_n_features(X, self.coef_, self)
 
         Y = label_binarize(y, classes=self.classes_)
         if Y.shape[1] == 1:
diff --git a/sklearn/neural_network/multilayer_perceptron.py b/sklearn/neural_network/multilayer_perceptron.py
index de559dc67e18f..ff41de7600ec5 100644
--- a/sklearn/neural_network/multilayer_perceptron.py
+++ b/sklearn/neural_network/multilayer_perceptron.py
@@ -24,7 +24,7 @@
 from ..utils import check_array, check_X_y, column_or_1d
 from ..exceptions import ConvergenceWarning
 from ..utils.extmath import safe_sparse_dot
-from ..utils.validation import check_is_fitted
+from ..utils.validation import check_is_fitted, check_partial_fit_n_features
 from ..utils.multiclass import _check_partial_fit_first_call, unique_labels
 from ..utils.multiclass import type_of_target
 
@@ -340,6 +340,8 @@ def _fit(self, X, y, incremental=False):
                 incremental):
             # First time training the model
             self._initialize(y, layer_units)
+        else:
+            check_partial_fit_n_features(X, self.coefs_[0].T, self)
 
         # lbfgs does not support mini-batches
         if self.solver == 'lbfgs':
diff --git a/sklearn/neural_network/rbm.py b/sklearn/neural_network/rbm.py
index 1361bffe0d240..6a76e568938c7 100644
--- a/sklearn/neural_network/rbm.py
+++ b/sklearn/neural_network/rbm.py
@@ -21,7 +21,7 @@
 from ..utils import gen_even_slices
 from ..utils.extmath import safe_sparse_dot
 from ..utils.extmath import log_logistic
-from ..utils.validation import check_is_fitted
+from ..utils.validation import check_is_fitted, check_partial_fit_n_features
 
 
 class BernoulliRBM(BaseEstimator, TransformerMixin):
@@ -243,6 +243,8 @@ def partial_fit(self, X, y=None):
                     (self.n_components, X.shape[1])
                 ),
                 order='F')
+        else:
+            check_partial_fit_n_features(X, self.components_, self)
         if not hasattr(self, 'intercept_hidden_'):
             self.intercept_hidden_ = np.zeros(self.n_components, )
         if not hasattr(self, 'intercept_visible_'):
diff --git a/sklearn/preprocessing/data.py b/sklearn/preprocessing/data.py
index 48e78302e0594..5482e8bb23f35 100644
--- a/sklearn/preprocessing/data.py
+++ b/sklearn/preprocessing/data.py
@@ -31,7 +31,7 @@
                                  mean_variance_axis, incr_mean_variance_axis,
                                  min_max_axis)
 from ..utils.validation import (check_is_fitted, check_random_state,
-                                FLOAT_DTYPES)
+                                FLOAT_DTYPES, check_partial_fit_n_features)
 from ._csr_polynomial_expansion import _csr_polynomial_expansion
 
 
@@ -358,6 +358,8 @@ def partial_fit(self, X, y=None):
             self.n_samples_seen_ = X.shape[0]
         # Next steps
         else:
+            check_partial_fit_n_features(X, self.scale_, self)
+
             data_min = np.minimum(self.data_min_, data_min)
             data_max = np.maximum(self.data_max_, data_max)
             self.n_samples_seen_ += X.shape[0]
@@ -652,6 +654,16 @@ def partial_fit(self, X, y=None):
             self.n_samples_seen_ = np.repeat(self.n_samples_seen_,
                                              X.shape[1]).astype(np.int64)
 
+        # if first pass: store number of features
+        if not hasattr(self, "mean_"):
+            self._n_features_ = X.shape[1]
+
+        # check number of features consistency for next passes
+        if hasattr(self, "mean_") and self.mean_ is not None:
+            check_partial_fit_n_features(X, self.mean_, self)
+        if hasattr(self, "scale_") and self.scale_ is not None:
+            check_partial_fit_n_features(X, self.scale_, self)
+
         if sparse.issparse(X):
             if self.with_mean:
                 raise ValueError(
@@ -911,6 +923,8 @@ def partial_fit(self, X, y=None):
             self.n_samples_seen_ = X.shape[0]
         # Next passes
         else:
+            check_partial_fit_n_features(X, self.scale_, self)
+
             max_abs = np.maximum(self.max_abs_, max_abs)
             self.n_samples_seen_ += X.shape[0]
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 5c226ac8ba8e7..8bb76692f0c79 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -126,7 +126,6 @@ def _yield_classifier_checks(name, classifier):
     # test classifiers trained on a single label always return this label
     yield check_classifiers_one_label
     yield check_classifiers_classes
-    yield check_estimators_partial_fit_n_features
     # basic consistency testing
     yield check_classifiers_train
     yield partial(check_classifiers_train, readonly_memmap=True)
@@ -179,7 +178,6 @@ def _yield_regressor_checks(name, regressor):
     yield check_regressors_train
     yield partial(check_regressors_train, readonly_memmap=True)
     yield check_regressor_data_not_an_array
-    yield check_estimators_partial_fit_n_features
    yield check_regressors_no_decision_function
     yield check_supervised_y_2d
     yield check_supervised_y_no_nan
@@ -220,7 +218,6 @@ def _yield_clustering_checks(name, clusterer):
     # let's not test that here.
     yield check_clustering
     yield partial(check_clustering, readonly_memmap=True)
-    yield check_estimators_partial_fit_n_features
     yield check_non_transformer_estimators_n_iter
 
 
@@ -268,6 +265,7 @@ def _yield_all_checks(name, estimator):
     yield check_dict_unchanged
     yield check_dont_overwrite_parameters
     yield check_fit_idempotent
+    yield check_estimators_partial_fit_n_features
 
 
 def check_estimator(Estimator):
@@ -1261,11 +1259,13 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
     except NotImplementedError:
         return
 
-    with assert_raises(ValueError,
-                       msg="The estimator {} does not raise an"
-                           " error when the number of features"
-                           " changes between calls to "
-                           "partial_fit.".format(name)):
+    match = ("Number of input features has changed .* between "
+             "calls to partial_fit")
+    msg = ("The estimator {} does not raise an appropriate error when "
+           "the number of features changes between calls to "
+           "partial_fit.".format(name))
+
+    with assert_raises_regex(ValueError, match, msg=msg):
         estimator.partial_fit(X[:, :-1], y)
 
 
diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py
index 3ae1b283ccef5..9793241d7f639 100644
--- a/sklearn/utils/validation.py
+++ b/sklearn/utils/validation.py
@@ -971,3 +971,26 @@ def check_non_negative(X, whom):
     if X_min < 0:
         raise ValueError("Negative values in data passed to %s"
                          % whom)
+
+
+def check_partial_fit_n_features(X, components, estimator):
+    """
+    Check if the number of features is preserved between calls to partial_fit.
+
+    Parameters
+    ----------
+    X : array-like
+        Input data for the new call to partial_fit.
+
+    components : array-like
+        Fitted attribute of an estimator which has the same number of features
+        as the input data from the first fit.
+
+    estimator : estimator instance
+        Estimator instance for which the check is performed.
+    """
+    if X.shape[-1] != components.shape[-1]:
+        raise ValueError("Number of input features has changed from {0} to"
+                         " {1} between calls to partial_fit of {2}."
+                         "".format(components.shape[-1], X.shape[-1],
+                                   type(estimator).__name__))
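
For reviewers: a minimal sketch of the behavior this patch standardizes. The choice of SGDClassifier and the array shapes below are illustrative only, not part of the patch; any estimator touched above behaves the same way. Once a first call to partial_fit fixes the number of features, a later call with a different number of columns is routed through the new check_partial_fit_n_features helper and fails with the shared message:

    import numpy as np
    from sklearn.linear_model import SGDClassifier

    # First call to partial_fit: the model is allocated for 3 features.
    X_first = np.array([[0., 1., 2.], [3., 4., 5.]])
    y = np.array([0, 1])
    clf = SGDClassifier(max_iter=5, tol=1e-3)
    clf.partial_fit(X_first, y, classes=[0, 1])

    # Second call with 2 features: _partial_fit now delegates to
    # check_partial_fit_n_features(X, self.coef_, self) instead of its own
    # ad-hoc ValueError, so the error reads:
    # "Number of input features has changed from 3 to 2 between calls to
    #  partial_fit of SGDClassifier."
    X_second = np.array([[0., 1.], [2., 3.]])
    try:
        clf.partial_fit(X_second, y)
    except ValueError as exc:
        print(exc)

Because every estimator now raises this one message, check_estimators_partial_fit_n_features can assert it with assert_raises_regex, which is what allows the check to move out of the per-type yield lists and into _yield_all_checks.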